├── mctx ├── py.typed ├── _src │ ├── __init__.py │ ├── tests │ │ ├── mctx_test.py │ │ ├── qtransforms_test.py │ │ ├── seq_halving_test.py │ │ ├── tree_test.py │ │ └── policies_test.py │ ├── seq_halving.py │ ├── base.py │ ├── qtransforms.py │ ├── action_selection.py │ ├── tree.py │ ├── search.py │ └── policies.py └── __init__.py ├── requirements ├── requirements-test.txt ├── requirements_examples.txt └── requirements.txt ├── MANIFEST.in ├── .gitignore ├── .github └── workflows │ ├── ci.yml │ └── pypi-publish.yml ├── CONTRIBUTING.md ├── test.sh ├── setup.py ├── README.md ├── examples ├── policy_improvement_demo.py └── visualization_demo.py ├── LICENSE ├── .pylintrc └── connect4.ipynb /mctx/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements/requirements-test.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.9.0 2 | numpy>=1.18.0 3 | -------------------------------------------------------------------------------- /requirements/requirements_examples.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.9.0 2 | pygraphviz>=1.7 3 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | chex>=0.0.8 2 | jax>=0.4.25 3 | jaxlib>=0.4.25 4 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements/* 4 | include mctx/py.typed 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Building and 
releasing library: 2 | *.egg-info 3 | *.pyc 4 | *.so 5 | build/ 6 | dist/ 7 | venv/ 8 | 9 | # Mac OS 10 | .DS_Store 11 | 12 | # Python tools 13 | .mypy_cache/ 14 | .pytype/ 15 | .ipynb_checkpoints 16 | 17 | # Editors 18 | .idea 19 | .vscode 20 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | jobs: 10 | build-and-test: 11 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 12 | runs-on: "${{ matrix.os }}" 13 | 14 | strategy: 15 | matrix: 16 | python-version: ["3.10", "3.11", "3.12"] 17 | os: [ubuntu-latest] 18 | 19 | steps: 20 | - uses: "actions/checkout@v4" 21 | - uses: "actions/setup-python@v4" 22 | with: 23 | python-version: "${{ matrix.python-version }}" 24 | - name: Run CI tests 25 | run: bash test.sh 26 | shell: bash 27 | -------------------------------------------------------------------------------- /mctx/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /.github/workflows/pypi-publish.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Check consistency between the package version and release tag 21 | run: | 22 | RELEASE_VER=${GITHUB_REF#refs/*/} 23 | PACKAGE_VER="v`python setup.py --version`" 24 | if [ $RELEASE_VER != $PACKAGE_VER ] 25 | then 26 | echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1 27 | fi 28 | - name: Build and publish 29 | env: 30 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 31 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 32 | run: | 33 | python setup.py sdist bdist_wheel 34 | twine upload dist/* 35 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 
13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Testing 26 | 27 | Please make sure that your PR passes all tests by running `bash test.sh` on your 28 | local machine. Also, you can run only tests that are affected by your code 29 | changes, but you will need to select them manually. 30 | 31 | ## Community Guidelines 32 | 33 | This project follows [Google's Open Source Community 34 | Guidelines](https://opensource.google.com/conduct/). 35 | -------------------------------------------------------------------------------- /mctx/_src/tests/mctx_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class MctxTest(absltest.TestCase):
  """Smoke test: the public `mctx` package exposes the expected symbols."""

  def test_import(self):
    """Asserts, in order, that each expected name is an attribute of `mctx`."""
    expected_symbols = (
        "gumbel_muzero_policy",
        "muzero_policy",
        "qtransform_by_min_max",
        "qtransform_by_parent_and_siblings",
        "qtransform_completed_by_mix_value",
        "PolicyOutput",
        "RootFnOutput",
        "RecurrentFnOutput",
        "get_subtree",
        "reset_search_tree",
    )
    # The first missing symbol fails the test, exactly like a sequence of
    # individual `assertTrue(hasattr(...))` statements would.
    for symbol in expected_symbols:
      self.assertTrue(hasattr(mctx, symbol))
class QtransformsTest(absltest.TestCase):
  """Tests for the private `qtransforms._compute_mixed_value` helper."""

  @staticmethod
  def _prior_probs():
    """Returns softmax probabilities of a prior whose end actions are masked."""
    masked_logits = jnp.array([-jnp.inf, -1.0, 2.0, -jnp.inf])
    return jax.nn.softmax(masked_logits)

  def test_mix_value(self):
    """Tests the output of _compute_mixed_value()."""
    root_value = jnp.array(-0.8)
    prior = self._prior_probs()
    counts = jnp.array([0, 4.0, 4.0, 0])
    scale = 10.0 / 54
    q = scale * jnp.array([20.0, 3.0, -1.0, 10.0])

    actual = qtransforms._compute_mixed_value(root_value, q, counts, prior)

    # Only actions 1 and 2 have visits, so only they enter the weighted sum.
    total_visits = jnp.sum(counts)
    visited_weighted_q = prior[1] * q[1] + prior[2] * q[2]
    expected = 1.0 / (total_visits + 1) * (
        root_value + total_visits * visited_weighted_q)
    np.testing.assert_allclose(expected, actual)

  def test_mix_value_with_zero_visits(self):
    """Tests that zero visit counts do not divide by zero."""
    root_value = jnp.array(-0.8)
    prior = self._prior_probs()
    counts = jnp.array([0, 0, 0, 0])
    q = jnp.zeros_like(prior)

    # `debug_nans` turns any NaN produced inside the computation into an error.
    with jax.debug_nans():
      actual = qtransforms._compute_mixed_value(root_value, q, counts, prior)

    np.testing.assert_allclose(root_value, actual)
# Runs CI tests on a local machine.
#
# Steps: create a throwaway virtualenv, install the requirements, lint with
# flake8 and pylint, build and install the wheel, type-check with pytype (on
# Python minor versions below 12) and run the tests against the installed
# package.
set -xeuo pipefail

# Install deps in a virtual env.
readonly VENV_DIR=/tmp/mctx-env
rm -rf "${VENV_DIR}"
python3 -m venv "${VENV_DIR}"
source "${VENV_DIR}/bin/activate"
python --version

# Install dependencies.
pip install --upgrade pip setuptools wheel
pip install flake8 pytest-xdist pylint pylint-exit
pip install -r requirements/requirements.txt
pip install -r requirements/requirements-test.txt

# Lint with flake8.
flake8 `find mctx -name '*.py' | xargs` --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics

# Lint with pylint.
# Fail on errors, warning, and conventions.
PYLINT_ARGS="-efail -wfail -cfail"
# Lint modules and tests separately.
pylint --rcfile=.pylintrc `find mctx -name '*.py' | grep -v 'test.py' | xargs` || pylint-exit $PYLINT_ARGS $?
# Disable `protected-access` warnings for tests.
pylint --rcfile=.pylintrc `find mctx -name '*_test.py' | xargs` -d W0212 || pylint-exit $PYLINT_ARGS $?

# Build the package.
python setup.py sdist
pip wheel --verbose --no-deps --no-clean dist/mctx*.tar.gz
pip install mctx*.whl

# Check types with pytype.
# Note: pytype does not support 3.12 as of 23.11.23
# See https://github.com/google/pytype/issues/1308
if [ `python -c 'import sys; print(sys.version_info.minor)'` -lt 12 ];
then
  pip install pytype
  pytype `find mctx/_src/ -name "*py" | xargs` -k
fi;

# Run tests using pytest.
# Change directory to avoid importing the package from repo root.
# `mkdir -p` and a stand-alone `cd` are deliberate: in `mkdir X && cd X` a
# failing `mkdir` (directory left over from a previous run) does not trigger
# `set -e`, the `cd` is silently skipped and the tests would then run from the
# repo root against the source tree instead of the installed wheel.
mkdir -p _testing
cd _testing

# Let pytest-xdist pick the worker count instead of parsing /proc/cpuinfo,
# which only exists on Linux.
pytest -n auto --pyargs mctx
cd ..

set +u
deactivate
echo "All tests passed. Congrats!"
def _get_version():
  """Returns the package version declared in `mctx/__init__.py`.

  The file is scanned line by line for an assignment that starts with
  `__version__`; the text after the first `=` is stripped of spaces, quotes
  and the trailing newline.

  The file is located relative to this script (`_CURRENT_DIR`) rather than the
  current working directory, consistent with how the requirements files and
  the README are located.

  Returns:
    The version string.

  Raises:
    ValueError: If no non-empty `__version__` assignment is found.
  """
  init_path = os.path.join(_CURRENT_DIR, 'mctx', '__init__.py')
  # Explicit encoding: the platform default is not guaranteed to be UTF-8.
  with open(init_path, encoding='utf-8') as fp:
    for line in fp:
      if line.startswith('__version__') and '=' in line:
        version = line[line.find('=') + 1:].strip(' \'"\n')
        if version:
          return version
  raise ValueError('`__version__` not defined in `mctx/__init__.py`')


def _parse_requirements(path):
  """Reads a requirements file and returns its requirement lines.

  Args:
    path: Path of the requirements file. It is joined onto `_CURRENT_DIR`, so
      a relative path is resolved against this script's directory while an
      absolute path is used as is.

  Returns:
    A list with one entry per line that is neither blank nor a `#` comment,
    with trailing whitespace removed.
  """
  with open(os.path.join(_CURRENT_DIR, path), encoding='utf-8') as f:
    return [
        line.rstrip()
        for line in f
        # `lstrip` so that indented comment lines are skipped as well.
        if not (line.isspace() or line.lstrip().startswith('#'))
    ]
61 | python_requires='>=3.9', 62 | classifiers=[ 63 | 'Development Status :: 4 - Beta', 64 | 'Intended Audience :: Science/Research', 65 | 'License :: OSI Approved :: Apache Software License', 66 | 'Programming Language :: Python', 67 | 'Programming Language :: Python :: 3', 68 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 69 | 'Topic :: Software Development :: Libraries :: Python Modules', 70 | ], 71 | ) 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mctx-az 2 | This fork of [google-deepmind/mctx](https://github.com/google-deepmind/mctx) introduces a new feature used in AlphaZero: continuing search from 3 | a subtree of a previous state's search output, or _subtree persistence_. 4 | 5 | This allows Monte Carlo Tree Search to continue from an already-initialized, partially populated search tree. This lets work done in a previous 6 | call to Monte Carlo Tree Search persist to the next call, avoiding lots of repeated work! 7 | 8 | ## Quickstart 9 | mctx-az introduces a new policy: `alphazero_policy` which allows the user to pass a pre-initialized `Tree` to continue the search with. 10 | 11 | Then, `get_subtree` can be used to extract the subtree rooted at a particular child node of the root, corresponding to a taken action. 12 | 13 | In cases where the search tree should not be saved, such as an episode terminating, `reset_search_tree` can be used to clear the tree. 14 | 15 | In order to initialize a new tree, pass `tree=None` to `alphazero_policy`, along with `max_nodes` to specify the capacity of the tree, which in most cases 16 | should be >= `num_simulations`. 17 | 18 | `alphazero_policy` otherwise functions exactly the same as `muzero_policy`. 
19 | 20 | #### Initializing a new Tree: 21 | ```python 22 | policy_output = mctx.alphazero_policy(params, rng_key, root, recurrent_fn, 23 | num_simulations=32, tree=None, max_nodes=48) 24 | tree = policy_output.search_tree 25 | ``` 26 | #### Extracting the subtree and continuing search 27 | ```python 28 | # get chosen action from policy output 29 | action = policy_output.action 30 | 31 | # extract the subtree corresponding to the chosen action 32 | tree = mctx.get_subtree(tree, action) 33 | 34 | # go to next environment state 35 | env_state = env.step(env_state, action) 36 | 37 | # reset the search tree where the environment has terminated 38 | tree = mctx.reset_search_tree(tree, env_state.terminated) 39 | 40 | # new search with subtree 41 | # (max_nodes has no effect when a tree is passed) 42 | policy_output = mctx.alphazero_policy(params, rng_key, root, recurrent_fn, 43 | num_simulations=32, tree=tree) 44 | ``` 45 | #### Note on out-of-bounds expansions: 46 | A call to any mctx policy will expand `num_simulations` nodes (assuming `max_depth` is not breached). 47 | 48 | Given that `alphazero_policy` accepts a pre-populated `Tree`, it is possible that there will not be enough 49 | room left for `num_simulations` new nodes. 50 | 51 | In the case where a tree is full, values and visit counts are still propagated backwards to all nodes along the visit path 52 | as they would if the expansion was in bounds. However, a new node is not created and stored in the search tree; only its 53 | in-bounds predecessors are updated. 54 | 55 | ## Examples 56 | The mctx readme links to a simple Connect4 example: https://github.com/Carbon225/mctx-classic 57 | 58 | I modified this example to demonstrate the use of `alphazero_policy` and `get_subtree`. You can see it [here](https://github.com/lowrollr/mctx-az/blob/main/connect4.ipynb) 59 | 60 | ## Issues 61 | If you run into problems or need help, please create an Issue and I will do my best to assist you promptly. 
62 | -------------------------------------------------------------------------------- /mctx/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Mctx: Monte Carlo tree search in JAX.""" 16 | 17 | from mctx._src.action_selection import gumbel_muzero_interior_action_selection 18 | from mctx._src.action_selection import gumbel_muzero_root_action_selection 19 | from mctx._src.action_selection import GumbelMuZeroExtraData 20 | from mctx._src.action_selection import muzero_action_selection 21 | from mctx._src.base import ChanceRecurrentFnOutput 22 | from mctx._src.base import DecisionRecurrentFnOutput 23 | from mctx._src.base import InteriorActionSelectionFn 24 | from mctx._src.base import LoopFn 25 | from mctx._src.base import PolicyOutput 26 | from mctx._src.base import RecurrentFn 27 | from mctx._src.base import RecurrentFnOutput 28 | from mctx._src.base import RecurrentState 29 | from mctx._src.base import RootActionSelectionFn 30 | from mctx._src.base import RootFnOutput 31 | from mctx._src.policies import alphazero_policy 32 | from mctx._src.policies import gumbel_muzero_policy 33 | from mctx._src.policies import muzero_policy 34 | from mctx._src.policies import 
stochastic_muzero_policy 35 | from mctx._src.qtransforms import qtransform_by_min_max 36 | from mctx._src.qtransforms import qtransform_by_parent_and_siblings 37 | from mctx._src.qtransforms import qtransform_completed_by_mix_value 38 | from mctx._src.search import search 39 | from mctx._src.tree import get_subtree 40 | from mctx._src.tree import reset_search_tree 41 | from mctx._src.tree import Tree 42 | 43 | __version__ = "0.0.5" 44 | 45 | __all__ = ( 46 | "ChanceRecurrentFnOutput", 47 | "DecisionRecurrentFnOutput", 48 | "GumbelMuZeroExtraData", 49 | "InteriorActionSelectionFn", 50 | "LoopFn", 51 | "PolicyOutput", 52 | "RecurrentFn", 53 | "RecurrentFnOutput", 54 | "RecurrentState", 55 | "RootActionSelectionFn", 56 | "RootFnOutput", 57 | "Tree", 58 | "alphazero_policy", 59 | "get_subtree", 60 | "gumbel_muzero_interior_action_selection", 61 | "gumbel_muzero_policy", 62 | "gumbel_muzero_root_action_selection", 63 | "muzero_action_selection", 64 | "muzero_policy", 65 | "qtransform_by_min_max", 66 | "qtransform_by_parent_and_siblings", 67 | "qtransform_completed_by_mix_value", 68 | "reset_search_tree", 69 | "search", 70 | "stochastic_muzero_policy", 71 | ) 72 | 73 | # _________________________________________ 74 | # / Please don't use symbols in `_src` they \ 75 | # \ are not part of the Mctx public API. / 76 | # ----------------------------------------- 77 | # \ ^__^ 78 | # \ (oo)\_______ 79 | # (__)\ )\/\ 80 | # ||----w | 81 | # || || 82 | # 83 | -------------------------------------------------------------------------------- /mctx/_src/seq_halving.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def score_considered(considered_visit, gumbel, logits, normalized_qvalues,
                     visit_counts):
  """Scores actions so that an argmax picks among the considered ones.

  Args:
    considered_visit: The visit count an action must have to be considered.
    gumbel: Gumbel noise, same shape as `logits`.
    logits: Prior logits; their maximum over the last axis is subtracted.
    normalized_qvalues: Q-value term added to the score, same shape as
      `logits`.
    visit_counts: Visit counts, compared element-wise to `considered_visit`.

  Returns:
    `max(-1e9, gumbel + shifted_logits + normalized_qvalues)` for actions whose
    visit count equals `considered_visit`, and `-inf` for all other actions.
  """
  shifted_logits = logits - jnp.max(logits, keepdims=True, axis=-1)
  is_considered = visit_counts == considered_visit
  not_considered_penalty = jnp.where(is_considered, 0, -jnp.inf)
  chex.assert_equal_shape(
      [gumbel, shifted_logits, normalized_qvalues, not_considered_penalty])
  # Clamping from below keeps a considered action's score finite, so it still
  # beats every non-considered action (which ends up at -inf) even when it is
  # the only considered child.
  floor = -1e9
  raw_score = gumbel + shifted_logits + normalized_qvalues
  return jnp.maximum(floor, raw_score) + not_considered_penalty


def get_sequence_of_considered_visits(max_num_considered_actions,
                                      num_simulations):
  """Builds the Sequential Halving schedule of considered visit counts.

  Sequential Halving is a "pure exploration" bandit algorithm, introduced in
  "Almost Optimal Exploration in Multi-Armed Bandits":
  http://proceedings.mlr.press/v28/karnin13.pdf

  Entry `t` of the result is the visit count an action must currently have to
  be eligible at simulation `t`; picking the best action among those
  implements Sequential Halving.

  Args:
    max_num_considered_actions: The maximum number of considered actions.
      The `max_num_considered_actions` can be smaller than the number of
      actions.
    num_simulations: The total simulation budget.

  Returns:
    A tuple with visit counts. Length `num_simulations`.
  """
  if max_num_considered_actions <= 1:
    # With at most one action there is nothing to halve.
    return tuple(range(num_simulations))

  num_phases = int(math.ceil(math.log2(max_num_considered_actions)))
  per_action_visits = [0] * max_num_considered_actions
  schedule = []
  active = max_num_considered_actions
  while len(schedule) < num_simulations:
    rounds = max(1, int(num_simulations / (num_phases * active)))
    for _ in range(rounds):
      # One round visits each active action once.
      schedule += per_action_visits[:active]
      per_action_visits[:active] = [
          count + 1 for count in per_action_visits[:active]
      ]
    # Halve the active set, but never go below two actions.
    active = max(2, active // 2)
  return tuple(schedule[:num_simulations])


def get_table_of_considered_visits(max_num_considered_actions, num_simulations):
  """Builds one visit-count schedule per possible number of considered actions.

  Args:
    max_num_considered_actions: The maximum number of considered actions.
      The `max_num_considered_actions` can be smaller than the number of
      actions.
    num_simulations: The total simulation budget.

  Returns:
    A tuple of sequences of visit counts; row `m` is the schedule for `m`
    considered actions.
    Shape [max_num_considered_actions + 1, num_simulations].
  """
  table = []
  for num_considered in range(max_num_considered_actions + 1):
    table.append(
        get_sequence_of_considered_visits(num_considered, num_simulations))
  return tuple(table)
class SeqHalvingTest(absltest.TestCase):
  """Tests `seq_halving.get_sequence_of_considered_visits`."""

  def _check_visits(self, expected_results, max_num_considered_actions,
                    num_simulations):
    """Asserts that the generated schedule equals `expected_results`."""
    # Guard against a mistyped expectation before calling the code under test.
    self.assertLen(expected_results, num_simulations)
    actual = seq_halving.get_sequence_of_considered_visits(
        max_num_considered_actions, num_simulations)
    self.assertEqual(tuple(expected_results), actual)

  def test_considered_min_sims(self):
    # Using exactly `num_simulations = max_num_considered_actions *
    # log2(max_num_considered_actions)`.
    expected = (
        [0] * 8  # Considering 8 actions.
        + [1] * 4  # Considering 4 actions.
        + [2] * 4  # Considering 4 actions, round 2.
        + [3, 3, 4, 4, 5, 5, 6, 6]  # Considering 2 actions.
    )
    self._check_visits(
        expected, max_num_considered_actions=8, num_simulations=24)

  def test_considered_extra_sims(self):
    # Using more simulations than `max_num_considered_actions *
    # log2(max_num_considered_actions)`.
    expected = (
        [0] * 8  # Considering 8 actions.
        + [1] * 4  # Considering 4 actions.
        + [2] * 4  # Considering 4 actions, round 2.
        + [3] * 4  # Considering 4 actions, round 3.
        + [4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10]
        + [11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17]
    )
    self._check_visits(
        expected, max_num_considered_actions=8, num_simulations=47)

  def test_considered_less_sims(self):
    # Using a very small number of simulations.
    self._check_visits(
        [0, 0], max_num_considered_actions=8, num_simulations=2)

  def test_considered_less_sims2(self):
    # Using `num_simulations < max_num_considered_actions *
    # log2(max_num_considered_actions)`.
    expected = (
        [0] * 8  # Considering 8 actions.
        + [1] * 4  # Considering 4 actions.
        + [2]
    )
    self._check_visits(
        expected, max_num_considered_actions=8, num_simulations=13)

  def test_considered_not_power_of_2(self):
    # Using max_num_considered_actions that is not a power of 2.
    expected = (
        [0] * 7  # Considering 7 actions.
        + [1, 1, 1, 2, 2, 2]  # Considering 3 actions.
        + [3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8]
    )
    self._check_visits(
        expected, max_num_considered_actions=7, num_simulations=24)

  def test_considered_action0(self):
    budget = 16
    self._check_visits(
        range(budget), max_num_considered_actions=0, num_simulations=budget)

  def test_considered_action1(self):
    budget = 16
    self._check_visits(
        range(budget), max_num_considered_actions=1, num_simulations=budget)
# ==============================================================================
"""A demonstration of the policy improvement by planning with Gumbel."""

import functools
from typing import Tuple

from absl import app
from absl import flags
import chex
import jax
import jax.numpy as jnp
import mctx

FLAGS = flags.FLAGS
flags.DEFINE_integer("seed", 42, "Random seed.")
flags.DEFINE_integer("batch_size", 256, "Batch size.")
flags.DEFINE_integer("num_actions", 82, "Number of actions.")
flags.DEFINE_integer("num_simulations", 4, "Number of simulations.")
flags.DEFINE_integer("max_num_considered_actions", 16,
                     "The maximum number of actions expanded at the root.")
flags.DEFINE_integer("num_runs", 1, "Number of runs on random data.")


@chex.dataclass(frozen=True)
class DemoOutput:
  """Values collected from one demo run, one entry per batch element."""
  # Expected Q-value under softmax(prior_logits).
  prior_policy_value: chex.Array
  # Q-value of the action picked by argmax(gumbel + prior_logits).
  prior_policy_action_value: chex.Array
  # Q-value of the action picked by the search.
  selected_action_value: chex.Array
  # Expected Q-value under the `action_weights` produced by the search.
  action_weights_policy_value: chex.Array


def _run_demo(rng_key: chex.PRNGKey) -> Tuple[chex.PRNGKey, DemoOutput]:
  """Runs one Gumbel MuZero search on a randomly generated bandit.

  Args:
    rng_key: random number generator state.

  Returns:
    A `(rng_key, output)` pair: the key to use for the next run and the
    `DemoOutput` needed to measure the policy improvement.
  """
  batch_size = FLAGS.batch_size
  batch_range = jnp.arange(batch_size)
  rng_key, logits_rng, q_rng, search_rng = jax.random.split(rng_key, 4)

  # Random prior_logits stand in for the output of a policy network.
  prior_logits = jax.random.normal(
      logits_rng, shape=[batch_size, FLAGS.num_actions])
  # The bandit is defined by random Q-values. The search algorithm only gets
  # to see the Q-values of the actions it visits.
  qvalues = jax.random.uniform(q_rng, shape=prior_logits.shape)
  # The value under the prior policy is known here, so it can be used to
  # complete the missing Q-values. The completed Q-values then produce an
  # improved policy in `policy_output.action_weights`.
  raw_value = jnp.sum(jax.nn.softmax(prior_logits) * qvalues, axis=-1)
  # The raw value, not the mixed value, completes the Q-values in this demo.
  qtransform = functools.partial(
      mctx.qtransform_completed_by_mix_value, use_mixed_value=False)

  policy_output = mctx.gumbel_muzero_policy(
      params=(),
      rng_key=search_rng,
      # In MuZero, the root would come from the representation network. The
      # embedding is only needed to implement the model.
      root=mctx.RootFnOutput(
          prior_logits=prior_logits,
          value=raw_value,
          embedding=jnp.zeros([batch_size]),
      ),
      # In MuZero, the recurrent_fn would be the dynamics network.
      recurrent_fn=_make_bandit_recurrent_fn(qvalues),
      num_simulations=FLAGS.num_simulations,
      max_num_considered_actions=FLAGS.max_num_considered_actions,
      qtransform=qtransform,
  )

  # The baseline is the action the prior policy would select with the very
  # same Gumbel random numbers that the search used at the root.
  root_gumbel = policy_output.search_tree.extra_data.root_gumbel
  prior_policy_action = jnp.argmax(root_gumbel + prior_logits, axis=-1)

  return rng_key, DemoOutput(
      prior_policy_value=raw_value,
      prior_policy_action_value=qvalues[batch_range, prior_policy_action],
      selected_action_value=qvalues[batch_range, policy_output.action],
      # The policy value under the new action_weights.
      action_weights_policy_value=jnp.sum(
          policy_output.action_weights * qvalues, axis=-1),
  )


def _make_bandit_recurrent_fn(qvalues):
  """Builds the recurrent_fn of a deterministic bandit with the given Q-values."""

  def recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key
    batch_range = jnp.arange(action.shape[0])
    # Only the first step, taken from the root (embedding == 0), is rewarded.
    reward = jnp.where(embedding == 0, qvalues[batch_range, action], 0.0)
    output = mctx.RecurrentFnOutput(
        reward=reward,
        # A single-player environment uses a discount from [0, 1].
        # A zero-sum self-play environment would use discount=-1.
        discount=jnp.ones_like(reward),
        prior_logits=jnp.zeros_like(qvalues),
        value=jnp.zeros_like(reward))
    # The embedding counts the steps taken from the root.
    return output, embedding + 1

  return recurrent_fn


def main(_):
  """Runs the demo `num_runs` times and prints the value improvements."""
  jitted_run_demo = jax.jit(_run_demo)
  rng_key = jax.random.PRNGKey(FLAGS.seed)
  for _ in range(FLAGS.num_runs):
    rng_key, output = jitted_run_demo(rng_key)
    # The obtained increase of the policy value should be non-negative.
    action_gain = (
        output.selected_action_value - output.prior_policy_action_value)
    weights_gain = (
        output.action_weights_policy_value - output.prior_policy_value)
    print("action value improvement:         %.3f (min=%.3f)" %
          (action_gain.mean(), action_gain.min()))
    print("action_weights value improvement: %.3f (min=%.3f)" %
          (weights_gain.mean(), weights_gain.min()))


if __name__ == "__main__":
  app.run(main)

# --------------------------------------------------------------------------------
# /mctx/_src/base.py:
# --------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core types used in mctx."""

from typing import Any, Callable, Generic, TypeVar, Tuple

import chex

from mctx._src import tree


# Parameters are arbitrary nested structures of `chex.Array`.
# A nested structure is either a single object, or a collection (list, tuple,
# dictionary, etc.) of other nested structures.
Params = chex.ArrayTree


# The model used to search is expressed by a `RecurrentFn` function that takes
# `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput` and
# the new state embedding.
@chex.dataclass(frozen=True)
class RecurrentFnOutput:
  """The output of a `RecurrentFn`.

  Attributes:
    reward: `[B]` an approximate reward from the state-action transition.
    discount: `[B]` the discount between the `reward` and the `value`.
    prior_logits: `[B, num_actions]` the logits produced by a policy network.
    value: `[B]` an approximate value of the state after the state-action
      transition.
  """
  reward: chex.Array
  discount: chex.Array
  prior_logits: chex.Array
  value: chex.Array


# An action is passed around as an array.
Action = chex.Array
# The state embedding handed from one `RecurrentFn` call to the next. No
# structure is imposed on it here.
# NOTE(review): `RecurrentState` is assigned again, as `chex.ArrayTree`, near
# the end of this module.
RecurrentState = Any
# Signature: `(params, rng_key, action, embedding)` ->
# `(RecurrentFnOutput, new_embedding)`.
RecurrentFn = Callable[
    [Params, chex.PRNGKey, Action, RecurrentState],
    Tuple[RecurrentFnOutput, RecurrentState]]


@chex.dataclass(frozen=True)
class RootFnOutput:
  """The output of a representation network.

  Attributes:
    prior_logits: `[B, num_actions]` the logits produced by a policy network.
    value: `[B]` an approximate value of the current state.
    embedding: `[B, ...]` the inputs to the next `recurrent_fn` call.
  """
  prior_logits: chex.Array
  value: chex.Array
  embedding: RecurrentState


# Action selection functions specify how to pick nodes to expand in the tree.
NodeIndices = chex.Array
Depth = chex.Array
# Signature: `(rng_key, tree, node_index)` -> selected action.
RootActionSelectionFn = Callable[
    [chex.PRNGKey, tree.Tree, NodeIndices], chex.Array]
# Signature: `(rng_key, tree, node_index, depth)` -> selected action.
InteriorActionSelectionFn = Callable[
    [chex.PRNGKey, tree.Tree, NodeIndices, Depth],
    chex.Array]
# Signature: `(tree, node_index)` -> transformed Q-values of the node's actions.
QTransform = Callable[[tree.Tree, chex.Array], chex.Array]
# LoopFn has the same interface as jax.lax.fori_loop.
LoopFn = Callable[
    [int, int, Callable[[Any, Any], Any], Tuple[chex.PRNGKey, tree.Tree]],
    Tuple[chex.PRNGKey, tree.Tree]]

T = TypeVar("T")


@chex.dataclass(frozen=True)
class PolicyOutput(Generic[T]):
  """The result returned by a search policy.

  Attributes:
    action: `[B]` the proposed action.
    action_weights: `[B, num_actions]` the targets used to train a policy
      network. The action weights sum to one. Usually, the policy network is
      trained by cross-entropy:
      `cross_entropy(labels=stop_gradient(action_weights),
      logits=prior_logits)`.
    search_tree: `[B, ...]` the search tree of the finished search.
  """
  action: chex.Array
  action_weights: chex.Array
  search_tree: tree.Tree[T]


@chex.dataclass(frozen=True)
class DecisionRecurrentFnOutput:
  """What expanding a decision node produces.

  A decision node is expanded with an action and a state embedding. The
  result is an afterstate: the environment after the action was taken, but
  before the environment has updated its state. The transition from state
  `s` to afterstate `sa` therefore carries no reward and no discount.

  Attributes:
    chance_logits: `[B, C]` logits of `C` chance outcomes at the afterstate.
    afterstate_value: `[B]` values of the afterstates `v(sa)`.
  """
  chance_logits: chex.Array  # [B, C]
  afterstate_value: chex.Array  # [B]


@chex.dataclass(frozen=True)
class ChanceRecurrentFnOutput:
  """What expanding a chance node produces.

  A chance node is expanded with a chance outcome and an afterstate
  embedding. The result is a state, capturing a potentially stochastic
  environment transition. Rewards and discounts are produced for this
  transition, as in a normal transition.

  Attributes:
    action_logits: `[B, A]` logits of different actions from the state.
    value: `[B]` values of the states `v(s)`.
    reward: `[B]` rewards at the states.
    discount: `[B]` discounts at the states.
  """
  action_logits: chex.Array  # [B, A]
  value: chex.Array  # [B]
  reward: chex.Array  # [B]
  discount: chex.Array  # [B]


@chex.dataclass(frozen=True)
class StochasticRecurrentState:
  """Embedding wrapper that tells decision nodes and chance nodes apart.

  A Stochastic MuZero tree has decision nodes and chance nodes. The two kinds
  are treated differently during expansion, search and backup, and a user may
  give them differently structured embeddings. The wrapper holds one
  embedding per kind, plus a flag saying which of them is meaningful.

  Attributes:
    state_embedding: `[B ...]` an optionally meaningful state embedding.
    afterstate_embedding: `[B ...]` an optionally meaningful afterstate
      embedding.
    is_decision_node: `[B]` whether the node is a decision or chance node. If
      it is a decision node, `afterstate_embedding` is a dummy value. If it is
      a chance node, `state_embedding` is a dummy value.
  """
  state_embedding: chex.ArrayTree  # [B, ...]
  afterstate_embedding: chex.ArrayTree  # [B, ...]
  is_decision_node: chex.Array  # [B]


RecurrentState = chex.ArrayTree

DecisionRecurrentFn = Callable[[Params, chex.PRNGKey, Action, RecurrentState],
                               Tuple[DecisionRecurrentFnOutput, RecurrentState]]

ChanceRecurrentFn = Callable[[Params, chex.PRNGKey, Action, RecurrentState],
                             Tuple[ChanceRecurrentFnOutput, RecurrentState]]

# --------------------------------------------------------------------------------
# /mctx/_src/qtransforms.py:
# --------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Monotonic transformations for the Q-values."""

import chex
import jax
import jax.numpy as jnp

from mctx._src import tree as tree_lib


def qtransform_by_min_max(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    min_value: chex.Numeric,
    max_value: chex.Numeric,
) -> chex.Array:
  """Normalizes the Q-values of a node by a fixed `[min_value, max_value]`.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the parent node.
    min_value: given minimum value. Usually the minimum possible
      untransformed Q-value.
    max_value: given maximum value. Usually the maximum possible
      untransformed Q-value.

  Returns:
    `(qvalues - min_value) / (max_value - min_value)`, with zero for the
    unvisited actions. Shape `[num_actions]`.
  """
  chex.assert_shape(node_index, ())
  is_visited = tree.children_visits[node_index] > 0
  # An unvisited action is given `min_value`, which normalizes to zero.
  filled_qvalues = jnp.where(is_visited, tree.qvalues(node_index), min_value)
  value_range = max_value - min_value
  return (filled_qvalues - min_value) / value_range


def qtransform_by_parent_and_siblings(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    epsilon: chex.Numeric = 1e-8,
) -> chex.Array:
  """Normalizes the Q-values by the min and max over V(node) and the Q-values.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the parent node.
    epsilon: the minimum denominator for the normalization.

  Returns:
    Q-values mapped into the [0, 1] interval, with zero for the unvisited
    actions. Shape `[num_actions]`.
  """
  chex.assert_shape(node_index, ())
  qvalues = tree.qvalues(node_index)
  visit_counts = tree.children_visits[node_index]
  chex.assert_rank([qvalues, visit_counts, node_index], [1, 1, 0])
  is_visited = visit_counts > 0
  parent_value = tree.node_values[node_index]

  # The unvisited actions must not affect the bounds, so they take the value
  # of the parent while the minimum and the maximum are computed.
  bounded_qvalues = jnp.where(is_visited, qvalues, parent_value)
  chex.assert_equal_shape([bounded_qvalues, qvalues])
  lowest = jnp.minimum(parent_value, jnp.min(bounded_qvalues, axis=-1))
  highest = jnp.maximum(parent_value, jnp.max(bounded_qvalues, axis=-1))

  # Completing by the minimum gives the unvisited actions a zero score.
  filled_qvalues = jnp.where(is_visited, qvalues, lowest)
  denominator = jnp.maximum(highest - lowest, epsilon)
  normalized = (filled_qvalues - lowest) / denominator
  chex.assert_equal_shape([normalized, qvalues])
  return normalized


def qtransform_completed_by_mix_value(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    value_scale: chex.Numeric = 0.1,
    maxvisit_init: chex.Numeric = 50.0,
    rescale_values: bool = True,
    use_mixed_value: bool = True,
    epsilon: chex.Numeric = 1e-8,
) -> chex.Array:
  """Completes the Q-values of a node and scales them.

  The unvisited actions have no Q-value. Their Q-value is replaced by the
  mixed value, defined in Appendix D of
  "Policy improvement by planning with Gumbel":
  https://openreview.net/forum?id=bERaNdoegnO

  The Q-values are then transformed by a linear transformation:
    `(maxvisit_init + max(visit_counts)) * value_scale * qvalues`.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the parent node.
    value_scale: scale for the Q-values.
    maxvisit_init: offset to the `max(visit_counts)` in the scaling factor.
    rescale_values: if True, scale the qvalues by `1 / (max_q - min_q)`.
    use_mixed_value: if True, complete the Q-values with mixed value,
      otherwise complete the Q-values with the raw value.
    epsilon: the minimum denominator when using `rescale_values`.

  Returns:
    Completed Q-values. Shape `[num_actions]`.
  """
  chex.assert_shape(node_index, ())
  qvalues = tree.qvalues(node_index)
  visit_counts = tree.children_visits[node_index]

  # Choosing the value that stands in for the missing Q-values.
  raw_value = tree.raw_values[node_index]
  prior_probs = jax.nn.softmax(tree.children_prior_logits[node_index])
  fill_value = (
      _compute_mixed_value(
          raw_value,
          qvalues=qvalues,
          visit_counts=visit_counts,
          prior_probs=prior_probs)
      if use_mixed_value else raw_value)
  completed = _complete_qvalues(
      qvalues, visit_counts=visit_counts, value=fill_value)

  # Scaling the Q-values.
  if rescale_values:
    completed = _rescale_qvalues(completed, epsilon)
  scale = (maxvisit_init + jnp.max(visit_counts, axis=-1)) * value_scale
  return scale * completed


def _rescale_qvalues(qvalues, epsilon):
  """Maps the given completed Q-values linearly into the [0, 1] interval."""
  lowest = jnp.min(qvalues, axis=-1, keepdims=True)
  highest = jnp.max(qvalues, axis=-1, keepdims=True)
  # `epsilon` bounds the denominator from below when all Q-values are equal.
  span = jnp.maximum(highest - lowest, epsilon)
  return (qvalues - lowest) / span


def _complete_qvalues(qvalues, *, visit_counts, value):
  """Fills the Q-values of the unvisited actions with the scalar `value`."""
  chex.assert_equal_shape([qvalues, visit_counts])
  chex.assert_shape(value, [])

  is_visited = visit_counts > 0
  completed_qvalues = jnp.where(is_visited, qvalues, value)
  chex.assert_equal_shape([completed_qvalues, qvalues])
  return completed_qvalues


def _compute_mixed_value(raw_value, qvalues, visit_counts, prior_probs):
  """Mixes the raw_value with the prior-weighted Q-values of visited actions.

  Args:
    raw_value: an approximate value of the state. Shape `[]`.
    qvalues: Q-values for all actions. Shape `[num_actions]`. The unvisited
      actions have undefined Q-value.
    visit_counts: the visit counts for all actions. Shape `[num_actions]`.
    prior_probs: the action probabilities, produced by the policy network for
      each action. Shape `[num_actions]`.

  Returns:
    An estimator of the state value. Shape `[]`.
  """
  is_visited = visit_counts > 0
  total_visits = jnp.sum(visit_counts, axis=-1)
  # A strictly positive probability keeps weighted_q free of nan, even if
  # the visited actions have zero prior probability.
  safe_probs = jnp.maximum(jnp.finfo(prior_probs.dtype).tiny, prior_probs)
  # The total probability of the visited actions.
  visited_prob_mass = jnp.sum(jnp.where(is_visited, safe_probs, 0.0), axis=-1)
  normalizer = jnp.where(is_visited, visited_prob_mass, 1.0)
  weighted_q = jnp.sum(
      jnp.where(is_visited, safe_probs * qvalues / normalizer, 0.0),
      axis=-1)
  return (raw_value + total_visits * weighted_q) / (total_visits + 1)

# --------------------------------------------------------------------------------
# /examples/visualization_demo.py:
# --------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A demo of Graphviz visualization of a search tree."""

from typing import Optional, Sequence

from absl import app
from absl import flags
import chex
import jax
import jax.numpy as jnp
import mctx
import pygraphviz

FLAGS = flags.FLAGS
flags.DEFINE_integer("seed", 42, "Random seed.")
flags.DEFINE_integer("num_simulations", 32, "Number of simulations.")
flags.DEFINE_integer("max_num_considered_actions", 16,
                     "The maximum number of actions expanded at the root.")
flags.DEFINE_integer("max_depth", None, "The maximum search depth.")
flags.DEFINE_string("output_file", "/tmp/search_tree.png",
                    "The output file for the visualization.")


def convert_tree_to_graph(
    tree: mctx.Tree,
    action_labels: Optional[Sequence[str]] = None,
    batch_index: int = 0
) -> pygraphviz.AGraph:
  """Builds a Graphviz graph from one batch element of a search tree.

  Args:
    tree: A `Tree` containing a batch of search data.
    action_labels: Optional labels for edges, defaults to the action index.
    batch_index: Index of the batch element to plot.

  Returns:
    A Graphviz graph representation of `tree`.

  Raises:
    ValueError: if `action_labels` does not hold one label per action.
  """
  chex.assert_rank(tree.node_values, 2)
  batch_size = tree.node_values.shape[0]
  num_actions = tree.num_actions
  if action_labels is None:
    action_labels = range(num_actions)
  elif len(action_labels) != num_actions:
    raise ValueError(
        f"action_labels {action_labels} has the wrong number of actions "
        f"({len(action_labels)}). "
        f"Expecting {num_actions}.")

  def node_to_str(node_i, reward=0, discount=1):
    node_value = tree.node_values[batch_index, node_i]
    node_visits = tree.node_visits[batch_index, node_i]
    return (f"{node_i}\n"
            f"Reward: {reward:.2f}\n"
            f"Discount: {discount:.2f}\n"
            f"Value: {node_value:.2f}\n"
            f"Visits: {node_visits}\n")

  def edge_to_str(node_i, a_i):
    # `tree.qvalues` is called with one node index per batch element.
    node_index = jnp.full([batch_size], node_i)
    qvalue = tree.qvalues(node_index)[batch_index, a_i]  # pytype: disable=unsupported-operands  # always-use-return-annotations
    probs = jax.nn.softmax(tree.children_prior_logits[batch_index, node_i])
    return (f"{action_labels[a_i]}\n"
            f"Q: {qvalue:.2f}\n"
            f"p: {probs[a_i]:.2f}\n")

  graph = pygraphviz.AGraph(directed=True)

  # The root is drawn first.
  graph.add_node(0, label=node_to_str(node_i=0), color="green")
  # Every expanded child is then drawn and connected to its parent.
  for parent_i in range(tree.num_simulations):
    for action_i in range(num_actions):
      child_i = tree.children_index[batch_index, parent_i, action_i]
      if child_i < 0:
        # A negative index marks a child that was not expanded.
        continue
      child_label = node_to_str(
          node_i=child_i,
          reward=tree.children_rewards[batch_index, parent_i, action_i],
          discount=tree.children_discounts[batch_index, parent_i, action_i])
      graph.add_node(child_i, label=child_label, color="red")
      graph.add_edge(
          parent_i, child_i, label=edge_to_str(parent_i, action_i))

  return graph


def _run_demo(rng_key: chex.PRNGKey):
  """Searches a small deterministic toy environment with Gumbel MuZero."""
  # The toy environment is deterministic. Both tables below have the shape
  # `[num_states, num_actions]`.
  # `transition_matrix[s, a]` holds the next state.
  transition_matrix = jnp.array([
      [1, 2, 3, 4],
      [0, 5, 0, 0],
      [0, 0, 0, 6],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
  ], dtype=jnp.int32)
  # `rewards[s, a]` holds the reward for that (s, a) pair.
  rewards = jnp.array([
      [1, -1, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [10, 0, 20, 0],
  ], dtype=jnp.float32)
  num_states = rewards.shape[0]

  root, recurrent_fn = _make_batched_env_model(
      # A batch of two elements exercises the batched search.
      batch_size=2,
      transition_matrix=transition_matrix,
      rewards=rewards,
      # The discount for each (s, a) pair: zero when the next state is zero.
      discounts=jnp.where(transition_matrix > 0, 1.0, 0.0),
      # Optimistic initial values encourage exploration.
      values=jnp.full([num_states], 15.0),
      # All-zero prior logits for each state.
      prior_logits=jnp.zeros_like(rewards))

  # Running the search.
  return mctx.gumbel_muzero_policy(
      params=(),
      rng_key=rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      num_simulations=FLAGS.num_simulations,
      max_depth=FLAGS.max_depth,
      max_num_considered_actions=FLAGS.max_num_considered_actions,
  )


def _make_batched_env_model(
    batch_size: int,
    *,
    transition_matrix: chex.Array,
    rewards: chex.Array,
    discounts: chex.Array,
    values: chex.Array,
    prior_logits: chex.Array):
  """Builds a batched `(root, recurrent_fn)` from the environment tables."""
  chex.assert_equal_shape([transition_matrix, rewards, discounts,
                           prior_logits])
  num_states, num_actions = transition_matrix.shape
  chex.assert_shape(values, [num_states])

  # The search starts at state zero. The embedding holds the state index.
  start_state = 0
  root = mctx.RootFnOutput(
      prior_logits=jnp.full([batch_size, num_actions],
                            prior_logits[start_state]),
      value=jnp.full([batch_size], values[start_state]),
      embedding=jnp.zeros([batch_size], dtype=jnp.int32),
  )

  def recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key
    chex.assert_shape(action, [batch_size])
    chex.assert_shape(embedding, [batch_size])
    # All outputs are looked up in the tables at the current state index.
    output = mctx.RecurrentFnOutput(
        reward=rewards[embedding, action],
        discount=discounts[embedding, action],
        prior_logits=prior_logits[embedding],
        value=values[embedding])
    return output, transition_matrix[embedding, action]

  return root, recurrent_fn


def main(_):
  """Runs the search, prints the selected action and saves the tree diagram."""
  rng_key = jax.random.PRNGKey(FLAGS.seed)
  jitted_run_demo = jax.jit(_run_demo)
  print("Starting search.")
  policy_output = jitted_run_demo(rng_key)
  search_tree = policy_output.search_tree

  batch_index = 0
  selected_action = policy_output.action[batch_index]
  root_qvalues = search_tree.summary().qvalues
  print("Selected action:", selected_action)
  # The Q-value of the selected action serves as the estimate of the value of
  # the root state. The exploration at the root node does not affect it.
  print("Selected action Q-value:", root_qvalues[batch_index, selected_action])

  graph = convert_tree_to_graph(search_tree)
  print("Saving tree diagram to:", FLAGS.output_file)
  graph.draw(FLAGS.output_file, prog="dot")


if __name__ == "__main__":
  app.run(main)

# --------------------------------------------------------------------------------
# /mctx/_src/action_selection.py:
# --------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A collection of action selection functions."""
from typing import Optional, TypeVar

import chex
import jax
import jax.numpy as jnp

from mctx._src import base
from mctx._src import qtransforms
from mctx._src import seq_halving
from mctx._src import tree as tree_lib


def switching_action_selection_wrapper(
    root_action_selection_fn: base.RootActionSelectionFn,
    interior_action_selection_fn: base.InteriorActionSelectionFn
) -> base.InteriorActionSelectionFn:
  """Wraps root and interior action selection fns in a conditional statement.

  Args:
    root_action_selection_fn: called as `(rng_key, tree, node_index)` when
      `depth == 0`.
    interior_action_selection_fn: called as
      `(rng_key, tree, node_index, depth)` otherwise.

  Returns:
    A function with the interior signature that dispatches on `depth`.
  """

  def switching_action_selection_fn(
      rng_key: chex.PRNGKey,
      tree: tree_lib.Tree,
      node_index: base.NodeIndices,
      depth: base.Depth) -> chex.Array:
    return jax.lax.cond(
        depth == 0,
        # The root function does not take `depth`, so only the first three
        # operands are forwarded to it.
        lambda x: root_action_selection_fn(*x[:3]),
        lambda x: interior_action_selection_fn(*x),
        (rng_key, tree, node_index, depth))

  return switching_action_selection_fn


def muzero_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    depth: chex.Numeric,
    *,
    pb_c_init: float = 1.25,
    pb_c_base: float = 19652.0,
    qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
) -> chex.Array:
  """Returns the action selected for a node index.

  See Appendix B in https://arxiv.org/pdf/1911.08265.pdf for more details.

  Args:
    rng_key: random number generator state.
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the node from which to select an action.
    depth: the scalar depth of the current node. The root has depth zero.
    pb_c_init: constant c_1 in the PUCT formula.
    pb_c_base: constant c_2 in the PUCT formula.
    qtransform: a monotonic transformation to convert the Q-values to [0, 1].

  Returns:
    action: the action selected from the given node.
  """
  visit_counts = tree.children_visits[node_index]
  node_visit = tree.node_visits[node_index]
  # The exploration coefficient grows with the log of the node visit count.
  pb_c = pb_c_init + jnp.log((node_visit + pb_c_base + 1.) / pb_c_base)
  prior_logits = tree.children_prior_logits[node_index]
  prior_probs = jax.nn.softmax(prior_logits)
  # The prior term shrinks as an action accumulates visits.
  policy_score = jnp.sqrt(node_visit) * pb_c * prior_probs / (visit_counts + 1)
  chex.assert_shape([node_index, node_visit], ())
  chex.assert_equal_shape([prior_probs, visit_counts, policy_score])
  value_score = qtransform(tree, node_index)

  # Add tiny bit of randomness for tie break
  node_noise_score = 1e-7 * jax.random.uniform(
      rng_key, (tree.num_actions,))
  to_argmax = value_score + policy_score + node_noise_score

  # Masking the invalid actions at the root.
  # Multiplying by `(depth == 0)` zeroes the mask at every other depth.
  # NOTE(review): `masked_argmax` is defined outside this excerpt.
  return masked_argmax(to_argmax, tree.root_invalid_actions * (depth == 0))


@chex.dataclass(frozen=True)
class GumbelMuZeroExtraData:
  """Extra data for Gumbel MuZero search."""
  # Read by `gumbel_muzero_root_action_selection` and passed, together with
  # the prior logits, to `seq_halving.score_considered`.
  root_gumbel: chex.Array


GumbelMuZeroExtraDataType = TypeVar(  # pylint: disable=invalid-name
    "GumbelMuZeroExtraDataType", bound=GumbelMuZeroExtraData)


def gumbel_muzero_root_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree[GumbelMuZeroExtraDataType],
    node_index: chex.Numeric,
    *,
    num_simulations: chex.Numeric,
    max_num_considered_actions: chex.Numeric,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> chex.Array:
  """Returns the action selected by Sequential Halving with Gumbel.

  Initially, we sample `max_num_considered_actions` actions without replacement.
  From these, the actions with the highest `gumbel + logits + qvalues` are
  visited first.

  Args:
    rng_key: random number generator state.
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the node from which to take an action.
    num_simulations: the simulation budget.
    max_num_considered_actions: the number of actions sampled without
      replacement.
    qtransform: a monotonic transformation for the Q-values.

  Returns:
    action: the action selected from the given node.
  """
  # The selection is deterministic given the tree; the Gumbel noise is read
  # from `tree.extra_data` below.
  del rng_key
  chex.assert_shape([node_index], ())
  visit_counts = tree.children_visits[node_index]
  prior_logits = tree.children_prior_logits[node_index]
  chex.assert_equal_shape([visit_counts, prior_logits])
  completed_qvalues = qtransform(tree, node_index)

  # The table is indexed below by `[num_considered, simulation_index]`.
  table = jnp.array(seq_halving.get_table_of_considered_visits(
      max_num_considered_actions, num_simulations))
  num_valid_actions = jnp.sum(
      1 - tree.root_invalid_actions, axis=-1).astype(jnp.int32)
  # No more actions can be considered than there are valid actions.
  num_considered = jnp.minimum(
      max_num_considered_actions, num_valid_actions)
  chex.assert_shape(num_considered, ())
  # At the root, the simulation_index is equal to the sum of visit counts.
  simulation_index = jnp.sum(visit_counts, -1)
  chex.assert_shape(simulation_index, ())
  considered_visit = table[num_considered, simulation_index]
  chex.assert_shape(considered_visit, ())
  gumbel = tree.extra_data.root_gumbel
  to_argmax = seq_halving.score_considered(
      considered_visit, gumbel, prior_logits, completed_qvalues,
      visit_counts)

  # Masking the invalid actions at the root.
  return masked_argmax(to_argmax, tree.root_invalid_actions)


def gumbel_muzero_interior_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    depth: chex.Numeric,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> chex.Array:
  """Selects the action with a deterministic action selection.

  The action is selected based on the visit counts to produce visitation
  frequencies similar to softmax(prior_logits + qvalues).

  Args:
    rng_key: random number generator state.
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the node from which to take an action.
    depth: the scalar depth of the current node. The root has depth zero.
176 | qtransform: function to obtain completed Q-values for a node. 177 | 178 | Returns: 179 | action: the action selected from the given node. 180 | """ 181 | del rng_key, depth 182 | chex.assert_shape([node_index], ()) 183 | visit_counts = tree.children_visits[node_index] 184 | prior_logits = tree.children_prior_logits[node_index] 185 | chex.assert_equal_shape([visit_counts, prior_logits]) 186 | completed_qvalues = qtransform(tree, node_index) 187 | 188 | # The `prior_logits + completed_qvalues` provide an improved policy, 189 | # because the missing qvalues are replaced by v_{prior_logits}(node). 190 | to_argmax = _prepare_argmax_input( 191 | probs=jax.nn.softmax(prior_logits + completed_qvalues), 192 | visit_counts=visit_counts) 193 | 194 | chex.assert_rank(to_argmax, 1) 195 | return jnp.argmax(to_argmax, axis=-1).astype(jnp.int32) 196 | 197 | 198 | def masked_argmax( 199 | to_argmax: chex.Array, 200 | invalid_actions: Optional[chex.Array]) -> chex.Array: 201 | """Returns a valid action with the highest `to_argmax`.""" 202 | if invalid_actions is not None: 203 | chex.assert_equal_shape([to_argmax, invalid_actions]) 204 | # The usage of the -inf inside the argmax does not lead to NaN. 205 | # Do not use -inf inside softmax, logsoftmax or cross-entropy. 206 | to_argmax = jnp.where(invalid_actions, -jnp.inf, to_argmax) 207 | # If all actions are invalid, the argmax returns action 0. 208 | return jnp.argmax(to_argmax, axis=-1).astype(jnp.int32) 209 | 210 | 211 | def _prepare_argmax_input(probs, visit_counts): 212 | """Prepares the input for the deterministic selection. 213 | 214 | When calling argmax(_prepare_argmax_input(...)) multiple times 215 | with updated visit_counts, the produced visitation frequencies will 216 | approximate the probs. 
217 | 218 | For the derivation, see Section 5 "Planning at non-root nodes" in 219 | "Policy improvement by planning with Gumbel": 220 | https://openreview.net/forum?id=bERaNdoegnO 221 | 222 | Args: 223 | probs: a policy or an improved policy. Shape `[num_actions]`. 224 | visit_counts: the existing visit counts. Shape `[num_actions]`. 225 | 226 | Returns: 227 | The input to an argmax. Shape `[num_actions]`. 228 | """ 229 | chex.assert_equal_shape([probs, visit_counts]) 230 | to_argmax = probs - visit_counts / ( 231 | 1 + jnp.sum(visit_counts, keepdims=True, axis=-1)) 232 | return to_argmax 233 | -------------------------------------------------------------------------------- /mctx/_src/tests/tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================
"""A unit test comparing the search tree to an expected search tree."""
# pylint: disable=use-dict-literal
import functools
import json

import chex
import jax
import jax.numpy as jnp
import numpy as np
from absl import logging
from absl.testing import absltest, parameterized

import mctx

jax.config.update("jax_threefry_partitionable", False)


def _prepare_root(batch_size, num_actions):
  """Returns a root consistent with the stored expected trees."""
  rng_key = jax.random.PRNGKey(0)
  # Using a different rng_key inside each batch element.
  rng_keys = [rng_key]
  for i in range(1, batch_size):
    rng_keys.append(jax.random.fold_in(rng_key, i))
  embedding = jnp.stack(rng_keys)
  output = jax.vmap(
      functools.partial(_produce_prediction_output, num_actions=num_actions))(
          embedding)
  return mctx.RootFnOutput(
      prior_logits=output["policy_logits"],
      value=output["value"],
      embedding=embedding,
  )


def _produce_prediction_output(rng_key, num_actions):
  """Producing the model output as in the stored expected trees."""
  # One split is enough: repeating the same split on the same key yields the
  # same three sub-keys.
  policy_rng, value_rng, reward_rng = jax.random.split(rng_key, 3)
  del rng_key
  # Producing value from [-1, +1).
  value = jax.random.uniform(value_rng, shape=(), minval=-1.0, maxval=1.0)
  # Producing reward from [-1, +1).
  reward = jax.random.uniform(reward_rng, shape=(), minval=-1.0, maxval=1.0)
  return dict(
      policy_logits=jax.random.normal(policy_rng, shape=[num_actions]),
      value=value,
      reward=reward,
  )


def _prepare_recurrent_fn(num_actions, *, discount, zero_reward):
  """Returns a dynamics function consistent with the expected trees."""

  def recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key
    # The embeddings serve as rng_keys.
    embedding = jax.vmap(
        functools.partial(_fold_action_in, num_actions=num_actions))(embedding,
                                                                     action)
    output = jax.vmap(
        functools.partial(_produce_prediction_output, num_actions=num_actions))(
            embedding)
    reward = output["reward"]
    if zero_reward:
      reward = jnp.zeros_like(reward)
    return mctx.RecurrentFnOutput(
        reward=reward,
        discount=jnp.full_like(reward, discount),
        prior_logits=output["policy_logits"],
        value=output["value"],
    ), embedding

  return recurrent_fn


def _fold_action_in(rng_key, action, num_actions):
  """Returns a new rng key, selected by the given action."""
  chex.assert_shape(action, ())
  chex.assert_type(action, jnp.int32)
  sub_rngs = jax.random.split(rng_key, num_actions)
  return sub_rngs[action]


def tree_to_pytree(tree: mctx.Tree, batch_i: int = 0):
  """Converts the MCTS tree to nested dicts.

  Args:
    tree: a batched search tree.
    batch_i: index of the batch element to convert.

  Returns:
    The dict of the root node, with children nested under `child_stats`.
  """
  nodes = {}
  if tree.node_visits[batch_i, 0] == 0:
    # The root node is unvisited, so there is no tree.
    return _create_bare_root(prior=1.0)
  nodes[0] = _create_pynode(
      tree, batch_i, 0, prior=1.0, action=None, reward=None)
  children_prior_probs = jax.nn.softmax(tree.children_prior_logits, axis=-1)
  for node_i in range(tree.num_simulations + 1):
    # Return early if we reach an unvisited node
    visits = tree.node_visits[batch_i, node_i]
    if visits == 0:
      return nodes[0]
    for a_i in range(tree.num_actions):
      prior = children_prior_probs[batch_i, node_i, a_i]
      # Index of children, or -1 if not expanded
      child_i = int(tree.children_index[batch_i, node_i, a_i])
      if child_i >= 0:
        reward = tree.children_rewards[batch_i, node_i, a_i]
        child = _create_pynode(
            tree, batch_i, child_i, prior=prior, action=a_i, reward=reward)
        nodes[child_i] = child
      else:
        child = _create_bare_pynode(prior=prior, action=a_i)
      # pylint: disable=line-too-long
      nodes[node_i]["child_stats"].append(child)  # pytype: disable=attribute-error
      # pylint: enable=line-too-long
  return nodes[0]


def _create_pynode(tree, batch_i, node_i, prior, action, reward):
  """Returns a dict with extracted search statistics."""
  node = dict(
      prior=_round_float(prior),
      visit=int(tree.node_visits[batch_i, node_i]),
      value_view=_round_float(tree.node_values[batch_i, node_i]),
      raw_value_view=_round_float(tree.raw_values[batch_i, node_i]),
      child_stats=[],
      evaluation_index=node_i,
  )
  if action is not None:
    node["action"] = action
  if reward is not None:
    node["reward"] = _round_float(reward)
  return node


def _create_bare_pynode(prior, action):
  """Returns the dict for a child that was never expanded."""
  return dict(
      prior=_round_float(prior),
      child_stats=[],
      action=action,
  )


def _create_bare_root(prior):
  """Returns the dict for a root that was never visited."""
  return dict(
      prior=_round_float(prior),
      child_stats=[],
  )


def _round_float(value, ndigits=10):
  """Converts `value` to a Python float rounded to `ndigits` digits."""
  return round(float(value), ndigits)


class TreeTest(parameterized.TestCase):
  """Compares searches against the expected trees stored as JSON files."""
  # Make sure to adjust the `shard_count` parameter in the build file to match
  # the number of parameter configurations passed to test_tree.
  # pylint: disable=line-too-long
  MUZERO_TREES = [("muzero_norescale",
                   "./mctx/_src/tests/test_data/muzero_tree.json"),
                  ("muzero_qtransform",
                   "./mctx/_src/tests/test_data/muzero_qtransform_tree.json")]
  GUMBEL_MUZERO_TREES = [("gumbel_muzero_norescale",
                          "./mctx/_src/tests/test_data/gumbel_muzero_tree.json"),
                         ("gumbel_muzero_reward",
                          "./mctx/_src/tests/test_data/gumbel_muzero_reward_tree.json")]
  TREES = MUZERO_TREES + GUMBEL_MUZERO_TREES
  # pylint: enable=line-too-long

  @parameterized.named_parameters(*TREES)
  def test_tree(self, tree_data_path):
    with open(tree_data_path, "rb") as fd:
      tree = json.load(fd)
    reproduced, _ = self._reproduce_tree(tree)
    chex.assert_trees_all_close(tree["tree"], reproduced, atol=1e-3)

  @parameterized.named_parameters(*MUZERO_TREES)
  def test_subtree(self, tree_data_path):
    with open(tree_data_path, "rb") as fd:
      tree = json.load(fd)
    reproduced, jax_tree = self._reproduce_tree(tree)
    # Read the batch size from the search tree instead of repeating the
    # constant used inside `_reproduce_tree`.
    batch_size = jax_tree.node_visits.shape[0]

    def rm_evaluation_index(node):
      if isinstance(node, dict):
        node.pop("evaluation_index", None)
        for child in node["child_stats"]:
          rm_evaluation_index(child)

    # test populated subtree
    for child_tree in reproduced["child_stats"]:
      action = child_tree["action"]
      # reflect that the chosen child node is now a root node
      child_tree["prior"] = 1.0
      child_tree.pop("action")
      if "reward" in child_tree:
        child_tree.pop("reward")

      # The same root action is requested for every batch element.
      child_index = jnp.full((batch_size,), action, dtype=jnp.int32)
      subtree = mctx.get_subtree(jax_tree, child_index)
      reproduced_subtree = tree_to_pytree(subtree)

      # evaluation indices will not match since subtree indices are
      # collapsed down so check everything but evaluation indices
      rm_evaluation_index(child_tree)
      rm_evaluation_index(reproduced_subtree)

      chex.assert_trees_all_close(reproduced_subtree, child_tree, atol=1e-3)

  def _reproduce_tree(self, tree):
    """Reproduces the given JSON tree by running a search.

    Args:
      tree: the parsed JSON description of an expected tree. It is not
        modified.

    Returns:
      A tuple `(pytree, search_tree)`: the nested-dict view of batch element 0
      and the batched search tree it was extracted from.
    """
    policy_fn = dict(
        gumbel_muzero=mctx.gumbel_muzero_policy,
        muzero=mctx.muzero_policy,
    )[tree["algorithm"]]

    env_config = tree["env_config"]
    root = tree["tree"]
    num_actions = len(root["child_stats"])
    num_simulations = root["visit"] - 1
    # Work on a shallow copy so that popping keys leaves the caller's dict
    # untouched.
    algorithm_config = dict(tree["algorithm_config"])
    qtransform = functools.partial(
        getattr(mctx, algorithm_config.pop("qtransform")),
        **algorithm_config.pop("qtransform_kwargs", {}))

    batch_size = 3
    # To test the independence of the batch computation, we use different
    # invalid actions for the other elements of the batch. The different batch
    # elements will then have different search tree depths.
    invalid_actions = np.zeros([batch_size, num_actions])
    invalid_actions[1, 1:] = 1
    invalid_actions[2, 2:] = 1

    def run_policy():
      return policy_fn(
          params=(),
          rng_key=jax.random.PRNGKey(1),
          root=_prepare_root(batch_size=batch_size, num_actions=num_actions),
          recurrent_fn=_prepare_recurrent_fn(num_actions, **env_config),
          num_simulations=num_simulations,
          qtransform=qtransform,
          invalid_actions=invalid_actions,
          **algorithm_config)

    policy_output = jax.jit(run_policy)()  # pylint: disable=not-callable
    logging.info("Done search.")

    return tree_to_pytree(policy_output.search_tree), policy_output.search_tree


if __name__ == "__main__":
  jax.config.update("jax_numpy_rank_promotion", "raise")
  absltest.main()

# --------------------------------------------------------------------------------
# /LICENSE:
# --------------------------------------------------------------------------------
#
#                                  Apache License
#                            Version 2.0, January 2004
#                         http://www.apache.org/licenses/
#
#    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
#
#    1. Definitions.
#
#       "License" shall mean the terms and conditions for use, reproduction,
#       and distribution as defined by Sections 1 through 9 of this document.
#
#       "Licensor" shall mean the copyright owner or entity authorized by
#       the copyright owner that is granting the License.
#
#       "Legal Entity" shall mean the union of the acting entity and all
#       other entities that control, are controlled by, or are under common
#       control with that entity. For the purposes of this definition,
#       "control" means (i) the power, direct or indirect, to cause the
#       direction or management of such entity, whether by contract or
#       otherwise, or (ii) ownership of fifty percent (50%) or more of the
#       outstanding shares, or (iii) beneficial ownership of such entity.
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
(Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /mctx/_src/tree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """A data structure used to hold / inspect search data for a batch of inputs.""" 16 | 17 | from __future__ import annotations 18 | from typing import Any, ClassVar, Generic, Optional, Tuple, TypeVar 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | 25 | T = TypeVar("T") 26 | 27 | 28 | @chex.dataclass(frozen=True) 29 | class Tree(Generic[T]): 30 | """State of a search tree. 31 | 32 | The `Tree` dataclass is used to hold and inspect search data for a batch of 33 | inputs. In the fields below `B` denotes the batch dimension, `N` represents 34 | the number of nodes in the tree, and `num_actions` is the number of discrete 35 | actions. 36 | 37 | node_visits: `[B, N]` the visit counts for each node. 38 | raw_values: `[B, N]` the raw value for each node. 39 | node_values: `[B, N]` the cumulative search value for each node. 40 | parents: `[B, N]` the node index for the parents for each node. 41 | action_from_parent: `[B, N]` action to take from the parent to reach each 42 | node. 43 | children_index: `[B, N, num_actions]` the node index of the children for each 44 | action. 45 | children_prior_logits: `[B, N, num_actions]` the action prior logits of each 46 | node. 47 | children_visits: `[B, N, num_actions]` the visit counts for children for 48 | each action. 49 | children_rewards: `[B, N, num_actions]` the immediate reward for each action. 50 | children_discounts: `[B, N, num_actions]` the discount between the 51 | `children_rewards` and the `children_values`. 52 | children_values: `[B, N, num_actions]` the value of the next node after the 53 | action. 54 | next_node_index: `[B]` the next free index where a new node can be inserted. 55 | embeddings: `[B, N, ...]` the state embeddings of each node. 56 | root_invalid_actions: `[B, num_actions]` a mask with invalid actions at the 57 | root. In the mask, invalid actions have ones, and valid actions have zeros. 
58 | extra_data: `[B, ...]` extra data passed to the search. 59 | """ 60 | node_visits: chex.Array # [B, N] 61 | raw_values: chex.Array # [B, N] 62 | node_values: chex.Array # [B, N] 63 | parents: chex.Array # [B, N] 64 | action_from_parent: chex.Array # [B, N] 65 | children_index: chex.Array # [B, N, num_actions] 66 | children_prior_logits: chex.Array # [B, N, num_actions] 67 | children_visits: chex.Array # [B, N, num_actions] 68 | children_rewards: chex.Array # [B, N, num_actions] 69 | children_discounts: chex.Array # [B, N, num_actions] 70 | children_values: chex.Array # [B, N, num_actions] 71 | next_node_index: chex.Array # [B] 72 | embeddings: Any # [B, N, ...] 73 | root_invalid_actions: chex.Array # [B, num_actions] 74 | extra_data: T # [B, ...] 75 | 76 | # The following attributes are class variables (and should not be set on 77 | # Tree instances). 78 | ROOT_INDEX: ClassVar[int] = 0 79 | NO_PARENT: ClassVar[int] = -1 80 | UNVISITED: ClassVar[int] = -1 81 | 82 | @property 83 | def num_actions(self): 84 | return self.children_index.shape[-1] 85 | 86 | @property 87 | def num_simulations(self): 88 | return self.node_visits.shape[-1] - 1 89 | 90 | def qvalues(self, indices): 91 | """Compute q-values for any node indices in the tree.""" 92 | # pytype: disable=wrong-arg-types # jnp-type 93 | if jnp.asarray(indices).shape: 94 | return jax.vmap(_unbatched_qvalues)(self, indices) 95 | else: 96 | return _unbatched_qvalues(self, indices) 97 | # pytype: enable=wrong-arg-types 98 | 99 | def summary(self) -> SearchSummary: 100 | """Extract summary statistics for the root node.""" 101 | # Get state and action values for the root nodes. 102 | chex.assert_rank(self.node_values, 2) 103 | value = self.node_values[:, Tree.ROOT_INDEX] 104 | batch_size, = value.shape 105 | root_indices = jnp.full((batch_size,), Tree.ROOT_INDEX) 106 | qvalues = self.qvalues(root_indices) 107 | # Extract visit counts and induced probabilities for the root nodes. 
108 | visit_counts = self.children_visits[:, Tree.ROOT_INDEX].astype(value.dtype) 109 | total_counts = jnp.sum(visit_counts, axis=-1, keepdims=True) 110 | visit_probs = visit_counts / jnp.maximum(total_counts, 1) 111 | visit_probs = jnp.where(total_counts > 0, visit_probs, 1 / self.num_actions) 112 | # Return relevant stats. 113 | return SearchSummary( # pytype: disable=wrong-arg-types # numpy-scalars 114 | visit_counts=visit_counts, 115 | visit_probs=visit_probs, 116 | value=value, 117 | qvalues=qvalues) 118 | 119 | 120 | def infer_batch_size(tree: Tree) -> int: 121 | """Recovers batch size from `Tree` data structure.""" 122 | if tree.node_values.ndim != 2: 123 | raise ValueError("Input tree is not batched.") 124 | chex.assert_equal_shape_prefix(jax.tree_util.tree_leaves(tree), 1) 125 | return tree.node_values.shape[0] 126 | 127 | 128 | # A number of aggregate statistics and predictions are extracted from the 129 | # search data and returned to the user for further processing. 130 | @chex.dataclass(frozen=True) 131 | class SearchSummary: 132 | """Stats from MCTS search.""" 133 | visit_counts: chex.Array 134 | visit_probs: chex.Array 135 | value: chex.Array 136 | qvalues: chex.Array 137 | 138 | 139 | def _unbatched_qvalues(tree: Tree, index: int) -> int: 140 | chex.assert_rank(tree.children_discounts, 2) 141 | return ( # pytype: disable=bad-return-type # numpy-scalars 142 | tree.children_rewards[index] 143 | + tree.children_discounts[index] * tree.children_values[index] 144 | ) 145 | 146 | 147 | def _get_translation( 148 | tree: Tree, 149 | child_index: chex.Array 150 | ) -> Tuple[chex.Array, chex.Array, chex.Array]: 151 | subtrees = jnp.arange(tree.num_simulations+1) 152 | 153 | def propagate_fun(_, subtrees): 154 | parents_subtrees = jnp.where( 155 | tree.parents != tree.NO_PARENT, 156 | subtrees[tree.parents], 157 | 0 158 | ) 159 | return jnp.where( 160 | jnp.greater(parents_subtrees, 0), 161 | parents_subtrees, 162 | subtrees 163 | ) 164 | 165 | subtrees = 
jax.lax.fori_loop(0, tree.num_simulations, propagate_fun, subtrees) 166 | slots_aranged = jnp.arange(tree.num_simulations+1) 167 | subtree_master_idx = tree.children_index[tree.ROOT_INDEX, child_index] 168 | nodes_to_retain = subtrees == subtree_master_idx 169 | old_subtree_idxs = nodes_to_retain * slots_aranged 170 | cumsum = jnp.cumsum(nodes_to_retain) 171 | new_next_node_index = cumsum[-1] 172 | 173 | translation = jnp.where( 174 | nodes_to_retain, 175 | nodes_to_retain * (cumsum-1), 176 | tree.UNVISITED 177 | ) 178 | erase_idxs = slots_aranged >= new_next_node_index 179 | 180 | return old_subtree_idxs, translation, erase_idxs 181 | 182 | 183 | @jax.vmap 184 | def get_subtree( 185 | tree: Tree, 186 | child_index: chex.Array 187 | ) -> Tree: 188 | """Extracts subtrees rooted at child indices of the root node, 189 | across a batch of trees. Converts node index mappings and collapses 190 | node data so that populated nodes are contiguous and start at index 0. 191 | 192 | Assumes `tree` elements and `child_index` have a batch dimension. 
193 | 194 | Args: 195 | tree: the tree to extract subtrees from 196 | child_index: `[B]` the index of the child (from the root) to extract each 197 | subtree from 198 | """ 199 | # get mapping from old node indices to new node indices 200 | # and a mask of which nodes indices to erase 201 | old_subtree_idxs, translation, erase_idxs = _get_translation( 202 | tree, child_index) 203 | new_next_node_index = translation.max(axis=-1) + 1 204 | 205 | def translate(x, null_value=0): 206 | return jnp.where( 207 | erase_idxs.reshape((-1,) + (1,) * (x.ndim - 1)), 208 | jnp.full_like(x, null_value), 209 | # cases where translation == -1 will set last index 210 | # but since we are at least removing the root node 211 | # (and making one of its children the new root) 212 | # the last index will always be freed 213 | # and overwritten with zeros 214 | x.at[translation].set(x[old_subtree_idxs]), 215 | ) 216 | 217 | def translate_idx(x, null_value=tree.UNVISITED): 218 | return jnp.where( 219 | erase_idxs.reshape((-1,) + (1,) * (x.ndim - 1)), 220 | jnp.full_like(x, null_value), 221 | # in this case we need to explicitly check for index 222 | # mappings to UNVISITED, since otherwise thsese will 223 | # map to the value of the last index of the translation 224 | x.at[translation].set(jnp.where( 225 | x == null_value, 226 | null_value, 227 | translation[x]))) 228 | 229 | def translate_pytree(x, null_value=0): 230 | return jax.tree.map( 231 | lambda t: translate(t, null_value=null_value), x) 232 | 233 | return tree.replace( 234 | node_visits=translate(tree.node_visits), 235 | raw_values=translate(tree.raw_values), 236 | node_values=translate(tree.node_values), 237 | parents=translate_idx(tree.parents), 238 | action_from_parent=translate( 239 | tree.action_from_parent, 240 | null_value=tree.NO_PARENT).at[tree.ROOT_INDEX].set(tree.NO_PARENT), 241 | children_index=translate_idx(tree.children_index), 242 | children_prior_logits=translate(tree.children_prior_logits), 243 | 
children_visits=translate(tree.children_visits), 244 | children_rewards=translate(tree.children_rewards), 245 | children_discounts=translate(tree.children_discounts), 246 | children_values=translate(tree.children_values), 247 | next_node_index=new_next_node_index, 248 | root_invalid_actions=jnp.zeros_like(tree.root_invalid_actions), 249 | embeddings=translate_pytree(tree.embeddings) 250 | ) 251 | 252 | 253 | def reset_search_tree( 254 | tree: Tree, 255 | select_batch: Optional[chex.Array] = None) -> Tree: 256 | """Fills search tree with default values for selected batches. 257 | 258 | Useful for resetting the search tree after a terminated episode. 259 | 260 | Args: 261 | tree: the tree to reset 262 | select_batch: `[B]` a boolean mask to select which batch elements to reset. 263 | If `None`, all batch elements are reset. 264 | """ 265 | if select_batch is None: 266 | select_batch = jnp.ones(tree.node_visits.shape[0], dtype=bool) 267 | 268 | return tree.replace( 269 | node_visits=tree.node_visits.at[select_batch].set(0), 270 | raw_values=tree.raw_values.at[select_batch].set(0), 271 | node_values=tree.node_values.at[select_batch].set(0), 272 | parents=tree.parents.at[select_batch].set(tree.NO_PARENT), 273 | action_from_parent=tree.action_from_parent.at[select_batch].set( 274 | tree.NO_PARENT), 275 | children_index=tree.children_index.at[select_batch].set(tree.UNVISITED), 276 | children_prior_logits=tree.children_prior_logits.at[select_batch].set(0), 277 | children_values=tree.children_values.at[select_batch].set(0), 278 | children_visits=tree.children_visits.at[select_batch].set(0), 279 | children_rewards=tree.children_rewards.at[select_batch].set(0), 280 | children_discounts=tree.children_discounts.at[select_batch].set(0), 281 | next_node_index=tree.next_node_index.at[select_batch].set(1), 282 | embeddings=jax.tree_util.tree_map( 283 | lambda t: t.at[select_batch].set(0), 284 | tree.embeddings), 285 | 
root_invalid_actions=tree.root_invalid_actions.at[select_batch].set(0) 286 | # extra_data is always overwritten by a call to search() 287 | ) 288 | -------------------------------------------------------------------------------- /mctx/_src/tests/policies_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================
"""Tests for `policies.py`."""
import functools

from absl.testing import absltest
import jax
import jax.numpy as jnp
import mctx
from mctx._src import policies
import numpy as np

jax.config.update("jax_threefry_partitionable", False)


def _make_bandit_recurrent_fn(rewards, dummy_embedding=()):
  """Returns a recurrent_fn with discount=0.

  The reward of an action is read from `rewards[batch, action]`; prior logits
  and values are all zeros and `dummy_embedding` is returned as the next
  embedding.
  """

  def recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key, embedding
    batch_indices = jnp.arange(action.shape[0])
    reward = rewards[batch_indices, action]
    output = mctx.RecurrentFnOutput(
        reward=reward,
        discount=jnp.zeros_like(reward),
        prior_logits=jnp.zeros_like(rewards),
        value=jnp.zeros_like(reward),
    )
    return output, dummy_embedding

  return recurrent_fn


def _make_bandit_decision_and_chance_fns(rewards, num_chance_outcomes):
  """Returns `(decision_recurrent_fn, chance_recurrent_fn)` for a bandit.

  The decision function puts a finite logit only on chance outcome 0 and
  forwards `(action, embedding)` as the afterstate embedding. The chance
  function then emits `rewards[batch, action]` with zero discount and value.
  """

  def decision_recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key
    batch_size = action.shape[0]
    reward = rewards[jnp.arange(batch_size), action]
    # Only the first chance outcome has a finite logit.
    dummy_chance_logits = jnp.full(
        [batch_size, num_chance_outcomes], -jnp.inf).at[:, 0].set(1.0)
    afterstate_embedding = (action, embedding)
    output = mctx.DecisionRecurrentFnOutput(
        chance_logits=dummy_chance_logits,
        afterstate_value=jnp.zeros_like(reward))
    return output, afterstate_embedding

  def chance_recurrent_fn(params, rng_key, chance_outcome,
                          afterstate_embedding):
    del params, rng_key, chance_outcome
    afterstate_action, embedding = afterstate_embedding
    batch_size = afterstate_action.shape[0]

    reward = rewards[jnp.arange(batch_size), afterstate_action]
    output = mctx.ChanceRecurrentFnOutput(
        action_logits=jnp.zeros_like(rewards),
        value=jnp.zeros_like(reward),
        discount=jnp.zeros_like(reward),
        reward=reward)
    return output, embedding

  return decision_recurrent_fn, chance_recurrent_fn


def _get_deepest_leaf(tree, node_index):
  """Returns `(leaf, depth)` with maximum depth and visit count.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: the node of the inspected subtree.

  Returns:
    `(leaf, depth)` of a deepest leaf. If multiple leaves have the same depth,
    the leaf with the highest visit count is returned.
  """
  np.testing.assert_equal(len(tree.children_index.shape), 2)
  best_leaf = node_index
  best_depth = 0
  num_actions = tree.children_index.shape[-1]
  for action in range(num_actions):
    child = tree.children_index[node_index, action]
    if child == tree.UNVISITED:
      continue
    child_leaf, child_depth = _get_deepest_leaf(tree, child)
    # Lexicographic comparison: depth first, visit count as the tie-breaker.
    candidate = (1 + child_depth, tree.node_visits[child_leaf])
    incumbent = (best_depth, tree.node_visits[best_leaf])
    if candidate > incumbent:
      best_leaf = child_leaf
      best_depth = 1 + child_depth
  return best_leaf, best_depth


class PoliciesTest(absltest.TestCase):

  def test_apply_temperature_one(self):
    """Tests temperature=1."""
    logits = jnp.arange(6, dtype=jnp.float32)
    new_logits = policies._apply_temperature(logits, temperature=1.0)
    expected = logits - logits.max()
    np.testing.assert_allclose(expected, new_logits)

  def test_apply_temperature_two(self):
    """Tests temperature=2."""
    temperature = 2.0
    logits = jnp.arange(6, dtype=jnp.float32)
    new_logits = policies._apply_temperature(logits, temperature)
    expected = (logits - logits.max()) / temperature
    np.testing.assert_allclose(expected, new_logits)

  def test_apply_temperature_zero(self):
    """Tests temperature=0."""
    logits = jnp.arange(4, dtype=jnp.float32)
    new_logits = policies._apply_temperature(logits, temperature=0.0)
    expected = jnp.array([-2.552118e+38, -1.701412e+38, -8.507059e+37, 0.0])
    np.testing.assert_allclose(expected, new_logits, rtol=1e-3)

  def test_apply_temperature_zero_on_large_logits(self):
    """Tests temperature=0 on large logits."""
    logits = jnp.array([100.0, 3.4028235e+38, -jnp.inf, -3.4028235e+38])
    new_logits = policies._apply_temperature(logits, temperature=0.0)
    expected = jnp.array([-jnp.inf, 0.0, -jnp.inf, -jnp.inf])
    np.testing.assert_allclose(expected, new_logits)

  def test_mask_invalid_actions(self):
    """Tests action masking."""
    logits = jnp.array([1e6, -jnp.inf, 1e6 + 1, -100.0])
    invalid_actions = jnp.array([0.0, 1.0, 0.0, 1.0])
    masked_logits = policies._mask_invalid_actions(
        logits, invalid_actions)
    # The two valid actions differ by exactly one logit.
    valid_probs = jax.nn.softmax(jnp.array([0.0, 1.0]))
    expected_probs = jnp.array([valid_probs[0], 0.0, valid_probs[1], 0.0])
    np.testing.assert_allclose(expected_probs, jax.nn.softmax(masked_logits))

  def test_mask_all_invalid_actions(self):
    """Tests a state with no valid action."""
    logits = jnp.array([-jnp.inf, -jnp.inf, -jnp.inf, -jnp.inf])
    invalid_actions = jnp.array([1.0, 1.0, 1.0, 1.0])
    masked_logits = policies._mask_invalid_actions(
        logits, invalid_actions)
    expected_probs = jnp.array([0.25, 0.25, 0.25, 0.25])
    np.testing.assert_allclose(expected_probs, jax.nn.softmax(masked_logits))

  def test_muzero_policy(self):
    """Tests the action and weights of `muzero_policy` with a masked action."""
    root = mctx.RootFnOutput(
        prior_logits=jnp.array([
            [-1.0, 0.0, 2.0, 3.0],
        ]),
        value=jnp.array([0.0]),
        embedding=(),
    )
    rewards = jnp.zeros_like(root.prior_logits)
    invalid_actions = jnp.array([
        [0.0, 0.0, 0.0, 1.0],
    ])

    policy_output = mctx.muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        recurrent_fn=_make_bandit_recurrent_fn(rewards),
        num_simulations=1,
        invalid_actions=invalid_actions,
        dirichlet_fraction=0.0)

    expected_action = jnp.array([2], dtype=jnp.int32)
    np.testing.assert_array_equal(expected_action, policy_output.action)

    expected_action_weights = jnp.array([
        [0.0, 0.0, 1.0, 0.0],
    ])
    np.testing.assert_allclose(expected_action_weights,
                               policy_output.action_weights)

  def test_gumbel_muzero_policy(self):
    """Tests `gumbel_muzero_policy` with two invalid actions."""
    root_value = jnp.array([-5.0])
    root = mctx.RootFnOutput(
        prior_logits=jnp.array([
            [0.0, -1.0, 2.0, 3.0],
        ]),
        value=root_value,
        embedding=(),
    )
    rewards = jnp.array([
        [20.0, 3.0, -1.0, 10.0],
    ])
    invalid_actions = jnp.array([
        [1.0, 0.0, 0.0, 1.0],
    ])

    value_scale = 0.05
    maxvisit_init = 60
    num_simulations = 17
    max_depth = 3
    qtransform = functools.partial(
        mctx.qtransform_completed_by_mix_value,
        value_scale=value_scale,
        maxvisit_init=maxvisit_init,
        rescale_values=True)
    policy_output = mctx.gumbel_muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        recurrent_fn=_make_bandit_recurrent_fn(rewards),
        num_simulations=num_simulations,
        invalid_actions=invalid_actions,
        max_depth=max_depth,
        qtransform=qtransform,
        gumbel_scale=1.0)

    # Testing the action.
    expected_action = jnp.array([1], dtype=jnp.int32)
    np.testing.assert_array_equal(expected_action, policy_output.action)

    # Testing the action_weights.
    masked_prior_logits = jnp.where(
        invalid_actions, -jnp.inf, root.prior_logits)
    probs = jax.nn.softmax(masked_prior_logits)
    valid_expected_reward = (
        probs[:, 1] * rewards[:, 1] + probs[:, 2] * rewards[:, 2])
    mix_value = 1.0 / (num_simulations + 1) * (
        root_value + num_simulations * valid_expected_reward)

    completed_qvalues = jnp.array([
        [mix_value[0], rewards[0, 1], rewards[0, 2], mix_value[0]],
    ])
    max_value = jnp.max(completed_qvalues, axis=-1, keepdims=True)
    min_value = jnp.min(completed_qvalues, axis=-1, keepdims=True)
    total_value_scale = (maxvisit_init + np.ceil(num_simulations / 2)
                         ) * value_scale
    rescaled_qvalues = total_value_scale * (completed_qvalues - min_value) / (
        max_value - min_value)
    expected_action_weights = jax.nn.softmax(
        jnp.where(invalid_actions,
                  -jnp.inf,
                  root.prior_logits + rescaled_qvalues))
    np.testing.assert_allclose(expected_action_weights,
                               policy_output.action_weights,
                               atol=1e-6)

    # Testing the visit_counts.
    summary = policy_output.search_tree.summary()
    expected_visit_counts = jnp.array(
        [[0.0, np.ceil(num_simulations / 2), num_simulations // 2, 0.0]])
    np.testing.assert_array_equal(expected_visit_counts, summary.visit_counts)

    # Testing max_depth.
    unbatched_tree = jax.tree.map(lambda x: x[0], policy_output.search_tree)
    leaf, max_found_depth = _get_deepest_leaf(
        unbatched_tree, policy_output.search_tree.ROOT_INDEX)
    self.assertEqual(max_depth, max_found_depth)
    self.assertEqual(6, policy_output.search_tree.node_visits[0, leaf])

  def test_gumbel_muzero_policy_without_invalid_actions(self):
    """Tests `gumbel_muzero_policy` when no invalid-action mask is given."""
    root_value = jnp.array([-5.0])
    root = mctx.RootFnOutput(
        prior_logits=jnp.array([
            [0.0, -1.0, 2.0, 3.0],
        ]),
        value=root_value,
        embedding=(),
    )
    rewards = jnp.array([
        [20.0, 3.0, -1.0, 10.0],
    ])

    value_scale = 0.05
    maxvisit_init = 60
    num_simulations = 17
    max_depth = 3
    qtransform = functools.partial(
        mctx.qtransform_completed_by_mix_value,
        value_scale=value_scale,
        maxvisit_init=maxvisit_init,
        rescale_values=True)
    policy_output = mctx.gumbel_muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        recurrent_fn=_make_bandit_recurrent_fn(rewards),
        num_simulations=num_simulations,
        invalid_actions=None,
        max_depth=max_depth,
        qtransform=qtransform,
        gumbel_scale=1.0)

    # Testing the action.
    expected_action = jnp.array([3], dtype=jnp.int32)
    np.testing.assert_array_equal(expected_action, policy_output.action)

    # Testing the action_weights.
    summary = policy_output.search_tree.summary()
    completed_qvalues = rewards
    max_value = jnp.max(completed_qvalues, axis=-1, keepdims=True)
    min_value = jnp.min(completed_qvalues, axis=-1, keepdims=True)
    total_value_scale = (maxvisit_init + summary.visit_counts.max()
                         ) * value_scale
    rescaled_qvalues = total_value_scale * (completed_qvalues - min_value) / (
        max_value - min_value)
    expected_action_weights = jax.nn.softmax(
        root.prior_logits + rescaled_qvalues)
    np.testing.assert_allclose(expected_action_weights,
                               policy_output.action_weights,
                               atol=1e-6)

    # Testing the visit_counts.
    expected_visit_counts = jnp.array(
        [[6, 2, 2, 7]])
    np.testing.assert_array_equal(expected_visit_counts, summary.visit_counts)

  def test_stochastic_muzero_policy(self):
    """Tests that SMZ is equivalent to MZ with a dummy chance function."""
    root = mctx.RootFnOutput(
        prior_logits=jnp.array([
            [-1.0, 0.0, 2.0, 3.0],
            [0.0, 2.0, 5.0, -4.0],
        ]),
        value=jnp.array([1.0, 0.0]),
        embedding=jnp.zeros([2, 4])
    )
    rewards = jnp.zeros_like(root.prior_logits)
    invalid_actions = jnp.array([
        [0.0, 0.0, 0.0, 1.0],
        [1.0, 0.0, 1.0, 0.0],
    ])

    num_simulations = 10
    num_chance_outcomes = 5

    # Reference run: plain MuZero.
    bandit_fn = _make_bandit_recurrent_fn(
        rewards,
        dummy_embedding=jnp.zeros_like(root.embedding))
    policy_output = mctx.muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        recurrent_fn=bandit_fn,
        num_simulations=num_simulations,
        invalid_actions=invalid_actions,
        dirichlet_fraction=0.0)

    # Stochastic MuZero run with twice the number of simulations.
    decision_rec_fn, chance_rec_fn = _make_bandit_decision_and_chance_fns(
        rewards, num_chance_outcomes)
    stochastic_policy_output = mctx.stochastic_muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        decision_recurrent_fn=decision_rec_fn,
        chance_recurrent_fn=chance_rec_fn,
        num_simulations=2 * num_simulations,
        invalid_actions=invalid_actions,
        dirichlet_fraction=0.0)

    np.testing.assert_array_equal(stochastic_policy_output.action,
                                  policy_output.action)

    np.testing.assert_allclose(stochastic_policy_output.action_weights,
                               policy_output.action_weights)


if __name__ == "__main__":
  absltest.main()

# --- Repository-dump boundary: the text below belongs to the next file and is carried over as-is. ---
# -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # This Pylint rcfile contains a best-effort configuration to uphold the 2 | # best-practices and style described in the Google Python style guide: 3 | # https://google.github.io/styleguide/pyguide.html 4 | # 5 | # Its canonical open-source location is: 6 | # https://google.github.io/styleguide/pylintrc 7 | 8 | [MASTER] 9 | 10 | # Files or directories to be skipped. They should be base names, not paths. 11 | ignore=third_party 12 | 13 | # Files or directories matching the regex patterns are skipped. The regex 14 | # matches against base names, not paths. 15 | ignore-patterns= 16 | 17 | # Pickle collected data for later comparisons. 18 | persistent=no 19 | 20 | # List of plugins (as comma separated values of python modules names) to load, 21 | # usually to register additional checkers. 22 | load-plugins= 23 | 24 | # Use multiple processes to speed up Pylint. 25 | jobs=4 26 | 27 | # Allow loading of arbitrary C extensions. Extensions are imported into the 28 | # active Python interpreter and may run arbitrary code. 29 | unsafe-load-any-extension=no 30 | 31 | 32 | [MESSAGES CONTROL] 33 | 34 | # Only show warnings with the listed confidence levels. Leave empty to show 35 | # all.
Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 36 | confidence= 37 | 38 | # Enable the message, report, category or checker with the given id(s). You can 39 | # either give multiple identifiers separated by comma (,) or put this option 40 | # multiple times (only on the command line, not in the configuration file where 41 | # it should appear only once). See also the "--disable" option for examples. 42 | #enable= 43 | 44 | # Disable the message, report, category or checker with the given id(s). You 45 | # can either give multiple identifiers separated by comma (,) or put this 46 | # option multiple times (only on the command line, not in the configuration 47 | # file where it should appear only once). You can also use "--disable=all" to 48 | # disable everything first and then reenable specific checks. For example, if 49 | # you want to run only the similarities checker, you can use "--disable=all 50 | # --enable=similarities". If you want to run only the classes checker, but have 51 | # no Warning level messages displayed, use "--disable=all --enable=classes 52 | # --disable=W" 53 | disable=abstract-method, 54 | apply-builtin, 55 | arguments-differ, 56 | attribute-defined-outside-init, 57 | backtick, 58 | bad-option-value, 59 | basestring-builtin, 60 | buffer-builtin, 61 | c-extension-no-member, 62 | consider-using-enumerate, 63 | cmp-builtin, 64 | cmp-method, 65 | coerce-builtin, 66 | coerce-method, 67 | delslice-method, 68 | div-method, 69 | duplicate-code, 70 | eq-without-hash, 71 | execfile-builtin, 72 | file-builtin, 73 | filter-builtin-not-iterating, 74 | fixme, 75 | getslice-method, 76 | global-statement, 77 | hex-method, 78 | idiv-method, 79 | implicit-str-concat, 80 | import-error, 81 | import-self, 82 | import-star-module-level, 83 | inconsistent-return-statements, 84 | input-builtin, 85 | intern-builtin, 86 | invalid-str-codec, 87 | locally-disabled, 88 | long-builtin, 89 | long-suffix, 90 | map-builtin-not-iterating, 91 |
misplaced-comparison-constant, 92 | missing-function-docstring, 93 | metaclass-assignment, 94 | next-method-called, 95 | next-method-defined, 96 | no-absolute-import, 97 | no-else-break, 98 | no-else-continue, 99 | no-else-raise, 100 | no-else-return, 101 | no-init, # added 102 | no-member, 103 | no-name-in-module, 104 | no-self-use, 105 | nonzero-method, 106 | oct-method, 107 | old-division, 108 | old-ne-operator, 109 | old-octal-literal, 110 | old-raise-syntax, 111 | parameter-unpacking, 112 | print-statement, 113 | raising-string, 114 | range-builtin-not-iterating, 115 | raw_input-builtin, 116 | rdiv-method, 117 | reduce-builtin, 118 | relative-import, 119 | reload-builtin, 120 | round-builtin, 121 | setslice-method, 122 | signature-differs, 123 | standarderror-builtin, 124 | suppressed-message, 125 | sys-max-int, 126 | too-few-public-methods, 127 | too-many-ancestors, 128 | too-many-arguments, 129 | too-many-boolean-expressions, 130 | too-many-branches, 131 | too-many-instance-attributes, 132 | too-many-locals, 133 | too-many-nested-blocks, 134 | too-many-public-methods, 135 | too-many-return-statements, 136 | too-many-statements, 137 | trailing-newlines, 138 | unichr-builtin, 139 | unicode-builtin, 140 | unnecessary-pass, 141 | unpacking-in-except, 142 | useless-else-on-loop, 143 | useless-object-inheritance, 144 | useless-suppression, 145 | using-cmp-argument, 146 | wrong-import-order, 147 | xrange-builtin, 148 | zip-builtin-not-iterating, 149 | 150 | 151 | [REPORTS] 152 | 153 | # Set the output format. Available formats are text, parseable, colorized, msvs 154 | # (visual studio) and html. You can also give a reporter class, eg 155 | # mypackage.mymodule.MyReporterClass. 156 | output-format=text 157 | 158 | # Tells whether to display a full report or only the messages 159 | reports=no 160 | 161 | # Python expression which should return a note less than 10 (10 is the highest 162 | # note). 
You have access to the variables errors warning, statement which 163 | # respectively contain the number of errors / warnings messages and the total 164 | # number of statements analyzed. This is used by the global evaluation report 165 | # (RP0004). 166 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 167 | 168 | # Template used to display messages. This is a python new-style format string 169 | # used to format the message information. See doc for all details 170 | #msg-template= 171 | 172 | 173 | [BASIC] 174 | 175 | # Good variable names which should always be accepted, separated by a comma 176 | good-names=main,_ 177 | 178 | # Bad variable names which should always be refused, separated by a comma 179 | bad-names= 180 | 181 | # Colon-delimited sets of names that determine each other's naming style when 182 | # the name regexes allow several styles. 183 | name-group= 184 | 185 | # Include a hint for the correct naming format with invalid-name 186 | include-naming-hint=no 187 | 188 | # List of decorators that produce properties, such as abc.abstractproperty. Add 189 | # to this list to register other decorators that produce valid properties. 
190 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl 191 | 192 | # Regular expression matching correct function names 193 | function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$ 194 | 195 | # Regular expression matching correct variable names 196 | variable-rgx=^[a-z][a-z0-9_]*$ 197 | 198 | # Regular expression matching correct constant names 199 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 200 | 201 | # Regular expression matching correct attribute names 202 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 203 | 204 | # Regular expression matching correct argument names 205 | argument-rgx=^[a-z][a-z0-9_]*$ 206 | 207 | # Regular expression matching correct class attribute names 208 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 209 | 210 | # Regular expression matching correct inline iteration names 211 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 212 | 213 | # Regular expression matching correct class names 214 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 215 | 216 | # Regular expression matching correct module names 217 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ 218 | 219 | # Regular expression matching correct method names 220 | method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$ 221 | 222 | # Regular expression which should only match function or class names that do 223 | # not require a docstring. 224 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ 225 | 226 | # Minimum line length for functions/classes that require docstrings, shorter 227 | # ones are exempt.
228 | docstring-min-length=10 229 | 230 | 231 | [TYPECHECK] 232 | 233 | # List of decorators that produce context managers, such as 234 | # contextlib.contextmanager. Add to this list to register other decorators that 235 | # produce valid context managers. 236 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 237 | 238 | # Tells whether missing members accessed in mixin class should be ignored. A 239 | # mixin class is detected if its name ends with "mixin" (case insensitive). 240 | ignore-mixin-members=yes 241 | 242 | # List of module names for which member attributes should not be checked 243 | # (useful for modules/projects where namespaces are manipulated during runtime 244 | # and thus existing member attributes cannot be deduced by static analysis). It 245 | # supports qualified module names, as well as Unix pattern matching. 246 | ignored-modules= 247 | 248 | # List of class names for which member attributes should not be checked (useful 249 | # for classes with dynamically set attributes). This supports the use of 250 | # qualified names. 251 | ignored-classes=optparse.Values,thread._local,_thread._local 252 | 253 | # List of members which are set dynamically and missed by pylint inference 254 | # system, and so shouldn't trigger E1101 when accessed. Python regular 255 | # expressions are accepted. 256 | generated-members= 257 | 258 | 259 | [FORMAT] 260 | 261 | # Maximum number of characters on a single line. 262 | max-line-length=80 263 | 264 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt 265 | # lines made too long by directives to pytype. 266 | 267 | # Regexp for a line that is allowed to be longer than the limit. 268 | ignore-long-lines=(?x)( 269 | ^\s*(\#\ )?<?https?://\S+>?$| 270 | ^\s*(from\s+\S+\s+)?import\s+.+$) 271 | 272 | # Allow the body of an if to be on the same line as the test if there is no 273 | # else.
274 | single-line-if-stmt=yes 275 | 276 | # Maximum number of lines in a module 277 | max-module-lines=99999 278 | 279 | # String used as indentation unit. The internal Google style guide mandates 2 280 | # spaces. Google's externally-published style guide says 4, consistent with 281 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google 282 | # projects (like TensorFlow). 283 | indent-string='  ' 284 | 285 | # Number of spaces of indent required inside a hanging or continued line. 286 | indent-after-paren=4 287 | 288 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 289 | expected-line-ending-format= 290 | 291 | 292 | [MISCELLANEOUS] 293 | 294 | # List of note tags to take into consideration, separated by a comma. 295 | notes=TODO 296 | 297 | 298 | [STRING] 299 | 300 | # This flag controls whether inconsistent-quotes generates a warning when the 301 | # character used as a quote delimiter is used inconsistently within a module. 302 | check-quote-consistency=yes 303 | 304 | 305 | [VARIABLES] 306 | 307 | # Tells whether we should check for unused import in __init__ files. 308 | init-import=no 309 | 310 | # A regular expression matching the name of dummy variables (i.e. expectedly 311 | # not used). 312 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 313 | 314 | # List of additional names supposed to be defined in builtins. Remember that 315 | # you should avoid defining new builtins when possible. 316 | additional-builtins= 317 | 318 | # List of strings which can identify a callback function by name. A callback 319 | # name must start or end with one of those strings. 320 | callbacks=cb_,_cb 321 | 322 | # List of qualified module names which can have objects that can redefine 323 | # builtins.
324 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 325 | 326 | 327 | [LOGGING] 328 | 329 | # Logging modules to check that the string format arguments are in logging 330 | # function parameter format 331 | logging-modules=logging,absl.logging,tensorflow.io.logging 332 | 333 | 334 | [SIMILARITIES] 335 | 336 | # Minimum lines number of a similarity. 337 | min-similarity-lines=4 338 | 339 | # Ignore comments when computing similarities. 340 | ignore-comments=yes 341 | 342 | # Ignore docstrings when computing similarities. 343 | ignore-docstrings=yes 344 | 345 | # Ignore imports when computing similarities. 346 | ignore-imports=no 347 | 348 | 349 | [SPELLING] 350 | 351 | # Spelling dictionary name. Available dictionaries: none. To make it working 352 | # install python-enchant package. 353 | spelling-dict= 354 | 355 | # List of comma separated words that should not be checked. 356 | spelling-ignore-words= 357 | 358 | # A path to a file that contains private dictionary; one word per line. 359 | spelling-private-dict-file= 360 | 361 | # Tells whether to store unknown words to indicated private dictionary in 362 | # --spelling-private-dict-file option instead of raising a message. 363 | spelling-store-unknown-words=no 364 | 365 | 366 | [IMPORTS] 367 | 368 | # Deprecated modules which should not be used, separated by a comma 369 | deprecated-modules=regsub, 370 | TERMIOS, 371 | Bastion, 372 | rexec, 373 | sets 374 | 375 | # Create a graph of every (i.e. 
internal and external) dependencies in the 376 | # given file (report RP0402 must not be disabled) 377 | import-graph= 378 | 379 | # Create a graph of external dependencies in the given file (report RP0402 must 380 | # not be disabled) 381 | ext-import-graph= 382 | 383 | # Create a graph of internal dependencies in the given file (report RP0402 must 384 | # not be disabled) 385 | int-import-graph= 386 | 387 | # Force import order to recognize a module as part of the standard 388 | # compatibility libraries. 389 | known-standard-library= 390 | 391 | # Force import order to recognize a module as part of a third party library. 392 | known-third-party=enchant, absl 393 | 394 | # Analyse import fallback blocks. This can be used to support both Python 2 and 395 | # 3 compatible code, which means that the block might have code that exists 396 | # only in one or another interpreter, leading to false positives when analysed. 397 | analyse-fallback-blocks=no 398 | 399 | 400 | [CLASSES] 401 | 402 | # List of method names used to declare (i.e. assign) instance attributes. 403 | defining-attr-methods=__init__, 404 | __new__, 405 | setUp 406 | 407 | # List of member names, which should be excluded from the protected access 408 | # warning. 409 | exclude-protected=_asdict, 410 | _fields, 411 | _replace, 412 | _source, 413 | _make 414 | 415 | # List of valid names for the first argument in a class method. 416 | valid-classmethod-first-arg=cls, 417 | class_ 418 | 419 | # List of valid names for the first argument in a metaclass class method. 420 | valid-metaclass-classmethod-first-arg=mcs 421 | 422 | 423 | [EXCEPTIONS] 424 | 425 | # Exceptions that will emit a warning when being caught. 
Defaults to 426 | # "Exception" 427 | overgeneral-exceptions=builtins.StandardError, 428 | builtins.Exception, 429 | builtins.BaseException 430 | -------------------------------------------------------------------------------- /mctx/_src/search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A JAX implementation of batched MCTS.""" 16 | import functools 17 | from typing import Any, NamedTuple, Optional, Tuple, TypeVar 18 | 19 | import chex 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | from mctx._src import action_selection, base 24 | from mctx._src import tree as tree_lib 25 | 26 | Tree = tree_lib.Tree 27 | T = TypeVar("T") 28 | 29 | 30 | def search( 31 | params: base.Params, 32 | rng_key: chex.PRNGKey, 33 | tree: Tree, 34 | *, 35 | recurrent_fn: base.RecurrentFn, 36 | root_action_selection_fn: base.RootActionSelectionFn, 37 | interior_action_selection_fn: base.InteriorActionSelectionFn, 38 | num_simulations: int, 39 | max_depth: Optional[int] = None, 40 | loop_fn: base.LoopFn = jax.lax.fori_loop) -> Tree: 41 | """Performs a full search and returns sampled actions. 42 | 43 | In the shape descriptions, `B` denotes the batch dimension. 
44 | 45 | Args: 46 | params: params to be forwarded to root and recurrent functions. 47 | rng_key: random number generator state, the key is consumed. 48 | tree: the initialized MCTS tree state to search upon. 49 | recurrent_fn: a callable to be called on the leaf nodes and unvisited 50 | actions retrieved by the simulation step, which takes as args 51 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput` 52 | and the new state embedding. The `rng_key` argument is consumed. 53 | root_action_selection_fn: function used to select an action at the root. 54 | interior_action_selection_fn: function used to select an action during 55 | simulation. 56 | num_simulations: the number of simulations. 57 | max_depth: maximum search tree depth allowed during simulation, defined as 58 | the number of edges from the root to a leaf node. 59 | loop_fn: Function used to run the simulations. It may be required to pass 60 | hk.fori_loop if using this function inside a Haiku module. 61 | 62 | Returns: 63 | The updated `Tree` after `num_simulations` simulations have been run from 64 | its root. 65 | """ 66 | action_selection_fn = action_selection.switching_action_selection_wrapper( 67 | root_action_selection_fn=root_action_selection_fn, 68 | interior_action_selection_fn=interior_action_selection_fn 69 | ) 70 | 71 | # Do simulation, expansion, and backward steps. 72 | batch_size = tree.node_visits.shape[0] 73 | batch_range = jnp.arange(batch_size) 74 | tree_capacity = tree.children_visits.shape[1] 75 | if max_depth is None: 76 | max_depth = tree_capacity - 1 77 | 78 | def body_fun(_, loop_state): 79 | rng_key, tree = loop_state 80 | rng_key, simulate_key, expand_key = jax.random.split(rng_key, 3) 81 | # simulate is vmapped and expects batched rng keys.
82 | simulate_keys = jax.random.split(simulate_key, batch_size) 83 | parent_index, action = simulate( 84 | simulate_keys, tree, action_selection_fn, max_depth) 85 | next_node_index = tree.children_index[batch_range, parent_index, action] 86 | unvisited = next_node_index == Tree.UNVISITED 87 | 88 | next_node_index = jnp.where(unvisited, 89 | tree.next_node_index, next_node_index) 90 | tree = expand( 91 | params, expand_key, tree, recurrent_fn, parent_index, 92 | action, next_node_index) 93 | # if next_node_index goes out of bounds (i.e. no room left for new nodes) 94 | # backward its (in-bounds) parent 95 | out_of_bounds = next_node_index >= tree_capacity 96 | next_node_index = jnp.where(out_of_bounds, 97 | parent_index, 98 | next_node_index) 99 | 100 | tree = backward(tree, next_node_index) 101 | # increment next_node_index if leaf was expanded 102 | tree = tree.replace(next_node_index=tree.next_node_index + \ 103 | jnp.logical_and(unvisited, (~out_of_bounds))) 104 | loop_state = rng_key, tree 105 | return loop_state 106 | 107 | _, tree = loop_fn( 108 | 0, num_simulations, body_fun, (rng_key, tree)) 109 | 110 | return tree 111 | 112 | 113 | class _SimulationState(NamedTuple): 114 | """The state for the simulation while loop.""" 115 | rng_key: chex.PRNGKey 116 | node_index: int 117 | action: int 118 | next_node_index: int 119 | depth: int 120 | is_continuing: bool 121 | 122 | 123 | @functools.partial(jax.vmap, in_axes=[0, 0, None, None], out_axes=0) 124 | def simulate( 125 | rng_key: chex.PRNGKey, 126 | tree: Tree, 127 | action_selection_fn: base.InteriorActionSelectionFn, 128 | max_depth: int) -> Tuple[chex.Array, chex.Array]: 129 | """Traverses the tree until reaching an unvisited action or `max_depth`. 130 | 131 | Each simulation starts from the root and keeps selecting actions traversing 132 | the tree until a leaf or `max_depth` is reached. 133 | 134 | Args: 135 | rng_key: random number generator state, the key is consumed. 
136 | tree: _unbatched_ MCTS tree state. 137 | action_selection_fn: function used to select an action during simulation. 138 | max_depth: maximum search tree depth allowed during simulation. 139 | 140 | Returns: 141 | `(parent_index, action)` tuple, where `parent_index` is the index of the 142 | node reached at the end of the simulation, and the `action` is the action to 143 | evaluate from the `parent_index`. 144 | """ 145 | def cond_fun(state): 146 | return state.is_continuing 147 | 148 | def body_fun(state): 149 | # Preparing the next simulation state. 150 | node_index = state.next_node_index 151 | rng_key, action_selection_key = jax.random.split(state.rng_key) 152 | action = action_selection_fn(action_selection_key, tree, node_index, 153 | state.depth) 154 | next_node_index = tree.children_index[node_index, action] 155 | # The returned action will be visited. 156 | depth = state.depth + 1 157 | is_before_depth_cutoff = depth < max_depth 158 | is_visited = next_node_index != Tree.UNVISITED 159 | is_continuing = jnp.logical_and(is_visited, is_before_depth_cutoff) 160 | return _SimulationState( # pytype: disable=wrong-arg-types # jax-types 161 | rng_key=rng_key, 162 | node_index=node_index, 163 | action=action, 164 | next_node_index=next_node_index, 165 | depth=depth, 166 | is_continuing=is_continuing) 167 | 168 | node_index = jnp.array(Tree.ROOT_INDEX, dtype=jnp.int32) 169 | depth = jnp.zeros((), dtype=tree.children_prior_logits.dtype) 170 | # pytype: disable=wrong-arg-types # jnp-type 171 | initial_state = _SimulationState( 172 | rng_key=rng_key, 173 | node_index=tree.NO_PARENT, 174 | action=tree.NO_PARENT, 175 | next_node_index=node_index, 176 | depth=depth, 177 | is_continuing=jnp.array(True)) 178 | # pytype: enable=wrong-arg-types 179 | end_state = jax.lax.while_loop(cond_fun, body_fun, initial_state) 180 | 181 | # Returning a node with a selected action. 182 | # The action can be already visited, if the max_depth is reached. 
183 | return end_state.node_index, end_state.action 184 | 185 | 186 | def expand( 187 | params: chex.Array, 188 | rng_key: chex.PRNGKey, 189 | tree: Tree[T], 190 | recurrent_fn: base.RecurrentFn, 191 | parent_index: chex.Array, 192 | action: chex.Array, 193 | next_node_index: chex.Array) -> Tree[T]: 194 | """Create and evaluate child nodes from given nodes and unvisited actions. 195 | 196 | Args: 197 | params: params to be forwarded to recurrent function. 198 | rng_key: random number generator state. 199 | tree: the MCTS tree state to update. 200 | recurrent_fn: a callable to be called on the leaf nodes and unvisited 201 | actions retrieved by the simulation step, which takes as args 202 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput` 203 | and the new state embedding. The `rng_key` argument is consumed. 204 | parent_index: the index of the parent node, from which the action will be 205 | expanded. Shape `[B]`. 206 | action: the action to expand. Shape `[B]`. 207 | next_node_index: the index of the newly expanded node. This can be the index 208 | of an existing node, if `max_depth` is reached. Shape `[B]`. 209 | 210 | Returns: 211 | tree: updated MCTS tree state. 212 | """ 213 | batch_size = tree_lib.infer_batch_size(tree) 214 | batch_range = jnp.arange(batch_size) 215 | chex.assert_shape([parent_index, action, next_node_index], (batch_size,)) 216 | 217 | # Retrieve states for nodes to be evaluated. 218 | embedding = jax.tree.map( 219 | lambda x: x[batch_range, parent_index], tree.embeddings) 220 | 221 | # Evaluate and create a new node. 
222 | step, embedding = recurrent_fn(params, rng_key, action, embedding) 223 | chex.assert_shape(step.prior_logits, [batch_size, tree.num_actions]) 224 | chex.assert_shape(step.reward, [batch_size]) 225 | chex.assert_shape(step.discount, [batch_size]) 226 | chex.assert_shape(step.value, [batch_size]) 227 | tree = update_tree_node( 228 | tree, next_node_index, step.prior_logits, step.value, embedding) 229 | 230 | # handle out-of-bounds next_node_index, 231 | # (parent should still point to UNVISITED) 232 | next_node_index_check_oob = jnp.where(next_node_index > tree.num_simulations, 233 | tree.UNVISITED, 234 | next_node_index) 235 | # Return updated tree topology. 236 | return tree.replace( 237 | children_index=batch_update( 238 | tree.children_index, next_node_index_check_oob, parent_index, action), 239 | children_rewards=batch_update( 240 | tree.children_rewards, step.reward, parent_index, action), 241 | children_discounts=batch_update( 242 | tree.children_discounts, step.discount, parent_index, action), 243 | parents=batch_update(tree.parents, parent_index, next_node_index), 244 | action_from_parent=batch_update( 245 | tree.action_from_parent, action, next_node_index)) 246 | 247 | 248 | @jax.vmap 249 | def backward( 250 | tree: Tree[T], 251 | leaf_index: chex.Numeric) -> Tree[T]: 252 | """Goes up and updates the tree until all nodes reached the root. 253 | 254 | Args: 255 | tree: the MCTS tree state to update, without the batch size. 256 | leaf_index: the node index from which to do the backward. 257 | 258 | Returns: 259 | Updated MCTS tree state. 260 | """ 261 | 262 | def cond_fun(loop_state): 263 | _, _, index = loop_state 264 | return index != Tree.ROOT_INDEX 265 | 266 | def body_fun(loop_state): 267 | # Here we update the value of our parent, so we start by reversing. 
268 | tree, leaf_value, index = loop_state 269 | parent = tree.parents[index] 270 | count = tree.node_visits[parent] 271 | action = tree.action_from_parent[index] 272 | reward = tree.children_rewards[parent, action] 273 | leaf_value = reward + tree.children_discounts[parent, action] * leaf_value 274 | parent_value = ( 275 | tree.node_values[parent] * count + leaf_value) / (count + 1.0) 276 | children_values = tree.node_values[index] 277 | children_counts = tree.children_visits[parent, action] + 1 278 | 279 | tree = tree.replace( 280 | node_values=update(tree.node_values, parent_value, parent), 281 | node_visits=update(tree.node_visits, count + 1, parent), 282 | children_values=update( 283 | tree.children_values, children_values, parent, action), 284 | children_visits=update( 285 | tree.children_visits, children_counts, parent, action)) 286 | 287 | return tree, leaf_value, parent 288 | 289 | leaf_index = jnp.asarray(leaf_index, dtype=jnp.int32) 290 | loop_state = (tree, tree.node_values[leaf_index], leaf_index) 291 | tree, _, _ = jax.lax.while_loop(cond_fun, body_fun, loop_state) 292 | return tree 293 | 294 | 295 | # Utility function to set the values of certain indices to prescribed values. 296 | # This is vmapped to operate seamlessly on batches. 297 | def update(x, vals, *indices): 298 | return x.at[indices].set(vals) 299 | 300 | 301 | batch_update = jax.vmap(update) 302 | 303 | 304 | def update_tree_node( 305 | tree: Tree[T], 306 | node_index: chex.Array, 307 | prior_logits: chex.Array, 308 | value: chex.Array, 309 | embedding: chex.Array) -> Tree[T]: 310 | """Updates the tree at node index. 311 | 312 | Args: 313 | tree: `Tree` to whose node is to be updated. 314 | node_index: the index of the expanded node. Shape `[B]`. 315 | prior_logits: the prior logits to fill in for the new node, of shape 316 | `[B, num_actions]`. 317 | value: the value to fill in for the new node. Shape `[B]`. 318 | embedding: the state embeddings for the node. Shape `[B, ...]`. 
319 | 320 | Returns: 321 | The new tree with updated nodes. 322 | """ 323 | batch_size = tree_lib.infer_batch_size(tree) 324 | batch_range = jnp.arange(batch_size) 325 | chex.assert_shape(prior_logits, (batch_size, tree.num_actions)) 326 | 327 | # When using max_depth, a leaf can be expanded multiple times. 328 | new_visit = tree.node_visits[batch_range, node_index] + 1 329 | updates = dict( # pylint: disable=use-dict-literal 330 | children_prior_logits=batch_update( 331 | tree.children_prior_logits, prior_logits, node_index), 332 | raw_values=batch_update( 333 | tree.raw_values, value, node_index), 334 | node_values=batch_update( 335 | tree.node_values, value, node_index), 336 | node_visits=batch_update( 337 | tree.node_visits, new_visit, node_index), 338 | embeddings=jax.tree.map( 339 | lambda t, s: batch_update(t, s, node_index), 340 | tree.embeddings, embedding)) 341 | 342 | return tree.replace(**updates) 343 | 344 | 345 | def instantiate_tree_from_root( 346 | root: base.RootFnOutput, 347 | num_nodes: int, 348 | extra_data: Any, 349 | root_invalid_actions: Optional[chex.Array] = None) -> Tree: 350 | """Initializes tree state at search root.""" 351 | chex.assert_rank(root.prior_logits, 2) 352 | batch_size, num_actions = root.prior_logits.shape 353 | chex.assert_shape(root.value, [batch_size]) 354 | 355 | data_dtype = root.value.dtype 356 | batch_node = (batch_size, num_nodes) 357 | batch_node_action = (batch_size, num_nodes, num_actions) 358 | 359 | if root_invalid_actions is None: 360 | root_invalid_actions = jnp.zeros_like(root.prior_logits) 361 | 362 | def _zeros(x): 363 | return jnp.zeros(batch_node + x.shape[1:], dtype=x.dtype) 364 | 365 | # Create a new empty tree state and fill its root. 
366 | tree = Tree( 367 | node_visits=jnp.zeros(batch_node, dtype=jnp.int32), 368 | raw_values=jnp.zeros(batch_node, dtype=data_dtype), 369 | node_values=jnp.zeros(batch_node, dtype=data_dtype), 370 | parents=jnp.full(batch_node, Tree.NO_PARENT, dtype=jnp.int32), 371 | action_from_parent=jnp.full( 372 | batch_node, Tree.NO_PARENT, dtype=jnp.int32), 373 | children_index=jnp.full( 374 | batch_node_action, Tree.UNVISITED, dtype=jnp.int32), 375 | children_prior_logits=jnp.zeros( 376 | batch_node_action, dtype=root.prior_logits.dtype), 377 | children_values=jnp.zeros(batch_node_action, dtype=data_dtype), 378 | children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32), 379 | children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype), 380 | children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype), 381 | next_node_index=jnp.ones((batch_size,), dtype=jnp.int32), 382 | embeddings=jax.tree_util.tree_map(_zeros, root.embedding), 383 | root_invalid_actions=root_invalid_actions, 384 | extra_data=extra_data) 385 | 386 | root_index = jnp.full([batch_size], Tree.ROOT_INDEX) 387 | tree = update_tree_node( 388 | tree, root_index, root.prior_logits, root.value, root.embedding) 389 | return tree 390 | 391 | 392 | def update_tree_with_root( 393 | tree: Tree, 394 | root: base.RootFnOutput, 395 | extra_data: Any, 396 | root_invalid_actions: Optional[chex.Array] = None) -> Tree: 397 | """Given a tree, updates its root with the given root output 398 | if it's not already populated.""" 399 | root_uninitialized = tree.node_visits[:, Tree.ROOT_INDEX] == 0 400 | batch_size = tree_lib.infer_batch_size(tree) 401 | root_index = jnp.full([batch_size], Tree.ROOT_INDEX) 402 | ones = jnp.ones([batch_size], dtype=jnp.int32) 403 | 404 | if root_invalid_actions is None: 405 | root_invalid_actions = jnp.zeros_like(root.prior_logits) 406 | 407 | updates = dict( # pylint: disable=use-dict-literal 408 | children_prior_logits=batch_update( 409 | tree.children_prior_logits, root.prior_logits, 
root_index), 410 | raw_values=batch_update( 411 | tree.raw_values, root.value, root_index), 412 | node_values=jnp.where( 413 | root_uninitialized[..., None], 414 | batch_update(tree.node_values, root.value, root_index), 415 | tree.node_values), 416 | node_visits=jnp.where( 417 | root_uninitialized[..., None], 418 | batch_update(tree.node_visits, ones, root_index), 419 | tree.node_visits), 420 | embeddings=jax.tree_util.tree_map( 421 | lambda t, s: batch_update(t, s, root_index), 422 | tree.embeddings, root.embedding), 423 | root_invalid_actions=root_invalid_actions, 424 | extra_data=extra_data) 425 | 426 | return tree.replace(**updates) 427 | -------------------------------------------------------------------------------- /mctx/_src/policies.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Search policies.""" 16 | import functools 17 | from typing import Optional, Tuple 18 | 19 | import chex 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | from mctx._src import action_selection 24 | from mctx._src import base 25 | from mctx._src import qtransforms 26 | from mctx._src import search 27 | from mctx._src import seq_halving 28 | 29 | 30 | def muzero_policy( 31 | params: base.Params, 32 | rng_key: chex.PRNGKey, 33 | root: base.RootFnOutput, 34 | recurrent_fn: base.RecurrentFn, 35 | num_simulations: int, 36 | invalid_actions: Optional[chex.Array] = None, 37 | max_depth: Optional[int] = None, 38 | loop_fn: base.LoopFn = jax.lax.fori_loop, 39 | *, 40 | qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings, 41 | dirichlet_fraction: chex.Numeric = 0.25, 42 | dirichlet_alpha: chex.Numeric = 0.3, 43 | pb_c_init: chex.Numeric = 1.25, 44 | pb_c_base: chex.Numeric = 19652, 45 | temperature: chex.Numeric = 1.0) -> base.PolicyOutput[None]: 46 | """Runs MuZero search and returns the `PolicyOutput`. 47 | 48 | In the shape descriptions, `B` denotes the batch dimension. 49 | 50 | Args: 51 | params: params to be forwarded to root and recurrent functions. 52 | rng_key: random number generator state, the key is consumed. 53 | root: a `(prior_logits, value, embedding)` `RootFnOutput`. The 54 | `prior_logits` are from a policy network. The shapes are 55 | `([B, num_actions], [B], [B, ...])`, respectively. 56 | recurrent_fn: a callable to be called on the leaf nodes and unvisited 57 | actions retrieved by the simulation step, which takes as args 58 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput` 59 | and the new state embedding. The `rng_key` argument is consumed. 60 | num_simulations: the number of simulations. 61 | invalid_actions: a mask with invalid actions. Invalid actions 62 | have ones, valid actions have zeros in the mask. 
Shape `[B, num_actions]`. 63 | max_depth: maximum search tree depth allowed during simulation. 64 | loop_fn: Function used to run the simulations. It may be required to pass 65 | hk.fori_loop if using this function inside a Haiku module. 66 | qtransform: function to obtain completed Q-values for a node. 67 | dirichlet_fraction: float from 0 to 1 interpolating between using only the 68 | prior policy or just the Dirichlet noise. 69 | dirichlet_alpha: concentration parameter to parametrize the Dirichlet 70 | distribution. 71 | pb_c_init: constant c_1 in the PUCT formula. 72 | pb_c_base: constant c_2 in the PUCT formula. 73 | temperature: temperature for acting proportionally to 74 | `visit_counts**(1 / temperature)`. 75 | 76 | Returns: 77 | `PolicyOutput` containing the proposed action, action_weights and the used 78 | search tree. 79 | """ 80 | rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3) 81 | 82 | # Adding Dirichlet noise. 83 | noisy_logits = _get_logits_from_probs( 84 | _add_dirichlet_noise( 85 | dirichlet_rng_key, 86 | jax.nn.softmax(root.prior_logits), 87 | dirichlet_fraction=dirichlet_fraction, 88 | dirichlet_alpha=dirichlet_alpha)) 89 | root = root.replace( 90 | prior_logits=_mask_invalid_actions(noisy_logits, invalid_actions)) 91 | 92 | # Running the search. 
93 | interior_action_selection_fn = functools.partial( 94 | action_selection.muzero_action_selection, 95 | pb_c_base=pb_c_base, 96 | pb_c_init=pb_c_init, 97 | qtransform=qtransform) 98 | root_action_selection_fn = functools.partial( 99 | interior_action_selection_fn, 100 | depth=0) 101 | search_tree = search.search( 102 | params=params, 103 | rng_key=search_rng_key, 104 | tree=search.instantiate_tree_from_root( 105 | root, num_simulations+1, 106 | root_invalid_actions=invalid_actions, 107 | extra_data=None 108 | ), 109 | recurrent_fn=recurrent_fn, 110 | root_action_selection_fn=root_action_selection_fn, 111 | interior_action_selection_fn=interior_action_selection_fn, 112 | num_simulations=num_simulations, 113 | max_depth=max_depth, 114 | loop_fn=loop_fn) 115 | 116 | # Sampling the proposed action proportionally to the visit counts. 117 | summary = search_tree.summary() 118 | action_weights = summary.visit_probs 119 | action_logits = _apply_temperature( 120 | _get_logits_from_probs(action_weights), temperature) 121 | action = jax.random.categorical(rng_key, action_logits) 122 | return base.PolicyOutput( 123 | action=action, 124 | action_weights=action_weights, 125 | search_tree=search_tree) 126 | 127 | 128 | def gumbel_muzero_policy( 129 | params: base.Params, 130 | rng_key: chex.PRNGKey, 131 | root: base.RootFnOutput, 132 | recurrent_fn: base.RecurrentFn, 133 | num_simulations: int, 134 | invalid_actions: Optional[chex.Array] = None, 135 | max_depth: Optional[int] = None, 136 | loop_fn: base.LoopFn = jax.lax.fori_loop, 137 | *, 138 | qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value, 139 | max_num_considered_actions: int = 16, 140 | gumbel_scale: chex.Numeric = 1., 141 | ) -> base.PolicyOutput[action_selection.GumbelMuZeroExtraData]: 142 | """Runs Gumbel MuZero search and returns the `PolicyOutput`. 143 | 144 | This policy implements Full Gumbel MuZero from 145 | "Policy improvement by planning with Gumbel". 
146 | https://openreview.net/forum?id=bERaNdoegnO 147 | 148 | At the root of the search tree, actions are selected by Sequential Halving 149 | with Gumbel. At non-root nodes (aka interior nodes), actions are selected by 150 | the Full Gumbel MuZero deterministic action selection. 151 | 152 | In the shape descriptions, `B` denotes the batch dimension. 153 | 154 | Args: 155 | params: params to be forwarded to root and recurrent functions. 156 | rng_key: random number generator state, the key is consumed. 157 | root: a `(prior_logits, value, embedding)` `RootFnOutput`. The 158 | `prior_logits` are from a policy network. The shapes are 159 | `([B, num_actions], [B], [B, ...])`, respectively. 160 | recurrent_fn: a callable to be called on the leaf nodes and unvisited 161 | actions retrieved by the simulation step, which takes as args 162 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput` 163 | and the new state embedding. The `rng_key` argument is consumed. 164 | num_simulations: the number of simulations. 165 | invalid_actions: a mask with invalid actions. Invalid actions 166 | have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`. 167 | max_depth: maximum search tree depth allowed during simulation. 168 | loop_fn: Function used to run the simulations. It may be required to pass 169 | hk.fori_loop if using this function inside a Haiku module. 170 | qtransform: function to obtain completed Q-values for a node. 171 | max_num_considered_actions: the maximum number of actions expanded at the 172 | root node. A smaller number of actions will be expanded if the number of 173 | valid actions is smaller. 174 | gumbel_scale: scale for the Gumbel noise. Evaluation on perfect-information 175 | games can use gumbel_scale=0.0. 176 | 177 | Returns: 178 | `PolicyOutput` containing the proposed action, action_weights and the used 179 | search tree. 180 | """ 181 | # Masking invalid actions.
182 | root = root.replace( 183 | prior_logits=_mask_invalid_actions(root.prior_logits, invalid_actions)) 184 | 185 | # Generating Gumbel. 186 | rng_key, gumbel_rng = jax.random.split(rng_key) 187 | gumbel = gumbel_scale * jax.random.gumbel( 188 | gumbel_rng, shape=root.prior_logits.shape, dtype=root.prior_logits.dtype) 189 | 190 | 191 | extra_data = action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel) 192 | # Searching. 193 | search_tree = search.instantiate_tree_from_root( 194 | root, num_simulations+1, 195 | root_invalid_actions=invalid_actions, 196 | extra_data=extra_data) 197 | search_tree = search.search( 198 | params=params, 199 | rng_key=rng_key, 200 | tree=search.instantiate_tree_from_root( 201 | root, num_simulations+1, 202 | root_invalid_actions=invalid_actions, 203 | extra_data=extra_data 204 | ), 205 | recurrent_fn=recurrent_fn, 206 | root_action_selection_fn=functools.partial( 207 | action_selection.gumbel_muzero_root_action_selection, 208 | num_simulations=num_simulations, 209 | max_num_considered_actions=max_num_considered_actions, 210 | qtransform=qtransform, 211 | ), 212 | interior_action_selection_fn=functools.partial( 213 | action_selection.gumbel_muzero_interior_action_selection, 214 | qtransform=qtransform, 215 | ), 216 | num_simulations=num_simulations, 217 | max_depth=max_depth, 218 | loop_fn=loop_fn) 219 | summary = search_tree.summary() 220 | 221 | # Acting with the best action from the most visited actions. 222 | # The "best" action has the highest `gumbel + logits + q`. 223 | # Inside the minibatch, the considered_visit can be different on states with 224 | # a smaller number of valid actions. 225 | considered_visit = jnp.max(summary.visit_counts, axis=-1, keepdims=True) 226 | # The completed_qvalues include imputed values for unvisited actions. 
def alphazero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    search_tree: Optional[search.Tree] = None,
    max_nodes: Optional[int] = None,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
    dirichlet_fraction: chex.Numeric = 0.25,
    dirichlet_alpha: chex.Numeric = 0.3,
    pb_c_init: chex.Numeric = 1.25,
    pb_c_base: chex.Numeric = 19652,
    temperature: chex.Numeric = 1.0) -> base.PolicyOutput[None]:
  """Runs an AlphaZero search, optionally resuming from an existing tree.

  When `search_tree` is given, the search continues from it once its root has
  been updated with `root`. Otherwise a new tree is created with capacity
  `max_nodes` (`num_simulations + 1` when `max_nodes` is `None`).
  This policy is otherwise identical to `muzero_policy`.

  In the shape descriptions, `B` denotes the batch dimension.

  Args:
    params: params to be forwarded to root and recurrent functions.
    rng_key: random number generator state, the key is consumed.
    root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
      `prior_logits` are from a policy network. The shapes are
      `([B, num_actions], [B], [B, ...])`, respectively.
    recurrent_fn: a callable to be called on the leaf nodes and unvisited
      actions retrieved by the simulation step, which takes as args
      `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
      and the new state embedding. The `rng_key` argument is consumed.
    num_simulations: the number of simulations.
    search_tree: an existing tree to continue searching from. When omitted, a
      new tree is created.
    max_nodes: capacity of the newly created tree. Only read when
      `search_tree` is not provided.
    invalid_actions: a mask with invalid actions. Invalid actions
      have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`.
    max_depth: maximum search tree depth allowed during simulation.
    loop_fn: Function used to run the simulations. It may be required to pass
      hk.fori_loop if using this function inside a Haiku module.
    qtransform: function to obtain completed Q-values for a node.
    dirichlet_fraction: float from 0 to 1 interpolating between using only the
      prior policy or just the Dirichlet noise.
    dirichlet_alpha: concentration parameter to parametrize the Dirichlet
      distribution.
    pb_c_init: constant c_1 in the PUCT formula.
    pb_c_base: constant c_2 in the PUCT formula.
    temperature: temperature for acting proportionally to
      `visit_counts**(1 / temperature)`.

  Returns:
    `PolicyOutput` containing the proposed action, action_weights and the used
    search tree.
  """
  # One key each for: final action sampling, root noise, the simulations.
  sampling_key, noise_key, simulation_key = jax.random.split(rng_key, 3)

  # Perturb the root prior with Dirichlet noise, then mask invalid actions.
  prior_probs = jax.nn.softmax(root.prior_logits)
  noisy_probs = _add_dirichlet_noise(
      noise_key,
      prior_probs,
      dirichlet_fraction=dirichlet_fraction,
      dirichlet_alpha=dirichlet_alpha)
  noisy_logits = _get_logits_from_probs(noisy_probs)
  root = root.replace(
      prior_logits=_mask_invalid_actions(noisy_logits, invalid_actions))

  # The root uses the same selection rule as interior nodes, pinned to depth 0.
  select_interior = functools.partial(
      action_selection.muzero_action_selection,
      pb_c_base=pb_c_base,
      pb_c_init=pb_c_init,
      qtransform=qtransform)
  select_root = functools.partial(select_interior, depth=0)

  # Either resume from the supplied tree or allocate a fresh one.
  if search_tree is not None:
    tree = search.update_tree_with_root(
        search_tree, root,
        root_invalid_actions=invalid_actions, extra_data=None)
  else:
    capacity = num_simulations + 1 if max_nodes is None else max_nodes
    tree = search.instantiate_tree_from_root(
        root, capacity,
        root_invalid_actions=invalid_actions,
        extra_data=None)

  tree = search.search(
      params=params,
      rng_key=simulation_key,
      tree=tree,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=select_root,
      interior_action_selection_fn=select_interior,
      num_simulations=num_simulations,
      max_depth=max_depth,
      loop_fn=loop_fn)

  # Sample the proposed action proportionally to the root visit counts.
  visit_probs = tree.summary().visit_probs
  action_logits = _apply_temperature(
      _get_logits_from_probs(visit_probs), temperature)
  action = jax.random.categorical(sampling_key, action_logits)
  return base.PolicyOutput(
      action=action,
      action_weights=visit_probs,
      search_tree=tree)
def stochastic_muzero_policy(
    params: chex.ArrayTree,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    decision_recurrent_fn: base.DecisionRecurrentFn,
    chance_recurrent_fn: base.ChanceRecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
    dirichlet_fraction: chex.Numeric = 0.25,
    dirichlet_alpha: chex.Numeric = 0.3,
    pb_c_init: chex.Numeric = 1.25,
    pb_c_base: chex.Numeric = 19652,
    temperature: chex.Numeric = 1.0) -> base.PolicyOutput[None]:
  """Runs Stochastic MuZero search.

  Implements search as described in the Stochastic MuZero paper:
    (https://openreview.net/forum?id=X6D9bAHhBQ1).

  Internally every node has `A + C` "actions": the `A` environment actions
  followed by the `C` chance outcomes. The returned tree is sliced back to the
  `A` environment actions before the summary is computed.

  In the shape descriptions, `B` denotes the batch dimension.

  Args:
    params: params to be forwarded to root and recurrent functions.
    rng_key: random number generator state, the key is consumed.
    root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
      `prior_logits` are from a policy network. The shapes are `([B,
      num_actions], [B], [B, ...])`, respectively.
    decision_recurrent_fn: a callable to be called on the leaf decision nodes
      and unvisited actions retrieved by the simulation step, which takes as
      args `(params, rng_key, action, state_embedding)` and returns a
      `(DecisionRecurrentFnOutput, afterstate_embedding)`.
    chance_recurrent_fn: a callable to be called on the leaf chance nodes and
      unvisited actions retrieved by the simulation step, which takes as args
      `(params, rng_key, chance_outcome, afterstate_embedding)` and returns a
      `(ChanceRecurrentFnOutput, state_embedding)`.
    num_simulations: the number of simulations.
    invalid_actions: a mask with invalid actions. Invalid actions have ones,
      valid actions have zeros in the mask. Shape `[B, num_actions]`.
    max_depth: maximum search tree depth allowed during simulation.
    loop_fn: Function used to run the simulations. It may be required to pass
      hk.fori_loop if using this function inside a Haiku module.
    qtransform: function to obtain completed Q-values for a node.
    dirichlet_fraction: float from 0 to 1 interpolating between using only the
      prior policy or just the Dirichlet noise.
    dirichlet_alpha: concentration parameter to parametrize the Dirichlet
      distribution.
    pb_c_init: constant c_1 in the PUCT formula.
    pb_c_base: constant c_2 in the PUCT formula.
    temperature: temperature for acting proportionally to `visit_counts**(1 /
      temperature)`.

  Returns:
    `PolicyOutput` containing the proposed action, action_weights and the used
    search tree.
  """
  num_actions = root.prior_logits.shape[-1]
  batch_size = jax.tree_util.tree_leaves(root.embedding)[0].shape[0]

  sampling_key, noise_key, simulation_key = jax.random.split(rng_key, 3)

  # Perturb the root prior with Dirichlet noise, then mask invalid actions.
  prior_probs = jax.nn.softmax(root.prior_logits)
  noisy_probs = _add_dirichlet_noise(
      noise_key,
      prior_probs,
      dirichlet_fraction=dirichlet_fraction,
      dirichlet_alpha=dirichlet_alpha)
  decision_logits = _mask_invalid_actions(
      _get_logits_from_probs(noisy_probs), invalid_actions)

  # Call the decision function once on a dummy action to obtain a placeholder
  # afterstate embedding and to discover the number of chance outcomes.
  # NOTE(review): `sampling_key` is consumed here and again for the final
  # action sampling below; kept as is to preserve the random stream.
  dummy_action = jnp.zeros([batch_size], dtype=jnp.int32)
  dummy_output, dummy_afterstate_embedding = decision_recurrent_fn(
      params, sampling_key, dummy_action, root.embedding)
  num_chance_outcomes = dummy_output.chance_logits.shape[-1]

  # Pad the action logits with `C` entries of -inf so the last dim is A + C,
  # and wrap the embedding so that each node records whether it is a decision
  # or a chance node (see `StochasticRecurrentState`).
  chance_padding = jnp.full(
      [batch_size, num_chance_outcomes], fill_value=-jnp.inf)
  root = root.replace(
      prior_logits=jnp.concatenate([decision_logits, chance_padding], axis=-1),
      embedding=base.StochasticRecurrentState(
          state_embedding=root.embedding,
          afterstate_embedding=dummy_afterstate_embedding,
          is_decision_node=jnp.ones([batch_size], dtype=bool)))

  # A single recurrent function dispatches on the node type.
  recurrent_fn = _make_stochastic_recurrent_fn(
      decision_node_fn=decision_recurrent_fn,
      chance_node_fn=chance_recurrent_fn,
      num_actions=num_actions,
      num_chance_outcomes=num_chance_outcomes,
  )

  # Decision nodes use the MuZero rule; chance nodes use their own rule.
  select_at_decision_node = functools.partial(
      action_selection.muzero_action_selection,
      pb_c_base=pb_c_base,
      pb_c_init=pb_c_init,
      qtransform=qtransform)
  select_interior = _make_stochastic_action_selection_fn(
      select_at_decision_node, num_actions)
  select_root = functools.partial(select_interior, depth=0)

  # Running the search.
  tree = search.instantiate_tree_from_root(
      root, num_simulations+1,
      root_invalid_actions=invalid_actions,
      extra_data=None)
  tree = search.search(
      params=params,
      rng_key=simulation_key,
      tree=tree,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=select_root,
      interior_action_selection_fn=select_interior,
      num_simulations=num_simulations,
      max_depth=max_depth,
      loop_fn=loop_fn)

  # Keep only the environment actions, then sample the proposed action
  # proportionally to the visit counts.
  tree = _mask_tree(tree, num_actions, 'decision')
  visit_probs = tree.summary().visit_probs
  action_logits = _apply_temperature(
      _get_logits_from_probs(visit_probs), temperature)
  action = jax.random.categorical(sampling_key, action_logits)
  return base.PolicyOutput(
      action=action, action_weights=visit_probs, search_tree=tree)
def _mask_invalid_actions(logits, invalid_actions):
  """Returns logits with zero mass to invalid actions.

  Args:
    logits: the logits to mask.
    invalid_actions: a mask with the same shape as `logits` whose truthy
      entries mark invalid actions, or `None` to leave `logits` untouched.

  Returns:
    `logits` itself when `invalid_actions` is `None`; otherwise the logits
    shifted so their maximum over the last axis is zero, with the invalid
    entries replaced by the most negative finite value of the dtype.
  """
  if invalid_actions is not None:
    chex.assert_equal_shape([logits, invalid_actions])
    shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
    # At the end of an episode, all actions can be invalid. A softmax would
    # then produce NaNs, if using -inf for the logits. A finite floor for the
    # invalid actions avoids the NaNs.
    floor = jnp.finfo(shifted.dtype).min
    return jnp.where(invalid_actions, floor, shifted)
  return logits


def _get_logits_from_probs(probs):
  """Returns `log(probs)`, clipping `probs` from below at the dtype's `tiny`.

  The clipping keeps the result finite for zero probabilities.
  """
  smallest = jnp.finfo(probs.dtype).tiny
  clipped = jnp.maximum(probs, smallest)
  return jnp.log(clipped)


def _add_dirichlet_noise(rng_key, probs, *, dirichlet_alpha,
                         dirichlet_fraction):
  """Mixes the probs with Dirichlet noise.

  Args:
    rng_key: random number generator key used to draw the noise.
    probs: rank-2 array of probabilities, `[batch_size, num_actions]`.
    dirichlet_alpha: float concentration, shared by all actions.
    dirichlet_fraction: float weight given to the noise; `probs` receive the
      weight `1 - dirichlet_fraction`.

  Returns:
    The mixture of `probs` and one Dirichlet sample per batch element.
  """
  chex.assert_rank(probs, 2)
  chex.assert_type([dirichlet_alpha, dirichlet_fraction], float)

  batch_size, num_actions = probs.shape
  concentration = jnp.full([num_actions], fill_value=dirichlet_alpha)
  noise = jax.random.dirichlet(
      rng_key, alpha=concentration, shape=(batch_size,))
  return (1 - dirichlet_fraction) * probs + dirichlet_fraction * noise


def _apply_temperature(logits, temperature):
  """Returns `logits / temperature`, supporting also temperature=0."""
  # Subtracting the max prevents +inf after dividing by a small temperature.
  shifted = logits - jnp.max(logits, keepdims=True, axis=-1)
  # The temperature is clipped from below so temperature=0 does not divide
  # by zero.
  safe_temperature = jnp.maximum(jnp.finfo(shifted.dtype).tiny, temperature)
  return shifted / safe_temperature
def _make_stochastic_recurrent_fn(
    decision_node_fn: base.DecisionRecurrentFn,
    chance_node_fn: base.ChanceRecurrentFn,
    num_actions: int,
    num_chance_outcomes: int,
) -> base.RecurrentFn:
  """Builds one recurrent function serving both decision and chance nodes.

  Args:
    decision_node_fn: recurrent function applied to decision nodes.
    chance_node_fn: recurrent function applied to chance nodes.
    num_actions: the number of environment actions `A`.
    num_chance_outcomes: the number of chance outcomes `C`.

  Returns:
    A recurrent function over `A + C` "actions". Both node functions are
    evaluated for the whole batch and the output is picked per batch element
    according to `state.is_decision_node`.
  """

  def stochastic_recurrent_fn(
      params: base.Params,
      rng: chex.PRNGKey,
      action_or_chance: base.Action,  # [B]
      state: base.StochasticRecurrentState
  ) -> Tuple[base.RecurrentFnOutput, base.StochasticRecurrentState]:
    batch_size = jax.tree_util.tree_leaves(state.state_embedding)[0].shape[0]
    # Internally there are `A' = A + C` "actions", so `action_or_chance`
    # takes values in `{0, 1, ..., A' - 1}`. Read as an action it is used
    # unchanged; read as a chance outcome it is shifted down by `A`.
    action = action_or_chance - 0
    chance_outcome = action_or_chance - num_actions

    # Decision branch. The decision function yields chance logits of dim `C`;
    # `A` dummy logits are prepended to match the `A'` convention. They are
    # ultimately ignored: see `_mask_tree`.
    decision_output, afterstate_embedding = decision_node_fn(
        params, rng, action, state.state_embedding)
    action_padding = jnp.full([batch_size, num_actions], fill_value=-jnp.inf)
    afterstate_value = decision_output.afterstate_value
    output_if_decision_node = base.RecurrentFnOutput(
        prior_logits=jnp.concatenate(
            [action_padding, decision_output.chance_logits], axis=-1),
        value=afterstate_value,
        reward=jnp.zeros_like(afterstate_value),
        discount=jnp.ones_like(afterstate_value))

    # Chance branch. The chance function yields action logits of dim `A`;
    # `C` dummy logits are appended to match the `A'` convention. They are
    # ultimately ignored: see `_mask_tree`.
    chance_output, state_embedding = chance_node_fn(
        params, rng, chance_outcome, state.afterstate_embedding)
    chance_padding = jnp.full(
        [batch_size, num_chance_outcomes], fill_value=-jnp.inf)
    output_if_chance_node = base.RecurrentFnOutput(
        prior_logits=jnp.concatenate(
            [chance_output.action_logits, chance_padding], axis=-1),
        value=chance_output.value,
        reward=chance_output.reward,
        discount=chance_output.discount)

    # The child of a decision node is a chance node, and vice versa.
    new_state = base.StochasticRecurrentState(
        state_embedding=state_embedding,
        afterstate_embedding=afterstate_embedding,
        is_decision_node=jnp.logical_not(state.is_decision_node))

    def _select_per_example(decision_leaf, chance_leaf):
      # Reshape the [B] flag to [B, 1, ..., 1] so it broadcasts over the leaf.
      flag_shape = [-1] + [1] * (len(decision_leaf.shape) - 1)
      is_decision = jnp.reshape(state.is_decision_node, flag_shape)
      return jnp.where(is_decision, decision_leaf, chance_leaf)

    output = jax.tree.map(_select_per_example,
                          output_if_decision_node,
                          output_if_chance_node)
    return output, new_state

  return stochastic_recurrent_fn
def _mask_tree(tree: search.Tree, num_actions: int, mode: str) -> search.Tree:
  """Masks out parts of the tree based upon node type.

  "Actions" in our tree can either be action or chance values: A' = A + C.
  The per-child fields, whose last dimension is A', are sliced down to the
  first A entries or the remaining C entries depending upon `mode`.

  Args:
    tree: The tree to be masked.
    num_actions: The number of environment actions A.
    mode: Either "decision" or "chance".

  Returns:
    An appropriately masked tree.

  Raises:
    ValueError: if `mode` is neither "decision" nor "chance".
  """
  if mode == 'decision':
    kept = slice(None, num_actions)
  elif mode == 'chance':
    kept = slice(num_actions, None)
  else:
    raise ValueError(f'Unknown mode: {mode}.')

  per_child_fields = (
      'children_index',
      'children_prior_logits',
      'children_visits',
      'children_rewards',
      'children_discounts',
      'children_values',
      'root_invalid_actions',
  )
  return tree.replace(
      **{name: getattr(tree, name)[..., kept] for name in per_child_fields})


def _make_stochastic_action_selection_fn(
    decision_node_selection_fn: base.InteriorActionSelectionFn,
    num_actions: int,
) -> base.InteriorActionSelectionFn:
  """Builds an action selection function handling decision and chance nodes.

  Args:
    decision_node_selection_fn: the selection function used at decision nodes.
    num_actions: The number of environment actions A.

  Returns:
    A selection function returning an index in the joint `A + C` space: the
    decision selection as is, or the chance selection offset by `num_actions`.
  """

  # NOTE: trees are unbatched here.

  def _chance_node_selection_fn(
      tree: search.Tree,
      node_index: chex.Array,
  ) -> chex.Array:
    # Pick the outcome maximising `prob / (visits + 1)`.
    visits = tree.children_visits[node_index]
    outcome_probs = jax.nn.softmax(tree.children_prior_logits[node_index])
    scores = outcome_probs / (visits + 1)
    return jnp.argmax(scores, axis=-1).astype(jnp.int32)

  def _action_selection_fn(key: chex.PRNGKey, tree: search.Tree,
                           node_index: chex.Array,
                           depth: chex.Array) -> chex.Array:
    is_decision = tree.embeddings.is_decision_node[node_index]
    # Both candidates are computed; `cond` only picks which one to return.
    chance_tree = _mask_tree(tree, num_actions, 'chance')
    chance_selection = _chance_node_selection_fn(
        tree=chance_tree, node_index=node_index) + num_actions
    decision_tree = _mask_tree(tree, num_actions, 'decision')
    decision_selection = decision_node_selection_fn(
        key, decision_tree, node_index, depth)
    return jax.lax.cond(is_decision, lambda: decision_selection,
                        lambda: chance_selection)

  return _action_selection_fn
_action_selection_fn 686 | -------------------------------------------------------------------------------- /connect4.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# MCTS in MCTX\n", 9 | "\n", 10 | "In this example, we will use the `mctx` library to play the game of Connect 4.\n", 11 | "We will implement the Monte Carlo Tree Search (MCTS) algorithm with random rollouts." 12 | ] 13 | }, 14 | { 15 | "attachments": {}, 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "## Game mechanics" 20 | ] 21 | }, 22 | { 23 | "attachments": {}, 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "Let's start by defining some type aliases to make the code more readable." 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 1, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import chex\n", 37 | "\n", 38 | "# 6x7 board\n", 39 | "# We use the following coordinate system:\n", 40 | "# ^\n", 41 | "# |\n", 42 | "# |\n", 43 | "# |\n", 44 | "# 0 +-------->\n", 45 | "# 0\n", 46 | "Board = chex.Array\n", 47 | "\n", 48 | "# Index of the column to play\n", 49 | "Action = chex.Array\n", 50 | "\n", 51 | "# Let's assume the game is played by players X and O.\n", 52 | "# 1 if player X, -1 if player O\n", 53 | "Player = chex.Array\n", 54 | "\n", 55 | "# Reward for the player who played the action.\n", 56 | "# 1 for winning, 0 for draw, -1 for losing\n", 57 | "Reward = chex.Array\n", 58 | "\n", 59 | "# True/False if the game is over\n", 60 | "Done = chex.Array" 61 | ] 62 | }, 63 | { 64 | "attachments": {}, 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "Now, we create a class defining the game state at a given time." 
69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 2, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "@chex.dataclass\n", 78 | "class Env:\n", 79 | " board: Board\n", 80 | " player: Player\n", 81 | " done: Done\n", 82 | " reward: Reward" 83 | ] 84 | }, 85 | { 86 | "attachments": {}, 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "To visualize the game state, we define a function that prints the board." 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "from rich import print\n", 100 | "\n", 101 | "BOARD_STRING = \"\"\"\n", 102 | " ? | ? | ? | ? | ? | ? | ?\n", 103 | "---|---|---|---|---|---|---\n", 104 | " ? | ? | ? | ? | ? | ? | ?\n", 105 | "---|---|---|---|---|---|---\n", 106 | " ? | ? | ? | ? | ? | ? | ?\n", 107 | "---|---|---|---|---|---|---\n", 108 | " ? | ? | ? | ? | ? | ? | ?\n", 109 | "---|---|---|---|---|---|---\n", 110 | " ? | ? | ? | ? | ? | ? | ?\n", 111 | "---|---|---|---|---|---|---\n", 112 | " ? | ? | ? | ? | ? | ? | ?\n", 113 | "---|---|---|---|---|---|---\n", 114 | " 1 2 3 4 5 6 7\n", 115 | "\"\"\"\n", 116 | "\n", 117 | "def print_board(board: Board):\n", 118 | " board_str = BOARD_STRING\n", 119 | " for i in reversed(range(board.shape[0])):\n", 120 | " for j in range(board.shape[1]):\n", 121 | " board_str = board_str.replace('?', '[green]X[/green]' if board[i, j] == 1 else '[red]O[/red]' if board[i, j] == -1 else ' ', 1)\n", 122 | " print(board_str)" 123 | ] 124 | }, 125 | { 126 | "attachments": {}, 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "Let's write a function that checks if the game is over.\n", 131 | "We iterate over all horizontal/vertical/diagonal/anti-diagonal lines and check if there are 4 consecutive pieces of the same color." 
132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 4, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "import jax.numpy as jnp\n", 141 | "\n", 142 | "def horizontals(board: Board) -> chex.Array:\n", 143 | " return jnp.stack([\n", 144 | " board[i, j:j+4]\n", 145 | " for i in range(board.shape[0])\n", 146 | " for j in range(board.shape[1] - 3)\n", 147 | " ])\n", 148 | "\n", 149 | "def verticals(board: Board) -> chex.Array:\n", 150 | " return jnp.stack([\n", 151 | " board[i:i+4, j]\n", 152 | " for i in range(board.shape[0] - 3)\n", 153 | " for j in range(board.shape[1])\n", 154 | " ])\n", 155 | "\n", 156 | "def diagonals(board: Board) -> chex.Array:\n", 157 | " return jnp.stack([\n", 158 | " jnp.diag(board[i:i+4, j:j+4])\n", 159 | " for i in range(board.shape[0] - 3)\n", 160 | " for j in range(board.shape[1] - 3)\n", 161 | " ])\n", 162 | "\n", 163 | "def antidiagonals(board: Board) -> chex.Array:\n", 164 | " return jnp.stack([\n", 165 | " jnp.diag(board[i:i+4, j:j+4][::-1])\n", 166 | " for i in range(board.shape[0] - 3)\n", 167 | " for j in range(board.shape[1] - 3)\n", 168 | " ])\n", 169 | "\n", 170 | "def get_winner(board: Board) -> Player:\n", 171 | " all_lines = jnp.concatenate((\n", 172 | " horizontals(board),\n", 173 | " verticals(board),\n", 174 | " diagonals(board),\n", 175 | " antidiagonals(board),\n", 176 | " ))\n", 177 | " # x_won and o_won are 1 if the player won, 0 otherwise\n", 178 | " x_won = jnp.any(jnp.all(all_lines == 1, axis=1)).astype(jnp.int8)\n", 179 | " o_won = jnp.any(jnp.all(all_lines == -1, axis=1)).astype(jnp.int8)\n", 180 | " # We consider the following cases:\n", 181 | " # - !x_won and !o_won -> 0 - 0 = 0 -> draw OR not finished\n", 182 | " # - x_won and !o_won -> 1 - 0 = 1 -> Player 1 (X) won\n", 183 | " # - !x_won and o_won -> 0 - 1 = -1 -> Player -1 (O) won\n", 184 | " # - x_won and o_won -> impossible, the game would have ended earlier\n", 185 | " return x_won - o_won" 186 | ] 187 | }, 188 | 
{ 189 | "attachments": {}, 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "Finally, we can implement the environment dynamics:\n", 194 | "\n", 195 | "* `env_reset` creates a brand new game\n", 196 | "* `env_step` plays a move and returns the new state, the reward **for the player who played the move** and whether the game has ended" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 5, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "import jax.numpy as jnp\n", 206 | "\n", 207 | "def env_reset(_):\n", 208 | " return Env(\n", 209 | " board=jnp.zeros((6, 7), dtype=jnp.int8),\n", 210 | " player=jnp.int8(1),\n", 211 | " done=jnp.bool_(False),\n", 212 | " reward=jnp.int8(0),\n", 213 | " )" 214 | ] 215 | }, 216 | { 217 | "attachments": {}, 218 | "cell_type": "markdown", 219 | "metadata": {}, 220 | "source": [ 221 | "You might be wondering why `env_reset` takes an unused argument.\n", 222 | "This allows us to parallelize the environment with the `jax.vmap` function:" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 6, 228 | "metadata": {}, 229 | "outputs": [ 230 | { 231 | "data": { 232 | "text/plain": [ 233 | "Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)" 234 | ] 235 | }, 236 | "execution_count": 6, 237 | "metadata": {}, 238 | "output_type": "execute_result" 239 | } 240 | ], 241 | "source": [ 242 | "import jax\n", 243 | "\n", 244 | "jax.vmap(env_reset)(jnp.arange(10)).player\n", 245 | "# 10 games have been created:" 246 | ] 247 | }, 248 | { 249 | "attachments": {}, 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "You need to perfectly understand the `env_step` function to implement the MCTS algorithm.\n", 254 | "\n", 255 | "In particular, you need to understand that the reward returned by `env_step` is the reward **for the player who played the move**.\n", 256 | "This means that the reward should always be either 0 for a move that does 
not end the game or causes a draw, or 1 for a move that ends the game with a win for the player who played the move.\n", 257 | "\n", 258 | "A reward of -1 is reserved for punishing illegal moves.\n", 259 | "Since `mctx` does not support masking illegal moves below the root node, we need another way to stop the MCTS algorithm from exploring illegal moves.\n", 260 | "Later, we will also write a custom policy function that will never select illegal moves." 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 7, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "def env_step(env: Env, action: Action) -> tuple[Env, Reward, Done]:\n", 270 | " col = action\n", 271 | "\n", 272 | " # Find the first empty row in the column.\n", 273 | " # If the column is full, this will be the top row.\n", 274 | " row = jnp.argmax(env.board[:, col] == 0)\n", 275 | "\n", 276 | " # If the column is full, the move is invalid.\n", 277 | " invalid_move = env.board[row, col] != 0\n", 278 | "\n", 279 | " # Place the player's piece in the board only if the move is valid and the game is not over.\n", 280 | " board = env.board.at[row, col].set(jnp.where(env.done | invalid_move, env.board[row, col], env.player))\n", 281 | "\n", 282 | " # The reward is computed as follows:\n", 283 | " # * 0 if the game is **already** over. 
This is to ignore nodes below terminal nodes.\n", 284 | " # * -1 if the move is invalid\n", 285 | " # * 1 if the move won the game for the current player\n", 286 | " # * 0 if the move caused a draw\n", 287 | " # * (impossible for Connect 4) -1 if the move lost the game for the current player\n", 288 | " reward = jnp.where(env.done, 0, jnp.where(invalid_move, -1, get_winner(board) * env.player)).astype(jnp.int8)\n", 289 | "\n", 290 | " # We end the game if (parentheses required: `|` binds tighter than `!=`):\n", 291 | " # * the game was already over\n", 292 | " # * the move won or lost the game\n", 293 | " # * the move was invalid\n", 294 | " # * the board is full (draw)\n", 295 | " done = env.done | (reward != 0) | invalid_move | jnp.all(board[-1] != 0)\n", 296 | "\n", 297 | " env = Env(\n", 298 | " board=board,\n", 299 | " # switch player\n", 300 | " player=jnp.where(done, env.player, -env.player),\n", 301 | " done=done,\n", 302 | " reward=reward,\n", 303 | " )\n", 304 | "\n", 305 | " return env, reward, done" 306 | ] 307 | }, 308 | { 309 | "attachments": {}, 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "## The MCTS algorithm\n", 314 | "\n", 315 | "We can now implement the MCTS algorithm." 316 | ] 317 | }, 318 | { 319 | "attachments": {}, 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "Let's start with some helper functions." 
324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 8, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "def valid_action_mask(env: Env) -> chex.Array:\n", 333 | " '''\n", 334 | " Computes which actions are valid in the current state.\n", 335 | " Returns an array of booleans, indicating which columns are not full.\n", 336 | " In case the game is over, all columns are considered invalid.\n", 337 | " '''\n", 338 | " return jnp.where(env.done, jnp.array([False] * env.board.shape[1]), env.board[-1] == 0)" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 9, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "def winning_action_mask(env: Env, player: Player) -> chex.Array:\n", 348 | " '''\n", 349 | " Finds all actions that would immediately win the game for the given player.\n", 350 | " '''\n", 351 | " # Override the next player in the environment with the given player.\n", 352 | " env = Env(board=env.board, player=player, done=env.done, reward=env.reward)\n", 353 | "\n", 354 | " # Play all actions and check the reward.\n", 355 | " # Remember that the reward is for the current player, so we expect it to be 1.\n", 356 | " env, reward, done = jax.vmap(env_step, (None, 0))(env, jnp.arange(7, dtype=jnp.int8))\n", 357 | " return reward == 1" 358 | ] 359 | }, 360 | { 361 | "attachments": {}, 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "Next, we define a policy function. 
The policy function is used in two places:\n", 366 | "\n", 367 | "* during random rollouts to select the next action to play\n", 368 | "* during node expansion as the prior distribution\n", 369 | "\n", 370 | "We could have used a uniform distribution for both, but we can do better.\n", 371 | "Let's use three simple heuristics:\n", 372 | "\n", 373 | "* never select illegal moves\n", 374 | "* always select winning moves if available\n", 375 | "* always block opponent's winning moves if we have no winning move\n", 376 | "\n", 377 | "Our policy function implementation returns unnormalized logits,\n", 378 | "so we cannot return pure 0% and 100% probabilities (using `inf` and `-inf` causes downstream numerical issues).\n", 379 | "Instead, we add a large bonus to the logits of every action which we want to prioritize (the bonuses are summed, so they accumulate):\n", 380 | "\n", 381 | "* the lowest priority of `0` is assigned to illegal moves\n", 382 | "* a medium priority of `100` is assigned to legal moves\n", 383 | "* a further bonus of `200` is added for the opponent's winning moves (logit `300`)\n", 384 | "* the largest bonus of `300` is added for our winning moves (logit `400`, or `600` if the move would also win for the opponent)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 10, 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "def policy_function(env: Env) -> chex.Array:\n", 394 | " return sum((\n", 395 | " valid_action_mask(env).astype(jnp.float32) * 100,\n", 396 | " winning_action_mask(env, -env.player).astype(jnp.float32) * 200,\n", 397 | " winning_action_mask(env, env.player).astype(jnp.float32) * 300,\n", 398 | " ))" 399 | ] 400 | }, 401 | { 402 | "attachments": {}, 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | "This is how these logits translate to probabilities:" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 11, 412 | "metadata": {}, 413 | "outputs": [ 414 | { 415 | "data": { 416 | "text/plain": [ 417 | "Array([0.14285715, 0.14285715, 0.14285715, 0.14285715, 
0.14285715,\n", 418 | " 0.14285715, 0.14285715], dtype=float32)" 419 | ] 420 | }, 421 | "execution_count": 11, 422 | "metadata": {}, 423 | "output_type": "execute_result" 424 | } 425 | ], 426 | "source": [ 427 | "jax.nn.softmax(jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 12, 433 | "metadata": {}, 434 | "outputs": [ 435 | { 436 | "data": { 437 | "text/plain": [ 438 | "Array([1., 0., 0., 0., 0., 0., 0.], dtype=float32)" 439 | ] 440 | }, 441 | "execution_count": 12, 442 | "metadata": {}, 443 | "output_type": "execute_result" 444 | } 445 | ], 446 | "source": [ 447 | "jax.nn.softmax(jnp.array([100.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 13, 453 | "metadata": {}, 454 | "outputs": [ 455 | { 456 | "data": { 457 | "text/plain": [ 458 | "Array([0., 1., 0., 0., 0., 0., 0.], dtype=float32)" 459 | ] 460 | }, 461 | "execution_count": 13, 462 | "metadata": {}, 463 | "output_type": "execute_result" 464 | } 465 | ], 466 | "source": [ 467 | "jax.nn.softmax(jnp.array([100.0, 200.0, 0.0, 0.0, 0.0, 0.0, 0.0]))" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 14, 473 | "metadata": {}, 474 | "outputs": [ 475 | { 476 | "data": { 477 | "text/plain": [ 478 | "Array([0., 0., 1., 0., 0., 0., 0.], dtype=float32)" 479 | ] 480 | }, 481 | "execution_count": 14, 482 | "metadata": {}, 483 | "output_type": "execute_result" 484 | } 485 | ], 486 | "source": [ 487 | "jax.nn.softmax(jnp.array([100.0, 200.0, 300.0, 0.0, 0.0, 0.0, 0.0]))" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": 15, 493 | "metadata": {}, 494 | "outputs": [ 495 | { 496 | "data": { 497 | "text/plain": [ 498 | "Array([0. , 0. , 0.5, 0.5, 0. , 0. , 0. 
def rollout(env: Env, rng_key: chex.PRNGKey) -> Reward:
    """Play `env` until it is done, sampling every move from `policy_function`.

    The returned reward is expressed from the perspective of the player to
    move in the `env` that was passed in (the "initial player").
    """
    def game_running(state):
        current_env, _ = state
        return ~current_env.done

    def play_move(state):
        current_env, current_key = state
        next_key, sample_key = jax.random.split(current_key)
        logits = policy_function(current_env)
        move = jax.random.categorical(sample_key, logits)
        # Only the updated environment is carried on; the per-step reward and
        # done flag returned by env_step are not needed inside the loop.
        next_env, _, _ = env_step(current_env, move)
        return next_env, next_key

    leaf, _ = jax.lax.while_loop(game_running, play_move, (env, rng_key))
    # The leaf reward is from the perspective of the last player.
    # We negate it if the last player is not the initial player.
    last_player_reward = leaf.reward * leaf.player
    return last_player_reward * env.player
def value_function(env: Env, rng_key: chex.PRNGKey) -> chex.Array:
    """Estimate the value of `env` as the float32 result of one `rollout`."""
    return rollout(env, rng_key).astype(jnp.float32)


import mctx


def root_fn(env: Env, rng_key: chex.PRNGKey) -> mctx.RootFnOutput:
    """Build the root node of the search tree for `env`.

    Args:
        env: the environment state to search from.
        rng_key: random key consumed by the rollout in `value_function`.

    Returns:
        A `mctx.RootFnOutput` holding the prior logits, the value estimate and
        the environment itself as the embedding.
    """
    return mctx.RootFnOutput(
        prior_logits=policy_function(env),
        value=value_function(env, rng_key),
        # We will use the `embedding` field to store the environment.
        embedding=env,
    )


def recurrent_fn(params, rng_key: chex.PRNGKey, action, embedding: Env):
    """Expand a new node by playing `action` in the parent environment.

    Args:
        params: unused here; `run_mcts` passes `None`.
        rng_key: random key consumed by the rollout in `value_function`.
        action: the action to play from the parent state.
        embedding: the parent environment (stored by `root_fn` or by a
            previous call to this function).

    Returns:
        A `(mctx.RecurrentFnOutput, env)` pair, where `env` is the environment
        after `action` and becomes the embedding of the new node.
    """
    # Extract the environment from the embedding.
    env = embedding

    # Play the action.
    env, reward, done = env_step(env, action)

    # Create the new MCTS node.
    recurrent_fn_output = mctx.RecurrentFnOutput(
        # reward for playing `action`
        reward=reward,
        # -1 flips the sign of rewards/values at each tree level because the
        # players alternate; 0 discards everything after the end of the game.
        discount=jnp.where(done, 0, -1).astype(jnp.float32),
        # prior for the new state
        prior_logits=policy_function(env),
        # value for the new state (forced to 0 once the game is over)
        value=jnp.where(done, 0, value_function(env, rng_key)).astype(jnp.float32),
    )

    # Return the new node and the new environment.
    return recurrent_fn_output, env


import functools
from typing import Optional


@functools.partial(jax.jit, static_argnums=(2,))
def run_mcts(rng_key: chex.PRNGKey, env: Env, num_simulations: int, tree: Optional[mctx.Tree] = None) -> mctx.PolicyOutput:
    """Run an AlphaZero-style search from `env` on a batch of size 1.

    Args:
        rng_key: random key; split once into a key for the search itself and
            a key for the root value estimate.
        env: the environment state to search from.
        num_simulations: number of simulations to run. It is a static argument
            of the jitted function (`static_argnums=(2,)`).
        tree: optional search tree from a previous call to continue from.

    Returns:
        The policy output of `mctx.alphazero_policy`. Callers read its
        `action`, `action_weights` and `search_tree` attributes, so it is not
        a plain array.
    """
    batch_size = 1
    key1, key2 = jax.random.split(rng_key)
    policy_output = mctx.alphazero_policy(
        # params can be used to pass additional data to the recurrent_fn like neural network weights
        params=None,

        rng_key=key1,

        # create a batch of environments (in this case, a batch of size 1)
        root=jax.vmap(root_fn, (None, 0))(env, jax.random.split(key2, batch_size)),

        # automatically vectorize the recurrent_fn
        recurrent_fn=jax.vmap(recurrent_fn, (None, None, 0, 0)),

        num_simulations=num_simulations,

        # we limit the depth of the search tree to 42, since we know that Connect Four can't last longer
        max_depth=42,
        max_nodes=int(num_simulations * 1.5),
        search_tree=tree,

        # our value is in the range [-1, 1], so we can use the min_max qtransform to map it to [0, 1]
        qtransform=functools.partial(mctx.qtransform_by_min_max, min_value=-1, max_value=1),

        # Dirichlet noise is used for exploration which we don't need in this example (we aren't training)
        dirichlet_fraction=0.0,
        temperature=0.0
    )
    return policy_output
# Game driver: players 1 and 2 alternate moves until `env_step` reports that
# the game is done. Each side is either the MCTS agent or a human typing a
# 1-based column number.

# set to False to enable human input
player_1_ai = True
player_2_ai = True
# when True, the agent passes the matching subtree of its previous search
# back into `run_mcts` instead of starting from an empty tree
player_1_use_subtree = True
player_2_use_subtree = False
simulations_p1 = 2000
simulations_p2 = 2000
key = jax.random.PRNGKey(0)
env = env_reset(0)
print_board(env.board)
tree = None
tree2 = None
while True:
    if player_1_ai:
        # Use a fresh subkey for every search; reusing the same key would feed
        # identical randomness into each call.
        key, subkey = jax.random.split(key)
        output = run_mcts(subkey, env, simulations_p1, tree)

        if player_1_use_subtree:
            tree = mctx.get_subtree(output.search_tree, output.action)
        action = output.action.item()
    else:
        action = int(input()) - 1

    # Advance player 2's stored tree past the move player 1 just made.
    if player_2_ai and player_2_use_subtree:
        if tree2 is not None:
            tree2 = mctx.get_subtree(tree2, jnp.array([action], dtype=jnp.int32).reshape(1, -1))

    env, reward, done = env_step(env, action)
    print_board(env.board)
    if done:
        break

    if player_2_ai:
        # you can give it more simulations to make it stronger
        key, subkey = jax.random.split(key)
        output = run_mcts(subkey, env, simulations_p2, tree2)

        if player_2_use_subtree:
            tree2 = mctx.get_subtree(output.search_tree, output.action)
        action = output.action.item()
    else:
        action = int(input()) - 1

    # Advance player 1's stored tree past the move player 2 just made.
    if player_1_ai and player_1_use_subtree:
        if tree is not None:
            tree = mctx.get_subtree(tree, jnp.array([action], dtype=jnp.int32).reshape(1, -1))

    env, reward, done = env_step(env, action)
    print_board(env.board)
    if done:
        break

players = {
    1: "[green]X[/green]",
    -1: "[red]O[/red]",
}
print(reward)
# NOTE(review): a winner is printed even if the game ended without one;
# confirm how `env_step` reports a draw before relying on this line.
print(f"Winner: {players[env.player.item()]}")