├── .github └── FUNDING.yml ├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── core ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── alphazero.py │ ├── baselines │ │ ├── __init__.py │ │ ├── baseline.py │ │ ├── best.py │ │ ├── greedy.py │ │ ├── greedy_mcts.py │ │ ├── lazy_greedy_mcts.py │ │ ├── lazy_rollout_mcts.py │ │ ├── random.py │ │ ├── rollout.py │ │ └── rollout_mcts.py │ ├── evaluator.py │ ├── lazy_mcts.py │ ├── lazyzero.py │ ├── load.py │ └── mcts.py ├── demo │ ├── __init__.py │ ├── demo.py │ ├── human.py │ └── load.py ├── env.py ├── resnet.py ├── test │ ├── __init__.py │ ├── tester.py │ └── tournament │ │ ├── __init__.py │ │ └── tournament.py ├── train │ ├── __init__.py │ ├── collector.py │ └── trainer.py └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── custom_activations.py │ ├── heatmap.py │ ├── history.py │ ├── memory.py │ └── utils.py ├── envs ├── _2048 │ ├── __init__.py │ ├── collector.py │ ├── env.py │ ├── misc │ │ ├── board_representation.png │ │ ├── high_tile.png │ │ ├── improvement.png │ │ ├── reward_distribution.png │ │ └── slide_conv.png │ ├── tester.py │ ├── torchscripts.py │ └── trainer.py ├── __init__.py ├── connect_x │ ├── __init__.py │ ├── collector.py │ ├── demo.py │ ├── env.py │ ├── tester.py │ └── trainer.py ├── load.py └── othello │ ├── __init__.py │ ├── collector.py │ ├── demo.py │ ├── env.py │ ├── evaluators │ ├── __init__.py │ └── edax.py │ ├── misc │ ├── book.txt │ ├── convolution.png │ ├── example_kernel.png │ ├── example_kernel2.png │ ├── legal_starting_actions.png │ └── othello_sandwich_menu.png │ ├── tester.py │ ├── torchscripts.py │ └── trainer.py ├── example_configs ├── 2048_gpu.yaml ├── 2048_mini.yaml ├── 2048_test.yaml ├── 2048_tiny.yaml ├── othello_baselines.yaml ├── othello_demo.yaml ├── othello_gpu.yaml ├── othello_mini.yaml ├── othello_tiny.yaml ├── othello_tournament.yaml ├── test_tournament.yaml └── tournament.yaml ├── misc ├── 2048.gif ├── 2048env.png ├── benchmark.png ├── cpu_bad.png ├── heatmap.png ├── high_tile.png ├── one_lazyzero_iteration.png ├── othello_game.gif ├── slide_conv.png ├── train_accuracy.png ├── train_high_tile.png ├── train_moves.png └── workflow.png ├── notebooks ├── benchmark.ipynb ├── demo.ipynb ├── hello_world.ipynb ├── hello_world_colab.ipynb ├── test.ipynb ├── tournament.ipynb └── train.ipynb ├── poetry.lock ├── pyproject.toml └── turbozero.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [lowrollr] 4 | 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | .vscode/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | authors: 3 | - family-names: "Marshall" 4 | given-names: "Jacob" 5 | title: "TurboZero: Vectorized AlphaZero, MCTS, and Environments" 6 | abstract: vectorized implementations of RL algorithms and environments, including AlphaZero/MCTS 7 | url: "https://github.com/lowrollr/turbozero" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.float32 = torch.float32 -------------------------------------------------------------------------------- /core/algorithms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/core/algorithms/__init__.py -------------------------------------------------------------------------------- /core/algorithms/alphazero.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from dataclasses import dataclass 5 | from typing import Callable, Optional, Tuple 6 | import torch 7 | from core.algorithms.evaluator import Evaluator, EvaluatorConfig, TrainableEvaluator 8 | from core.algorithms.mcts import MCTS, MCTSConfig 9 | from core.env import Env 10 | from core.utils.utils import rand_argmax_2d 11 | 12 | @dataclass 13 | class AlphaZeroConfig(MCTSConfig): 14 | temperature: float = 1.0 15 | 16 | 17 | class AlphaZero(MCTS, TrainableEvaluator): 18 | def __init__(self, env: Env, config: AlphaZeroConfig, model: torch.nn.Module) -> None: 19 | super().__init__(env, config, model) 20 | self.config: AlphaZeroConfig 21 | 22 | def choose_actions(self, visits: torch.Tensor) -> torch.Tensor: 23 | if self.config.temperature > 0: 24 | return torch.multinomial(torch.pow(visits, 1/self.config.temperature), 1, replacement=True).flatten() 25 | else: 26 | return rand_argmax_2d(visits).flatten() 27 | 28 | # see MCTS 29 | def evaluate(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 30 | evaluation_fn = lambda env: self.model(env.get_nn_input()) 31 | return super().evaluate(evaluation_fn) 32 | 33 | 34 | -------------------------------------------------------------------------------- /core/algorithms/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/core/algorithms/baselines/__init__.py -------------------------------------------------------------------------------- /core/algorithms/baselines/baseline.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import logging 4 | 5 | from dataclasses import dataclass 6 | from core.algorithms.evaluator import Evaluator, EvaluatorConfig 7 | from core.env import Env 8 | from core.utils.history import Metric 9 | 10 | 11 | class Baseline(Evaluator): 12 | def __init__(self, env: Env, config: EvaluatorConfig, *args, **kwargs): 13 | super().__init__(env, config, *args, **kwargs) 14 | self.metrics_key = 'baseline' 15 | self.proper_name = 'Baseline' 16 | 17 | def add_metrics(self, history): 18 | if history.epoch_metrics.get(self.metrics_key) is None: 19 | history.epoch_metrics[self.metrics_key] = Metric( 20 | name=self.metrics_key, 21 | xlabel='Epoch', 22 | ylabel='Win Rate', 23 | maximize=True, 24 | alert_on_best=False, 25 | proper_name=f'Win Rate (Current Model vs. 
{self.proper_name})' 26 | ) 27 | 28 | def add_metrics_data(self, data, history, log=True): 29 | history.add_epoch_data({self.metrics_key: data}, log=log) 30 | 31 | 32 | -------------------------------------------------------------------------------- /core/algorithms/baselines/best.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from copy import deepcopy 4 | from dataclasses import dataclass 5 | from typing import Optional, Tuple 6 | import torch 7 | from core.algorithms.baselines.baseline import Baseline 8 | from core.algorithms.evaluator import Evaluator, EvaluatorConfig 9 | from core.env import Env 10 | from core.utils.history import Metric 11 | 12 | class BestModelBaseline(Baseline): 13 | def __init__(self, 14 | env: Env, 15 | config: EvaluatorConfig, 16 | evaluator: Evaluator, 17 | best_model: torch.nn.Module, 18 | best_model_optimizer: torch.optim.Optimizer, 19 | metrics_key: str = 'win_rate_vs_best', 20 | proper_name: str = 'Best Model', 21 | *args, **kwargs 22 | ): 23 | super().__init__(env, config, *args, **kwargs) 24 | self.best_model = deepcopy(best_model) 25 | self.best_model_optimizer = deepcopy(best_model_optimizer.state_dict()) if best_model_optimizer is not None else None 26 | self.evaluator = evaluator.__class__(env, evaluator.config, self.best_model) 27 | self.metrics_key = metrics_key 28 | self.proper_name = proper_name 29 | 30 | def step_evaluator(self, actions, terminated) -> None: 31 | self.evaluator.step_evaluator(actions, terminated) 32 | 33 | def step(self, *args) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]: 34 | return self.evaluator.step(self.best_model) -------------------------------------------------------------------------------- /core/algorithms/baselines/greedy.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | from dataclasses import dataclass 7 | import torch 8 | from core.algorithms.baselines.baseline import Baseline 9 | from core.algorithms.evaluator import EvaluatorConfig 10 | from core.env import Env 11 | 12 | @dataclass 13 | class GreedyConfig(EvaluatorConfig): 14 | heuristic: str 15 | 16 | 17 | class GreedyBaseline(Baseline): 18 | def __init__(self, env: Env, config: GreedyConfig, *args, **kwargs): 19 | super().__init__(env, config, *args, **kwargs) 20 | self.metrics_key = 'greedy_' + config.heuristic 21 | self.proper_name = 'Greedy_' + config.heuristic 22 | self.config: GreedyConfig 23 | 24 | def evaluate(self): 25 | saved = self.env.save_node() 26 | legal_actions = self.env.get_legal_actions().clone() 27 | starting_players = self.env.cur_players.clone() 28 | action_scores = torch.zeros( 29 | (self.env.parallel_envs, self.env.policy_shape[0]), 30 | dtype=torch.float32, 31 | device=self.device, 32 | requires_grad=False 33 | ) 34 | for action_idx in range(self.env.policy_shape[0]): 35 | terminated = self.env.step(torch.full((self.env.parallel_envs, ), action_idx, dtype=torch.long, device=self.device, requires_grad=False)) 36 | rewards = self.env.get_rewards(starting_players) 37 | greedy_rewards = self.env.get_greedy_rewards(starting_players) 38 | is_legal = legal_actions[:, action_idx] 39 | action_scores[:, action_idx] = ((rewards * terminated) + (greedy_rewards * (~terminated))) * is_legal 40 | 41 | self.env.load_node(torch.full_like(starting_players, True, dtype=torch.bool, device=self.device, requires_grad=False), saved=saved) 42 | 43 | return action_scores, (action_scores * 
legal_actions).max(dim=1) 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /core/algorithms/baselines/greedy_mcts.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple 7 | import torch 8 | from core.algorithms.baselines.baseline import Baseline 9 | from core.algorithms.evaluator import Evaluator 10 | from core.algorithms.mcts import MCTS, MCTSConfig 11 | from core.env import Env 12 | 13 | @dataclass 14 | class GreedyMCTSConfig(MCTSConfig): 15 | heuristic: str 16 | 17 | 18 | class GreedyMCTS(MCTS, Baseline): 19 | def __init__(self, env: Env, config: GreedyMCTSConfig, *args, **kwargs) -> None: 20 | super().__init__(env, config, *args, **kwargs) 21 | self.metrics_key = 'greedymcts_' + config.heuristic + '_' + str(config.num_iters) 22 | self.proper_name = 'GreedyMCTS_' + config.heuristic + '_' + str(config.num_iters) 23 | self.uniform_probabilities = torch.ones((self.env.parallel_envs, self.env.policy_shape[0]), device=self.device, requires_grad=False) / self.env.policy_shape[0] 24 | self.config: GreedyMCTSConfig 25 | 26 | def evaluate(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 27 | evaluation_fn = lambda env: (self.uniform_probabilities, env.get_greedy_rewards()) 28 | return super().evaluate(evaluation_fn) 29 | 30 | 31 | -------------------------------------------------------------------------------- /core/algorithms/baselines/lazy_greedy_mcts.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple 7 | import torch 8 | from core.algorithms.baselines.baseline import Baseline 9 | from core.algorithms.lazy_mcts import LazyMCTS, LazyMCTSConfig 10 | from core.env import Env 11 | 12 | @dataclass 13 | class LazyGreedyMCTSConfig(LazyMCTSConfig): 14 | heuristic: str 15 | 16 | 17 | class LazyGreedyMCTS(LazyMCTS, Baseline): 18 | def __init__(self, env: Env, config: LazyGreedyMCTSConfig, *args, **kwargs) -> None: 19 | super().__init__(env, config, *args, **kwargs) 20 | self.metrics_key = 'lazygreedymcts_' + config.heuristic + '_' + str(config.num_policy_rollouts) 21 | self.proper_name = 'LazyGreedyMCTS_' + config.heuristic + '_' + str(config.num_policy_rollouts) 22 | self.uniform_probabilities = torch.ones((self.env.parallel_envs, self.env.policy_shape[0]), device=self.device, requires_grad=False) / self.env.policy_shape[0] 23 | self.config: LazyGreedyMCTSConfig 24 | 25 | def evaluate(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 26 | evaluation_fn = lambda env: (self.uniform_probabilities, env.get_greedy_rewards()) 27 | return super().evaluate(evaluation_fn) 28 | 29 | 30 | -------------------------------------------------------------------------------- /core/algorithms/baselines/lazy_rollout_mcts.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | from dataclasses import dataclass 6 | 7 | import torch 8 | from core.algorithms.baselines.baseline import Baseline 9 | from core.algorithms.lazy_mcts import LazyMCTS, LazyMCTSConfig 10 | from core.env import Env 11 | 12 | 13 | @dataclass 14 | class RandomRolloutLazyMCTSConfig(LazyMCTSConfig): 15 | rollouts_per_leaf: int 16 | 17 | class RandomRolloutLazyMCTS(LazyMCTS, Baseline): 18 | def __init__(self, env: Env, config: RandomRolloutLazyMCTSConfig, *args, **kwargs): 19 | super().__init__(env, 
config, *args, **kwargs) 20 | self.config: RandomRolloutLazyMCTSConfig 21 | self.uniform_probabilities = torch.ones((self.env.parallel_envs, self.env.policy_shape[0]), device=self.device, requires_grad=False) / self.env.policy_shape[0] 22 | 23 | def evaluate(self): 24 | evaluation_fn = lambda env: (self.uniform_probabilities, env.random_rollout(num_rollouts=self.config.rollouts_per_leaf)) 25 | return super().evaluate(evaluation_fn) -------------------------------------------------------------------------------- /core/algorithms/baselines/random.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | from typing import Optional, Tuple 7 | import torch 8 | from core.algorithms.baselines.baseline import Baseline 9 | from core.algorithms.evaluator import Evaluator, EvaluatorConfig 10 | from core.env import Env 11 | from core.utils.history import Metric 12 | 13 | 14 | class RandomBaseline(Baseline): 15 | def __init__(self, 16 | env: Env, 17 | config: EvaluatorConfig, 18 | metrics_key: str = 'random', 19 | proper_name: str = 'Random', 20 | *args, **kwargs 21 | ) -> None: 22 | super().__init__(env, config, *args, **kwargs) 23 | self.metrics_key = metrics_key 24 | self.proper_name = proper_name 25 | self.sample = torch.zeros(self.env.parallel_envs, self.env.policy_shape[0], device=self.device, requires_grad=False, dtype=torch.float32) 26 | 27 | def evaluate(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 28 | self.sample.uniform_(0, 1) 29 | return self.sample, None -------------------------------------------------------------------------------- /core/algorithms/baselines/rollout.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | from dataclasses import dataclass 8 | import torch 9 | from core.algorithms.baselines.baseline import Baseline 10 | from core.algorithms.evaluator import EvaluatorConfig 11 | from core.env import Env 12 | 13 | @dataclass 14 | class RolloutConfig(EvaluatorConfig): 15 | num_rollouts: int 16 | 17 | class Rollout(Baseline): 18 | def __init__(self, env: Env, config: RolloutConfig, *args, **kwargs): 19 | super().__init__(env, config, *args, **kwargs) 20 | self.config: RolloutConfig 21 | 22 | def evaluate(self): 23 | saved = self.env.save_node() 24 | legal_actions = self.env.get_legal_actions().clone() 25 | starting_players = self.env.cur_players.clone() 26 | action_scores = torch.zeros( 27 | (self.env.parallel_envs, self.env.policy_shape[0]), 28 | dtype=torch.float32, 29 | device=self.device, 30 | requires_grad=False 31 | ) 32 | for action_idx in range(self.env.policy_shape[0]): 33 | rewards = self.env.random_rollout(num_rollouts=self.config.num_rollouts) 34 | is_legal = legal_actions[:, action_idx] 35 | action_scores[:, action_idx] = rewards * is_legal 36 | 37 | self.env.load_node(torch.full_like(starting_players, True, dtype=torch.bool, device=self.device, requires_grad=False), saved=saved) 38 | 39 | return action_scores, None -------------------------------------------------------------------------------- /core/algorithms/baselines/rollout_mcts.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | from core.algorithms.baselines.baseline import Baseline 10 | from core.algorithms.evaluator import EvaluatorConfig 11 | from core.algorithms.mcts import MCTS, MCTSConfig 12 | from core.env import Env 13 | from 
functools import partial 14 | 15 | @dataclass 16 | class RandomRolloutMCTSConfig(MCTSConfig): 17 | rollouts_per_leaf: int 18 | 19 | 20 | 21 | class RandomRolloutMCTS(MCTS, Baseline): 22 | def __init__(self, env: Env, config: RandomRolloutMCTSConfig, *args, **kwargs): 23 | super().__init__(env, config, *args, **kwargs) 24 | self.config: RandomRolloutMCTSConfig 25 | self.uniform_probabilities = torch.ones((self.env.parallel_envs, self.env.policy_shape[0]), device=self.device, requires_grad=False) / self.env.policy_shape[0] 26 | 27 | 28 | def evaluate(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 29 | evaluation_fn = lambda env: (self.uniform_probabilities, env.random_rollout(num_rollouts=self.config.rollouts_per_leaf)) 30 | return super().evaluate(evaluation_fn) 31 | 32 | 33 | -------------------------------------------------------------------------------- /core/algorithms/evaluator.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple 4 | import torch 5 | from core.env import Env 6 | from core.utils.utils import rand_argmax_2d 7 | 8 | @dataclass 9 | class EvaluatorConfig: 10 | name: str 11 | 12 | class Evaluator: 13 | def __init__(self, env: Env, config: EvaluatorConfig, *args, **kwargs): 14 | self.device = env.device 15 | self.env = env 16 | self.env.reset() 17 | self.config = config 18 | self.epsilon = 1e-8 19 | self.args = args 20 | self.kwargs = kwargs 21 | 22 | def reset(self, seed=None) -> int: 23 | return self.env.reset(seed=seed) 24 | 25 | def evaluate(self, *args) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 26 | # returns probability distribution over actions, and optionally the value of the current state 27 | raise NotImplementedError() 28 | 29 | def step_evaluator(self, actions, terminated) -> None: 30 | pass 31 | 32 | def step_env(self, actions) -> torch.Tensor: 33 | terminated = self.env.step(actions) 34 | self.step_evaluator(actions, terminated) 35 | return terminated 36 | 37 | def choose_actions(self, probs: torch.Tensor) -> torch.Tensor: 38 | legal_actions = self.env.get_legal_actions() 39 | return rand_argmax_2d((probs + self.epsilon) * legal_actions).flatten() 40 | 41 | def step(self, *args) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]: 42 | initial_states = self.env.get_nn_input().clone() 43 | probs, values = self.evaluate(*args) 44 | actions = self.choose_actions(probs) 45 | terminated = self.step_env(actions) 46 | return initial_states, probs, values, actions, terminated 47 | 48 | def reset_evaluator_states(self, evals_to_reset: torch.Tensor) -> None: 49 | pass 50 | 51 | class TrainableEvaluator(Evaluator): 52 | def __init__(self, env: Env, config: EvaluatorConfig, model: torch.nn.Module): 53 | super().__init__(env, config) 54 | self.env = env 55 | self.config = config 56 | self._model = model 57 | 58 | @property 59 | def model(self) -> torch.nn.Module: 60 | return self._model -------------------------------------------------------------------------------- /core/algorithms/lazy_mcts.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import Callable, Optional, Tuple 4 | import torch 5 | 6 | from core.algorithms.evaluator import Evaluator, EvaluatorConfig 7 | from core.utils.utils import rand_argmax_2d 8 | from ..env import Env 9 | 10 | 11 | @dataclass 12 | class LazyMCTSConfig(EvaluatorConfig): 13 | 
num_policy_rollouts: int # number of policy rollouts to run per evaluation call 14 | rollout_depth: int # depth of each policy rollout, once this depth is reached, return the network's evaluation (value head) of the state 15 | puct_coeff: float # C-value in PUCT formula 16 | 17 | 18 | class LazyMCTS(Evaluator): 19 | def __init__(self, env: Env, config: LazyMCTSConfig, *args, **kwargs) -> None: 20 | super().__init__(env, config, *args, **kwargs) 21 | 22 | self.action_scores = torch.zeros( 23 | (self.env.parallel_envs, *self.env.policy_shape), 24 | dtype=torch.float32, 25 | device=self.env.device, 26 | requires_grad=False 27 | ) 28 | 29 | self.visit_counts = torch.zeros_like( 30 | self.action_scores, 31 | dtype=torch.float32, 32 | device=self.env.device, 33 | requires_grad=False 34 | ) 35 | 36 | self.puct_coeff = config.puct_coeff 37 | self.policy_rollouts = config.num_policy_rollouts 38 | self.rollout_depth = config.rollout_depth 39 | 40 | self.all_nodes = torch.ones(env.parallel_envs, dtype=torch.bool, device=self.device) 41 | 42 | def reset(self, seed=None) -> None: 43 | self.env.reset(seed=seed) 44 | self.reset_puct() 45 | 46 | def reset_puct(self) -> None: 47 | self.action_scores.zero_() 48 | self.visit_counts.zero_() 49 | 50 | def evaluate(self, evaluation_fn: Callable) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 51 | self.reset_puct() 52 | 53 | return self.explore_for_iters(evaluation_fn, self.policy_rollouts, self.rollout_depth) 54 | 55 | def choose_action_with_puct(self, probs: torch.Tensor, legal_actions: torch.Tensor) -> torch.Tensor: 56 | n_sum = torch.sum(self.visit_counts, dim=1, keepdim=True) 57 | zero_counts = self.visit_counts == 0 58 | visit_counts_augmented = self.visit_counts + zero_counts 59 | q_values = self.action_scores / visit_counts_augmented 60 | 61 | puct_scores = q_values + \ 62 | (self.puct_coeff * probs * torch.sqrt(n_sum + 1) / (1 + self.visit_counts)) 63 | 64 | puct_scores = (puct_scores * legal_actions) + (torch.finfo(torch.float32).min * (~legal_actions)) 65 | return torch.argmax(puct_scores, dim=1) 66 | 67 | def iterate(self, evaluation_fn: Callable, depth: int, rewards: torch.Tensor) -> torch.Tensor: # type: ignore 68 | while depth > 0: 69 | with torch.no_grad(): 70 | policy_logits, values = evaluation_fn(self.env) 71 | depth -= 1 72 | if depth == 0: 73 | rewards = self.env.get_rewards() 74 | final_values = values.flatten() * torch.logical_not(self.env.terminated) 75 | final_values += rewards * self.env.terminated 76 | return final_values 77 | else: 78 | legal_actions = self.env.get_legal_actions() 79 | policy_logits = (policy_logits * legal_actions) + (torch.finfo(torch.float32).min * (~legal_actions)) 80 | distribution = torch.nn.functional.softmax( 81 | policy_logits, dim=1) 82 | next_actions = torch.multinomial(distribution + self.env.is_terminal().unsqueeze(1), 1, replacement=True).flatten() 83 | self.env.step(next_actions) 84 | 85 | def explore_for_iters(self, evaluation_fn: Callable, iters: int, search_depth: int) -> Tuple[torch.Tensor, torch.Tensor]: 86 | legal_actions = self.env.get_legal_actions() 87 | with torch.no_grad(): 88 | policy_logits, initial_values = evaluation_fn(self.env) 89 | policy_logits = (policy_logits * legal_actions) + (torch.finfo(torch.float32).min * (~legal_actions)) 90 | policy_logits = torch.nn.functional.softmax(policy_logits, dim=1) 91 | saved = self.env.save_node() 92 | 93 | for _ in range(iters): 94 | actions = self.choose_action_with_puct( 95 | policy_logits, legal_actions) 96 | self.env.step(actions) 97 | 
rewards = self.env.get_rewards() 98 | values = self.iterate(evaluation_fn, search_depth, rewards) 99 | if search_depth % self.env.num_players: 100 | values = 1 - values 101 | self.visit_counts[self.env.env_indices, actions] += 1 102 | self.action_scores[self.env.env_indices, actions] += values 103 | self.env.load_node(self.all_nodes, saved) 104 | 105 | return self.visit_counts / self.visit_counts.sum(dim=1, keepdim=True), initial_values 106 | -------------------------------------------------------------------------------- /core/algorithms/lazyzero.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple 4 | import torch 5 | from core.algorithms.evaluator import TrainableEvaluator 6 | 7 | 8 | from core.algorithms.lazy_mcts import LazyMCTS, LazyMCTSConfig 9 | from core.env import Env 10 | from core.utils.utils import rand_argmax_2d 11 | 12 | @dataclass 13 | class LazyZeroConfig(LazyMCTSConfig): 14 | temperature: float 15 | 16 | 17 | 18 | class LazyZero(LazyMCTS, TrainableEvaluator): 19 | def __init__(self, env: Env, config: LazyZeroConfig, model: torch.nn.Module, *args, **kwargs): 20 | super().__init__(env, config, model, *args, **kwargs) 21 | self.config: LazyZeroConfig 22 | 23 | 24 | # all additional alphazero implementation details live in MCTS, for now 25 | def choose_actions(self, visits: torch.Tensor) -> torch.Tensor: 26 | if self.config.temperature > 0: 27 | return torch.multinomial(torch.pow(visits, 1/self.config.temperature), 1, replacement=True).flatten() 28 | else: 29 | return rand_argmax_2d(visits).flatten() 30 | 31 | def evaluate(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 32 | evaluation_fn = lambda env: self.model(env.get_nn_input()) 33 | return super().evaluate(evaluation_fn) -------------------------------------------------------------------------------- /core/algorithms/load.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from envs.othello.evaluators.edax import Edax, EdaxConfig 4 | import torch 5 | from core.algorithms.alphazero import AlphaZero, AlphaZeroConfig 6 | from core.algorithms.baselines.greedy import GreedyBaseline, GreedyConfig 7 | from core.algorithms.baselines.greedy_mcts import GreedyMCTS, GreedyMCTSConfig 8 | from core.algorithms.baselines.lazy_greedy_mcts import LazyGreedyMCTS, LazyGreedyMCTSConfig 9 | from core.algorithms.baselines.lazy_rollout_mcts import RandomRolloutLazyMCTS, RandomRolloutLazyMCTSConfig 10 | from core.algorithms.baselines.random import RandomBaseline 11 | from core.algorithms.baselines.rollout import Rollout, RolloutConfig 12 | from core.algorithms.baselines.rollout_mcts import RandomRolloutMCTS, RandomRolloutMCTSConfig 13 | from core.algorithms.evaluator import EvaluatorConfig 14 | from core.algorithms.lazyzero import LazyZero, LazyZeroConfig 15 | from core.demo.human import HumanEvaluator 16 | from core.env import Env 17 | 18 | 19 | def init_evaluator(algo_config: dict, env: Env, *args, **kwargs): 20 | algo_type = algo_config['name'] 21 | if algo_type == 'lazyzero': 22 | config = LazyZeroConfig(**algo_config) 23 | return LazyZero(env, config, *args, **kwargs) 24 | elif algo_type == 'alphazero': 25 | config = AlphaZeroConfig(**algo_config) 26 | return AlphaZero(env, config, *args, **kwargs) 27 | elif algo_type == 'random': 28 | config = EvaluatorConfig(**algo_config) 29 | return RandomBaseline(env, config, *args, **kwargs) 30 | elif algo_type == 'human': 31 | config = 
EvaluatorConfig(**algo_config) 32 | return HumanEvaluator(env, config, *args, **kwargs) 33 | elif algo_type == 'greedy_mcts': 34 | config = GreedyMCTSConfig(**algo_config) 35 | return GreedyMCTS(env, config, *args, **kwargs) 36 | elif algo_type == 'edax': 37 | config = EdaxConfig(**algo_config) 38 | return Edax(env, config, *args, **kwargs) 39 | elif algo_type == 'greedy': 40 | config = GreedyConfig(**algo_config) 41 | return GreedyBaseline(env, config, *args, **kwargs) 42 | elif algo_type == 'random_rollout_mcts': 43 | config = RandomRolloutMCTSConfig(**algo_config) 44 | return RandomRolloutMCTS(env, config, *args, **kwargs) 45 | elif algo_type == 'random_rollout_lazy_mcts': 46 | config = RandomRolloutLazyMCTSConfig(**algo_config) 47 | return RandomRolloutLazyMCTS(env, config, *args, **kwargs) 48 | elif algo_type == 'rollout': 49 | config = RolloutConfig(**algo_config) 50 | return Rollout(env, config, *args, **kwargs) 51 | elif algo_type == 'lazy_greedy_mcts': 52 | config = LazyGreedyMCTSConfig(**algo_config) 53 | return LazyGreedyMCTS(env, config, *args, **kwargs) 54 | else: 55 | raise NotImplementedError(f'Unknown evaluator type: {algo_type}') 56 | 57 | -------------------------------------------------------------------------------- /core/demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/core/demo/__init__.py -------------------------------------------------------------------------------- /core/demo/demo.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import torch 4 | from core.algorithms.evaluator import Evaluator 5 | from IPython.display import clear_output 6 | import os 7 | 8 | class Demo: 9 | def __init__(self, 10 | evaluator: Evaluator, 11 | manual_step: bool = False 12 | ): 13 | self.manual_step = manual_step 14 | self.evaluator = evaluator 15 | assert self.evaluator.env.parallel_envs == 1 16 | 17 | def run(self, print_evaluation: bool = False, print_state: bool = True, interactive: bool =True): 18 | self.evaluator.reset() 19 | actions = None 20 | while True: 21 | if print_state: 22 | self.evaluator.env.print_state(actions.item() if actions is not None else None) 23 | if self.manual_step: 24 | input('Press any key to continue...') 25 | _, _, value, actions, terminated = self.evaluator.step() 26 | if interactive: 27 | clear_output(wait=True) 28 | else: 29 | os.system('clear') 30 | if print_evaluation and value is not None: 31 | print(f'Evaluation: {value[0]}') 32 | if terminated: 33 | print('Game over!') 34 | print('Final state:') 35 | print(self.evaluator.env) 36 | break 37 | 38 | class TwoPlayerDemo(Demo): 39 | def __init__(self, 40 | evaluator, 41 | evaluator2, 42 | manual_step: bool = False 43 | ) -> None: 44 | super().__init__(evaluator, manual_step) 45 | self.evaluator2 = evaluator2 46 | assert self.evaluator2.env.parallel_envs == 1 47 | 48 | def run(self, print_evaluation: bool = False, print_state: bool = True, interactive: bool =True): 49 | seed = random.randint(0, 2**32 - 1) 50 | 51 | self.evaluator.reset(seed) 52 | self.evaluator2.reset(seed) 53 | p1_turn = random.choice([True, False]) 54 | cur_player = self.evaluator.env.cur_players.item() 55 | p1_player_id = cur_player if p1_turn else 1 - cur_player 56 | p1_evaluation = 0.5 57 | p2_evaluation = 0.5 58 | actions = None 59 | while True: 60 | 61 | 62 | active_evaluator = self.evaluator if p1_turn else 
self.evaluator2 63 | other_evaluator = self.evaluator2 if p1_turn else self.evaluator 64 | if print_state: 65 | print(f'Player 1 (O): {self.evaluator.__class__.__name__ if p1_player_id == 0 else self.evaluator2.__class__.__name__}') 66 | print(f'Player 2 (X): {self.evaluator.__class__.__name__ if p1_player_id == 1 else self.evaluator2.__class__.__name__}') 67 | active_evaluator.env.print_state(int(actions.item()) if actions is not None else None) 68 | if self.manual_step: 69 | input('Press any key to continue...') 70 | _, _, value, actions, terminated = active_evaluator.step() 71 | if p1_turn: 72 | p1_evaluation = value[0] if value is not None else None 73 | else: 74 | p2_evaluation = value[0] if value is not None else None 75 | other_evaluator.step_evaluator(actions, terminated) 76 | if interactive: 77 | clear_output(wait=False) 78 | else: 79 | os.system('clear') 80 | if print_evaluation: 81 | if p1_evaluation is not None: 82 | print(f'{self.evaluator.__class__.__name__} Evaluation: {p1_evaluation}') 83 | if p2_evaluation is not None: 84 | print(f'{self.evaluator2.__class__.__name__} Evaluation: {p2_evaluation}') 85 | if terminated: 86 | print('Game over!') 87 | print('Final state:') 88 | active_evaluator.env.print_state(int(actions.item())) 89 | self.print_rewards(p1_player_id) 90 | break 91 | 92 | p1_turn = not p1_turn 93 | 94 | def print_rewards(self, p1_player_id): 95 | reward = self.evaluator.env.get_rewards(torch.tensor([p1_player_id]))[0] 96 | if reward == 1: 97 | print(f'Player 1 ({self.evaluator.__class__.__name__}) won!') 98 | elif reward == 0: 99 | print(f'Player 2 ({self.evaluator2.__class__.__name__}) won!') 100 | else: 101 | print('Draw!') -------------------------------------------------------------------------------- /core/demo/human.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import torch 5 | from core.algorithms.evaluator import Evaluator 6 | 7 | 8 | class HumanEvaluator(Evaluator): 9 | def __init__(self, env, config): 10 | super().__init__(env, config) 11 | assert self.env.parallel_envs == 1, 'HumanEvaluator only supports parallel_envs=1' 12 | 13 | def evaluate(self): 14 | legal_action_ids = [] 15 | for index, i in enumerate(self.env.get_legal_actions()[0]): 16 | if i: 17 | legal_action_ids.append(index) 18 | 19 | print('Legal actions:', legal_action_ids) 20 | 21 | while True: 22 | action = input('Enter action: ') 23 | try: 24 | action = int(action) 25 | if action in legal_action_ids: 26 | break 27 | else: 28 | print('Action not legal, choose a legal action') 29 | except: 30 | print('Invalid input') 31 | 32 | return torch.nn.functional.one_hot(torch.tensor([action]), num_classes=self.env.policy_shape[0]).float(), None 33 | 34 | -------------------------------------------------------------------------------- /core/demo/load.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | from core.algorithms.load import init_evaluator 5 | 6 | from core.demo.demo import Demo 7 | from core.utils.checkpoint import load_checkpoint, load_model_and_optimizer_from_checkpoint 8 | from envs.connect_x.demo import ConnectXDemo 9 | from envs.load import init_env 10 | from envs.othello.demo import OthelloDemo 11 | 12 | 13 | def init_demo(env_config: dict, demo_config: dict, device: torch.device, *args, **kwargs) -> Demo: 14 | env = init_env(device, 1, env_config, debug=False) 15 | evaluator1_config = demo_config['evaluator1_config'] 16 | if evaluator1_config.get('checkpoint'): 17 
| model, _ = load_model_and_optimizer_from_checkpoint(load_checkpoint(evaluator1_config['checkpoint']), env, device) 18 | evaluator1 = init_evaluator(evaluator1_config['algo_config'], env, model, *args, **kwargs) 19 | else: 20 | evaluator1 = init_evaluator(evaluator1_config['algo_config'], env, *args, **kwargs) 21 | if env_config['env_type'] == 'othello': 22 | evaluator2_config = demo_config['evaluator2_config'] 23 | if evaluator2_config.get('checkpoint'): 24 | model, _ = load_model_and_optimizer_from_checkpoint(load_checkpoint(evaluator2_config['checkpoint']), env, device) 25 | evaluator2 = init_evaluator(evaluator2_config['algo_config'], env, model, *args, **kwargs) 26 | else: 27 | evaluator2 = init_evaluator(evaluator2_config['algo_config'], env, *args, **kwargs) 28 | return OthelloDemo(evaluator1, evaluator2, demo_config['manual_step']) 29 | elif env_config['env_type'] == 'connect_x': 30 | evaluator2_config = demo_config['evaluator2_config'] 31 | if evaluator2_config.get('checkpoint'): 32 | model, _ = load_model_and_optimizer_from_checkpoint(load_checkpoint(evaluator2_config['checkpoint']), env, device) 33 | evaluator2 = init_evaluator(evaluator2_config['algo_config'], env, model, *args, **kwargs) 34 | else: 35 | evaluator2 = init_evaluator(evaluator2_config['algo_config'], env, *args, **kwargs) 36 | return ConnectXDemo(evaluator1, evaluator2, demo_config['manual_step']) 37 | elif env_config['env_type'] == '2048': 38 | return Demo(evaluator1, demo_config['manual_step']) 39 | else: 40 | return Demo(evaluator1, demo_config['manual_step']) -------------------------------------------------------------------------------- /core/env.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | import torch 4 | 5 | 6 | @dataclass 7 | class EnvConfig: 8 | env_type: str 9 | 10 | 11 | class Env: 12 | def __init__(self, 13 | parallel_envs: int, 14 | config: EnvConfig, 15 | device: torch.device, 16 | num_players: int, 17 | state_shape: torch.Size, 18 | policy_shape: torch.Size, 19 | value_shape: torch.Size, 20 | debug: bool = False 21 | ): 22 | self.config = config 23 | self.parallel_envs = parallel_envs 24 | self.state_shape = state_shape 25 | self.policy_shape = policy_shape 26 | self.value_shape = value_shape 27 | self.num_players = num_players 28 | self.debug = debug 29 | 30 | 31 | self.states = torch.zeros((self.parallel_envs, *state_shape), dtype=torch.float32, device=device, requires_grad=False) 32 | self.terminated = torch.zeros(self.parallel_envs, dtype=torch.bool, device=device, requires_grad=False) 33 | self.cur_players = torch.zeros((self.parallel_envs, ), dtype=torch.long, device=device, requires_grad=False) 34 | self.env_indices = torch.arange(self.parallel_envs, device=device, requires_grad=False) 35 | 36 | self.device = device 37 | 38 | def __str__(self): 39 | return str(self.states) 40 | 41 | def reset(self, seed: Optional[int] = None): 42 | raise NotImplementedError() 43 | 44 | def step(self, actions) -> torch.Tensor: 45 | self.push_actions(actions) 46 | self.next_turn() 47 | self.update_terminated() 48 | return self.terminated 49 | 50 | def update_terminated(self): 51 | self.terminated = self.is_terminal() 52 | 53 | def is_terminal(self): 54 | raise NotImplementedError() 55 | 56 | def push_actions(self, actions): 57 | raise NotImplementedError() 58 | 59 | def get_nn_input(self): 60 | return self.states 61 | 62 | def get_legal_actions(self): 63 | return 
torch.ones(self.parallel_envs, *self.policy_shape, dtype=torch.bool, device=self.device, requires_grad=False) 64 | 65 | def apply_stochastic_progressions(self, mask=None) -> None: 66 | progs, probs = self.get_stochastic_progressions() 67 | if mask is None: 68 | indices = torch.multinomial(probs, 1, replacement=True).flatten() 69 | else: 70 | indices = torch.multinomial(probs + (~(mask.unsqueeze(1))), 1, replacement=True).flatten() 71 | 72 | new_states = progs[(self.env_indices, indices)].unsqueeze(1) 73 | 74 | if mask is not None: 75 | mask = mask.view(self.parallel_envs, 1, 1, 1) 76 | self.states = self.states * (~mask) + new_states * mask 77 | else: 78 | self.states = new_states 79 | 80 | def get_stochastic_progressions(self) -> Tuple[torch.Tensor, torch.Tensor]: 81 | raise NotImplementedError() 82 | 83 | def reset_terminated_states(self, seed: Optional[int] = None): 84 | raise NotImplementedError() 85 | 86 | def next_player(self): 87 | self.cur_players = (self.cur_players + 1) % self.num_players 88 | 89 | def get_rewards(self, player_ids: Optional[torch.Tensor] = None): 90 | raise NotImplementedError() 91 | 92 | def next_turn(self): 93 | raise NotImplementedError() 94 | 95 | def save_node(self): 96 | raise NotImplementedError() 97 | 98 | def load_node(self, load_envs, saved): 99 | raise NotImplementedError() 100 | 101 | def get_greedy_rewards(self, player_ids: Optional[torch.Tensor] = None, heuristic: Optional[str] = None): 102 | # returns instantaneous reward, used in greedy algorithms 103 | raise NotImplementedError() 104 | 105 | def choose_random_legal_action(self) -> torch.Tensor: 106 | legal_actions = self.get_legal_actions() 107 | return torch.multinomial(legal_actions.float(), 1, replacement=True).flatten() 108 | 109 | def random_rollout(self, num_rollouts: int) -> torch.Tensor: 110 | saved = self.save_node() 111 | cumulative_rewards = torch.zeros(self.parallel_envs, dtype=torch.float32, device=self.device, requires_grad=False) 112 | for _ in range(num_rollouts): 113 | completed = torch.zeros(self.parallel_envs, dtype=torch.bool, device=self.device, requires_grad=False) 114 | starting_players = self.cur_players.clone() 115 | while not completed.all(): 116 | actions = self.choose_random_legal_action() 117 | terminated = self.step(actions) 118 | rewards = self.get_rewards(starting_players) 119 | rewards = ((self.cur_players == starting_players) * rewards) + ((self.cur_players != starting_players) * 1-rewards) 120 | cumulative_rewards += rewards * terminated * (~completed) 121 | completed = completed | terminated 122 | self.load_node(torch.full((self.parallel_envs,), True, dtype=torch.bool, device=self.device), saved) 123 | cumulative_rewards /= num_rollouts 124 | return cumulative_rewards 125 | 126 | def print_state(self, last_action: Optional[int] = None) -> None: 127 | pass -------------------------------------------------------------------------------- /core/resnet.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import List, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | 6 | from dataclasses import dataclass 7 | from typing import Callable 8 | 9 | from core.utils.custom_activations import load_activation 10 | 11 | 12 | @dataclass 13 | class ResNetConfig: 14 | res_channels: int 15 | res_blocks: int 16 | kernel_size: int 17 | value_fc_size: int = 32 18 | value_output_activation: str = '' 19 | 20 | def reset_model_weights(m): 21 | reset_parameters = getattr(m, "reset_parameters", None) 22 | if 
callable(reset_parameters): 23 | m.reset_parameters() 24 | 25 | class ResidualBlock(nn.Module): 26 | def __init__(self, in_channels, out_channels, kernel_size, stride = 1): 27 | super(ResidualBlock, self).__init__() 28 | self.conv1 = nn.Sequential( 29 | nn.Conv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = 'same', bias=False), 30 | nn.BatchNorm2d(out_channels), 31 | nn.ReLU()) 32 | self.conv2 = nn.Sequential( 33 | nn.Conv2d(out_channels, out_channels, kernel_size = kernel_size, stride = 1, padding = 'same', bias=False), 34 | nn.BatchNorm2d(out_channels)) 35 | self.relu = nn.ReLU() 36 | self.out_channels = out_channels 37 | 38 | def forward(self, x): 39 | residual = x 40 | out = self.conv1(x) 41 | out = self.conv2(out) 42 | out += residual 43 | out = self.relu(out) 44 | return out 45 | 46 | def fuse(self): 47 | torch.quantization.fuse_modules(self.conv1, ['0', '1', '2'], inplace=True) 48 | torch.quantization.fuse_modules(self.conv2, ['0', '1'], inplace=True) 49 | 50 | 51 | class TurboZeroResnet(nn.Module): 52 | def __init__(self, config: ResNetConfig, input_shape: torch.Size, output_shape: torch.Size) -> None: 53 | super().__init__() 54 | assert len(input_shape) == 3 # (channels, height, width) 55 | self.value_head_activation: Optional[torch.nn.Module] = load_activation(config.value_output_activation) 56 | self.input_channels, self.input_height, self.input_width = input_shape 57 | 58 | self.input_block = nn.Sequential( 59 | nn.Conv2d(self.input_channels, config.res_channels, kernel_size = config.kernel_size, stride = 1, padding = 'same', bias=False), 60 | nn.BatchNorm2d(config.res_channels), 61 | nn.ReLU() 62 | ) 63 | 64 | self.res_blocks = nn.Sequential( 65 | *[ResidualBlock(config.res_channels, config.res_channels, config.kernel_size) \ 66 | for _ in range(config.res_blocks)] 67 | ) 68 | 69 | self.policy_head = nn.Sequential( 70 | nn.Conv2d(config.res_channels, 2, kernel_size = 1, stride = 1, padding = 0, bias=False), 71 | nn.BatchNorm2d(2), 72 | nn.ReLU(), 73 | nn.Flatten(start_dim=1), 74 | nn.Linear(2 * self.input_height * self.input_width, output_shape[0]) 75 | # we use cross entropy loss so no need for softmax 76 | ) 77 | 78 | self.value_head = nn.Sequential( 79 | nn.Conv2d(config.res_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = False), 80 | nn.BatchNorm2d(1), 81 | nn.ReLU(), 82 | nn.Flatten(start_dim=1), 83 | nn.Linear(self.input_height * self.input_width, config.value_fc_size), 84 | nn.ReLU(), 85 | nn.Linear(config.value_fc_size, 1) 86 | # value head activation handled in forward 87 | ) 88 | 89 | self.config = config 90 | 91 | def forward(self, x): 92 | x = self.input_block(x) 93 | x = self.res_blocks(x) 94 | policy = self.policy_head(x) 95 | value = self.value_head(x) 96 | return policy, self.value_head_activation(value) if self.value_head_activation is not None else value 97 | 98 | def fuse(self): 99 | torch.quantization.fuse_modules(self.input_block, ['0', '1', '2'], inplace=True) 100 | for b in self.res_blocks: 101 | if isinstance(b, ResidualBlock): 102 | b.fuse() 103 | for b in self.policy_head: 104 | if isinstance(b, ResidualBlock): 105 | b.fuse() 106 | for b in self.value_head: 107 | if isinstance(b, ResidualBlock): 108 | b.fuse() 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /core/test/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/core/test/__init__.py -------------------------------------------------------------------------------- /core/test/tester.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from copy import deepcopy 4 | import random 5 | from tqdm import tqdm 6 | import torch 7 | from typing import List, Optional 8 | from core.algorithms.baselines.baseline import Baseline 9 | from core.algorithms.baselines.best import BestModelBaseline 10 | from core.algorithms.evaluator import Evaluator, EvaluatorConfig 11 | from core.algorithms.load import init_evaluator 12 | from core.train.collector import Collector 13 | from core.utils.history import TrainingMetrics 14 | import logging 15 | 16 | from dataclasses import dataclass 17 | 18 | @dataclass 19 | class TesterConfig: 20 | algo_config: EvaluatorConfig 21 | episodes_per_epoch: int 22 | 23 | class Tester: 24 | def __init__(self, 25 | collector: Collector, 26 | config: TesterConfig, 27 | model: torch.nn.Module, 28 | history: TrainingMetrics, 29 | optimizer: Optional[torch.optim.Optimizer] = None, 30 | log_results: bool = True, 31 | debug: bool = False 32 | ): 33 | self.config = config 34 | self.model = model 35 | self.optimizer = optimizer 36 | self.collector = collector 37 | self.device = self.collector.evaluator.env.device 38 | self.history = history 39 | self.log_results = log_results 40 | self.debug = debug 41 | 42 | def add_evaluation_metrics(self, episodes): 43 | raise NotImplementedError() 44 | 45 | def collect_test_batch(self): 46 | self.collector.reset() 47 | # we make a big assumption here that all evaluation episodes can be run in parallel 48 | # if the user wants to evaluate an obscene number of episodes this will become problematic 49 | # TODO: batch evaluation episodes where necessary 50 | completed_episodes = torch.zeros(self.config.episodes_per_epoch, dtype=torch.bool, device=self.collector.evaluator.env.device) 51 | while not completed_episodes.all(): 52 | episodes, terminated = self.collector.collect(inactive_mask=completed_episodes) 53 | self.add_evaluation_metrics(episodes) 54 | completed_episodes |= terminated 55 | 56 | def generate_plots(self): 57 | self.history.generate_plots(categories=['eval']) 58 | 59 | 60 | @dataclass 61 | class TwoPlayerTesterConfig(TesterConfig): 62 | baselines: List[dict] 63 | improvement_threshold_pct: float = 0.0 64 | 65 | 66 | class TwoPlayerTester(Tester): 67 | def __init__(self, 68 | collector: Collector, 69 | config: TwoPlayerTesterConfig, 70 | model: torch.nn.Module, 71 | history: TrainingMetrics, 72 | optimizer: Optional[torch.optim.Optimizer] = None, 73 | log_results: bool = True, 74 | debug: bool = False 75 | ): 76 | super().__init__(collector, config, model, history, optimizer, log_results, debug) 77 | self.config: TwoPlayerTesterConfig 78 | self.baselines = [] 79 | for baseline_config in config.baselines: 80 | baseline: Baseline = init_evaluator(baseline_config, collector.evaluator.env, evaluator=collector.evaluator, best_model=model, best_model_optimizer=optimizer) 81 | self.baselines.append(baseline) 82 | baseline.add_metrics(self.history) 83 | 84 | def collect_test_batch(self): 85 | for baseline in self.baselines: 86 | scores = collect_games(self.collector.evaluator, baseline, self.config.episodes_per_epoch, self.device, debug=self.debug) 87 | wins = (scores == 1).sum().cpu().clone() 88 | draws = (scores == 0.5).sum().cpu().clone() 89 | losses = (scores ==
0).sum().cpu().clone() 90 | win_pct = wins / self.config.episodes_per_epoch 91 | if isinstance(baseline, BestModelBaseline): 92 | if win_pct > self.config.improvement_threshold_pct: 93 | baseline.best_model = deepcopy(self.model) 94 | baseline.best_model_optimizer = deepcopy(self.optimizer.state_dict()) if self.optimizer is not None else None 95 | logging.info('************ NEW BEST MODEL ************') 96 | if self.history: 97 | baseline.add_metrics_data(win_pct, self.history, log=self.log_results) 98 | logging.info(f'Epoch {self.history.cur_epoch} Current vs. {baseline.proper_name}:') 99 | logging.info(f'W/L/D: {wins}/{losses}/{draws}') 100 | 101 | 102 | 103 | def collect_games(evaluator1: Evaluator, evaluator2: Evaluator, num_games: int, device: torch.device, debug: bool) -> torch.Tensor: 104 | if not debug: 105 | progress_bar = tqdm(total=num_games, desc='Collecting games...', leave=True, position=0) 106 | seed = random.randint(0, 2**32 - 1) 107 | evaluator1.reset(seed) 108 | evaluator2.reset(seed) 109 | split = num_games // 2 110 | reset = torch.zeros(num_games, dtype=torch.bool, device=device, requires_grad=False) 111 | reset[:split] = True 112 | 113 | completed_episodes = torch.zeros(num_games, dtype=torch.bool, device=device, requires_grad=False) 114 | scores = torch.zeros(num_games, dtype=torch.float32, device=device, requires_grad=False) 115 | 116 | _, _, _, actions, terminated = evaluator1.step() 117 | 118 | envs_to_reset = terminated | reset 119 | 120 | evaluator1.env.terminated[:split] = True 121 | evaluator1.env.reset_terminated_states(seed) 122 | evaluator1.reset_evaluator_states(envs_to_reset) 123 | evaluator2.step_evaluator(actions, envs_to_reset) 124 | 125 | starting_players = (evaluator1.env.cur_players.clone() - 1) % 2 126 | use_second_evaluator = True 127 | while not completed_episodes.all(): 128 | if use_second_evaluator: 129 | _, _, _, actions, terminated = evaluator2.step() 130 | evaluator1.step_evaluator(actions, terminated) 131 | else: 132 | _, _, _, actions, terminated = evaluator1.step() 133 | evaluator2.step_evaluator(actions, terminated) 134 | rewards = evaluator1.env.get_rewards(starting_players) 135 | scores += rewards * terminated * (~completed_episodes) 136 | new_completed = (terminated & (~completed_episodes)).long().sum().item() 137 | completed_episodes |= terminated 138 | evaluator1.env.reset_terminated_states(seed) 139 | use_second_evaluator = not use_second_evaluator 140 | if not debug: 141 | progress_bar.update(new_completed) 142 | 143 | return scores 144 | -------------------------------------------------------------------------------- /core/test/tournament/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/core/test/tournament/__init__.py -------------------------------------------------------------------------------- /core/test/tournament/tournament.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from collections import defaultdict 5 | from dataclasses import dataclass 6 | import logging 7 | from random import shuffle 8 | import random 9 | from typing import Dict, List, Optional, Tuple 10 | 11 | import torch 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | from core.algorithms.evaluator import Evaluator 15 | from itertools import combinations 16 | from core.algorithms.load import init_evaluator 17 | from core.env import Env 18 | from 
core.utils.checkpoint import load_checkpoint, load_model_and_optimizer_from_checkpoint 19 | from core.utils.heatmap import annotate_heatmap, heatmap 20 | from core.test.tester import collect_games 21 | 22 | from envs.load import init_env 23 | 24 | class TournamentPlayer: 25 | def __init__(self, name: str, evaluator: Evaluator, raw_config: dict, initial_rating=1500) -> None: 26 | self.rating: float = initial_rating 27 | self.initial_rating = initial_rating 28 | self.name = name 29 | self.evaluator = evaluator 30 | self.raw_config = raw_config 31 | 32 | def expected_result(self, opponent_rating: float): 33 | return 1 / (1 + 10 ** ((opponent_rating - self.rating) / 400)) 34 | 35 | def update_rating(self, opponent_rating: float, result: float): 36 | self.rating += 16 * (result - self.expected_result(opponent_rating)) 37 | 38 | def reset_rating(self): 39 | self.rating = self.initial_rating 40 | 41 | 42 | @dataclass 43 | class GameResult: 44 | player1_name: str 45 | player2_name: str 46 | player1_result: float 47 | player2_result: float 48 | 49 | 50 | class Tournament: 51 | def __init__(self, env: Env, n_games: int, n_tournaments: int, device: torch.device, name: str, debug: bool = False): 52 | self.competitors = [] 53 | self.competitors_dict = dict() 54 | self.env = env 55 | self.n_games = n_games 56 | self.n_tournaments = n_tournaments 57 | self.results: List[GameResult] = [] 58 | self.device = device 59 | self.name = name 60 | self.debug = debug 61 | 62 | def init_competitor(self, config: dict) -> TournamentPlayer: 63 | if config.get('checkpoint'): 64 | model, _ = load_model_and_optimizer_from_checkpoint(load_checkpoint(config['checkpoint']), self.env, self.device) 65 | evaluator = init_evaluator(config['algo_config'], self.env, model) 66 | else: 67 | evaluator = init_evaluator(config['algo_config'], self.env) 68 | return TournamentPlayer(config['name'], evaluator, config) 69 | 70 | def collect_games(self, new_competitor_config: dict): 71 | new_competitor = self.init_competitor(new_competitor_config) 72 | if new_competitor.name not in self.competitors_dict: 73 | for competitor in self.competitors: 74 | competitor.evaluator.env = self.env 75 | new_competitor.evaluator.env = self.env 76 | p1_scores = collect_games(competitor.evaluator, new_competitor.evaluator, self.n_games, self.device, self.debug) 77 | new_results = [] 78 | for p1_score in p1_scores: 79 | new_results.append(GameResult( 80 | player1_name=competitor.name, 81 | player2_name=new_competitor.name, 82 | player1_result=p1_score, 83 | player2_result=1 - p1_score 84 | )) 85 | logging.info(f'{competitor.name}: {sum([r.player1_result for r in new_results])}, {new_competitor.name}: {sum([r.player2_result for r in new_results])}') 86 | self.results.extend(new_results) 87 | self.competitors.append(new_competitor) 88 | self.competitors_dict[new_competitor.name] = new_competitor 89 | else: 90 | logging.warn(f'Already have data for competitor {new_competitor.name}, skipping...') 91 | 92 | def remove_competitor(self, name: str): 93 | self.competitors = [competitor for competitor in self.competitors if competitor.name != name] 94 | self.competitors_dict.pop(name) 95 | self.results = [result for result in self.results if result.player1_name != name and result.player2_name != name] 96 | 97 | def save(self, path: Optional[str] = '') -> None: 98 | if not path: 99 | path = f'{self.name}.pt' 100 | data = dict() 101 | data['competitor_configs'] = dict() 102 | for competitor in self.competitors: 103 | data['competitor_configs'][competitor.name] = 
competitor.raw_config 104 | data['env_config'] = self.env.config.__dict__ 105 | data['n_games'] = self.n_games 106 | data['n_tournaments'] = self.n_tournaments 107 | data['results'] = self.results 108 | torch.save(data, path) 109 | 110 | def simulate_elo(self, interactive: bool = True) -> Dict[str, int]: 111 | player_ratings = defaultdict(lambda: []) 112 | matchups: Dict[Tuple[str, str], float] = defaultdict(lambda: 0) 113 | for _ in range(self.n_tournaments): 114 | shuffle(self.results) 115 | for result in self.results: 116 | self.competitors_dict[result.player1_name].update_rating(self.competitors_dict[result.player2_name].rating, result.player1_result) 117 | self.competitors_dict[result.player2_name].update_rating(self.competitors_dict[result.player1_name].rating, result.player2_result) 118 | for competitor in self.competitors: 119 | player_ratings[competitor.name].append(competitor.rating) 120 | competitor.reset_rating() 121 | for result in self.results: 122 | matchups[(result.player1_name, result.player2_name)] += result.player1_result 123 | matchups[(result.player2_name, result.player1_name)] += result.player2_result 124 | for key, value in matchups.items(): 125 | matchups[key] = (value / self.n_games) * 100 126 | 127 | matchup_matrix = np.zeros((len(self.competitors), len(self.competitors))) 128 | 129 | final_ratings = {name: int(sum(ratings) / len(ratings)) for name, ratings in player_ratings.items()} 130 | 131 | sorted_competitors = sorted(self.competitors, key=lambda c: final_ratings[c.name]) 132 | for p1_idx in range(len(sorted_competitors)): 133 | for p2_idx in range(p1_idx+1, len(sorted_competitors)): 134 | p1_name = sorted_competitors[p1_idx].name 135 | p2_name = sorted_competitors[p2_idx].name 136 | matchup_matrix[p1_idx, p2_idx] = matchups[(p1_name, p2_name)] 137 | matchup_matrix[p2_idx, p1_idx] = matchups[(p2_name, p1_name)] 138 | 139 | 140 | logging.info(f'Final ratings: {final_ratings}') 141 | player_names = [c.name for c in sorted_competitors] 142 | player_names_elo = [f'{c.name} ({final_ratings[c.name]})' for c in sorted_competitors] 143 | 144 | 145 | if interactive: 146 | height = len(player_names) / 1.5 147 | width = len(player_names) * 1.5 148 | fig, ax = plt.subplots(figsize=(width, height), dpi=500) 149 | im, cbar = heatmap(matchup_matrix, player_names_elo, player_names, ax=ax, 150 | cmap="YlGn", cbarlabel="Head-to-Head Win Rate (%)") 151 | texts = annotate_heatmap(im, valfmt="{x:.1f}%") 152 | fig.tight_layout() 153 | plt.show() 154 | 155 | return final_ratings 156 | 157 | def run(self, competitors: List[dict], interactive: bool = True): 158 | for competitor in competitors: 159 | self.collect_games(competitor) 160 | self.save() 161 | return self.simulate_elo(interactive) 162 | 163 | def load_tournament(path: str, device: torch.device): 164 | tournament_data = torch.load(path) 165 | tournament = Tournament( 166 | init_env(device, tournament_data['n_games'], tournament_data['env_config'], False), 167 | tournament_data['n_games'], 168 | tournament_data['n_tournaments'], 169 | device, 170 | tournament_data.get('name', 'tournament') 171 | ) 172 | tournament.results = tournament_data['results'] 173 | for competitor_config in tournament_data['competitor_configs'].values(): 174 | tournament.competitors.append(tournament.init_competitor(competitor_config)) 175 | tournament.competitors_dict = {competitor.name: competitor for competitor in tournament.competitors} 176 | 177 | return tournament 178 | 179 | 180 | 181 | 182 | 
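As a quick illustration of the Elo bookkeeping implemented by TournamentPlayer above (initial rating 1500, K-factor 16, 400-point rating scale), here is a minimal standalone sketch; the two ratings below are hypothetical and not taken from this repository:

def expected_result(rating: float, opponent_rating: float) -> float:
    # win probability implied by the rating gap, mirroring TournamentPlayer.expected_result
    return 1 / (1 + 10 ** ((opponent_rating - rating) / 400))

r_a, r_b = 1500.0, 1600.0               # hypothetical ratings for two competitors
e_a = expected_result(r_a, r_b)         # ~0.36: the lower-rated player is expected to score less
r_a += 16 * (1.0 - e_a)                 # A wins one game, so A's rating rises by ~10 points
r_b += 16 * (0.0 - (1.0 - e_a))         # B loses that game, so B's rating falls by ~10 points
print(round(r_a), round(r_b))           # 1510 1590

simulate_elo above repeats this same update for every recorded GameResult (in shuffled order), then averages each competitor's final rating over n_tournaments passes.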
-------------------------------------------------------------------------------- /core/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/core/train/__init__.py -------------------------------------------------------------------------------- /core/train/collector.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | from typing import Optional 7 | import torch 8 | from core.utils.memory import EpisodeMemory 9 | from core.algorithms.evaluator import Evaluator, TrainableEvaluator 10 | 11 | 12 | class Collector: 13 | def __init__(self, 14 | evaluator: TrainableEvaluator, 15 | episode_memory_device: torch.device 16 | ) -> None: 17 | parallel_envs = evaluator.env.parallel_envs 18 | self.evaluator = evaluator 19 | self.evaluator.reset() 20 | self.episode_memory = EpisodeMemory(parallel_envs, episode_memory_device) 21 | 22 | 23 | def collect(self, inactive_mask: Optional[torch.Tensor] = None): 24 | _, terminated = self.collect_step() 25 | terminated = terminated.clone() 26 | 27 | if inactive_mask is not None: 28 | terminated *= ~inactive_mask 29 | terminated_episodes = self.episode_memory.pop_terminated_episodes(terminated) 30 | 31 | terminated_episodes = self.assign_rewards(terminated_episodes, terminated) 32 | 33 | self.evaluator.env.reset_terminated_states() 34 | 35 | return terminated_episodes, terminated 36 | 37 | def collect_step(self): 38 | self.evaluator.model.eval() 39 | legal_actions = self.evaluator.env.get_legal_actions().clone() 40 | initial_states, probs, _, actions, terminated = self.evaluator.step() 41 | self.episode_memory.insert(initial_states, probs, legal_actions) 42 | return actions, terminated 43 | 44 | def assign_rewards(self, terminated_episodes, terminated): 45 | raise NotImplementedError() 46 | 47 | def postprocess(self, terminated_episodes): 48 | return terminated_episodes 49 | 50 | def reset(self): 51 | self.episode_memory = EpisodeMemory(self.evaluator.env.parallel_envs, self.episode_memory.device) 52 | self.evaluator.reset() 53 | -------------------------------------------------------------------------------- /core/train/trainer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple, Type 3 | 4 | from tqdm import tqdm 5 | import torch 6 | import logging 7 | from pathlib import Path 8 | from core.algorithms.evaluator import EvaluatorConfig 9 | from core.resnet import TurboZeroResnet 10 | from core.test.tester import TesterConfig, Tester 11 | from core.train.collector import Collector 12 | from core.utils.history import Metric, TrainingMetrics 13 | from collections import deque 14 | 15 | from core.utils.memory import GameReplayMemory, ReplayMemory 16 | import time 17 | 18 | 19 | def init_history(log_results: bool = True): 20 | return TrainingMetrics( 21 | train_metrics=[ 22 | Metric(name='loss', xlabel='Step', ylabel='Loss', addons={'running_mean': 25}, maximize=False, alert_on_best=log_results), 23 | Metric(name='value_loss', xlabel='Step', ylabel='Loss', addons={'running_mean': 25}, maximize=False, alert_on_best=log_results, proper_name='Value Loss'), 24 | Metric(name='policy_loss', xlabel='Step', ylabel='Loss', addons={'running_mean': 25}, maximize=False, alert_on_best=log_results, proper_name='Policy Loss'), 25 | Metric(name='policy_accuracy', 
xlabel='Step', ylabel='Accuracy (%)', addons={'running_mean': 25}, maximize=True, alert_on_best=log_results, proper_name='Policy Accuracy'), 26 | Metric(name='replay_memory_similarity', xlabel='Step', ylabel='Cosine Centroid Similarity', addons={'running_mean': 25}, maximize=False, alert_on_best=log_results, proper_name='Replay Memory Similarity'), 27 | ], 28 | episode_metrics=[], 29 | eval_metrics=[], 30 | epoch_metrics=[] 31 | ) 32 | 33 | 34 | @dataclass 35 | class TrainerConfig: 36 | algo_config: EvaluatorConfig 37 | episodes_per_epoch: int 38 | episodes_per_minibatch: int 39 | minibatch_size: int 40 | learning_rate: float 41 | momentum: float 42 | c_reg: float 43 | lr_decay_gamma: float 44 | parallel_envs: int 45 | policy_factor: float 46 | replay_memory_min_size: int 47 | replay_memory_max_size: int 48 | test_config: TesterConfig 49 | replay_memory_sample_games: bool = True 50 | 51 | 52 | class Trainer: 53 | def __init__(self, 54 | config: TrainerConfig, 55 | collector: Collector, 56 | tester: Tester, 57 | model: torch.nn.Module, 58 | optimizer: torch.optim.Optimizer, 59 | device: torch.device, 60 | raw_train_config: dict, 61 | raw_env_config: dict, 62 | history: TrainingMetrics, 63 | log_results: bool = True, 64 | interactive: bool = True, 65 | run_tag: str = 'model', 66 | debug: bool = False 67 | ): 68 | self.collector = collector 69 | self.tester = tester 70 | self.parallel_envs = collector.evaluator.env.parallel_envs 71 | self.model = model 72 | self.optimizer = optimizer 73 | self.config = config 74 | self.device = device 75 | self.log_results = log_results 76 | self.interactive = interactive 77 | self.run_tag = run_tag 78 | self.raw_train_config = raw_train_config 79 | self.raw_env_config = raw_env_config 80 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.config.lr_decay_gamma) 81 | self.debug = debug 82 | self.collection_queue = deque() 83 | 84 | self.history = history 85 | 86 | if config.replay_memory_sample_games: 87 | self.replay_memory = GameReplayMemory( 88 | config.replay_memory_max_size 89 | ) 90 | else: 91 | self.replay_memory = ReplayMemory( 92 | config.replay_memory_max_size, 93 | ) 94 | 95 | def add_collection_metrics(self, episodes): 96 | raise NotImplementedError() 97 | 98 | def add_train_metrics(self, policy_loss, value_loss, policy_accuracy, loss): 99 | self.history.add_training_data({ 100 | 'policy_loss': policy_loss, 101 | 'value_loss': value_loss, 102 | 'policy_accuracy': policy_accuracy, 103 | 'loss': loss 104 | }, log=self.log_results) 105 | 106 | def add_epoch_metrics(self): 107 | raise NotImplementedError() 108 | 109 | def training_step(self): 110 | inputs, target_policy, target_value, legal_actions = zip(*self.replay_memory.sample(self.config.minibatch_size)) 111 | inputs = torch.stack(inputs).to(device=self.device) 112 | 113 | target_policy = torch.stack(target_policy).to(device=self.device) 114 | target_policy = self.policy_transform(target_policy) 115 | 116 | target_value = torch.stack(target_value).to(device=self.device) 117 | target_value = self.value_transform(target_value) 118 | 119 | legal_actions = torch.stack(legal_actions).to(device=self.device) 120 | 121 | self.optimizer.zero_grad() 122 | policy_logits, values = self.model(inputs) 123 | # multiply policy logits by legal actions mask, set illegal actions to smallest possible negative float32 124 | # consistent with open_spiel implementation 125 | policy_logits = (policy_logits * legal_actions) + (torch.finfo(torch.float32).min * (~legal_actions)) 126 | 127 | 
policy_loss = self.config.policy_factor * torch.nn.functional.cross_entropy(policy_logits, target_policy) 128 | # multiply by 2 since most other implementations have values ranging from -1 to 1 whereas ours range from 0 to 1 129 | # this makes value loss a bit more comparable 130 | value_loss = torch.nn.functional.mse_loss(values.flatten() * 2, target_value * 2) 131 | loss = policy_loss + value_loss 132 | 133 | policy_accuracy = torch.eq(torch.argmax(target_policy, dim=1), torch.argmax(policy_logits, dim=1)).float().mean() 134 | 135 | loss.backward() 136 | self.optimizer.step() 137 | 138 | return policy_loss.item(), value_loss.item(), policy_accuracy.item(), loss.item() 139 | 140 | def policy_transform(self, policy): 141 | return policy 142 | 143 | def value_transform(self, value): 144 | return value 145 | 146 | def train_minibatch(self): 147 | self.model.train() 148 | memory_size = self.replay_memory.size() 149 | if memory_size >= self.config.replay_memory_min_size: 150 | policy_loss, value_loss, policy_accuracy, loss = self.training_step() 151 | self.add_train_metrics(policy_loss, value_loss, policy_accuracy, loss) 152 | else: 153 | logging.info(f'Replay memory samples ({memory_size}) < min samples ({self.config.replay_memory_min_size}), skipping training step') 154 | 155 | def train_epoch(self): 156 | new_episodes = 0 157 | threshold_for_training_step = self.config.episodes_per_minibatch 158 | if not self.debug: 159 | progress_bar = tqdm(total=self.config.episodes_per_epoch, desc='Collecting self-play episodes...', leave=True, position=0) 160 | while new_episodes < self.config.episodes_per_epoch: 161 | if self.collection_queue: 162 | episode = self.collection_queue.popleft() 163 | self.replay_memory.insert(episode) 164 | if not self.debug: 165 | progress_bar.update(1) 166 | new_episodes += 1 167 | if new_episodes >= threshold_for_training_step: 168 | threshold_for_training_step += self.config.episodes_per_minibatch 169 | self.train_minibatch() 170 | else: 171 | finished_episodes, _ = self.collector.collect() 172 | if finished_episodes: 173 | for episode in finished_episodes: 174 | episode = self.collector.postprocess(episode) 175 | if new_episodes >= self.config.episodes_per_epoch: 176 | self.collection_queue.append(episode) 177 | else: 178 | self.replay_memory.insert(episode) 179 | if not self.debug: 180 | progress_bar.update(1) 181 | new_episodes += 1 182 | if new_episodes >= threshold_for_training_step: 183 | threshold_for_training_step += self.config.episodes_per_minibatch 184 | self.train_minibatch() 185 | 186 | self.add_collection_metrics(finished_episodes) 187 | 188 | def fill_replay_memory(self): 189 | if not self.debug: 190 | progress_bar = tqdm(total=self.config.replay_memory_min_size, desc='Populating Replay Memory...', leave=True, position=0) 191 | while self.replay_memory.size() < self.config.replay_memory_min_size: 192 | finished_episodes, _ = self.collector.collect() 193 | if finished_episodes: 194 | for episode in finished_episodes: 195 | episode = self.collector.postprocess(episode) 196 | self.replay_memory.insert(episode) 197 | if not self.debug: 198 | progress_bar.update(1) 199 | 200 | 201 | def training_loop(self, epochs: Optional[int] = None): 202 | total_epochs = self.history.cur_epoch + epochs if epochs is not None else None 203 | if self.history.cur_epoch == 0: 204 | # run initial test batch with untrained model 205 | if self.tester.config.episodes_per_epoch > 0: 206 | self.tester.collect_test_batch() 207 | self.save_checkpoint() 208 | 209 | if
self.replay_memory.size() <= self.config.replay_memory_min_size: 210 | logging.info('Populating replay memory...') 211 | self.fill_replay_memory() 212 | 213 | while self.history.cur_epoch < total_epochs if total_epochs is not None else True: 214 | self.history.start_new_epoch() 215 | self.train_epoch() 216 | 217 | if self.tester.config.episodes_per_epoch > 0: 218 | self.tester.collect_test_batch() 219 | self.add_epoch_metrics() 220 | 221 | if self.interactive: 222 | self.history.generate_plots() 223 | 224 | self.scheduler.step() 225 | self.save_checkpoint() 226 | 227 | 228 | 229 | def save_checkpoint(self, custom_name: Optional[str] = None) -> None: 230 | directory = f'./checkpoints/{self.run_tag}/' 231 | Path(directory).mkdir(parents=True, exist_ok=True) 232 | filename = custom_name if custom_name is not None else str(self.history.cur_epoch) 233 | filepath = directory + f'{filename}.pt' 234 | torch.save({ 235 | 'model_arch_params': self.model.config, 236 | 'model_state_dict': self.model.state_dict(), 237 | 'optimizer_state_dict': self.optimizer.state_dict(), 238 | 'history': self.history, 239 | 'run_tag': self.run_tag, 240 | 'raw_train_config': self.raw_train_config, 241 | 'raw_env_config': self.raw_env_config 242 | }, filepath) 243 | 244 | def benchmark_collection_step(self): 245 | start = time.time() 246 | self.collector.collect() 247 | tottime = time.time() - start 248 | time_per_env_step = tottime / self.config.parallel_envs 249 | print(f'Stepped {self.config.parallel_envs} envs in {tottime:.4f} seconds ({time_per_env_step:.4f} seconds per step)') 250 | 251 | 252 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/core/utils/__init__.py -------------------------------------------------------------------------------- /core/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | 2 | from core.env import Env 3 | from core.resnet import TurboZeroResnet 4 | import torch 5 | 6 | 7 | def load_checkpoint(path: str): 8 | return torch.load(path, map_location=torch.device('cpu')) 9 | 10 | def load_model_and_optimizer_from_checkpoint(checkpoint: dict, env: Env, device: torch.device): 11 | model = TurboZeroResnet(checkpoint['model_arch_params'], env.state_shape, env.policy_shape) 12 | model.load_state_dict(checkpoint['model_state_dict']) 13 | model = model.to(device) 14 | model.eval() 15 | optimizer = torch.optim.SGD(model.parameters(), 16 | lr = checkpoint['raw_train_config']['learning_rate'], 17 | momentum = checkpoint['raw_train_config']['momentum'], 18 | weight_decay = checkpoint['raw_train_config']['c_reg']) 19 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 20 | return model, optimizer 21 | -------------------------------------------------------------------------------- /core/utils/custom_activations.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from typing import Optional 5 | import torch 6 | import logging 7 | 8 | 9 | class Tanh0to1(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.tanh = torch.nn.Tanh() 13 | 14 | def forward(self, x): 15 | return (self.tanh(x) + 1) / 2 16 | 17 | 18 | def load_activation(activation: str) -> Optional[torch.nn.Module]: 19 | if activation == 'tanh0to1': 20 | return Tanh0to1() 21 | 
elif activation == 'tanh': 22 | return torch.nn.Tanh() 23 | elif activation == 'relu': 24 | return torch.nn.ReLU() 25 | elif activation == 'sigmoid': 26 | return torch.nn.Sigmoid() 27 | else: 28 | if activation != '': 29 | logging.warn(f'Warning: activation {activation} not found') 30 | logging.warn(f'No activation will be applied to value head') 31 | return None -------------------------------------------------------------------------------- /core/utils/heatmap.py: -------------------------------------------------------------------------------- 1 | # ripped straight from matplotlib docs 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import matplotlib 6 | 7 | def heatmap(data, row_labels, col_labels, ax=None, 8 | cbar_kw=None, cbarlabel="", **kwargs): 9 | """ 10 | Create a heatmap from a numpy array and two lists of labels. 11 | 12 | Parameters 13 | ---------- 14 | data 15 | A 2D numpy array of shape (M, N). 16 | row_labels 17 | A list or array of length M with the labels for the rows. 18 | col_labels 19 | A list or array of length N with the labels for the columns. 20 | ax 21 | A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If 22 | not provided, use current axes or create a new one. Optional. 23 | cbar_kw 24 | A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional. 25 | cbarlabel 26 | The label for the colorbar. Optional. 27 | **kwargs 28 | All other arguments are forwarded to `imshow`. 29 | """ 30 | 31 | if ax is None: 32 | ax = plt.gca() 33 | 34 | if cbar_kw is None: 35 | cbar_kw = {} 36 | 37 | # Plot the heatmap 38 | im = ax.imshow(data, **kwargs) 39 | 40 | # Create colorbar 41 | cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw) 42 | cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") 43 | 44 | # Show all ticks and label them with the respective list entries. 45 | ax.set_xticks(np.arange(data.shape[1]), labels=col_labels) 46 | ax.set_yticks(np.arange(data.shape[0]), labels=row_labels) 47 | 48 | # Let the horizontal axes labeling appear on top. 49 | ax.tick_params(top=True, bottom=False, 50 | labeltop=True, labelbottom=False) 51 | 52 | # Rotate the tick labels and set their alignment. 53 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", 54 | rotation_mode="anchor") 55 | 56 | # Turn spines off and create white grid. 57 | ax.spines[:].set_visible(False) 58 | 59 | ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True) 60 | ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True) 61 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 62 | ax.tick_params(which="minor", bottom=False, left=False) 63 | 64 | return im, cbar 65 | 66 | 67 | def annotate_heatmap(im, data=None, valfmt="{x:.2f}", 68 | textcolors=("black", "white"), 69 | threshold=None, **textkw): 70 | """ 71 | A function to annotate a heatmap. 72 | 73 | Parameters 74 | ---------- 75 | im 76 | The AxesImage to be labeled. 77 | data 78 | Data used to annotate. If None, the image's data is used. Optional. 79 | valfmt 80 | The format of the annotations inside the heatmap. This should either 81 | use the string format method, e.g. "$ {x:.2f}", or be a 82 | `matplotlib.ticker.Formatter`. Optional. 83 | textcolors 84 | A pair of colors. The first is used for values below a threshold, 85 | the second for those above. Optional. 86 | threshold 87 | Value in data units according to which the colors from textcolors are 88 | applied. If None (the default) uses the middle of the colormap as 89 | separation. Optional. 
90 | **kwargs 91 | All other arguments are forwarded to each call to `text` used to create 92 | the text labels. 93 | """ 94 | 95 | if not isinstance(data, (list, np.ndarray)): 96 | data = im.get_array() 97 | 98 | # Normalize the threshold to the images color range. 99 | if threshold is not None: 100 | threshold = im.norm(threshold) 101 | else: 102 | threshold = im.norm(data.max())/2. 103 | 104 | # Set default alignment to center, but allow it to be 105 | # overwritten by textkw. 106 | kw = dict(horizontalalignment="center", 107 | verticalalignment="center") 108 | kw.update(textkw) 109 | 110 | # Get the formatter in case a string is supplied 111 | if isinstance(valfmt, str): 112 | valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) 113 | 114 | # Loop over the data and create a `Text` for each "pixel". 115 | # Change the text's color depending on the data. 116 | texts = [] 117 | for i in range(data.shape[0]): 118 | for j in range(data.shape[1]): 119 | kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)]) 120 | text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) 121 | texts.append(text) 122 | 123 | return texts -------------------------------------------------------------------------------- /core/utils/history.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import bottleneck 3 | 4 | from collections import OrderedDict 5 | from matplotlib import pyplot as plt 6 | import IPython.display as display 7 | import numpy as np 8 | import logging 9 | 10 | class Metric: 11 | def __init__(self, name, xlabel, ylabel, pl_type: str ='plot', addons={}, maximize=True, alert_on_best=False, proper_name=None, best=None) -> None: 12 | self.name = name 13 | self.ts = [] 14 | self.data = [] 15 | self.plot = plt.figure() 16 | self.xlabel = xlabel 17 | self.ylabel = ylabel 18 | self.pl_type = pl_type 19 | self.addons = addons 20 | self.maximize = maximize 21 | if best is None: 22 | self.best = float('-inf') if maximize else float('inf') 23 | else: 24 | self.best = best 25 | self.alert_on_best = alert_on_best 26 | self.proper_name = name.capitalize() if proper_name is None else proper_name 27 | 28 | def add_data(self, ts, data): 29 | self.ts.append(ts) 30 | self.data.append(data) 31 | if self.maximize: 32 | if data > self.best: 33 | self.best = data 34 | if self.alert_on_best: 35 | logging.info(f'**** NEW BEST {self.proper_name}: {self.best:.4f} ****') 36 | else: 37 | if data < self.best: 38 | self.best = data 39 | if self.alert_on_best: 40 | logging.info(f'**** NEW BEST {self.proper_name}: {self.best:.4f} ****') 41 | 42 | def reset_fig(self): 43 | self.plot = plt.figure() 44 | 45 | def clear_data(self): 46 | self.ts = [] 47 | self.data = [] 48 | 49 | def generate_plot(self): 50 | data = self.data 51 | ts = self.ts 52 | if data: 53 | self.plot.clear() 54 | ax = self.plot.add_subplot(111) 55 | 56 | if 'window' in self.addons: 57 | # get data within window 58 | data = data[-self.addons['window']:] 59 | ts = ts[-self.addons['window']:] 60 | 61 | if self.pl_type == 'hist': 62 | ax.hist(data, bins='auto') 63 | elif self.pl_type == 'bar': 64 | values, counts = np.unique(data, return_counts=True) 65 | ax.bar(values.astype(np.str_), counts) 66 | else: 67 | ax.plot(ts, data) 68 | ax.annotate('%0.3f' % data[-1], xy=(1, data[-1]), xytext=(8, 0), 69 | xycoords=('axes fraction', 'data'), textcoords='offset points', color='blue') 70 | ax.set_title(self.proper_name) 71 | ax.set_xlabel(self.xlabel) 72 | ax.set_ylabel(self.ylabel) 73 | 74 | 
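# 'running_mean' addon (set in self.addons): overlay a moving average of the plotted data in red, with the window capped at the number of available points, and annotate its latest value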
75 | if 'running_mean' in self.addons: 76 | running_mean = bottleneck.move_mean(data, window=min(self.addons['running_mean'], len(data)), min_count=1) 77 | ax.plot(ts, running_mean, color='red', linewidth=2) 78 | ax.annotate('%0.3f' % running_mean[-1], xy=(1, running_mean[-1]), xytext=(8, 0), 79 | xycoords=('axes fraction', 'data'), textcoords='offset points', color='red') 80 | display.display(self.plot) 81 | 82 | class TrainingMetrics: 83 | def __init__(self, train_metrics: List[Metric], episode_metrics: List[Metric], eval_metrics: List[Metric], epoch_metrics: List[Metric]) -> None: 84 | self.train_metrics: Dict[str, Metric] = { 85 | metric.name: metric for metric in train_metrics 86 | } 87 | self.episode_metrics: Dict[str, Metric] = { 88 | metric.name: metric for metric in episode_metrics 89 | } 90 | self.epoch_metrics: Dict[str, Metric] = { 91 | metric.name: metric for metric in epoch_metrics 92 | } 93 | self.eval_metrics: Dict[str, List[Metric]] = { 94 | metric.name: [metric] for metric in eval_metrics 95 | } 96 | 97 | self.cur_epoch = 0 98 | self.cur_test_step = 0 99 | self.cur_train_step = 0 100 | self.cur_train_episode = 0 101 | 102 | def add_training_data(self, data, log=True): 103 | if log: 104 | logging.info(f'Step {self.cur_train_step}') 105 | for metric_name, metric_data in data.items(): 106 | if log: 107 | logging.info(f'\t{self.train_metrics[metric_name].proper_name}: {metric_data:.4f}') 108 | self.train_metrics[metric_name].add_data(self.cur_train_step, metric_data) 109 | self.cur_train_step += 1 110 | 111 | def add_episode_data(self, data, log=True): 112 | if log: 113 | logging.info(f'Episode {self.cur_train_episode}') 114 | for metric_name, metric_data in data.items(): 115 | if log: 116 | logging.info(f'\t{self.episode_metrics[metric_name].proper_name}: {metric_data:.4f}') 117 | self.episode_metrics[metric_name].add_data(self.cur_train_episode, metric_data) 118 | self.cur_train_episode += 1 119 | 120 | def add_evaluation_data(self, data, log=True): 121 | if log: 122 | logging.info(f'Eval episode {self.cur_test_step}') 123 | 124 | for metric_name, metric_data in data.items(): 125 | if log: 126 | logging.info(f'\t{self.eval_metrics[metric_name][0].proper_name}: {metric_data:.4f}') 127 | self.eval_metrics[metric_name][self.cur_epoch].add_data(self.cur_epoch, metric_data) 128 | 129 | self.cur_test_step += 1 130 | 131 | def add_epoch_data(self, data, log=True): 132 | if log: 133 | logging.info(f'Epoch {self.cur_epoch}') 134 | for metric_name, metric_data in data.items(): 135 | if log: 136 | logging.info(f'\t{self.epoch_metrics[metric_name].proper_name}: {metric_data:.4f}') 137 | self.epoch_metrics[metric_name].add_data(self.cur_epoch, metric_data) 138 | 139 | def start_new_epoch(self): 140 | for k in self.eval_metrics.keys(): 141 | if self.eval_metrics[k]: 142 | self.eval_metrics[k].append(Metric(k, self.eval_metrics[k][-1].xlabel, self.eval_metrics[k][-1].ylabel, pl_type=self.eval_metrics[k][-1].pl_type, \ 143 | addons=self.eval_metrics[k][-1].addons, maximize=self.eval_metrics[k][-1].maximize, alert_on_best=self.eval_metrics[k][-1].alert_on_best, \ 144 | proper_name=self.eval_metrics[k][-1].proper_name, best=self.eval_metrics[k][-1].best)) 145 | self.cur_epoch += 1 146 | self.cur_test_step = 0 147 | 148 | def reset_all_figs(self): # for matplotlib compatibility 149 | for metric in self.train_metrics.values(): 150 | metric.reset_fig() 151 | for metric_list in self.eval_metrics.values(): 152 | for metric in metric_list: 153 | metric.reset_fig() 154 | for metric in 
self.epoch_metrics.values(): 155 | metric.reset_fig() 156 | 157 | 158 | def generate_plots(self, categories=['train', 'eval', 'episode', 'epoch']): 159 | display.clear_output(wait=False) 160 | if 'train' in categories: 161 | for metric in self.train_metrics.values(): 162 | if metric.data: 163 | metric.generate_plot() 164 | if 'eval' in categories: 165 | for metrics_list in self.eval_metrics.values(): 166 | if metrics_list: 167 | if metrics_list[-1].data: 168 | metrics_list[-1].generate_plot() 169 | if 'episode' in categories: 170 | for metric in self.episode_metrics.values(): 171 | if metric.data: 172 | metric.generate_plot() 173 | if 'epoch' in categories: 174 | for metric in self.epoch_metrics.values(): 175 | if metric.data: 176 | metric.generate_plot() 177 | -------------------------------------------------------------------------------- /core/utils/memory.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | from core.utils.utils import cosine_centroid_similarity 4 | 5 | import torch 6 | from collections import deque 7 | 8 | class ReplayMemory: 9 | def __init__(self, max_size=10000) -> None: 10 | self.max_size = max_size 11 | self.memory = deque([], maxlen=max_size) 12 | pass 13 | 14 | def sample(self, num_samples): 15 | return random.sample(self.memory, num_samples) 16 | 17 | def insert(self, sample): 18 | self.memory.append(sample) 19 | 20 | def size(self): 21 | return len(self.memory) 22 | 23 | def similarity(self): 24 | episodes = torch.stack([x[0] for x in self.sample(4096)]) 25 | return cosine_centroid_similarity(episodes) 26 | 27 | class GameReplayMemory(ReplayMemory): 28 | def __init__(self, max_size=10000) -> None: 29 | super().__init__(max_size) 30 | 31 | def sample(self, num_samples): 32 | games = random.choices(self.memory, k=num_samples) 33 | samples = [] 34 | for game in games: 35 | samples.append(random.sample(game, 1)[0]) 36 | return samples 37 | 38 | 39 | class EpisodeMemory: 40 | def __init__(self, parallel_envs: int, device: torch.device) -> None: 41 | self.memory = [[] for _ in range(parallel_envs)] 42 | self.parallel_envs = parallel_envs 43 | self.device = device 44 | 45 | def insert(self, inputs: torch.Tensor, action_visits: torch.Tensor, legal_actions: torch.Tensor): 46 | inputs = inputs.clone().to(device=self.device) 47 | action_visits = action_visits.clone().to(device=self.device) 48 | legal_actions = legal_actions.clone().to(device=self.device) 49 | 50 | for i in range(self.parallel_envs): 51 | self.memory[i].append((inputs[i], action_visits[i], legal_actions[i])) 52 | 53 | def pop_terminated_episodes(self, terminated: torch.Tensor): 54 | episodes = [] 55 | for i in terminated.nonzero().flatten(): 56 | episode = self.memory[i] 57 | episodes.append(episode) 58 | self.memory[i] = [] 59 | return episodes 60 | 61 | -------------------------------------------------------------------------------- /core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rand_argmax_2d(values): 5 | inds = values == values.max(dim=1, keepdim=True).values 6 | return torch.multinomial(inds.float(), 1) 7 | 8 | 9 | def count_parameters(model): 10 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 11 | 12 | def jaccard_similarity(tensor1, tensor2): 13 | # Compute the intersection and union 14 | intersection = (tensor1 & tensor2).float().sum() 15 | union = (tensor1 | tensor2).float().sum() 16 | 17 | # Compute 
the Jaccard similarity coefficient 18 | jaccard = intersection / union 19 | 20 | return jaccard 21 | 22 | 23 | def cosine_centroid_similarity(data): 24 | centroid = data.mean(dim=0) 25 | 26 | centroid_similarity = torch.nn.functional.cosine_similarity(data, centroid.unsqueeze(0), dim=1).mean().item() 27 | 28 | return centroid_similarity -------------------------------------------------------------------------------- /envs/_2048/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/_2048/__init__.py -------------------------------------------------------------------------------- /envs/_2048/collector.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | from typing import Optional 6 | import torch 7 | from core.train.collector import Collector 8 | from core.algorithms.evaluator import Evaluator 9 | 10 | class _2048Collector(Collector): 11 | def __init__(self, 12 | evaluator: Evaluator, 13 | episode_memory_device: torch.device 14 | ) -> None: 15 | super().__init__(evaluator, episode_memory_device) 16 | 17 | def assign_rewards(self, terminated_episodes, terminated): 18 | episodes = [] 19 | for episode in terminated_episodes: 20 | episode_with_rewards = [] 21 | moves = len(episode) 22 | for (inputs, visits, legal_actions) in episode: 23 | episode_with_rewards.append((inputs, visits, torch.tensor(moves, dtype=torch.float32, requires_grad=False, device=inputs.device), legal_actions)) 24 | moves -= 1 25 | episodes.append(episode_with_rewards) 26 | return episodes 27 | 28 | def postprocess(self, terminated_episodes): 29 | # TODO: too many lists 30 | inputs, probs, rewards, legal_actions = zip(*terminated_episodes) 31 | rotated_inputs = [] 32 | for i in inputs: 33 | for k in range(4): 34 | rotated_inputs.append(torch.rot90(i, k=k, dims=(1, 2))) 35 | rotated_probs = [] 36 | for p in probs: 37 | # left -> down 38 | # down -> right 39 | # right -> up 40 | # up -> left 41 | for k in range(4): 42 | rotated_probs.append(torch.roll(p, k)) 43 | rotated_legal_actions = [] 44 | for l in legal_actions: 45 | for k in range(4): 46 | rotated_legal_actions.append(torch.roll(l, k)) 47 | rotated_rewards = [] 48 | for r in rewards: 49 | rotated_rewards.extend([r] * 4) 50 | 51 | return list(zip(rotated_inputs, rotated_probs, rotated_rewards, rotated_legal_actions)) 52 | -------------------------------------------------------------------------------- /envs/_2048/env.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple 4 | from colorama import Fore 5 | import torch 6 | from core.env import Env, EnvConfig 7 | from .torchscripts import get_stochastic_progressions, push_actions, get_legal_actions 8 | 9 | COLOR_MAP = { 10 | 1: Fore.WHITE, 11 | 2: Fore.LIGHTWHITE_EX, 12 | 3: Fore.LIGHTYELLOW_EX, 13 | 4: Fore.LIGHTRED_EX, 14 | 5: Fore.LIGHTMAGENTA_EX, 15 | 6: Fore.LIGHTGREEN_EX, 16 | 7: Fore.LIGHTCYAN_EX, 17 | 8: Fore.LIGHTBLUE_EX, 18 | 9: Fore.YELLOW, 19 | 10: Fore.RED, 20 | 11: Fore.MAGENTA, 21 | 12: Fore.GREEN, 22 | 13: Fore.CYAN, 23 | 14: Fore.BLUE, 24 | 15: Fore.WHITE, 25 | 16: Fore.BLACK, 26 | 17: Fore.LIGHTBLACK_EX 27 | } 28 | 29 | 30 | @dataclass 31 | class _2048EnvConfig(EnvConfig): 32 | pass 33 | 34 | class _2048Env(Env): 35 | def __init__(self, 36 | parallel_envs: int, 37 | config: _2048EnvConfig, 38 | device: 
torch.device, 39 | debug=False 40 | ) -> None: 41 | super().__init__( 42 | parallel_envs=parallel_envs, 43 | config=config, 44 | device=device, 45 | num_players=1, 46 | state_shape=torch.Size([1, 4, 4]), 47 | policy_shape=torch.Size([4]), 48 | value_shape=torch.Size([1]), 49 | debug=debug 50 | ) 51 | 52 | if self.debug: 53 | self.get_stochastic_progressions_ts = get_stochastic_progressions 54 | self.get_legal_actions_ts = get_legal_actions 55 | self.push_actions_ts = push_actions 56 | else: 57 | self.get_stochastic_progressions_ts = torch.jit.trace(get_stochastic_progressions, ( # type: ignore 58 | self.states, 59 | )) 60 | 61 | self.get_legal_actions_ts = torch.jit.trace(get_legal_actions, ( # type: ignore 62 | self.states, 63 | )) 64 | 65 | self.push_actions_ts = torch.jit.trace(push_actions, ( # type: ignore 66 | self.states, 67 | torch.zeros((self.parallel_envs, ), dtype=torch.int64, device=device) 68 | )) 69 | 70 | self.saved_states = self.states.clone() 71 | # self.rewards is just a dummy tensor here 72 | self.rewards = torch.zeros((self.parallel_envs, ), dtype=torch.float32, device=device, requires_grad=False) 73 | 74 | def reset(self, seed=None) -> int: 75 | if seed is not None: 76 | torch.manual_seed(seed) 77 | else: 78 | seed = 0 79 | self.states.zero_() 80 | self.terminated.zero_() 81 | self.apply_stochastic_progressions() 82 | self.apply_stochastic_progressions() 83 | return seed 84 | 85 | def reset_terminated_states(self, seed: Optional[int] = None) -> int: 86 | if seed is not None: 87 | torch.manual_seed(seed) 88 | else: 89 | seed = 0 90 | self.states *= torch.logical_not(self.terminated).view(self.parallel_envs, 1, 1, 1) 91 | self.apply_stochastic_progressions(self.terminated) 92 | self.apply_stochastic_progressions(self.terminated) 93 | self.terminated.zero_() 94 | return seed 95 | 96 | def next_turn(self): 97 | self.apply_stochastic_progressions(torch.logical_not(self.terminated)) 98 | 99 | def get_high_squares(self): 100 | return torch.amax(self.states, dim=(1, 2, 3)) 101 | 102 | def get_rewards(self, player_ids: Optional[torch.Tensor] = None) -> torch.Tensor: 103 | # TODO: handle rewards in env instead of collector postprocessing 104 | return self.rewards 105 | 106 | def update_terminated(self) -> None: 107 | self.terminated = self.is_terminal() 108 | 109 | def is_terminal(self): 110 | return (self.get_legal_actions().sum(dim=1, keepdim=True) == 0).flatten() 111 | 112 | def get_legal_actions(self) -> torch.Tensor: 113 | return self.get_legal_actions_ts(self.states) # type: ignore 114 | 115 | def get_stochastic_progressions(self) -> Tuple[torch.Tensor, torch.Tensor]: 116 | return self.get_stochastic_progressions_ts(self.states) # type: ignore 117 | 118 | def push_actions(self, actions) -> None: 119 | self.states = self.push_actions_ts(self.states, actions) # type: ignore 120 | 121 | def save_node(self) -> torch.Tensor: 122 | return self.states.clone() 123 | 124 | def load_node(self, load_envs: torch.Tensor, saved: torch.Tensor): 125 | load_envs_expnd = load_envs.view(self.parallel_envs, 1, 1, 1) 126 | self.states = saved.clone() * load_envs_expnd + self.states * (~load_envs_expnd) 127 | self.update_terminated() 128 | 129 | def print_state(self, action=None) -> None: 130 | envstr = [] 131 | assert self.parallel_envs == 1 132 | envstr.append((Fore.BLUE if action == 3 else '') + '+' + '--------+' * 4) 133 | 134 | envstr.append(Fore.RESET + '\n') 135 | for i in range(4): 136 | envstr.append((Fore.BLUE if action == 0 else '')) 137 | envstr.append('|' + Fore.RESET + ' |' * 3) 138 |
envstr.append((Fore.BLUE if action == 2 else '')) 139 | envstr.append(' |') 140 | envstr.append(Fore.RESET + '\n') 141 | for j in range(4): 142 | color = Fore.RESET 143 | if j == 0 and action == 0: 144 | color = Fore.BLUE 145 | if self.states[0, 0, i, j] == 0: 146 | envstr.append(color + '| ') 147 | else: 148 | num = int(self.states[0, 0, i, j]) 149 | envstr.append(color + '|' + COLOR_MAP[num] + str(2**num).center(8)) 150 | envstr.append((Fore.BLUE if action == 2 else Fore.RESET)) 151 | envstr.append('|') 152 | envstr.append(Fore.RESET + '\n') 153 | envstr.append((Fore.BLUE if action == 0 else '')) 154 | envstr.append('|' + Fore.RESET + ' |' * 3) 155 | envstr.append((Fore.BLUE if action == 2 else '')) 156 | envstr.append(' |') 157 | envstr.append(Fore.RESET + '\n') 158 | if i < 3: 159 | envstr.append((Fore.BLUE if action == 0 else '')) 160 | envstr.append('+' + Fore.RESET + '--------+' * 3) 161 | envstr.append('--------' +(Fore.BLUE if action == 2 else '') + '+') 162 | else: 163 | envstr.append((Fore.BLUE if action == 1 else '')) 164 | envstr.append('+' + '--------+' * 4) 165 | envstr.append(Fore.RESET + '\n') 166 | print(''.join(envstr)) -------------------------------------------------------------------------------- /envs/_2048/misc/board_representation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/_2048/misc/board_representation.png -------------------------------------------------------------------------------- /envs/_2048/misc/high_tile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/_2048/misc/high_tile.png -------------------------------------------------------------------------------- /envs/_2048/misc/improvement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/_2048/misc/improvement.png -------------------------------------------------------------------------------- /envs/_2048/misc/reward_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/_2048/misc/reward_distribution.png -------------------------------------------------------------------------------- /envs/_2048/misc/slide_conv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/_2048/misc/slide_conv.png -------------------------------------------------------------------------------- /envs/_2048/tester.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from core.test.tester import Tester 5 | 6 | 7 | class _2048Tester(Tester): 8 | def add_evaluation_metrics(self, episodes): 9 | if self.history is not None: 10 | for episode in episodes: 11 | moves = len(episode) 12 | last_state = episode[-1][0] 13 | high_square = 2 ** int(last_state.max().item()) 14 | self.history.add_evaluation_data({ 15 | 'reward': moves, 16 | 'high_square': high_square, 17 | }, log=self.log_results) -------------------------------------------------------------------------------- 
/envs/_2048/torchscripts.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Tuple 3 | import torch 4 | 5 | def collapse(rankt, bs_flat) -> torch.Tensor: 6 | non_zero_mask = (bs_flat != 0) 7 | 8 | # Set the rank of zero elements to zero 9 | rank = (rankt * non_zero_mask) 10 | 11 | # Create a tensor of sorted indices by sorting the rank tensor along dim=-1 12 | sorted_indices = torch.argsort(rank, dim=-1, descending=True, stable=True) 13 | 14 | return torch.gather(bs_flat, dim=-1, index=sorted_indices) 15 | 16 | 17 | def merge(states, rankt) -> torch.Tensor: 18 | shape = states.shape 19 | bs_flat = states.view(-1, shape[-1]) 20 | 21 | # Step 1: initial sort using a customized version of merge sort 22 | bs_flat = collapse(rankt, bs_flat) 23 | 24 | # Step 2: apply merge operation 25 | for i in range(3): 26 | is_same = torch.logical_and(bs_flat[:,i] == bs_flat[:,i+1], bs_flat[:,i] != 0) 27 | bs_flat[:,i].add_(is_same) 28 | bs_flat[:,i+1].masked_fill_(is_same, 0) 29 | 30 | # Step 3: reapply the customized merge sort 31 | bs_flat = collapse(rankt, bs_flat) 32 | 33 | return bs_flat.view(shape) 34 | 35 | def rotate_by_amnts(states, amnts): 36 | mask1 = amnts == 1 37 | mask2 = amnts == 2 38 | mask3 = amnts == 3 39 | 40 | states[mask1] = states.flip(2).transpose(2, 3)[mask1] 41 | states[mask2] = states.flip(3).flip(2)[mask2] 42 | states[mask3] = states.flip(3).transpose(2, 3)[mask3] 43 | 44 | return states 45 | 46 | def push_actions(states, actions) -> torch.Tensor: 47 | rankt = torch.arange(4, 0, -1, device=states.device, requires_grad=False).expand((states.shape[0] * 4, 4)) 48 | actions = actions.view(-1, 1, 1, 1).expand_as(states) 49 | states = rotate_by_amnts(states, actions) 50 | states = merge(states, rankt) 51 | states = rotate_by_amnts(states, (4-actions) % 4) 52 | return states 53 | 54 | def get_legal_actions(states): 55 | n = states.shape[0] 56 | device = states.device 57 | dtype = states.dtype 58 | 59 | mask0 = torch.tensor([[[[-1e5, 1]]]], dtype=dtype, device=device, requires_grad=False) 60 | mask1 = torch.tensor([[[[1], [-1e5]]]], dtype=dtype, device=device, requires_grad=False) 61 | mask2 = torch.tensor([[[[1, -1e5]]]], dtype=dtype, device=device, requires_grad=False) 62 | mask3 = torch.tensor([[[[-1e5], [1]]]], dtype=dtype, device=device, requires_grad=False) 63 | 64 | m0 = torch.nn.functional.conv2d(states, mask0, padding=0, bias=None).view(n, 12) 65 | m1 = torch.nn.functional.conv2d(states, mask1, padding=0, bias=None).view(n, 12) 66 | m2 = torch.nn.functional.conv2d(states, mask2, padding=0, bias=None).view(n, 12) 67 | m3 = torch.nn.functional.conv2d(states, mask3, padding=0, bias=None).view(n, 12) 68 | 69 | m0_valid = torch.any(m0 > 0.5, dim=1, keepdim=True) 70 | m1_valid = torch.any(m1 > 0.5, dim=1, keepdim=True) 71 | m2_valid = torch.any(m2 > 0.5, dim=1, keepdim=True) 72 | m3_valid = torch.any(m3 > 0.5, dim=1, keepdim=True) 73 | 74 | # Compute the differences between adjacent elements in the 2nd and 3rd dimensions 75 | vertical_diff = states[:, :, :-1, :] - states[:, :, 1:, :] 76 | horizontal_diff = states[:, :, :, :-1] - states[:, :, :, 1:] 77 | 78 | # Check where the differences are zero, excluding the zero elements in the original matrix 79 | vertical_zeros = torch.logical_and(vertical_diff == 0, states[:, :, 1:, :] != 0) 80 | horizontal_zeros = torch.logical_and(horizontal_diff == 0, states[:, :, :, 1:] != 0) 81 | 82 | # Flatten the last two dimensions and compute the logical OR along the last dimension 83 | 
vertical_comparison = vertical_zeros.view(n, 12).any(dim=1, keepdim=True) 84 | horizontal_comparison = horizontal_zeros.view(n, 12).any(dim=1, keepdim=True) 85 | m0_valid.logical_or_(horizontal_comparison) 86 | m2_valid.logical_or_(horizontal_comparison) 87 | m1_valid.logical_or_(vertical_comparison) 88 | m3_valid.logical_or_(vertical_comparison) 89 | 90 | return torch.concat([m0_valid, m1_valid, m2_valid, m3_valid], dim=1) 91 | 92 | def get_stochastic_progressions(states) -> Tuple[torch.Tensor, torch.Tensor]: 93 | ones = torch.eye(16, dtype=states.dtype).view(16, 4, 4) 94 | twos = torch.eye(16, dtype=states.dtype).view(16, 4, 4) * 2 95 | base_progressions = torch.concat([ones, twos], dim=0).to(states.device) 96 | base_probabilities = torch.concat([torch.full((16,), 0.9), torch.full((16,), 0.1)], dim=0).to(states.device) 97 | # check and see if each of the progressions are valid (no tile already in that spot) 98 | # base_progressions is a 32x4x4 tensor with all the possible progressions 99 | # bs is an Nx4x4 tensor with N board states 100 | # returns an 32xNx4x4 tensor with 32 possible progressions for each board state 101 | valid_progressions = torch.logical_not(torch.any((states * base_progressions).view(-1, 32, 16), dim=2)) 102 | progressions = (states + base_progressions) * valid_progressions.view(states.shape[0], 32, 1, 1) 103 | probs = base_probabilities * valid_progressions 104 | return progressions, probs -------------------------------------------------------------------------------- /envs/_2048/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from typing import Optional 5 | import torch 6 | import numpy as np 7 | from core.test.tester import Tester 8 | from core.utils.history import Metric, TrainingMetrics 9 | from core.train.trainer import Trainer, TrainerConfig 10 | from envs._2048.collector import _2048Collector 11 | 12 | class _2048Trainer(Trainer): 13 | def __init__(self, 14 | config: TrainerConfig, 15 | collector: _2048Collector, 16 | tester: Tester, 17 | model: torch.nn.Module, 18 | optimizer: torch.optim.Optimizer, 19 | device: torch.device, 20 | raw_train_config: dict, 21 | raw_env_config: dict, 22 | history: TrainingMetrics, 23 | log_results: bool = True, 24 | interactive: bool = True, 25 | run_tag: str = '2048', 26 | debug: bool = False 27 | ): 28 | super().__init__( 29 | config=config, 30 | collector=collector, 31 | tester=tester, 32 | model=model, 33 | optimizer=optimizer, 34 | device=device, 35 | raw_train_config=raw_train_config, 36 | raw_env_config=raw_env_config, 37 | history=history, 38 | log_results=log_results, 39 | interactive=interactive, 40 | run_tag=run_tag, 41 | debug=debug 42 | ) 43 | if self.history.cur_epoch == 0: 44 | self.history.episode_metrics.update({ 45 | 'reward': Metric(name='reward', xlabel='Episode', ylabel='Reward', addons={'running_mean': 100}, maximize=True, alert_on_best=self.log_results), 46 | 'log2_high_square': Metric(name='log2_high_square', xlabel='Episode', ylabel='High Tile', addons={'running_mean': 100}, maximize=True, alert_on_best=self.log_results, proper_name='High Tile (log2)'), 47 | }) 48 | 49 | self.history.eval_metrics.update({ 50 | 'reward': [Metric(name='reward', xlabel='Reward', ylabel='Frequency', pl_type='hist', maximize=True, alert_on_best=False)], 51 | 'high_square': [Metric(name='high_square', xlabel='High Tile', ylabel='Frequency Tile', pl_type='bar', maximize=True, alert_on_best=False, proper_name='High Tile')] 52 | }) 53 | 54 | 
self.history.epoch_metrics.update({ 55 | 'avg_reward': Metric(name='avg_reward', xlabel='Epoch', ylabel='Average Reward', maximize=True, alert_on_best=self.log_results, proper_name='Average Reward'), 56 | 'avg_log2_high_square': Metric(name='avg_log2_high_square', xlabel='Epoch', ylabel='Average High Tile', maximize=True, alert_on_best=self.log_results, proper_name='Average High Tile (log2)'), 57 | }) 58 | 59 | def add_collection_metrics(self, episodes): 60 | for episode in episodes: 61 | moves = len(episode) 62 | last_state = episode[-1][0] 63 | high_square = int(last_state.max().item()) 64 | self.history.add_episode_data({ 65 | 'reward': moves, 66 | 'log2_high_square': high_square, 67 | }, log=self.log_results) 68 | 69 | def add_epoch_metrics(self): 70 | if self.history.eval_metrics['reward'][-1].data: 71 | self.history.add_epoch_data({ 72 | 'avg_reward': np.mean(self.history.eval_metrics['reward'][-1].data) 73 | }, log=self.log_results) 74 | if self.history.eval_metrics['high_square'][-1].data: 75 | self.history.add_epoch_data({ 76 | 'avg_log2_high_square': np.log2(np.mean(self.history.eval_metrics['high_square'][-1].data)) 77 | }, log=self.log_results) 78 | 79 | def value_transform(self, value): 80 | return super().value_transform(value).log() 81 | 82 | -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/__init__.py -------------------------------------------------------------------------------- /envs/connect_x/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/connect_x/__init__.py -------------------------------------------------------------------------------- /envs/connect_x/collector.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | import torch 7 | from core.algorithms.evaluator import TrainableEvaluator 8 | from core.train.collector import Collector 9 | from envs.connect_x.env import ConnectXConfig 10 | 11 | 12 | class ConnectXCollector(Collector): 13 | def __init__(self, 14 | evaluator: TrainableEvaluator, 15 | episode_memory_device: torch.device 16 | ) -> None: 17 | super().__init__(evaluator, episode_memory_device) 18 | assert isinstance(evaluator.env.config, ConnectXConfig) 19 | 20 | def assign_rewards(self, terminated_episodes, terminated): 21 | episodes = [] 22 | 23 | if terminated.any(): 24 | term_indices = terminated.nonzero(as_tuple=False).flatten() 25 | rewards = self.evaluator.env.get_rewards(torch.zeros_like(self.evaluator.env.env_indices)).clone().cpu().numpy() 26 | for i, episode in enumerate(terminated_episodes): 27 | episode_with_rewards = [] 28 | ti = term_indices[i] 29 | p1_reward = rewards[ti] 30 | p2_reward = 1 - p1_reward 31 | for ei, (inputs, visits, legal_actions) in enumerate(episode): 32 | if visits.sum(): # only append states where a move was possible 33 | episode_with_rewards.append((inputs, visits, torch.tensor(p2_reward if ei%2 else p1_reward, dtype=torch.float32, requires_grad=False, device=inputs.device), legal_actions)) 34 | episodes.append(episode_with_rewards) 35 | return episodes -------------------------------------------------------------------------------- /envs/connect_x/demo.py: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | from core.demo.demo import TwoPlayerDemo 4 | 5 | 6 | class ConnectXDemo(TwoPlayerDemo): 7 | pass -------------------------------------------------------------------------------- /envs/connect_x/env.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import Optional, Tuple 4 | 5 | import torch 6 | from core.env import Env, EnvConfig 7 | 8 | @dataclass 9 | class ConnectXConfig(EnvConfig): 10 | board_width: int 11 | board_height: int 12 | inarow: int 13 | 14 | class ConnectXEnv(Env): 15 | def __init__(self, 16 | parallel_envs: int, 17 | config: ConnectXConfig, 18 | device: torch.device, 19 | debug: bool = False 20 | ): 21 | self.config = config 22 | self.parallel_envs = parallel_envs 23 | state_shape = torch.Size([2, config.board_height, config.board_width]) 24 | policy_shape = torch.Size([config.board_width]) 25 | value_shape = torch.Size([1]) 26 | 27 | super().__init__( 28 | parallel_envs=parallel_envs, 29 | config=config, 30 | device=device, 31 | num_players=2, 32 | state_shape=state_shape, 33 | policy_shape=policy_shape, 34 | value_shape=value_shape, 35 | debug=debug 36 | ) 37 | 38 | self.next_empty = torch.full( 39 | (self.parallel_envs, self.config.board_width), 40 | self.config.board_height - 1, 41 | dtype=torch.int64, 42 | device=self.device, 43 | requires_grad=False 44 | ) 45 | self.reward_check = torch.zeros( 46 | (self.parallel_envs, 2), 47 | dtype=torch.long, 48 | device=self.device, 49 | requires_grad=False 50 | ) 51 | self.reward_check[:, 0] = self.config.inarow 52 | self.reward_check[:, 1] = -self.config.inarow 53 | 54 | self.kernels = self.create_kernels() 55 | 56 | def reset(self, seed=None): 57 | if seed is not None: 58 | torch.manual_seed(seed) 59 | else: 60 | seed = 0 61 | self.states.zero_() 62 | self.terminated.zero_() 63 | self.cur_players.zero_() 64 | self.next_empty.fill_(self.config.board_height - 1) 65 | return seed 66 | 67 | def push_actions(self, actions): 68 | self.states[self.env_indices, 0, self.next_empty[self.env_indices, actions], actions] = 1 69 | self.next_empty[self.env_indices, actions] -= 1 70 | 71 | def create_kernels(self): 72 | size = self.config.inarow 73 | 74 | horiz = torch.zeros((2, size, size), dtype=torch.float32, device=self.device, requires_grad=False) 75 | horiz[0, 0, :] = 1 76 | horiz[1, 0, :] = -1 77 | vert = torch.zeros((2, size, size), dtype=torch.float32, device=self.device, requires_grad=False) 78 | vert[0, :, 0] = 1 79 | vert[1, :, 0] = -1 80 | diag = torch.eye(size, device=self.device) 81 | inv_diag = torch.flip(torch.eye(size, device=self.device), dims=(0,)) 82 | 83 | kernels = torch.stack( 84 | [ 85 | torch.stack((diag, -diag), dim=0), 86 | torch.stack((inv_diag, -inv_diag), dim=0), 87 | horiz, 88 | vert 89 | ] 90 | ) 91 | return kernels 92 | 93 | def get_legal_actions(self) -> torch.Tensor: 94 | return self.next_empty >= 0 95 | 96 | def next_turn(self, *args, **kwargs): 97 | self.states = torch.roll(self.states, 1, dims=1) 98 | self.next_player() 99 | 100 | def save_node(self): 101 | return ( 102 | self.states.clone(), 103 | self.cur_players.clone(), 104 | self.next_empty.clone() 105 | ) 106 | 107 | def load_node(self, load_envs: torch.Tensor, saved: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): 108 | load_envs_expnd = load_envs.view(-1, 1, 1, 1) 109 | self.states = saved[0].clone() * load_envs_expnd + self.states * (~load_envs_expnd) 110 |
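# masked blend: keep the saved snapshot wherever load_envs is True and the live value
# everywhere else (new = saved * mask + current * ~mask); cur_players and next_empty
# below are restored with the same pattern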
self.cur_players = saved[1].clone() * load_envs + self.cur_players * (~load_envs) 111 | self.next_empty = saved[2].clone() * load_envs.view(-1, 1) + self.next_empty * (~load_envs).view(-1, 1) 112 | self.update_terminated() 113 | 114 | def reset_terminated_states(self, seed: int | None = None): 115 | if seed is not None: 116 | torch.manual_seed(seed) 117 | else: 118 | seed = 0 119 | terminated = self.terminated.clone() 120 | self.states *= 1 * ~terminated.view(-1, 1, 1, 1) 121 | self.cur_players *= 1 * ~terminated 122 | self.next_empty *= 1 * ~terminated.view(-1, 1) 123 | self.next_empty += (self.config.board_height - 1) * terminated.view(-1, 1) 124 | self.terminated.zero_() 125 | return seed 126 | 127 | 128 | def is_terminal(self): 129 | return (self.get_rewards() != 0.5) | ~(self.get_legal_actions().any(dim=1)) 130 | 131 | def get_rewards(self, player_ids: Optional[torch.Tensor] = None) -> torch.Tensor: 132 | if player_ids is None: 133 | player_ids = self.cur_players 134 | idx = ((player_ids == self.cur_players).int() - 1) % 2 135 | other_idx = 1 - idx 136 | convolved = torch.functional.F.conv2d(self.states, self.kernels, padding=(self.config.inarow - 1, self.config.inarow - 1)).view(self.parallel_envs, -1) 137 | p1_rewards = (convolved == self.reward_check[self.env_indices, idx].view(self.parallel_envs, 1)).any(dim=1) 138 | p2_rewards = (convolved == self.reward_check[self.env_indices, other_idx].view(self.parallel_envs, 1)).any(dim=1) 139 | rewards = (1 * (p1_rewards > p2_rewards)) + (0.5 * (p1_rewards == p2_rewards)) 140 | return rewards 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /envs/connect_x/tester.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from core.test.tester import TwoPlayerTester 5 | 6 | 7 | class ConnectXTester(TwoPlayerTester): 8 | def add_evaluation_metrics(self, episodes): 9 | if self.history is not None: 10 | for _ in episodes: 11 | self.history.add_evaluation_data({}, log=self.log_results) -------------------------------------------------------------------------------- /envs/connect_x/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from core.test.tester import TwoPlayerTester 4 | from core.train.trainer import Trainer, TrainerConfig 5 | from core.utils.history import TrainingMetrics 6 | from envs.connect_x.collector import ConnectXCollector 7 | 8 | 9 | 10 | class ConnectXTrainer(Trainer): 11 | def __init__(self, 12 | config: TrainerConfig, 13 | collector: ConnectXCollector, 14 | tester: TwoPlayerTester, 15 | model: torch.nn.Module, 16 | optimizer: torch.optim.Optimizer, 17 | device: torch.device, 18 | raw_train_config: dict, 19 | raw_env_config: dict, 20 | history: TrainingMetrics, 21 | log_results: bool = True, 22 | interactive: bool = True, 23 | run_tag: str = 'connect_x', 24 | debug: bool = False 25 | ): 26 | super().__init__( 27 | config=config, 28 | collector=collector, 29 | tester=tester, 30 | model=model, 31 | optimizer=optimizer, 32 | device=device, 33 | raw_train_config=raw_train_config, 34 | raw_env_config=raw_env_config, 35 | history=history, 36 | log_results=log_results, 37 | interactive=interactive, 38 | run_tag=run_tag, 39 | debug=debug 40 | ) 41 | 42 | def add_collection_metrics(self, episodes): 43 | for _ in episodes: 44 | self.history.add_episode_data({}, log=self.log_results) 45 | 46 | def add_epoch_metrics(self): 47 | pass 48 | 49 | 
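# ConnectXTrainer adds no environment-specific metrics of its own: add_collection_metrics
# only records an empty entry per finished episode and add_epoch_metrics is a no-op, so
# any further logging is left to the base Trainer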
-------------------------------------------------------------------------------- /envs/load.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import logging 4 | from typing import Optional 5 | import torch 6 | from core.algorithms.evaluator import Evaluator 7 | 8 | from core.resnet import TurboZeroResnet 9 | from core.test.tester import TesterConfig, Tester, TwoPlayerTesterConfig, TwoPlayerTester 10 | from core.train.collector import Collector 11 | from core.train.trainer import Trainer, TrainerConfig 12 | from core.utils.history import TrainingMetrics 13 | from envs._2048.collector import _2048Collector 14 | from envs._2048.tester import _2048Tester 15 | from envs._2048.trainer import _2048Trainer 16 | from envs.connect_x.collector import ConnectXCollector 17 | from envs.connect_x.env import ConnectXConfig, ConnectXEnv 18 | from envs.connect_x.tester import ConnectXTester 19 | from envs.connect_x.trainer import ConnectXTrainer 20 | from envs.othello.collector import OthelloCollector 21 | from envs.othello.tester import OthelloTester 22 | from envs.othello.trainer import OthelloTrainer 23 | from .othello.env import OthelloEnv, OthelloEnvConfig 24 | from ._2048.env import _2048Env, _2048EnvConfig 25 | 26 | def init_env(device: torch.device, parallel_envs: int, env_config: dict, debug: bool): 27 | env_type = env_config['env_type'] 28 | if env_type == 'othello': 29 | config = OthelloEnvConfig(**env_config) 30 | return OthelloEnv(parallel_envs, config, device, debug) 31 | elif env_type == '2048': 32 | config = _2048EnvConfig(**env_config) 33 | return _2048Env(parallel_envs, config, device, debug) 34 | elif env_type == 'connect_x': 35 | config = ConnectXConfig(**env_config) 36 | return ConnectXEnv(parallel_envs, config, device, debug) 37 | else: 38 | raise NotImplementedError(f'Environment {env_type} not implemented') 39 | 40 | def init_collector(episode_memory_device: torch.device, env_type: str, evaluator: Evaluator): 41 | if env_type == 'othello': 42 | return OthelloCollector( 43 | evaluator=evaluator, 44 | episode_memory_device=episode_memory_device 45 | ) 46 | elif env_type == '2048': 47 | return _2048Collector( 48 | evaluator=evaluator, 49 | episode_memory_device=episode_memory_device 50 | ) 51 | elif env_type == 'connect_x': 52 | return ConnectXCollector( 53 | evaluator=evaluator, 54 | episode_memory_device=episode_memory_device 55 | ) 56 | else: 57 | raise NotImplementedError(f'Collector for environment {env_type} not supported') 58 | 59 | def init_tester( 60 | test_config: dict, 61 | env_type: str, 62 | collector: Collector, 63 | model: torch.nn.Module, 64 | history: TrainingMetrics, 65 | optimizer: Optional[torch.optim.Optimizer], 66 | log_results: bool, 67 | debug: bool 68 | ): 69 | if env_type == 'othello': 70 | return OthelloTester( 71 | config=TwoPlayerTesterConfig(**test_config), 72 | collector=collector, 73 | model=model, 74 | optimizer=optimizer, 75 | history=history, 76 | log_results=log_results, 77 | debug=debug 78 | ) 79 | elif env_type == '2048': 80 | return _2048Tester( 81 | config=TesterConfig(**test_config), 82 | collector=collector, 83 | model=model, 84 | optimizer=optimizer, 85 | history=history, 86 | log_results=log_results, 87 | debug=debug 88 | ) 89 | elif env_type == 'connect_x': 90 | return ConnectXTester( 91 | config=TwoPlayerTesterConfig(**test_config), 92 | collector=collector, 93 | model=model, 94 | optimizer=optimizer, 95 | history=history, 96 | log_results=log_results, 97 | debug=debug 98 | ) 99 | else: 100 | 
raise NotImplementedError(f'Tester for {env_type} not supported') 101 | 102 | def init_trainer( 103 | device: torch.device, 104 | env_type: str, 105 | collector: Collector, 106 | tester: Tester, 107 | model: TurboZeroResnet, 108 | optimizer: torch.optim.Optimizer, 109 | train_config: dict, 110 | raw_env_config: dict, 111 | history: TrainingMetrics, 112 | log_results: bool, 113 | interactive: bool, 114 | run_tag: str = '', 115 | debug: bool = False 116 | ): 117 | trainer_config = TrainerConfig(**train_config) 118 | if env_type == 'othello': 119 | assert isinstance(collector, OthelloCollector) 120 | assert isinstance(tester, TwoPlayerTester) 121 | return OthelloTrainer( 122 | config = trainer_config, 123 | collector = collector, 124 | tester = tester, 125 | model = model, 126 | optimizer = optimizer, 127 | device = device, 128 | raw_train_config = train_config, 129 | raw_env_config = raw_env_config, 130 | history = history, 131 | log_results=log_results, 132 | interactive=interactive, 133 | run_tag = run_tag, 134 | debug = debug 135 | ) 136 | elif env_type == '2048': 137 | assert isinstance(collector, _2048Collector) 138 | return _2048Trainer( 139 | config = trainer_config, 140 | collector = collector, 141 | tester = tester, 142 | model = model, 143 | optimizer = optimizer, 144 | device = device, 145 | raw_train_config = train_config, 146 | raw_env_config = raw_env_config, 147 | history = history, 148 | log_results=log_results, 149 | interactive=interactive, 150 | run_tag = run_tag, 151 | debug = debug 152 | ) 153 | elif env_type == 'connect_x': 154 | assert isinstance(collector, ConnectXCollector) 155 | assert isinstance(tester, TwoPlayerTester) 156 | return ConnectXTrainer( 157 | config = trainer_config, 158 | collector = collector, 159 | tester = tester, 160 | model = model, 161 | optimizer = optimizer, 162 | device = device, 163 | raw_train_config = train_config, 164 | raw_env_config = raw_env_config, 165 | history = history, 166 | log_results=log_results, 167 | interactive=interactive, 168 | run_tag = run_tag, 169 | debug = debug 170 | ) 171 | else: 172 | logging.warn(f'No trainer found for environment {env_type}') 173 | return Trainer( 174 | config = trainer_config, 175 | collector = collector, 176 | tester = tester, 177 | model = model, 178 | optimizer = optimizer, 179 | device = device, 180 | raw_train_config = train_config, 181 | raw_env_config = raw_env_config, 182 | history = history, 183 | log_results=log_results, 184 | interactive=interactive, 185 | run_tag = run_tag, 186 | debug = debug 187 | ) 188 | -------------------------------------------------------------------------------- /envs/othello/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/othello/__init__.py -------------------------------------------------------------------------------- /envs/othello/collector.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from core.algorithms.evaluator import TrainableEvaluator 4 | from core.train.collector import Collector 5 | from envs.othello.env import OthelloEnvConfig 6 | 7 | 8 | class OthelloCollector(Collector): 9 | def __init__(self, 10 | evaluator: TrainableEvaluator, 11 | episode_memory_device: torch.device 12 | ) -> None: 13 | super().__init__(evaluator, episode_memory_device) 14 | assert isinstance(evaluator.env.config, OthelloEnvConfig) 15 | board_size = 
self.evaluator.env.config.board_size 16 | ids = torch.arange(self.evaluator.env.policy_shape[0]-1, device=episode_memory_device) 17 | self.rotated_action_ids = torch.zeros((board_size**2)+1, dtype=torch.long, requires_grad=False, device=ids.device) 18 | self.rotated_action_ids[:board_size**2] = torch.rot90(ids.reshape(board_size, board_size), k=1, dims=(0, 1)).flatten() 19 | self.rotated_action_ids[board_size**2] = board_size**2 20 | 21 | def assign_rewards(self, terminated_episodes, terminated): 22 | episodes = [] 23 | 24 | if terminated.any(): 25 | term_indices = terminated.nonzero(as_tuple=False).flatten() 26 | rewards = self.evaluator.env.get_rewards(torch.zeros_like(self.evaluator.env.env_indices)).clone().cpu().numpy() 27 | for i, episode in enumerate(terminated_episodes): 28 | episode_with_rewards = [] 29 | ti = term_indices[i] 30 | p1_reward = rewards[ti] 31 | p2_reward = 1 - p1_reward 32 | for ei, (inputs, visits, legal_actions) in enumerate(episode): 33 | if visits.sum(): # only append states where a move was possible 34 | episode_with_rewards.append((inputs, visits, torch.tensor(p2_reward if ei%2 else p1_reward, dtype=torch.float32, requires_grad=False, device=inputs.device), legal_actions)) 35 | episodes.append(episode_with_rewards) 36 | return episodes 37 | 38 | def postprocess(self, terminated_episodes): 39 | inputs, probs, rewards, legal_actions = zip(*terminated_episodes) 40 | rotated_inputs = [] 41 | for i in inputs: 42 | for k in range(4): 43 | rotated_inputs.append(torch.rot90(i, k=k, dims=(1, 2))) 44 | rotated_probs = [] 45 | for p in probs: 46 | # left -> down 47 | # down -> right 48 | # right -> up 49 | # up -> left 50 | new_p = p 51 | for k in range(4): 52 | rotated_probs.append(new_p) 53 | new_p = new_p[self.rotated_action_ids] 54 | 55 | rotated_legal_actions = [] 56 | for l in legal_actions: 57 | new_l = l 58 | for k in range(4): 59 | rotated_legal_actions.append(new_l) 60 | new_l = new_l[self.rotated_action_ids] 61 | 62 | rotated_rewards = [] 63 | for r in rewards: 64 | rotated_rewards.extend([r] * 4) 65 | 66 | return list(zip(rotated_inputs, rotated_probs, rotated_rewards, rotated_legal_actions)) -------------------------------------------------------------------------------- /envs/othello/demo.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from core.demo.demo import TwoPlayerDemo 4 | 5 | 6 | class OthelloDemo(TwoPlayerDemo): 7 | 8 | 9 | def print_rewards(self, p1_started: bool): 10 | super().print_rewards(p1_started) 11 | p1_idx = 0 if self.evaluator.env.cur_players.item() == 0 else 1 12 | p2_idx = 1 if p1_idx == 0 else 0 13 | p1_tiles = self.evaluator.env.states[0,p1_idx].sum().item() 14 | p2_tiles = self.evaluator.env.states[0,p2_idx].sum().item() 15 | print(f'Player 1 Tiles: {int(p1_tiles)}') 16 | print(f'Player 2 Tiles: {int(p2_tiles)}') -------------------------------------------------------------------------------- /envs/othello/env.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Tuple 4 | import torch 5 | from colorama import Fore, Back, Style 6 | 7 | from core.env import Env, EnvConfig 8 | from .torchscripts import get_legal_actions, push_actions, build_filters, build_flips 9 | 10 | 11 | @dataclass 12 | class OthelloEnvConfig(EnvConfig): 13 | board_size: int = 8 14 | book: Optional[str] = None 15 | 16 | def coord_to_action_id(coord): 17 | coord = coord.upper() 18 | if coord == "PS": 19 | 
return 64 20 | letter = ord(coord[0]) - ord('A') 21 | number = int(coord[1]) - 1 22 | return letter + number * 8 23 | 24 | 25 | class OthelloEnv(Env): 26 | def __init__(self, 27 | parallel_envs: int, 28 | config: OthelloEnvConfig, 29 | device: torch.device, 30 | debug=False 31 | ) -> None: 32 | self.board_size = config.board_size 33 | self.config: OthelloEnvConfig 34 | state_shape = torch.Size((2, self.board_size, self.board_size)) 35 | policy_shape = torch.Size(((self.board_size ** 2) + 1,)) 36 | value_shape = torch.Size((2, )) 37 | 38 | super().__init__( 39 | parallel_envs = parallel_envs, 40 | config = config, 41 | device=device, 42 | num_players=2, 43 | state_shape=state_shape, 44 | policy_shape=policy_shape, 45 | value_shape=value_shape, 46 | debug=debug 47 | ) 48 | 49 | 50 | 51 | num_rays = (8 * (self.board_size - 2)) + 1 52 | self.ray_tensor = torch.zeros((self.parallel_envs, num_rays, self.board_size, self.board_size), dtype=torch.float32, device=device, requires_grad=False) 53 | 54 | self.filters_and_indices = build_filters(device, self.board_size) 55 | self.flips = build_flips(num_rays, self.board_size, device) 56 | 57 | self.consecutive_passes = torch.zeros((self.parallel_envs, ), dtype=torch.long, device=device, requires_grad=False) 58 | self.legal_actions = torch.zeros((self.parallel_envs, (self.board_size ** 2) + 1), dtype=torch.bool, device=device, requires_grad=False) 59 | 60 | self.need_to_calculate_rays = True 61 | 62 | 63 | if self.debug: 64 | self.get_legal_actions_traced = get_legal_actions 65 | self.push_actions_traced = push_actions 66 | else: 67 | self.get_legal_actions_traced = torch.jit.trace(get_legal_actions, (self.states, self.ray_tensor, self.legal_actions, *self.filters_and_indices)) # type: ignore 68 | self.push_actions_traced = torch.jit.trace(push_actions, (self.states, self.ray_tensor, torch.zeros((self.parallel_envs, ), dtype=torch.long, device=device), self.flips)) # type: ignore 69 | 70 | self.book_opening_actions = None 71 | if config.book is not None: 72 | self.book_opening_actions = self.parse_opening_book(config.book) 73 | 74 | self.reset() 75 | 76 | 77 | def parse_opening_book(self, path_to_book): 78 | with open(path_to_book, 'r') as f: 79 | lines = f.readlines() 80 | lines = [l.strip() for l in lines if l.strip() != ''] 81 | ind = 0 82 | book_actions = [] 83 | while ind < len(lines[0]): 84 | 85 | actions = torch.tensor([coord_to_action_id(line[ind:ind+2]) for line in lines], dtype=torch.long, device=self.device) 86 | book_actions.append(actions) 87 | ind += 2 88 | return torch.stack(book_actions) 89 | 90 | 91 | def get_legal_actions(self): 92 | if self.need_to_calculate_rays: 93 | self.need_to_calculate_rays = False 94 | return self.get_legal_actions_traced(self.states, self.ray_tensor, self.legal_actions, *self.filters_and_indices) # type: ignore 95 | else: 96 | return self.legal_actions 97 | 98 | def push_actions(self, actions): 99 | if self.need_to_calculate_rays: 100 | self.get_legal_actions() # updates ray tensor 101 | _, passes = self.push_actions_traced(self.states, self.ray_tensor, actions, self.flips) # type: ignore 102 | self.consecutive_passes += passes 103 | self.consecutive_passes *= passes 104 | self.need_to_calculate_rays = True 105 | 106 | def next_turn(self): 107 | self.states = torch.roll(self.states, 1, dims=1) 108 | self.next_player() 109 | 110 | def reset(self, seed=None) -> int: 111 | if seed is not None: 112 | torch.manual_seed(seed) 113 | seed = 0 114 | self.states.zero_() 115 | self.ray_tensor.zero_() 116 | 
self.terminated.zero_() 117 | self.cur_players.zero_() 118 | self.consecutive_passes.zero_() 119 | self.legal_actions.zero_() 120 | self.states[:, 0, 3, 4] = 1 121 | self.states[:, 1, 3, 3] = 1 122 | self.states[:, 1, 4, 4] = 1 123 | self.states[:, 0, 4, 3] = 1 124 | self.need_to_calculate_rays = True 125 | if self.book_opening_actions is not None: 126 | opening_ids = torch.randint(0, self.book_opening_actions.shape[1], (self.parallel_envs, )) 127 | for i in range(self.book_opening_actions.shape[0]): 128 | self.step(self.book_opening_actions[i, opening_ids]) 129 | return seed 130 | 131 | 132 | def is_terminal(self): 133 | return (self.states.sum(dim=(1, 2, 3)) == (self.board_size ** 2)) | (self.consecutive_passes >= 2) 134 | 135 | def update_terminated(self): 136 | super().update_terminated() 137 | 138 | def get_rewards(self, player_ids: Optional[torch.Tensor] = None): 139 | if player_ids is None: 140 | player_ids = self.cur_players 141 | idx = ((player_ids == self.cur_players).int() - 1) % 2 142 | other_idx = 1 - idx 143 | 144 | p1_sum = self.states[self.env_indices, idx].sum(dim=(1, 2)) 145 | p2_sum = self.states[self.env_indices, other_idx].sum(dim=(1, 2)) 146 | rewards = (1 * (p1_sum > p2_sum)) + (0.5 * (p1_sum == p2_sum)) 147 | return rewards 148 | 149 | def reset_terminated_states(self, seed: Optional[int] = None) -> int: 150 | if seed is not None: 151 | torch.manual_seed(seed) 152 | seed = 0 153 | terminated = self.terminated.clone() 154 | self.states *= 1 * ~terminated.view(-1, 1, 1, 1) 155 | self.cur_players *= 1 * ~terminated 156 | self.consecutive_passes *= 1 * ~terminated 157 | mask = 1 * terminated 158 | self.states[:, 0, 3, 4] += mask 159 | self.states[:, 1, 3, 3] += mask 160 | self.states[:, 1, 4, 4] += mask 161 | self.states[:, 0, 4, 3] += mask 162 | 163 | saved = self.save_node() 164 | if self.book_opening_actions is not None: 165 | opening_ids = torch.randint(0, self.book_opening_actions.shape[1], (self.parallel_envs, )) 166 | for i in range(self.book_opening_actions.shape[0]): 167 | actions = self.book_opening_actions[i, opening_ids] 168 | actions[~terminated] = 64 169 | self.step(self.book_opening_actions[i, opening_ids]) 170 | self.load_node(~terminated, saved) 171 | self.need_to_calculate_rays = True 172 | self.terminated.zero_() 173 | return seed 174 | 175 | def save_node(self): 176 | return ( 177 | self.states.clone(), 178 | self.cur_players.clone(), 179 | self.consecutive_passes.clone() 180 | ) 181 | 182 | def load_node(self, load_envs: torch.Tensor, saved: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): 183 | load_envs_expnd = load_envs.view(-1, 1, 1, 1) 184 | self.states = saved[0].clone() * load_envs_expnd + self.states * (~load_envs_expnd) 185 | self.cur_players = saved[1].clone() * load_envs + self.cur_players * (~load_envs) 186 | self.consecutive_passes = saved[2].clone() * load_envs + self.consecutive_passes * (~load_envs) 187 | self.need_to_calculate_rays = True 188 | self.update_terminated() 189 | 190 | def get_greedy_rewards(self, player_ids: Optional[torch.Tensor] = None, heuristic: str = 'num_tiles'): 191 | if heuristic == 'minmax_moves': 192 | legal_actions_sum = self.get_legal_actions().sum(dim=1) 193 | return (((legal_actions_sum) * (self.cur_players == player_ids)) + (self.policy_shape[0] - legal_actions_sum) * (self.cur_players != player_ids)) / self.policy_shape[0] 194 | elif heuristic == 'num_tiles': 195 | if player_ids is None: 196 | player_ids = self.cur_players 197 | idx = ((player_ids == self.cur_players).int() - 1) % 2 198 | other_idx = 1 
- idx 199 | return 0.5 + ((self.states[self.env_indices, idx].sum(dim=(1, 2)) - self.states[self.env_indices, other_idx].sum(dim=(1, 2))) / (2 * (self.board_size ** 2))) 200 | elif heuristic == 'corners': 201 | if player_ids is None: 202 | player_ids = self.cur_players 203 | idx = ((player_ids == self.cur_players).int() - 1) % 2 204 | other_idx = 1 - idx 205 | top_left_corner = self.states[self.env_indices, idx, 0, 0] - self.states[self.env_indices, other_idx, 0, 0] 206 | top_right_corner = self.states[self.env_indices, idx, 0, self.board_size - 1] - self.states[self.env_indices, other_idx, 0, self.board_size - 1] 207 | bottom_left_corner = self.states[self.env_indices, idx, self.board_size - 1, 0] - self.states[self.env_indices, other_idx, self.board_size - 1, 0] 208 | bottom_right_corner = self.states[self.env_indices, idx, self.board_size - 1, self.board_size - 1] - self.states[self.env_indices, other_idx, self.board_size - 1, self.board_size - 1] 209 | return 0.5 + ((top_left_corner + top_right_corner + bottom_left_corner + bottom_right_corner) / 8) 210 | elif heuristic == 'corners_and_edges': 211 | if player_ids is None: 212 | player_ids = self.cur_players 213 | idx = ((player_ids == self.cur_players).int() - 1) % 2 214 | other_idx = 1 - idx 215 | edge = self.states[self.env_indices, idx, 0, :] - self.states[self.env_indices, other_idx, 0, :] 216 | edge += self.states[self.env_indices, idx, :, 0] - self.states[self.env_indices, other_idx, :, 0] 217 | edge += self.states[self.env_indices, idx, :, self.board_size - 1] - self.states[self.env_indices, other_idx, :, self.board_size - 1] 218 | edge += self.states[self.env_indices, idx, self.board_size - 1, :] - self.states[self.env_indices, other_idx, self.board_size - 1, :] 219 | # corners are counted twice, but corners are good, so it's fine! 
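# (for the default board_size of 8: circumference = 2*8 + 2*6 = 28 border squares, while the
#  four edge slices above contribute 4*8 = 32 cell terms since each corner lands in two slices,
#  so 0.5 + edge/56 stays near [0, 1] but can reach roughly -0.07 or 1.07 when one side owns
#  the entire border, corners included)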
220 | circumference = (2 * self.board_size) + (2 * (self.board_size - 2)) 221 | return 0.5 + (edge / (2 * circumference)) 222 | 223 | else: 224 | raise NotImplementedError(f'Heuristic {heuristic} not implemented for OthelloEnv') 225 | 226 | def print_state(self, last_action: Optional[int] = None) -> None: 227 | envstr = [] 228 | assert self.parallel_envs == 1 229 | cur_player_is_o = self.cur_players[0] == 0 230 | cur_player = 'O' if cur_player_is_o else 'X' 231 | other_player = 'X' if cur_player_is_o else 'O' 232 | envstr.append('+' + '---+' * (self.config.board_size)) 233 | envstr.append('\n') 234 | legal_actions = set(self.get_legal_actions()[0].nonzero().flatten().tolist()) 235 | for i in range(self.config.board_size): 236 | for j in range(self.config.board_size): 237 | action_idx = i*self.config.board_size + j 238 | color = Fore.RED if cur_player_is_o else Fore.GREEN 239 | other_color = Fore.GREEN if cur_player_is_o else Fore.RED 240 | if action_idx == last_action: 241 | color = Fore.BLUE 242 | other_color = Fore.BLUE 243 | if action_idx in legal_actions: 244 | envstr.append('|' + Fore.YELLOW + f'{action_idx}'.rjust(3)) 245 | elif self.states[0,0,i,j] == 1: 246 | envstr.append('|' + color + f' {cur_player} '.rjust(3)) 247 | elif self.states[0,1,i,j] == 1: 248 | envstr.append('|' + other_color + f' {other_player} '.rjust(3)) 249 | else: 250 | envstr.append(Fore.RESET + '| ') 251 | envstr.append(Fore.RESET) 252 | envstr.append('|') 253 | envstr.append('\n') 254 | envstr.append('+' + '---+' * (self.config.board_size)) 255 | envstr.append('\n') 256 | print(''.join(envstr)) 257 | -------------------------------------------------------------------------------- /envs/othello/evaluators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/othello/evaluators/__init__.py -------------------------------------------------------------------------------- /envs/othello/evaluators/edax.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | import logging 4 | from subprocess import PIPE, STDOUT, Popen 5 | from typing import Optional, Set 6 | from core.algorithms.evaluator import Evaluator, EvaluatorConfig 7 | from core.env import Env 8 | from envs.othello.env import OthelloEnv, coord_to_action_id 9 | import torch 10 | 11 | # this is an adapted version of EdaxPlayer: https://github.com/2Bear/othello-zero/blob/master/api.py 12 | 13 | @dataclass 14 | class EdaxConfig(EvaluatorConfig): 15 | edax_path: str 16 | edax_weights_path: str 17 | edax_book_path: str 18 | level: int 19 | 20 | 21 | 22 | class Edax(Evaluator): 23 | def __init__(self, env: Env, config: EdaxConfig, *args, **kwargs): 24 | assert isinstance(env, OthelloEnv), "EdaxPlayer only supports OthelloEnv" 25 | super().__init__(env, config, *args, **kwargs) 26 | self.edax_exec = config.edax_path + " -q -eval-file " + config.edax_weights_path \ 27 | + " -level " + str(config.level) \ 28 | + " -n 4" 29 | self.edax_procs = [] 30 | 31 | 32 | def start_procs(self): 33 | self.edax_procs = [ 34 | Popen(self.edax_exec, shell=True, stdout=PIPE, stdin=PIPE, stderr=STDOUT) for i in range(self.env.parallel_envs) 35 | ] 36 | 37 | def reset(self, seed: Optional[int] = None) -> int: 38 | seed = super().reset(seed) 39 | for proc in self.edax_procs: 40 | proc.terminate() 41 | self.start_procs() 42 | self.read_stdout() 43 | for i in 
range(len(self.edax_procs)): 44 | self.write_to_proc(f"setboard {self.stringify_board(i)}", i) 45 | self.read_stdout() 46 | return seed 47 | 48 | def stringify_board(self, index): 49 | cur_player = self.env.cur_players[index] 50 | board = self.env.states[index] 51 | black_pieces = board[0] if cur_player == 0 else board[1] 52 | white_pieces = board[1] if cur_player == 0 else board[0] 53 | board_str = [] 54 | for i in range(64): 55 | r, c = i // 8, i % 8 56 | if black_pieces[r][c] == 1: 57 | board_str.append('b') 58 | elif white_pieces[r][c] == 1: 59 | board_str.append('w') 60 | else: 61 | board_str.append('-') 62 | return ''.join(board_str) 63 | 64 | 65 | def step_evaluator(self, actions, terminated): 66 | # push other evaluators actions to edax 67 | for i, action in enumerate(actions.tolist()): 68 | if action == 64: 69 | self.write_to_proc("PS", i) 70 | else: 71 | self.write_to_proc(self.action_id_to_coord(action), i) 72 | self.read_stdout() 73 | self.reset_evaluator_states(terminated) 74 | 75 | def reset_evaluator_states(self, evals_to_reset: torch.Tensor) -> None: 76 | read_from_procs = set() 77 | for i, t in enumerate(evals_to_reset): 78 | if t: 79 | self.edax_procs[i].terminate() 80 | self.edax_procs[i] = Popen(self.edax_exec, shell=True, stdout=PIPE, stdin=PIPE, stderr=STDOUT) 81 | read_from_procs.add(i) 82 | if read_from_procs: 83 | self.read_stdout(read_from_procs) 84 | for i in read_from_procs: 85 | self.write_to_proc(f"setboard {self.stringify_board(i)}", i) 86 | self.read_stdout(read_from_procs) 87 | 88 | def evaluate(self, *args): 89 | for i in range(len(self.edax_procs)): 90 | self.write_to_proc("go", i) 91 | results = self.read_stdout() 92 | actions = [] 93 | for i, result in enumerate(results): 94 | if result != '\n\n*** Game Over ***\n': 95 | move = coord_to_action_id(result.split("plays ")[-1][:2]) 96 | actions.append(move) 97 | else: 98 | # logging.info(f'PROCESS {i}: EDAX gives GAME OVER') 99 | actions.append(64) 100 | actions = torch.tensor(actions, dtype=torch.long, device=self.device) 101 | return torch.nn.functional.one_hot(actions, self.env.policy_shape[0]).float(), None 102 | 103 | def step_env(self, actions): 104 | terminated = self.env.step(actions) 105 | return terminated 106 | 107 | def write_to_proc(self, command, proc_id): 108 | self.edax_procs[proc_id].stdin.write(str.encode(command + "\n")) 109 | self.edax_procs[proc_id].stdin.flush() 110 | 111 | @staticmethod 112 | def action_id_to_coord(action_id): 113 | letter = chr(ord('a') + action_id % 8) 114 | number = str(action_id // 8 + 1) 115 | if action_id == 64: 116 | return "PS" 117 | 118 | return letter + number 119 | 120 | def write_stdin(self, command): 121 | self.edax.stdin.write(str.encode(command + "\n")) 122 | self.edax.stdin.flush() 123 | 124 | def read_stdout(self, proc_ids: Optional[Set[int]] = None): 125 | outputs = [] 126 | for i in range(len(self.edax_procs)): 127 | out = b'' 128 | if proc_ids is None or i in proc_ids: 129 | while True: 130 | next_b = self.edax_procs[i].stdout.read(1) 131 | if next_b == b'>' and ((len(out) > 0 and out[-1] == 10) or len(out) == 0): 132 | break 133 | else: 134 | out += next_b 135 | outputs.append(out) 136 | decoded = [o.decode("utf-8") for o in outputs] 137 | return decoded 138 | 139 | def close(self): 140 | for p in self.edax_procs: 141 | p.terminate() 142 | -------------------------------------------------------------------------------- /envs/othello/misc/convolution.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/othello/misc/convolution.png -------------------------------------------------------------------------------- /envs/othello/misc/example_kernel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/othello/misc/example_kernel.png -------------------------------------------------------------------------------- /envs/othello/misc/example_kernel2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/othello/misc/example_kernel2.png -------------------------------------------------------------------------------- /envs/othello/misc/legal_starting_actions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/othello/misc/legal_starting_actions.png -------------------------------------------------------------------------------- /envs/othello/misc/othello_sandwich_menu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/envs/othello/misc/othello_sandwich_menu.png -------------------------------------------------------------------------------- /envs/othello/tester.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from core.test.tester import TwoPlayerTester 5 | 6 | 7 | class OthelloTester(TwoPlayerTester): 8 | def add_evaluation_metrics(self, episodes): 9 | if self.history is not None: 10 | for _ in episodes: 11 | self.history.add_evaluation_data({}, log=self.log_results) -------------------------------------------------------------------------------- /envs/othello/torchscripts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def build_filters(device: torch.device, board_size: int): 5 | num_filters = (board_size - 2) * 8 6 | filters = torch.zeros((num_filters + 1, 2, board_size, board_size), dtype=torch.float32, device=device, requires_grad=False) 7 | 8 | index = 1 9 | top_left_indices = [] 10 | top_right_indices = [] 11 | bottom_left_indices = [] 12 | bottom_right_indices = [] 13 | close_to = (torch.arange(-1, num_filters, dtype=torch.long, device=device) // 8) + 2 14 | 15 | for i in range(2, board_size): 16 | # right 17 | filters[index, 1, 0, 1:i] = 1 18 | filters[index, 0, 0, i] = 1 19 | filters[index, :, 0, 0] = -1 20 | top_left_indices.append(index) 21 | index += 1 22 | # left 23 | filters[index, 1, -1, -i:-1] = 1 24 | filters[index, 0, -1, -i-1] = 1 25 | filters[index, :, -1, -1] = -1 26 | bottom_right_indices.append(index) 27 | index += 1 28 | # down 29 | filters[index, 1, 1:i, 0] = 1 30 | filters[index, 0, i, 0] = 1 31 | filters[index, :, 0, 0] = -1 32 | top_left_indices.append(index) 33 | index += 1 34 | # up 35 | filters[index, 1, -i:-1, -1] = 1 36 | filters[index, 0, -i-1, -1] = 1 37 | filters[index, :, -1, -1] = -1 38 | bottom_right_indices.append(index) 39 | index += 1 40 | 41 | for j in range(1, i): 42 | filters[index, 1, j, j] = 1 # down right 43 | filters[index+1, 1, -j-1, j] = 1 # up right 44 | filters[index+2, 1, -j-1, -j-1] = 1 
# up left 45 | filters[index+3, 1, j, -j-1] = 1 # down left 46 | 47 | # down right 48 | filters[index, 0, i, i] = 1 49 | filters[index, :, 0, 0] = -1 50 | top_left_indices.append(index) 51 | index += 1 52 | # up right 53 | filters[index, 0, -1-i, i] = 1 54 | filters[index, :, -1, 0] = -1 55 | bottom_left_indices.append(index) 56 | index += 1 57 | # up left 58 | filters[index, 0, -1-i, -1-i] = 1 59 | filters[index, :, -1, -1] = -1 60 | bottom_right_indices.append(index) 61 | index += 1 62 | # down left 63 | filters[index, 0, i, -1-i] = 1 64 | filters[index, :, 0, -1] = -1 65 | top_right_indices.append(index) 66 | index += 1 67 | 68 | 69 | return filters, \ 70 | torch.tensor(bottom_left_indices, device=device, requires_grad=False), \ 71 | torch.tensor(bottom_right_indices, device=device, requires_grad=False), \ 72 | torch.tensor(top_left_indices, device=device, requires_grad=False), \ 73 | torch.tensor(top_right_indices, device=device, requires_grad=False), \ 74 | close_to.view(1, -1, 1, 1) 75 | 76 | def build_flips(num_rays, states_size, device): 77 | flips = torch.zeros((num_rays, states_size, states_size, states_size, states_size), device=device, requires_grad=False, dtype=torch.float32) 78 | f_index = 1 79 | for i in range(2, states_size): 80 | for x in range(states_size): 81 | for y in range(states_size): 82 | # right, left, down, up 83 | if x+1 < states_size: 84 | flips[f_index, y, x, y, x+1:min(x+i, states_size)] = 1 85 | flips[f_index+1, y, x, y, max(x-i+1, 0):x] = 1 86 | if y+1 < states_size: 87 | flips[f_index+2, y, x, y+1:min(y+i, states_size), x] = 1 88 | flips[f_index+3, y, x, max(y-i+1, 0):y, x] = 1 89 | 90 | # diag right down, diag left down, diag left up, diag right up 91 | for j in range(1, i): 92 | if y+j < states_size: 93 | if x+j < states_size: 94 | flips[f_index+4, y, x, y+j, x+j] = 1 95 | if x-j >= 0: 96 | flips[f_index+7, y, x, y+j, x-j] = 1 97 | if y-j >= 0: 98 | if x-j >= 0: 99 | flips[f_index+6, y, x, y-j, x-j] = 1 100 | if x+j < states_size: 101 | flips[f_index+5, y, x, y-j, x+j] = 1 102 | f_index += 8 103 | return flips 104 | 105 | 106 | 107 | def get_legal_actions(states, ray_tensor, legal_actions, filters, bl_idx, br_idx, tl_idx, tr_idx, ct): 108 | board_size = int(states.shape[-1]) # need to wrap in int() for tracing 109 | conv_results = torch.nn.functional.conv2d(states, filters, padding=board_size-1, bias=None) 110 | ray_tensor.zero_() 111 | ray_tensor[:, tl_idx] = conv_results[:, tl_idx, board_size-1:, board_size-1:] 112 | ray_tensor[:, tr_idx] = conv_results[:, tr_idx, board_size-1:, :-(board_size-1)] 113 | ray_tensor[:, bl_idx] = conv_results[:, bl_idx, :-(board_size-1), board_size-1:] 114 | ray_tensor[:, br_idx] = conv_results[:, br_idx, :-(board_size-1), :-(board_size-1)] 115 | ray_tensor[:] = (ray_tensor.round() == ct).float() 116 | legal_actions.zero_() 117 | legal_actions[:,:board_size**2] = ray_tensor.any(dim=1).view(-1, board_size ** 2) 118 | legal_actions[:,board_size**2] = ~(legal_actions.any(dim=1)) 119 | return legal_actions 120 | 121 | 122 | def push_actions(states, ray_tensor, actions, flips): 123 | num_rays = ray_tensor.shape[1] 124 | states_size = states.shape[-1] 125 | num_states = states.shape[0] 126 | state_indices = torch.arange(num_states, device=states.device, requires_grad=False, dtype=torch.long) 127 | 128 | is_not_null = actions != states_size ** 2 129 | action_ys, action_xs = actions // states_size, actions % states_size 130 | action_ys *= is_not_null # puts null action in-bounds 131 | action_xs *= is_not_null 132 | activated_rays 
= ray_tensor[state_indices, :, action_ys, action_xs] * (torch.arange(num_rays, device=states.device, requires_grad=False).unsqueeze(0)) * is_not_null.view(-1, 1) 133 | 134 | flips_to_apply = flips[activated_rays.long(), action_ys.unsqueeze(1), action_xs.unsqueeze(1)].amax(dim=1) * is_not_null.view(-1, 1, 1) 135 | 136 | states[:, 0, :, :].logical_or_(flips_to_apply) 137 | states[:, 1, :, :] *= torch.logical_not(flips_to_apply) 138 | states[state_indices, 0, action_ys, action_xs] += is_not_null.float() 139 | return states, ~is_not_null -------------------------------------------------------------------------------- /envs/othello/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | from copy import deepcopy 3 | from pathlib import Path 4 | from typing import List, Optional 5 | import torch 6 | import logging 7 | from core.test.tester import TwoPlayerTester 8 | from core.train.trainer import Trainer, TrainerConfig 9 | from core.utils.history import Metric, TrainingMetrics 10 | from core.resnet import TurboZeroResnet 11 | from core.utils.memory import GameReplayMemory 12 | from envs.othello.collector import OthelloCollector 13 | from core.resnet import reset_model_weights 14 | 15 | 16 | 17 | class OthelloTrainer(Trainer): 18 | def __init__(self, 19 | config: TrainerConfig, 20 | collector: OthelloCollector, 21 | tester: TwoPlayerTester, 22 | model: torch.nn.Module, 23 | optimizer: torch.optim.Optimizer, 24 | device: torch.device, 25 | raw_train_config: dict, 26 | raw_env_config: dict, 27 | history: TrainingMetrics, 28 | log_results: bool = True, 29 | interactive: bool = True, 30 | run_tag: str = 'othello', 31 | debug: bool = False 32 | ): 33 | super().__init__( 34 | config=config, 35 | collector=collector, 36 | tester=tester, 37 | model=model, 38 | optimizer=optimizer, 39 | device=device, 40 | raw_train_config=raw_train_config, 41 | raw_env_config=raw_env_config, 42 | history=history, 43 | log_results=log_results, 44 | interactive=interactive, 45 | run_tag=run_tag, 46 | debug=debug 47 | ) 48 | 49 | def add_collection_metrics(self, episodes): 50 | for _ in episodes: 51 | self.history.add_episode_data({}, log=self.log_results) 52 | 53 | def add_epoch_metrics(self): 54 | pass 55 | 56 | -------------------------------------------------------------------------------- /example_configs/2048_gpu.yaml: -------------------------------------------------------------------------------- 1 | run_tag: '2048' 2 | env_config: { 3 | env_type: '2048' 4 | } 5 | model_config: { # overwritten if loading a checkpoint 6 | res_channels: 16, 7 | res_blocks: 6, 8 | kernel_size: 3, 9 | value_fc_size: 32, 10 | value_output_activation: "" 11 | } 12 | train_mode_config: { 13 | algo_config: { 14 | name: 'lazyzero', 15 | temperature: 1.0, 16 | num_policy_rollouts: 500, 17 | rollout_depth: 3, 18 | puct_coeff: 1.0 19 | }, 20 | learning_rate: 0.1, 21 | momentum: 0.8, 22 | c_reg: 0.0001, 23 | replay_memory_max_size: 20000, 24 | replay_memory_min_size: 20000, 25 | parallel_envs: 8192, 26 | policy_factor: 1.0, 27 | minibatch_size: 4096, 28 | episodes_per_epoch: 20000, 29 | episodes_per_minibatch: 1, 30 | test_config: { 31 | algo_config: { 32 | name: 'lazyzero', 33 | temperature: 0.0, 34 | num_policy_rollouts: 500, 35 | rollout_depth: 3, 36 | puct_coeff: 1.0 37 | }, 38 | episodes_per_epoch: 1000 39 | } 40 | } 41 | test_mode_config: { 42 | algo_config: { 43 | name: 'lazyzero', 44 | temperature: 0.0, 45 | num_policy_rollouts: 500, 46 | rollout_depth: 3, 47 | puct_coeff: 1.0 48 | }, 49 | 
episodes_per_epoch: 1000 50 | } 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /example_configs/2048_mini.yaml: -------------------------------------------------------------------------------- 1 | run_tag: '2048_mini' 2 | env_config: { 3 | env_type: '2048', 4 | } 5 | model_config: { # overwritten if loading a checkpoint 6 | res_channels: 16, 7 | res_blocks: 6, 8 | kernel_size: 3, 9 | value_fc_size: 32, 10 | value_output_activation: "tanh0to1" 11 | } 12 | train_mode_config: { 13 | algo_config: { 14 | name: "lazyzero", 15 | temperature: 1.0, 16 | num_policy_rollouts: 300, 17 | rollout_depth: 3, 18 | puct_coeff: 1.0, 19 | }, 20 | learning_rate: 0.1, 21 | momentum: 0.8, 22 | c_reg: 0.0001, 23 | lr_decay_gamma: 0.9, 24 | replay_memory_max_size: 5000, 25 | replay_memory_min_size: 5000, 26 | parallel_envs: 4096, 27 | policy_factor: 1.0, 28 | minibatch_size: 1024, 29 | episodes_per_epoch: 5000, 30 | episodes_per_minibatch: 1, 31 | test_config: { 32 | algo_config: { 33 | name: "lazyzero", 34 | temperature: 0.0, 35 | num_policy_rollouts: 300, 36 | rollout_depth: 3, 37 | puct_coeff: 1.0, 38 | }, 39 | episodes_per_epoch: 0 40 | } 41 | } -------------------------------------------------------------------------------- /example_configs/2048_test.yaml: -------------------------------------------------------------------------------- 1 | env_config: { 2 | env_type: '2048', 3 | } 4 | test_config: { 5 | algo_config: { 6 | name: "lazyzero", 7 | temperature: 0.0, 8 | num_policy_rollouts: 300, 9 | rollout_depth: 3, 10 | puct_coeff: 1.0, 11 | }, 12 | episodes_per_epoch: 32 13 | } -------------------------------------------------------------------------------- /example_configs/2048_tiny.yaml: -------------------------------------------------------------------------------- 1 | run_tag: '2048_tiny' 2 | env_config: { 3 | env_type: '2048', 4 | } 5 | model_config: { # overwritten if loading a checkpoint 6 | res_channels: 8, 7 | res_blocks: 2, 8 | kernel_size: 3, 9 | value_fc_size: 32, 10 | value_output_activation: "tanh0to1" 11 | } 12 | train_mode_config: { 13 | algo_config: { 14 | name: "lazyzero", 15 | temperature: 1.0, 16 | num_policy_rollouts: 100, 17 | rollout_depth: 2, 18 | puct_coeff: 1.0, 19 | }, 20 | learning_rate: 0.1, 21 | momentum: 0.8, 22 | c_reg: 0.0001, 23 | lr_decay_gamma: 0.9, 24 | replay_memory_max_size: 1000, 25 | replay_memory_min_size: 10, 26 | parallel_envs: 32, 27 | policy_factor: 1.0, 28 | minibatch_size: 256, 29 | episodes_per_epoch: 32, 30 | episodes_per_minibatch: 1, 31 | test_config: { 32 | algo_config: { 33 | name: "lazyzero", 34 | temperature: 0.0, 35 | num_policy_rollouts: 100, 36 | rollout_depth: 2, 37 | puct_coeff: 1.0, 38 | }, 39 | episodes_per_epoch: 0 40 | } 41 | } -------------------------------------------------------------------------------- /example_configs/othello_baselines.yaml: -------------------------------------------------------------------------------- 1 | env_config: { 2 | env_type: 'othello', 3 | board_size: 8 4 | } 5 | test_mode_config: { 6 | algo_config: { 7 | name: "alphazero", 8 | temperature: 0.0, 9 | num_iters: 400, 10 | max_nodes: 400, 11 | puct_coeff: 1.0, 12 | dirichlet_alpha: 0.1, 13 | dirichlet_epsilon: 0.0 14 | }, 15 | episodes_per_epoch: 64, 16 | baselines: [ 17 | { 18 | name: 'greedy', 19 | heuristic: 'corners' 20 | }, 21 | { 22 | name: 'greedy', 23 | heuristic: 'corners_and_edges' 24 | }, 25 | { 26 | name: 'greedy', 27 | heuristic: 'minmax_moves' 28 | }, 29 | { 30 | name: 'greedy', 31 | heuristic: 
'num_tiles' 32 | }, 33 | ] 34 | } -------------------------------------------------------------------------------- /example_configs/othello_demo.yaml: -------------------------------------------------------------------------------- 1 | env_config: { 2 | env_type: 'othello', 3 | board_size: 8 4 | } 5 | demo_config: { 6 | evaluator1_config: { 7 | algo_config: { 8 | name: 'greedy_mcts', 9 | heuristic: 'num_tiles', 10 | num_iters: 100, 11 | max_nodes: 100, 12 | puct_coeff: 1.0, 13 | dirichlet_alpha: 0.1, 14 | dirichlet_epsilon: 0.0 15 | } 16 | }, 17 | evaluator2_config: { 18 | algo_config: { 19 | name: 'human' 20 | } 21 | }, 22 | manual_step: False 23 | } -------------------------------------------------------------------------------- /example_configs/othello_gpu.yaml: -------------------------------------------------------------------------------- 1 | run_tag: 'othello' 2 | env_config: { 3 | env_type: 'othello', 4 | board_size: 8 5 | } 6 | model_config: { # overwritten if loading a checkpoint 7 | res_channels: 128, 8 | res_blocks: 12, 9 | kernel_size: 3, 10 | value_fc_size: 256, 11 | value_output_activation: "tanh0to1" 12 | } 13 | train_mode_config: { 14 | algo_config: { 15 | name: 'alphazero', 16 | temperature: 0.9, 17 | num_iters: 2500, 18 | max_nodes: 1000, 19 | puct_coeff: 1.0, 20 | dirichlet_alpha: 0.45, 21 | dirichlet_epsilon: 0.25 22 | }, 23 | learning_rate: 0.02, 24 | lr_decay_gamma: 0.9, 25 | momentum: 0.9, 26 | c_reg: 0.0001, 27 | replay_memory_max_size: 75000, 28 | replay_memory_min_size: 75000, 29 | parallel_envs: 8192, 30 | policy_factor: 1.0, 31 | minibatch_size: 2048, 32 | episodes_per_epoch: 25000, 33 | episodes_per_minibatch: 25, 34 | test_config: { 35 | algo_config: { 36 | name: 'alphazero', 37 | temperature: 0.0, 38 | num_iters: 2500, 39 | max_nodes: 1000, 40 | puct_coeff: 1.0, 41 | dirichlet_alpha: 0.01, 42 | dirichlet_epsilon: 0.0 43 | }, 44 | episodes_per_epoch: 2048, 45 | baselines: [ 46 | { 47 | name: 'greedy', 48 | heuristic: 'minmax_moves' 49 | } 50 | ] 51 | } 52 | } -------------------------------------------------------------------------------- /example_configs/othello_mini.yaml: -------------------------------------------------------------------------------- 1 | run_tag: 'othello_mini' 2 | env_config: { 3 | env_type: 'othello', 4 | board_size: 8 5 | } 6 | model_config: { # overwritten if loading a checkpoint 7 | res_channels: 32, 8 | res_blocks: 6, 9 | kernel_size: 3, 10 | value_fc_size: 32, 11 | value_output_activation: "tanh0to1" 12 | } 13 | train_mode_config: { 14 | algo_config: { 15 | name: "alphazero", 16 | temperature: 1.0, 17 | num_iters: 1000, 18 | max_nodes: 400, 19 | puct_coeff: 1.0, 20 | dirichlet_alpha: 0.45, 21 | dirichlet_epsilon: 0.25 22 | }, 23 | learning_rate: 0.1, 24 | lr_decay_gamma: 0.8, 25 | momentum: 0.9, 26 | c_reg: 0.0001, 27 | replay_memory_max_size: 10000, 28 | replay_memory_min_size: 10000, 29 | parallel_envs: 4096, 30 | policy_factor: 1.0, 31 | minibatch_size: 1024, 32 | episodes_per_epoch: 5000, 33 | episodes_per_minibatch: 25, 34 | test_config: { 35 | algo_config: { 36 | name: "alphazero", 37 | temperature: 0.0, 38 | num_iters: 1000, 39 | max_nodes: 400, 40 | puct_coeff: 1.0, 41 | dirichlet_alpha: 0.1, 42 | dirichlet_epsilon: 0.0 43 | }, 44 | episodes_per_epoch: 256, 45 | baselines: [ 46 | { 47 | name: 'random' 48 | }, 49 | { 50 | name: 'greedy', 51 | heuristic: 'corners_and_edges' 52 | } 53 | ] 54 | } 55 | } -------------------------------------------------------------------------------- /example_configs/othello_tiny.yaml: 
-------------------------------------------------------------------------------- 1 | run_tag: 'othello_tiny' 2 | env_config: { 3 | env_type: 'othello', 4 | board_size: 8 5 | } 6 | model_config: { # overwritten if loading a checkpoint 7 | res_channels: 16, 8 | res_blocks: 4, 9 | kernel_size: 3, 10 | value_fc_size: 32, 11 | value_output_activation: "tanh0to1" 12 | } 13 | train_mode_config: { 14 | algo_config: { 15 | name: "alphazero", 16 | temperature: 1.0, 17 | num_iters: 100, 18 | max_nodes: 100, 19 | puct_coeff: 1.0, 20 | dirichlet_alpha: 0.45, 21 | dirichlet_epsilon: 0.25 22 | }, 23 | learning_rate: 0.1, 24 | lr_decay_gamma: 0.8, 25 | momentum: 0.9, 26 | c_reg: 0.01, 27 | replay_memory_max_size: 1000, 28 | replay_memory_min_size: 1000, 29 | parallel_envs: 256, 30 | policy_factor: 1.0, 31 | minibatch_size: 256, 32 | episodes_per_epoch: 256, 33 | episodes_per_minibatch: 16, 34 | test_config: { 35 | algo_config: { 36 | name: "alphazero", 37 | temperature: 0.0, 38 | num_iters: 100, 39 | max_nodes: 100, 40 | puct_coeff: 1.0, 41 | dirichlet_alpha: 0.1, 42 | dirichlet_epsilon: 0.0 43 | }, 44 | episodes_per_epoch: 256, 45 | baselines: [ 46 | { 47 | name: 'greedy', 48 | heuristic: 'num_tiles' 49 | }, 50 | ] 51 | } 52 | } -------------------------------------------------------------------------------- /example_configs/othello_tournament.yaml: -------------------------------------------------------------------------------- 1 | run_tag: 'othello' 2 | env_config: { 3 | env_type: 'othello', 4 | board_size: 8 5 | } 6 | tournament_mode_config: { 7 | num_games: 32, 8 | num_tournaments: 64, 9 | tournament_name: 'othello_baselines', 10 | competitors: [ 11 | # Random 12 | { 13 | name: 'random', 14 | algo_config: { 15 | name: 'random' 16 | } 17 | }, 18 | # Greedy 19 | { 20 | name: 'greedy_moves', 21 | algo_config: { 22 | name: 'greedy', 23 | heuristic: 'minmax_moves' 24 | } 25 | }, 26 | { 27 | name: 'greedy_tiles', 28 | algo_config: { 29 | name: 'greedy', 30 | heuristic: 'num_tiles' 31 | } 32 | }, 33 | { 34 | name: 'greedy_corners', 35 | algo_config: { 36 | name: 'greedy', 37 | heuristic: 'corners' 38 | } 39 | }, 40 | { 41 | name: 'greedy_corners_and_edges', 42 | algo_config: { 43 | name: 'greedy', 44 | heuristic: 'corners_and_edges' 45 | } 46 | } 47 | ] 48 | } -------------------------------------------------------------------------------- /example_configs/test_tournament.yaml: -------------------------------------------------------------------------------- 1 | run_tag: 'othello' 2 | env_config: { 3 | env_type: 'othello', 4 | board_size: 8 5 | } 6 | tournament_mode_config: { 7 | num_games: 32, 8 | num_tournaments: 128, 9 | competitors: [ 10 | # Random 11 | { 12 | name: 'random', 13 | algo_config: { 14 | name: 'random' 15 | } 16 | }, 17 | { 18 | name: 'greedy_lazymcts_100_corners_and_edges', 19 | algo_config: { 20 | name: 'lazy_greedy_mcts', 21 | num_policy_rollouts: 33, 22 | rollout_depth: 3, 23 | puct_coeff: 1.0, 24 | heuristic: 'corners_and_edges' 25 | }, 26 | }, 27 | 28 | ] 29 | } -------------------------------------------------------------------------------- /example_configs/tournament.yaml: -------------------------------------------------------------------------------- 1 | run_tag: 'othello' 2 | env_config: { 3 | env_type: 'othello', 4 | board_size: 8 5 | } 6 | tournament_mode_config: { 7 | num_games: 1024, 8 | num_tournaments: 128, 9 | competitors: [{ 10 | name: 'lazy_greedy_mcts', 11 | algo_config: { 12 | name: 'lazy_greedy_mcts', 13 | num_policy_rollouts: 64, # number of policy rollouts to run per 
evaluation call 14 | rollout_depth: 3, # depth of each policy rollout, once this depth is reached, return the network's evaluation (value head) of the state 15 | puct_coeff: 1.0, # C-value (PUCT exploration coefficient) 16 | heuristic: 'minmax_moves' 17 | }}, 18 | # { 19 | # name: 'greedy_tiles', 20 | # algo_config: { 21 | # name: 'greedy', 22 | # heuristic: 'num_tiles' 23 | # } 24 | # }, 25 | # { 26 | # name: 'corners', 27 | # algo_config: { 28 | # name: 'greedy', 29 | # heuristic: 'corners' 30 | # } 31 | # }, 32 | # { 33 | # name: 'corners_and_edges', 34 | # algo_config: { 35 | # name: 'greedy', 36 | # heuristic: 'corners_and_edges' 37 | # } 38 | # }, 39 | { 40 | name: 'random', 41 | algo_config: { 42 | name: 'random' 43 | } 44 | } 45 | # { 46 | # name: 'greedy_mcts_500', 47 | # algo_config: { 48 | # name: 'greedy_mcts', 49 | # num_iters: 500, 50 | # max_nodes: 1000, 51 | # puct_coeff: 1.0, 52 | # dirichlet_alpha: 0.01, 53 | # dirichlet_epsilon: 0.0 54 | # }, 55 | # }, 56 | # { 57 | # name: 'tiny_alphazero_500', 58 | # algo_config: { 59 | # name: 'alphazero', 60 | # temperature: 0.0, 61 | # num_iters: 500, 62 | # max_nodes: 1000, 63 | # puct_coeff: 1.0, 64 | # dirichlet_alpha: 0.01, 65 | # dirichlet_epsilon: 0.0 66 | # }, 67 | # checkpoint: 'tiny.pt' 68 | # }, 69 | # { 70 | # name: 'greedy_mcts_100', 71 | # algo_config: { 72 | # name: 'greedy_mcts', 73 | # num_iters: 100, 74 | # max_nodes: 200, 75 | # puct_coeff: 1.0, 76 | # dirichlet_alpha: 0.01, 77 | # dirichlet_epsilon: 0.0 78 | # }, 79 | # }, 80 | # { 81 | # name: 'tiny_alphazero_100', 82 | # algo_config: { 83 | # name: 'alphazero', 84 | # temperature: 0.0, 85 | # num_iters: 100, 86 | # max_nodes: 200, 87 | # puct_coeff: 1.0, 88 | # dirichlet_alpha: 0.01, 89 | # dirichlet_epsilon: 0.0 90 | # }, 91 | # checkpoint: 'tiny.pt' 92 | # }, 93 | # { 94 | # name: 'big_alphazero_100', 95 | # algo_config: { 96 | # name: 'alphazero', 97 | # temperature: 0.0, 98 | # num_iters: 100, 99 | # max_nodes: 200, 100 | # puct_coeff: 1.0, 101 | # dirichlet_alpha: 0.01, 102 | # dirichlet_epsilon: 0.0 103 | # }, 104 | # checkpoint: 'big.pt' 105 | # }, 106 | # { 107 | # name: 'big_alphazero_500', 108 | # algo_config: { 109 | # name: 'alphazero', 110 | # temperature: 0.0, 111 | # num_iters: 500, 112 | # max_nodes: 1000, 113 | # puct_coeff: 1.0, 114 | # dirichlet_alpha: 0.01, 115 | # dirichlet_epsilon: 0.0 116 | # }, 117 | # checkpoint: 'big.pt' 118 | # }, 119 | # { 120 | # name: 'tiny_alphazero_10', 121 | # algo_config: { 122 | # name: 'alphazero', 123 | # temperature: 0.0, 124 | # num_iters: 10, 125 | # max_nodes: 20, 126 | # puct_coeff: 1.0, 127 | # dirichlet_alpha: 0.01, 128 | # dirichlet_epsilon: 0.0 129 | # }, 130 | # checkpoint: 'big.pt' 131 | # }, 132 | # { 133 | # name: 'big_alphazero_10', 134 | # algo_config: { 135 | # name: 'alphazero', 136 | # temperature: 0.0, 137 | # num_iters: 10, 138 | # max_nodes: 20, 139 | # puct_coeff: 1.0, 140 | # dirichlet_alpha: 0.01, 141 | # dirichlet_epsilon: 0.0 142 | # }, 143 | # checkpoint: 'big.pt' 144 | # }, 145 | 146 | 147 | ] 148 | } -------------------------------------------------------------------------------- /misc/2048.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/2048.gif -------------------------------------------------------------------------------- /misc/2048env.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/2048env.png -------------------------------------------------------------------------------- /misc/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/benchmark.png -------------------------------------------------------------------------------- /misc/cpu_bad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/cpu_bad.png -------------------------------------------------------------------------------- /misc/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/heatmap.png -------------------------------------------------------------------------------- /misc/high_tile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/high_tile.png -------------------------------------------------------------------------------- /misc/one_lazyzero_iteration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/one_lazyzero_iteration.png -------------------------------------------------------------------------------- /misc/othello_game.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/othello_game.gif -------------------------------------------------------------------------------- /misc/slide_conv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/slide_conv.png -------------------------------------------------------------------------------- /misc/train_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/train_accuracy.png -------------------------------------------------------------------------------- /misc/train_high_tile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/train_high_tile.png -------------------------------------------------------------------------------- /misc/train_moves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/train_moves.png -------------------------------------------------------------------------------- /misc/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowrollr/turbozero_torch/e9283ad5002fcbea95ebe977c7731aa218eb15c6/misc/workflow.png 
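A minimal sketch of consuming the tournament configs above (othello_tournament.yaml, test_tournament.yaml, tournament.yaml) programmatically, mirroring notebooks/tournament.ipynb and the tournament mode of turbozero.py further below; the relative paths assume the snippet is run from the repository root and are illustrative, not fixed defaults:

import torch
from turbozero import load_tournament_nb

# load the env and the competitor list defined under tournament_mode_config
tournament, competitors = load_tournament_nb(
    config_file='example_configs/othello_tournament.yaml',  # or tournament.yaml / test_tournament.yaml
    gpu=torch.cuda.is_available(),
    debug=False,
    logfile='turbozero.log',
    verbose_logging=True,
)
# run the tournament between the configured competitors and print the results
print(tournament.run(competitors, interactive=False))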
-------------------------------------------------------------------------------- /notebooks/benchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 11, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import time\n", 11 | "\n", 12 | "tests = {\n", 13 | " 'devices': [torch.device('cpu')],\n", 14 | " 'parallel_envs': [1, 128, 1024, 8192]\n", 15 | "}\n", 16 | "\n", 17 | "def test_evaluator(n_times, evaluator, device, parallel_envs):\n", 18 | " timings = []\n", 19 | " for _ in range(n_times):\n", 20 | " completed = torch.zeros(parallel_envs, dtype=torch.bool, device=device, requires_grad=False)\n", 21 | " while not completed.all():\n", 22 | " start_time = time.time()\n", 23 | " _, _, _, _, terminated = evaluator.step()\n", 24 | " tot_time = time.time() - start_time\n", 25 | " timings.append(tot_time)\n", 26 | " completed |= terminated\n", 27 | " \n", 28 | " return parallel_envs / (sum(timings) / len(timings)) \n" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "from core.algorithms.load import init_evaluator\n", 38 | "from envs.load import init_env\n", 39 | "\n", 40 | "\n", 41 | "if torch.cuda.is_available():\n", 42 | " tests['devices'].append(torch.device('cuda'))\n", 43 | "\n", 44 | "results = dict()\n", 45 | "for device in tests['devices']:\n", 46 | " for n_envs in tests['parallel_envs']:\n", 47 | " env = init_env(device, n_envs, {'env_type': 'othello', 'board_size': 8}, False)\n", 48 | " evaluator = init_evaluator({'name': 'random'}, env) \n", 49 | " result = test_evaluator(10, evaluator, device, n_envs)\n", 50 | " results[(str(device), n_envs)] = result" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "import matplotlib.pyplot as plt\n", 60 | "\n", 61 | "plt.bar(range(len(results)), list(results.values()), align='center')\n", 62 | "plt.xticks(range(len(results)), list(results.keys()))\n", 63 | "plt.ylabel('Steps/Second')\n", 64 | "plt.xlabel('Device, # Envs')\n", 65 | "plt.title('Othello Env Benchmark')\n", 66 | "\n", 67 | "plt.show()" 68 | ] 69 | } 70 | ], 71 | "metadata": { 72 | "kernelspec": { 73 | "display_name": "turbozero-mMa0U6zx-py3.10", 74 | "language": "python", 75 | "name": "python3" 76 | }, 77 | "language_info": { 78 | "codemirror_mode": { 79 | "name": "ipython", 80 | "version": 3 81 | }, 82 | "file_extension": ".py", 83 | "mimetype": "text/x-python", 84 | "name": "python", 85 | "nbconvert_exporter": "python", 86 | "pygments_lexer": "ipython3", 87 | "version": "3.10.9" 88 | }, 89 | "orig_nbformat": 4 90 | }, 91 | "nbformat": 4, 92 | "nbformat_minor": 2 93 | } 94 | -------------------------------------------------------------------------------- /notebooks/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from turbozero import *" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "demo = load_demo_nb(config_file='../example_configs/othello_demo.yaml')\n", 19 | "demo.run(print_state=True, interactive=True)" 20 | ] 21 | } 22 | ], 23 | "metadata": { 24 | "kernelspec": { 25 | 
"display_name": "python3", 26 | "language": "python", 27 | "name": "python3" 28 | }, 29 | "language_info": { 30 | "codemirror_mode": { 31 | "name": "ipython", 32 | "version": 3 33 | }, 34 | "file_extension": ".py", 35 | "mimetype": "text/x-python", 36 | "name": "python", 37 | "nbconvert_exporter": "python", 38 | "pygments_lexer": "ipython3", 39 | "version": "3.8.16" 40 | }, 41 | "orig_nbformat": 4 42 | }, 43 | "nbformat": 4, 44 | "nbformat_minor": 2 45 | } 46 | -------------------------------------------------------------------------------- /notebooks/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "from turbozero import *" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "tester = load_tester_nb(\n", 21 | " config_file='../example_configs/othello_baselines.yaml', #TODO: specify a config file\n", 22 | " gpu=torch.cuda.is_available(),\n", 23 | " debug=False,\n", 24 | " logfile='../turbozero.log',\n", 25 | " verbose_logging=True,\n", 26 | " checkpoint='../checkpoints/path/to/your/checkpoint.pt'\n", 27 | ")\n", 28 | "plt.close('all');" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# sometimes JupyterLab stops reporting cell output properly and if something breaks you may lose the stack trace! this ensures it is captured in the log \n", 38 | "try:\n", 39 | " tester.collect_test_batch()\n", 40 | "except Exception as e:\n", 41 | " logging.exception(e) \n", 42 | " raise e" 43 | ] 44 | } 45 | ], 46 | "metadata": { 47 | "kernelspec": { 48 | "display_name": "python3", 49 | "language": "python", 50 | "name": "python3" 51 | }, 52 | "language_info": { 53 | "codemirror_mode": { 54 | "name": "ipython", 55 | "version": 3 56 | }, 57 | "file_extension": ".py", 58 | "mimetype": "text/x-python", 59 | "name": "python", 60 | "nbconvert_exporter": "python", 61 | "pygments_lexer": "ipython3", 62 | "version": "3.8.16" 63 | }, 64 | "orig_nbformat": 4 65 | }, 66 | "nbformat": 4, 67 | "nbformat_minor": 2 68 | } 69 | -------------------------------------------------------------------------------- /notebooks/tournament.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from turbozero import *" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "tournament, competitors = load_tournament_nb(\n", 19 | " config_file='../example_configs/othello_tournament.yaml', #TODO: specify a config file\n", 20 | " gpu=torch.cuda.is_available(),\n", 21 | " debug=False,\n", 22 | " logfile='../turbozero.log',\n", 23 | " verbose_logging=True,\n", 24 | " # tournament_checkpoint='path/to/checkpoint' \n", 25 | ")" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "tournament.run(competitors, interactive=True)\n", 35 | "# tournament.simulate_elo(interactive=True)" 36 | ] 37 | } 38 | ], 39 | "metadata": { 40 | "kernelspec": { 41 | "display_name": "python3", 42 | 
"language": "python", 43 | "name": "python3" 44 | }, 45 | "language_info": { 46 | "codemirror_mode": { 47 | "name": "ipython", 48 | "version": 3 49 | }, 50 | "file_extension": ".py", 51 | "mimetype": "text/x-python", 52 | "name": "python", 53 | "nbconvert_exporter": "python", 54 | "pygments_lexer": "ipython3", 55 | "version": "3.10.9" 56 | }, 57 | "orig_nbformat": 4 58 | }, 59 | "nbformat": 4, 60 | "nbformat_minor": 2 61 | } 62 | -------------------------------------------------------------------------------- /notebooks/train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "from turbozero import *" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "trainer = load_trainer_nb(\n", 21 | " config_file='../example_configs/connectx_tiny.yaml', #TODO: specify a config file\n", 22 | " gpu=torch.cuda.is_available(),\n", 23 | " debug=False,\n", 24 | " logfile='../turbozero.log',\n", 25 | " verbose_logging=True,\n", 26 | " checkpoint='' #TODO: specify a checkpoint to load if you'd like to resume training from a checkpoint\n", 27 | ")\n", 28 | "plt.close('all'); " 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# sometimes JupyterLab stops reporting cell output properly and if something breaks you may lose the stack trace! this ensures it is captured in the log \n", 38 | "try:\n", 39 | " trainer.training_loop(epochs=100)\n", 40 | "except Exception as e:\n", 41 | " logging.exception(e)\n", 42 | " raise e" 43 | ] 44 | } 45 | ], 46 | "metadata": { 47 | "kernelspec": { 48 | "display_name": "python3", 49 | "language": "python", 50 | "name": "python3" 51 | }, 52 | "language_info": { 53 | "codemirror_mode": { 54 | "name": "ipython", 55 | "version": 3 56 | }, 57 | "file_extension": ".py", 58 | "mimetype": "text/x-python", 59 | "name": "python", 60 | "nbconvert_exporter": "python", 61 | "pygments_lexer": "ipython3", 62 | "version": "3.10.9" 63 | }, 64 | "orig_nbformat": 4 65 | }, 66 | "nbformat": 4, 67 | "nbformat_minor": 2 68 | } 69 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "turbozero" 3 | version = "0.1.0" 4 | description = "vectorized implementations of alphazero + envs" 5 | authors = ["lowrollr <92640744+lowrollr@users.noreply.github.com>"] 6 | license = "Apache-2.0" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.9" 11 | torch = "=2.0.0" # dependencies are not specified properly for torch==2.0.1 12 | numpy = "^1.25.2" 13 | matplotlib = "^3.7.2" 14 | ipython = "^8.14.0" 15 | colorama = "^0.4.6" 16 | Bottleneck = "^1.3.7" 17 | pyyaml = "^6.0.1" 18 | ipykernel = "^6.25.1" 19 | tqdm = "^4.66.1" 20 | 21 | 22 | [tool.poetry.group.dev.dependencies] 23 | ipykernel = "^6.25.1" 24 | 25 | [build-system] 26 | requires = ["poetry-core"] 27 | build-backend = "poetry.core.masonry.api" 28 | -------------------------------------------------------------------------------- /turbozero.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import List, Tuple, 
Union 3 | import torch 4 | import logging 5 | import argparse 6 | import sys 7 | from core.algorithms.load import init_evaluator 8 | from core.demo.demo import Demo 9 | from core.demo.load import init_demo 10 | from core.resnet import ResNetConfig, TurboZeroResnet 11 | from core.test.tester import Tester 12 | from core.test.tournament.tournament import Tournament, TournamentPlayer, load_tournament as load_tournament_checkpoint 13 | from core.train.trainer import Trainer, init_history 14 | from core.utils.checkpoint import load_checkpoint, load_model_and_optimizer_from_checkpoint 15 | import matplotlib.pyplot as plt 16 | import yaml 17 | 18 | from envs.load import init_collector, init_env, init_tester, init_trainer 19 | 20 | def setup_logging(logfile: str): 21 | if logfile: 22 | logging.basicConfig(filename=logfile, filemode='a', level=logging.INFO, format='%(asctime)s %(message)s', force=True) 23 | else: 24 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s', handlers=[logging.StreamHandler(sys.stdout)], force=True) 25 | 26 | def load_trainer_nb( 27 | config_file: str, 28 | gpu: bool, 29 | debug: bool, 30 | logfile: str = '', 31 | verbose_logging: bool = True, 32 | checkpoint: str = '' 33 | ) -> Trainer: 34 | args = argparse.Namespace( 35 | config=config_file, 36 | gpu=gpu, 37 | debug=debug, 38 | logfile=logfile, 39 | verbose=verbose_logging, 40 | checkpoint=checkpoint 41 | ) 42 | setup_logging(args.logfile) 43 | 44 | trainer = load_trainer(args, interactive=True) 45 | plt.close('all') 46 | return trainer 47 | 48 | def load_tester_nb( 49 | config_file: str, 50 | gpu: bool, 51 | debug: bool, 52 | logfile: str = '', 53 | verbose_logging: bool = True, 54 | checkpoint: str = '' 55 | ) -> Tester: 56 | args = argparse.Namespace( 57 | config=config_file, 58 | gpu=gpu, 59 | debug=debug, 60 | logfile=logfile, 61 | verbose=verbose_logging, 62 | checkpoint=checkpoint 63 | ) 64 | setup_logging(args.logfile) 65 | tester = load_tester(args, interactive=True) 66 | plt.close('all') 67 | return tester 68 | 69 | def load_tournament_nb( 70 | config_file: str, 71 | gpu: bool, 72 | debug: bool, 73 | logfile: str = '', 74 | tournament_checkpoint: str = '', 75 | verbose_logging: bool = True 76 | ) -> Tuple[Tournament, List[dict]]: 77 | args = argparse.Namespace( 78 | config=config_file, 79 | gpu=gpu, 80 | debug=debug, 81 | logfile=logfile, 82 | verbose=verbose_logging, 83 | checkpoint=tournament_checkpoint 84 | ) 85 | setup_logging(args.logfile) 86 | return load_tournament(args, interactive=True) 87 | 88 | def load_demo_nb( 89 | config_file: str 90 | ) -> Demo: 91 | args = argparse.Namespace( 92 | config=config_file 93 | ) 94 | return load_demo(args) 95 | 96 | 97 | def load_config(config_file: str) -> dict: 98 | if config_file: 99 | with open(config_file, "r") as stream: 100 | raw_config = yaml.safe_load(stream) 101 | else: 102 | print('No config file provided, please provide a config file with --config') 103 | exit(1) 104 | return raw_config 105 | 106 | 107 | def load_trainer(args, interactive: bool) -> Trainer: 108 | raw_config = load_config(args.config) 109 | 110 | if torch.cuda.is_available() and args.gpu: 111 | device = torch.device('cuda') 112 | torch.backends.cudnn.benchmark = True 113 | else: 114 | device = torch.device('cpu') 115 | episode_memory_device = torch.device('cpu') # we do not support episode memory on GPU yet 116 | 117 | if args.checkpoint: 118 | checkpoint = load_checkpoint(args.checkpoint) 119 | train_config = checkpoint['raw_train_config'] 120 | env_config = 
checkpoint['raw_env_config'] 121 | else: 122 | env_config = raw_config['env_config'] 123 | train_config = raw_config['train_mode_config'] 124 | 125 | run_tag = raw_config.get('run_tag', '') 126 | env_type = env_config['env_type'] 127 | parallel_envs_train = train_config['parallel_envs'] 128 | parallel_envs_test = train_config['test_config']['episodes_per_epoch'] 129 | env_train = init_env(device, parallel_envs_train, env_config, args.debug) 130 | env_test = init_env(device, parallel_envs_test, env_config, args.debug) 131 | 132 | if args.checkpoint: 133 | model, optimizer = load_model_and_optimizer_from_checkpoint(checkpoint, env_train, device) 134 | history = checkpoint['history'] 135 | else: 136 | model = TurboZeroResnet(ResNetConfig(**raw_config['model_config']), env_train.state_shape, env_train.policy_shape).to(device) 137 | optimizer = torch.optim.SGD(model.parameters(), lr=raw_config['train_mode_config']['learning_rate'], momentum=raw_config['train_mode_config']['momentum'], weight_decay=raw_config['train_mode_config']['c_reg']) 138 | history = init_history() 139 | 140 | train_evaluator = init_evaluator(train_config['algo_config'], env_train, model) 141 | train_collector = init_collector(episode_memory_device, env_type, train_evaluator) 142 | test_evaluator = init_evaluator(train_config['test_config']['algo_config'], env_test, model) 143 | test_collector = init_collector(episode_memory_device, env_type, test_evaluator) 144 | tester = init_tester(train_config['test_config'], env_type, test_collector, model, history, optimizer, args.verbose, args.debug) 145 | trainer = init_trainer(device, env_type, train_collector, tester, model, optimizer, train_config, env_config, history, args.verbose, interactive, run_tag, debug=args.debug) 146 | return trainer 147 | 148 | def load_tester(args, interactive: bool) -> Tester: 149 | raw_config = load_config(args.config) 150 | test_config = raw_config['test_mode_config'] 151 | 152 | if torch.cuda.is_available() and args.gpu: 153 | device = torch.device('cuda') 154 | torch.backends.cudnn.benchmark = True 155 | else: 156 | device = torch.device('cpu') 157 | episode_memory_device = torch.device('cpu') 158 | parallel_envs = test_config['episodes_per_epoch'] 159 | if args.checkpoint: 160 | checkpoint = load_checkpoint(args.checkpoint) 161 | env_config = checkpoint['raw_env_config'] 162 | env = init_env(device, parallel_envs, env_config, args.debug) 163 | model, _ = load_model_and_optimizer_from_checkpoint(checkpoint, env, device) 164 | history = checkpoint['history'] 165 | 166 | else: 167 | print('No checkpoint provided, please provide a checkpoint with --checkpoint') 168 | exit(1) 169 | 170 | 171 | model = model.to(device) 172 | env_type = env_config['env_type'] 173 | evaluator = init_evaluator(test_config['algo_config'], env, model) 174 | collector = init_collector(episode_memory_device, env_type, evaluator) 175 | tester = init_tester(test_config, env_type, collector, model, history, None, args.verbose, args.debug) 176 | return tester 177 | 178 | def load_tournament(args, interactive: bool) -> Tuple[Tournament, List[dict]]: 179 | raw_config = load_config(args.config) 180 | if torch.cuda.is_available() and args.gpu: 181 | device = torch.device('cuda') 182 | torch.backends.cudnn.deterministic = True 183 | else: 184 | device = torch.device('cpu') 185 | 186 | tournament_config = raw_config['tournament_mode_config'] 187 | if args.checkpoint: 188 | tournament = load_tournament_checkpoint(args.checkpoint, device) 189 | else: 190 | env = init_env(device, 
tournament_config['num_games'], raw_config['env_config'], args.debug) 191 | tournament_name = tournament_config.get('tournament_name', 'tournament') 192 | tournament = Tournament(env, tournament_config['num_games'], tournament_config['num_tournaments'], device, tournament_name, args.debug) 193 | 194 | competitors = tournament_config['competitors'] 195 | 196 | return tournament, competitors 197 | 198 | def load_demo(args) -> Demo: 199 | raw_config = load_config(args.config) 200 | return init_demo(raw_config['env_config'], raw_config['demo_config'], torch.device('cpu')) 201 | 202 | 203 | if __name__ == '__main__': 204 | parser = argparse.ArgumentParser(prog='TurboZero') 205 | parser.add_argument('--checkpoint', type=str) 206 | parser.add_argument('--mode', type=str, default='demo', choices=['train', 'test', 'demo', 'tournament']) 207 | parser.add_argument('--config', type=str) 208 | parser.add_argument('--gpu', action='store_true') 209 | parser.add_argument('--interactive', action='store_true') 210 | parser.add_argument('--debug', action='store_true') 211 | parser.add_argument('--logfile', type=str, default='') 212 | parser.add_argument('--verbose', action='store_true') 213 | args = parser.parse_args() 214 | 215 | if args.config: 216 | with open(args.config, "r") as stream: 217 | raw_config = yaml.safe_load(stream) 218 | else: 219 | print('No config file provided, please provide a config file with --config') 220 | exit(1) 221 | 222 | setup_logging(args.logfile) 223 | 224 | if args.mode == 'train': 225 | trainer = load_trainer(args, interactive=False) 226 | trainer.training_loop() 227 | elif args.mode == 'test': 228 | tester = load_tester(args, interactive=False) 229 | tester.collect_test_batch() 230 | elif args.mode == 'tournament': 231 | tournament, competitors = load_tournament(args, interactive=False) 232 | print(tournament.run(competitors, interactive=False)) 233 | elif args.mode == 'demo': 234 | demo = load_demo(args) 235 | demo.run(print_state=True, interactive=False) 236 | --------------------------------------------------------------------------------
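A minimal sketch of driving the command-line entry point above programmatically for training, equivalent to: python turbozero.py --mode train --config example_configs/othello_tiny.yaml --gpu --logfile turbozero.log --verbose. The config path, log path, and epoch count are illustrative choices rather than project defaults:

import argparse
import torch
from turbozero import setup_logging, load_trainer

args = argparse.Namespace(
    config='example_configs/othello_tiny.yaml',  # any example config that defines model_config and train_mode_config
    gpu=torch.cuda.is_available(),
    debug=False,
    logfile='turbozero.log',
    verbose=True,
    checkpoint='',  # optionally a saved checkpoint path to resume training from
)
setup_logging(args.logfile)
trainer = load_trainer(args, interactive=False)
trainer.training_loop(epochs=100)  # same call used in notebooks/train.ipynb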