├── mctx
│   ├── py.typed
│   ├── _src
│   │   ├── __init__.py
│   │   ├── tests
│   │   │   ├── mctx_test.py
│   │   │   ├── qtransforms_test.py
│   │   │   ├── seq_halving_test.py
│   │   │   ├── tree_test.py
│   │   │   └── policies_test.py
│   │   ├── seq_halving.py
│   │   ├── base.py
│   │   ├── qtransforms.py
│   │   ├── action_selection.py
│   │   ├── tree.py
│   │   ├── search.py
│   │   └── policies.py
│   └── __init__.py
├── requirements
│   ├── requirements-test.txt
│   ├── requirements_examples.txt
│   └── requirements.txt
├── MANIFEST.in
├── .gitignore
├── .github
│   └── workflows
│       ├── ci.yml
│       └── pypi-publish.yml
├── CONTRIBUTING.md
├── test.sh
├── setup.py
├── README.md
├── examples
│   ├── policy_improvement_demo.py
│   └── visualization_demo.py
├── LICENSE
├── .pylintrc
└── connect4.ipynb
/mctx/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements/requirements-test.txt:
--------------------------------------------------------------------------------
1 | absl-py>=0.9.0
2 | numpy>=1.18.0
3 |
--------------------------------------------------------------------------------
/requirements/requirements_examples.txt:
--------------------------------------------------------------------------------
1 | absl-py>=0.9.0
2 | pygraphviz>=1.7
3 |
--------------------------------------------------------------------------------
/requirements/requirements.txt:
--------------------------------------------------------------------------------
1 | chex>=0.0.8
2 | jax>=0.4.25
3 | jaxlib>=0.4.25
4 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE
3 | include requirements/*
4 | include mctx/py.typed
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Building and releasing library:
2 | *.egg-info
3 | *.pyc
4 | *.so
5 | build/
6 | dist/
7 | venv/
8 |
9 | # Mac OS
10 | .DS_Store
11 |
12 | # Python tools
13 | .mypy_cache/
14 | .pytype/
15 | .ipynb_checkpoints
16 |
17 | # Editors
18 | .idea
19 | .vscode
20 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: ci
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | pull_request:
7 | branches: ["main"]
8 |
9 | jobs:
10 | build-and-test:
11 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
12 | runs-on: "${{ matrix.os }}"
13 |
14 | strategy:
15 | matrix:
16 | python-version: ["3.10", "3.11", "3.12"]
17 | os: [ubuntu-latest]
18 |
19 | steps:
20 | - uses: "actions/checkout@v4"
21 | - uses: "actions/setup-python@v4"
22 | with:
23 | python-version: "${{ matrix.python-version }}"
24 | - name: Run CI tests
25 | run: bash test.sh
26 | shell: bash
27 |
--------------------------------------------------------------------------------
/mctx/_src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
1 | name: pypi
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - name: Set up Python
13 | uses: actions/setup-python@v4
14 | with:
15 | python-version: '3.x'
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install setuptools wheel twine
20 | - name: Check consistency between the package version and release tag
21 | run: |
22 | RELEASE_VER=${GITHUB_REF#refs/*/}
23 | PACKAGE_VER="v`python setup.py --version`"
24 | if [ $RELEASE_VER != $PACKAGE_VER ]
25 | then
26 | echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1
27 | fi
28 | - name: Build and publish
29 | env:
30 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
31 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
32 | run: |
33 | python setup.py sdist bdist_wheel
34 | twine upload dist/*
35 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Testing
26 |
27 | Please make sure that your PR passes all tests by running `bash test.sh` on your
28 | local machine. You can also run only the tests affected by your code
29 | changes, but you will need to select them manually.
30 |
31 | ## Community Guidelines
32 |
33 | This project follows [Google's Open Source Community
34 | Guidelines](https://opensource.google.com/conduct/).
35 |
--------------------------------------------------------------------------------
/mctx/_src/tests/mctx_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for Mctx."""
16 |
17 | from absl.testing import absltest
18 | import mctx
19 |
20 |
21 | class MctxTest(absltest.TestCase):
22 | """Test mctx can be imported correctly."""
23 |
24 | def test_import(self):
25 | self.assertTrue(hasattr(mctx, "gumbel_muzero_policy"))
26 | self.assertTrue(hasattr(mctx, "muzero_policy"))
27 | self.assertTrue(hasattr(mctx, "qtransform_by_min_max"))
28 | self.assertTrue(hasattr(mctx, "qtransform_by_parent_and_siblings"))
29 | self.assertTrue(hasattr(mctx, "qtransform_completed_by_mix_value"))
30 | self.assertTrue(hasattr(mctx, "PolicyOutput"))
31 | self.assertTrue(hasattr(mctx, "RootFnOutput"))
32 | self.assertTrue(hasattr(mctx, "RecurrentFnOutput"))
33 | self.assertTrue(hasattr(mctx, "get_subtree"))
34 | self.assertTrue(hasattr(mctx, "reset_search_tree"))
35 |
36 | if __name__ == "__main__":
37 | absltest.main()
38 |
--------------------------------------------------------------------------------
/mctx/_src/tests/qtransforms_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `qtransforms.py`."""
16 | from absl.testing import absltest
17 | import jax
18 | import jax.numpy as jnp
19 | from mctx._src import qtransforms
20 | import numpy as np
21 |
22 |
23 | class QtransformsTest(absltest.TestCase):
24 |
25 | def test_mix_value(self):
26 | """Tests the output of _compute_mixed_value()."""
27 | raw_value = jnp.array(-0.8)
28 | prior_logits = jnp.array([-jnp.inf, -1.0, 2.0, -jnp.inf])
29 | probs = jax.nn.softmax(prior_logits)
30 | visit_counts = jnp.array([0, 4.0, 4.0, 0])
31 | qvalues = 10.0 / 54 * jnp.array([20.0, 3.0, -1.0, 10.0])
32 | mix_value = qtransforms._compute_mixed_value(
33 | raw_value, qvalues, visit_counts, probs)
34 |
35 | num_simulations = jnp.sum(visit_counts)
36 | expected_mix_value = 1.0 / (num_simulations + 1) * (
37 | raw_value + num_simulations *
38 | (probs[1] * qvalues[1] + probs[2] * qvalues[2]))
39 | np.testing.assert_allclose(expected_mix_value, mix_value)
40 |
41 | def test_mix_value_with_zero_visits(self):
42 | """Tests that zero visit counts do not divide by zero."""
43 | raw_value = jnp.array(-0.8)
44 | prior_logits = jnp.array([-jnp.inf, -1.0, 2.0, -jnp.inf])
45 | probs = jax.nn.softmax(prior_logits)
46 | visit_counts = jnp.array([0, 0, 0, 0])
47 | qvalues = jnp.zeros_like(probs)
48 |     with jax.debug_nans(True):
49 | mix_value = qtransforms._compute_mixed_value(
50 | raw_value, qvalues, visit_counts, probs)
51 |
52 | np.testing.assert_allclose(raw_value, mix_value)
53 |
54 |
55 | if __name__ == "__main__":
56 | absltest.main()
57 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | # Runs CI tests on a local machine.
17 | set -xeuo pipefail
18 |
19 | # Install deps in a virtual env.
20 | readonly VENV_DIR=/tmp/mctx-env
21 | rm -rf "${VENV_DIR}"
22 | python3 -m venv "${VENV_DIR}"
23 | source "${VENV_DIR}/bin/activate"
24 | python --version
25 |
26 | # Install dependencies.
27 | pip install --upgrade pip setuptools wheel
28 | pip install flake8 pytest-xdist pylint pylint-exit
29 | pip install -r requirements/requirements.txt
30 | pip install -r requirements/requirements-test.txt
31 |
32 | # Lint with flake8.
33 | flake8 `find mctx -name '*.py' | xargs` --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics
34 |
35 | # Lint with pylint.
36 | # Fail on errors, warnings, and conventions.
37 | PYLINT_ARGS="-efail -wfail -cfail"
38 | # Lint modules and tests separately.
39 | pylint --rcfile=.pylintrc `find mctx -name '*.py' | grep -v 'test.py' | xargs` || pylint-exit $PYLINT_ARGS $?
40 | # Disable `protected-access` warnings for tests.
41 | pylint --rcfile=.pylintrc `find mctx -name '*_test.py' | xargs` -d W0212 || pylint-exit $PYLINT_ARGS $?
42 |
43 | # Build the package.
44 | python setup.py sdist
45 | pip wheel --verbose --no-deps --no-clean dist/mctx*.tar.gz
46 | pip install mctx*.whl
47 |
48 | # Check types with pytype.
49 | # Note: pytype does not support 3.12 as of 23.11.23
50 | # See https://github.com/google/pytype/issues/1308
51 | if [ `python -c 'import sys; print(sys.version_info.minor)'` -lt 12 ];
52 | then
53 | pip install pytype
54 | pytype `find mctx/_src/ -name "*py" | xargs` -k
55 | fi;
56 |
57 | # Run tests using pytest.
58 | # Change directory to avoid importing the package from repo root.
59 | mkdir _testing && cd _testing
60 |
61 | # Run the tests in parallel, one worker per CPU.
62 | pytest -n "$(grep -c ^processor /proc/cpuinfo)" --pyargs mctx
63 | cd ..
64 |
65 | set +u
66 | deactivate
67 | echo "All tests passed. Congrats!"
68 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Install script for setuptools."""
16 |
17 | import os
18 | from setuptools import find_namespace_packages
19 | from setuptools import setup
20 |
21 | _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
22 |
23 |
24 | def _get_version():
25 | with open('mctx/__init__.py') as fp:
26 | for line in fp:
27 | if line.startswith('__version__') and '=' in line:
28 | version = line[line.find('=') + 1:].strip(' \'"\n')
29 | if version:
30 | return version
31 | raise ValueError('`__version__` not defined in `mctx/__init__.py`')
32 |
33 |
34 | def _parse_requirements(path):
35 |
36 | with open(os.path.join(_CURRENT_DIR, path)) as f:
37 | return [
38 | line.rstrip()
39 | for line in f
40 | if not (line.isspace() or line.startswith('#'))
41 | ]
42 |
43 |
44 | setup(
45 | name='mctx',
46 | version=_get_version(),
47 | url='https://github.com/google-deepmind/mctx',
48 | license='Apache 2.0',
49 | author='DeepMind',
50 | description=('Monte Carlo tree search in JAX.'),
51 | long_description=open(os.path.join(_CURRENT_DIR, 'README.md')).read(),
52 | long_description_content_type='text/markdown',
53 | author_email='mctx-dev@google.com',
54 | keywords='jax planning reinforcement-learning python machine learning',
55 | packages=find_namespace_packages(exclude=['*_test.py']),
56 | install_requires=_parse_requirements(
57 | os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')),
58 | tests_require=_parse_requirements(
59 | os.path.join(_CURRENT_DIR, 'requirements', 'requirements-test.txt')),
60 | zip_safe=False, # Required for full installation.
61 | python_requires='>=3.9',
62 | classifiers=[
63 | 'Development Status :: 4 - Beta',
64 | 'Intended Audience :: Science/Research',
65 | 'License :: OSI Approved :: Apache Software License',
66 | 'Programming Language :: Python',
67 | 'Programming Language :: Python :: 3',
68 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
69 | 'Topic :: Software Development :: Libraries :: Python Modules',
70 | ],
71 | )
72 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # mctx-az
2 | This fork of [google-deepmind/mctx](https://github.com/google-deepmind/mctx) introduces a new feature used in AlphaZero: continuing search from
3 | a subtree of a previous state's search output, or _subtree persistence_.
4 |
5 | This allows Monte Carlo Tree Search to continue from an already-initialized, partially populated search tree, so work done in a previous
6 | call persists to the next call, avoiding a lot of repeated work!
7 |
8 | ## Quickstart
9 | mctx-az introduces a new policy, `alphazero_policy`, which allows the user to pass a pre-initialized `Tree` from which to continue the search.
10 |
11 | Then, `get_subtree` can be used to extract the subtree rooted at a particular child node of the root, corresponding to the action taken.
12 |
13 | In cases where the search tree should not be saved, such as an episode terminating, `reset_search_tree` can be used to clear the tree.
14 |
15 | To initialize a new tree, pass `tree=None` to `alphazero_policy`, along with `max_nodes` to specify the capacity of the tree, which in most cases
16 | should be >= `num_simulations`.
17 |
18 | `alphazero_policy` otherwise functions exactly the same as `muzero_policy`.
19 |
20 | #### Initializing a new Tree:
21 | ```python
22 | policy_output = mctx.alphazero_policy(params, rng_key, root, recurrent_fn,
23 | num_simulations=32, tree=None, max_nodes=48)
24 | tree = policy_output.search_tree
25 | ```
26 | #### Extracting the subtree and continuing search
27 | ```python
28 | # get chosen action from policy output
29 | action = policy_output.action
30 |
31 | # extract the subtree corresponding to the chosen action
32 | tree = mctx.get_subtree(tree, action)
33 |
34 | # go to next environment state
35 | env_state = env.step(env_state, action)
36 |
37 | # reset the search tree where the environment has terminated
38 | tree = mctx.reset_search_tree(tree, env_state.terminated)
39 |
40 | # new search with subtree
41 | # (max_nodes has no effect when a tree is passed)
42 | policy_output = mctx.alphazero_policy(params, rng_key, root, recurrent_fn,
43 | num_simulations=32, tree=tree)
44 | ```
45 | #### Note on out-of-bounds expansions:
46 | A call to any mctx policy expands `num_simulations` nodes (assuming `max_depth` is not breached).
47 |
48 | Because `alphazero_policy` accepts a pre-populated `Tree`, it is possible that there will not be enough
49 | room left for `num_simulations` new nodes.
50 |
51 | When the tree is full, values and visit counts are still propagated backwards to all nodes along the visit path,
52 | just as they would be if the expansion were in bounds. However, the new node is not created and stored in the
53 | search tree; only its in-bounds predecessors are updated.
54 |
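55 | #### Putting it together
56 | A minimal acting loop combining the snippets above might look like this. It is only a sketch: `env`, `root_fn`, `params`, `rng_key`, and `max_steps` stand in for user code; only the `mctx` calls come from this README.
57 | ```python
58 | tree = None
59 | env_state = env.reset()
60 | for _ in range(max_steps):
61 |   # root_fn produces a mctx.RootFnOutput for the current state
62 |   root = root_fn(params, env_state)
63 |   # max_nodes only matters when tree=None; a full tree is handled
64 |   # gracefully (see the note on out-of-bounds expansions above)
65 |   policy_output = mctx.alphazero_policy(params, rng_key, root, recurrent_fn,
66 |                                         num_simulations=32, tree=tree, max_nodes=48)
67 |   action = policy_output.action
68 |   # keep the subtree under the chosen action for the next search
69 |   tree = mctx.get_subtree(policy_output.search_tree, action)
70 |   env_state = env.step(env_state, action)
71 |   # reset the tree where the episode has terminated
72 |   tree = mctx.reset_search_tree(tree, env_state.terminated)
73 | ```
74 |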
75 | ## Examples
76 | The mctx readme links to a simple Connect4 example: https://github.com/Carbon225/mctx-classic
77 |
78 | I modified this example to demonstrate the use of `alphazero_policy` and `get_subtree`. You can see it [here](https://github.com/lowrollr/mctx-az/blob/main/connect4.ipynb).
79 |
80 | ## Issues
81 | If you run into problems or need help, please create an Issue and I will do my best to assist you promptly.
82 |
--------------------------------------------------------------------------------
/mctx/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Mctx: Monte Carlo tree search in JAX."""
16 |
17 | from mctx._src.action_selection import gumbel_muzero_interior_action_selection
18 | from mctx._src.action_selection import gumbel_muzero_root_action_selection
19 | from mctx._src.action_selection import GumbelMuZeroExtraData
20 | from mctx._src.action_selection import muzero_action_selection
21 | from mctx._src.base import ChanceRecurrentFnOutput
22 | from mctx._src.base import DecisionRecurrentFnOutput
23 | from mctx._src.base import InteriorActionSelectionFn
24 | from mctx._src.base import LoopFn
25 | from mctx._src.base import PolicyOutput
26 | from mctx._src.base import RecurrentFn
27 | from mctx._src.base import RecurrentFnOutput
28 | from mctx._src.base import RecurrentState
29 | from mctx._src.base import RootActionSelectionFn
30 | from mctx._src.base import RootFnOutput
31 | from mctx._src.policies import alphazero_policy
32 | from mctx._src.policies import gumbel_muzero_policy
33 | from mctx._src.policies import muzero_policy
34 | from mctx._src.policies import stochastic_muzero_policy
35 | from mctx._src.qtransforms import qtransform_by_min_max
36 | from mctx._src.qtransforms import qtransform_by_parent_and_siblings
37 | from mctx._src.qtransforms import qtransform_completed_by_mix_value
38 | from mctx._src.search import search
39 | from mctx._src.tree import get_subtree
40 | from mctx._src.tree import reset_search_tree
41 | from mctx._src.tree import Tree
42 |
43 | __version__ = "0.0.5"
44 |
45 | __all__ = (
46 | "ChanceRecurrentFnOutput",
47 | "DecisionRecurrentFnOutput",
48 | "GumbelMuZeroExtraData",
49 | "InteriorActionSelectionFn",
50 | "LoopFn",
51 | "PolicyOutput",
52 | "RecurrentFn",
53 | "RecurrentFnOutput",
54 | "RecurrentState",
55 | "RootActionSelectionFn",
56 | "RootFnOutput",
57 | "Tree",
58 | "alphazero_policy",
59 | "get_subtree",
60 | "gumbel_muzero_interior_action_selection",
61 | "gumbel_muzero_policy",
62 | "gumbel_muzero_root_action_selection",
63 | "muzero_action_selection",
64 | "muzero_policy",
65 | "qtransform_by_min_max",
66 | "qtransform_by_parent_and_siblings",
67 | "qtransform_completed_by_mix_value",
68 | "reset_search_tree",
69 | "search",
70 | "stochastic_muzero_policy",
71 | )
72 |
73 | # _________________________________________
74 | # / Please don't use symbols in `_src` they \
75 | # \ are not part of the Mctx public API. /
76 | # -----------------------------------------
77 | # \ ^__^
78 | # \ (oo)\_______
79 | # (__)\ )\/\
80 | # ||----w |
81 | # || ||
82 | #
83 |
--------------------------------------------------------------------------------
/mctx/_src/seq_halving.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Functions for Sequential Halving."""
16 |
17 | import math
18 |
19 | import chex
20 | import jax.numpy as jnp
21 |
22 |
23 | def score_considered(considered_visit, gumbel, logits, normalized_qvalues,
24 | visit_counts):
25 | """Returns a score usable for an argmax."""
26 |   # We allow visiting a child if it is the only considered child.
27 | low_logit = -1e9
28 | logits = logits - jnp.max(logits, keepdims=True, axis=-1)
29 | penalty = jnp.where(
30 | visit_counts == considered_visit,
31 | 0, -jnp.inf)
32 | chex.assert_equal_shape([gumbel, logits, normalized_qvalues, penalty])
33 | return jnp.maximum(low_logit, gumbel + logits + normalized_qvalues) + penalty
34 |
35 |
36 | def get_sequence_of_considered_visits(max_num_considered_actions,
37 | num_simulations):
38 | """Returns a sequence of visit counts considered by Sequential Halving.
39 |
40 | Sequential Halving is a "pure exploration" algorithm for bandits, introduced
41 | in "Almost Optimal Exploration in Multi-Armed Bandits":
42 | http://proceedings.mlr.press/v28/karnin13.pdf
43 |
44 |   The visit counts allow implementing Sequential Halving by selecting the best
45 | action from the actions with the currently considered visit count.
46 |
47 | Args:
48 | max_num_considered_actions: The maximum number of considered actions.
49 | The `max_num_considered_actions` can be smaller than the number of
50 | actions.
51 | num_simulations: The total simulation budget.
52 |
53 | Returns:
54 | A tuple with visit counts. Length `num_simulations`.
55 | """
56 | if max_num_considered_actions <= 1:
57 | return tuple(range(num_simulations))
58 | log2max = int(math.ceil(math.log2(max_num_considered_actions)))
59 | sequence = []
60 | visits = [0] * max_num_considered_actions
61 | num_considered = max_num_considered_actions
62 | while len(sequence) < num_simulations:
63 | num_extra_visits = max(1, int(num_simulations / (log2max * num_considered)))
64 | for _ in range(num_extra_visits):
65 | sequence.extend(visits[:num_considered])
66 | for i in range(num_considered):
67 | visits[i] += 1
68 | # Halving the number of considered actions.
69 | num_considered = max(2, num_considered // 2)
70 | return tuple(sequence[:num_simulations])
71 |
72 |
73 | def get_table_of_considered_visits(max_num_considered_actions, num_simulations):
74 | """Returns a table of sequences of visit counts.
75 |
76 | Args:
77 | max_num_considered_actions: The maximum number of considered actions.
78 | The `max_num_considered_actions` can be smaller than the number of
79 | actions.
80 | num_simulations: The total simulation budget.
81 |
82 | Returns:
83 | A tuple of sequences of visit counts.
84 | Shape [max_num_considered_actions + 1, num_simulations].
85 | """
86 | return tuple(
87 | get_sequence_of_considered_visits(m, num_simulations)
88 | for m in range(max_num_considered_actions + 1))
89 |
90 |
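91 | # Worked example (illustration only): `get_sequence_of_considered_visits(4, 8)`
92 | # returns (0, 0, 0, 0, 1, 1, 2, 2): all four actions are first compared at
93 | # visit count 0, then the best two actions receive the remaining budget at
94 | # visit counts 1 and 2.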
--------------------------------------------------------------------------------
/mctx/_src/tests/seq_halving_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `seq_halving.py`."""
16 | from absl.testing import absltest
17 | from mctx._src import seq_halving
18 |
19 |
20 | class SeqHalvingTest(absltest.TestCase):
21 |
22 | def _check_visits(self, expected_results, max_num_considered_actions,
23 | num_simulations):
24 | """Compares the expected results to the returned considered visits."""
25 | self.assertLen(expected_results, num_simulations)
26 | results = seq_halving.get_sequence_of_considered_visits(
27 | max_num_considered_actions, num_simulations)
28 | self.assertEqual(tuple(expected_results), results)
29 |
30 | def test_considered_min_sims(self):
31 | # Using exactly `num_simulations = max_num_considered_actions *
32 | # log2(max_num_considered_actions)`.
33 | num_sims = 24
34 | max_num_considered = 8
35 | expected_results = [
36 | 0, 0, 0, 0, 0, 0, 0, 0, # Considering 8 actions.
37 | 1, 1, 1, 1, # Considering 4 actions.
38 | 2, 2, 2, 2, # Considering 4 actions, round 2.
39 | 3, 3, 4, 4, 5, 5, 6, 6, # Considering 2 actions.
40 | ] # pyformat: disable
41 | self._check_visits(expected_results, max_num_considered, num_sims)
42 |
43 | def test_considered_extra_sims(self):
44 | # Using more simulations than `max_num_considered_actions *
45 | # log2(max_num_considered_actions)`.
46 | num_sims = 47
47 | max_num_considered = 8
48 | expected_results = [
49 | 0, 0, 0, 0, 0, 0, 0, 0, # Considering 8 actions.
50 | 1, 1, 1, 1, # Considering 4 actions.
51 | 2, 2, 2, 2, # Considering 4 actions, round 2.
52 | 3, 3, 3, 3, # Considering 4 actions, round 3.
53 | 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10,
54 | 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17,
55 | ] # pyformat: disable
56 | self._check_visits(expected_results, max_num_considered, num_sims)
57 |
58 | def test_considered_less_sims(self):
59 | # Using a very small number of simulations.
60 | num_sims = 2
61 | max_num_considered = 8
62 | expected_results = [0, 0]
63 | self._check_visits(expected_results, max_num_considered, num_sims)
64 |
65 | def test_considered_less_sims2(self):
66 | # Using `num_simulations < max_num_considered_actions *
67 | # log2(max_num_considered_actions)`.
68 | num_sims = 13
69 | max_num_considered = 8
70 | expected_results = [
71 | 0, 0, 0, 0, 0, 0, 0, 0, # Considering 8 actions.
72 | 1, 1, 1, 1, # Considering 4 actions.
73 | 2,
74 | ] # pyformat: disable
75 | self._check_visits(expected_results, max_num_considered, num_sims)
76 |
77 | def test_considered_not_power_of_2(self):
78 | # Using max_num_considered_actions that is not a power of 2.
79 | num_sims = 24
80 | max_num_considered = 7
81 | expected_results = [
82 | 0, 0, 0, 0, 0, 0, 0, # Considering 7 actions.
83 | 1, 1, 1, 2, 2, 2, # Considering 3 actions.
84 | 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8,
85 | ] # pyformat: disable
86 | self._check_visits(expected_results, max_num_considered, num_sims)
87 |
88 | def test_considered_action0(self):
89 | num_sims = 16
90 | max_num_considered = 0
91 | expected_results = range(num_sims)
92 | self._check_visits(expected_results, max_num_considered, num_sims)
93 |
94 | def test_considered_action1(self):
95 | num_sims = 16
96 | max_num_considered = 1
97 | expected_results = range(num_sims)
98 | self._check_visits(expected_results, max_num_considered, num_sims)
99 |
100 |
101 | if __name__ == "__main__":
102 | absltest.main()
103 |
--------------------------------------------------------------------------------
/examples/policy_improvement_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A demonstration of the policy improvement by planning with Gumbel."""
16 |
17 | import functools
18 | from typing import Tuple
19 |
20 | from absl import app
21 | from absl import flags
22 | import chex
23 | import jax
24 | import jax.numpy as jnp
25 | import mctx
26 |
27 | FLAGS = flags.FLAGS
28 | flags.DEFINE_integer("seed", 42, "Random seed.")
29 | flags.DEFINE_integer("batch_size", 256, "Batch size.")
30 | flags.DEFINE_integer("num_actions", 82, "Number of actions.")
31 | flags.DEFINE_integer("num_simulations", 4, "Number of simulations.")
32 | flags.DEFINE_integer("max_num_considered_actions", 16,
33 | "The maximum number of actions expanded at the root.")
34 | flags.DEFINE_integer("num_runs", 1, "Number of runs on random data.")
35 |
36 |
37 | @chex.dataclass(frozen=True)
38 | class DemoOutput:
39 | prior_policy_value: chex.Array
40 | prior_policy_action_value: chex.Array
41 | selected_action_value: chex.Array
42 | action_weights_policy_value: chex.Array
43 |
44 |
45 | def _run_demo(rng_key: chex.PRNGKey) -> Tuple[chex.PRNGKey, DemoOutput]:
46 | """Runs a search algorithm on random data."""
47 | batch_size = FLAGS.batch_size
48 | rng_key, logits_rng, q_rng, search_rng = jax.random.split(rng_key, 4)
49 | # We will demonstrate the algorithm on random prior_logits.
50 | # Normally, the prior_logits would be produced by a policy network.
51 | prior_logits = jax.random.normal(
52 | logits_rng, shape=[batch_size, FLAGS.num_actions])
53 | # Defining a bandit with random Q-values. Only the Q-values of the visited
54 | # actions will be revealed to the search algorithm.
55 | qvalues = jax.random.uniform(q_rng, shape=prior_logits.shape)
56 | # If we know the value under the prior policy, we can use the value to
57 | # complete the missing Q-values. The completed Q-values will produce an
58 | # improved policy in `policy_output.action_weights`.
59 | raw_value = jnp.sum(jax.nn.softmax(prior_logits) * qvalues, axis=-1)
60 | use_mixed_value = False
61 |
62 |   # The root output would be the output of the MuZero representation network.
63 | root = mctx.RootFnOutput(
64 | prior_logits=prior_logits,
65 | value=raw_value,
66 | # The embedding is used only to implement the MuZero model.
67 | embedding=jnp.zeros([batch_size]),
68 | )
69 |   # The recurrent_fn would be provided by the MuZero dynamics network.
70 | recurrent_fn = _make_bandit_recurrent_fn(qvalues)
71 |
72 | # Running the search.
73 | policy_output = mctx.gumbel_muzero_policy(
74 | params=(),
75 | rng_key=search_rng,
76 | root=root,
77 | recurrent_fn=recurrent_fn,
78 | num_simulations=FLAGS.num_simulations,
79 | max_num_considered_actions=FLAGS.max_num_considered_actions,
80 | qtransform=functools.partial(
81 | mctx.qtransform_completed_by_mix_value,
82 | use_mixed_value=use_mixed_value),
83 | )
84 |
85 | # Collecting the Q-value of the selected action.
86 | selected_action_value = qvalues[jnp.arange(batch_size), policy_output.action]
87 |
88 | # We will compare the selected action to the action selected by the
89 | # prior policy, while using the same Gumbel random numbers.
90 | gumbel = policy_output.search_tree.extra_data.root_gumbel
91 | prior_policy_action = jnp.argmax(gumbel + prior_logits, axis=-1)
92 | prior_policy_action_value = qvalues[jnp.arange(batch_size),
93 | prior_policy_action]
94 |
95 | # Computing the policy value under the new action_weights.
96 | action_weights_policy_value = jnp.sum(
97 | policy_output.action_weights * qvalues, axis=-1)
98 |
99 | output = DemoOutput(
100 | prior_policy_value=raw_value,
101 | prior_policy_action_value=prior_policy_action_value,
102 | selected_action_value=selected_action_value,
103 | action_weights_policy_value=action_weights_policy_value,
104 | )
105 | return rng_key, output
106 |
107 |
108 | def _make_bandit_recurrent_fn(qvalues):
109 | """Returns a recurrent_fn for a determistic bandit."""
110 |
111 | def recurrent_fn(params, rng_key, action, embedding):
112 | del params, rng_key
113 | # For the bandit, the reward will be non-zero only at the root.
114 | reward = jnp.where(embedding == 0,
115 | qvalues[jnp.arange(action.shape[0]), action],
116 | 0.0)
117 | # On a single-player environment, use discount from [0, 1].
118 | # On a zero-sum self-play environment, use discount=-1.
119 | discount = jnp.ones_like(reward)
120 | recurrent_fn_output = mctx.RecurrentFnOutput(
121 | reward=reward,
122 | discount=discount,
123 | prior_logits=jnp.zeros_like(qvalues),
124 | value=jnp.zeros_like(reward))
125 | next_embedding = embedding + 1
126 | return recurrent_fn_output, next_embedding
127 |
128 | return recurrent_fn
129 |
130 |
131 | def main(_):
132 | rng_key = jax.random.PRNGKey(FLAGS.seed)
133 | jitted_run_demo = jax.jit(_run_demo)
134 | for _ in range(FLAGS.num_runs):
135 | rng_key, output = jitted_run_demo(rng_key)
136 | # Printing the obtained increase of the policy value.
137 | # The obtained increase should be non-negative.
138 | action_value_improvement = (
139 | output.selected_action_value - output.prior_policy_action_value)
140 | weights_value_improvement = (
141 | output.action_weights_policy_value - output.prior_policy_value)
142 | print("action value improvement: %.3f (min=%.3f)" %
143 | (action_value_improvement.mean(), action_value_improvement.min()))
144 | print("action_weights value improvement: %.3f (min=%.3f)" %
145 | (weights_value_improvement.mean(), weights_value_improvement.min()))
146 |
147 |
148 | if __name__ == "__main__":
149 | app.run(main)
150 |
--------------------------------------------------------------------------------
/mctx/_src/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Core types used in mctx."""
16 |
17 | from typing import Any, Callable, Generic, TypeVar, Tuple
18 |
19 | import chex
20 |
21 | from mctx._src import tree
22 |
23 |
24 | # Parameters are arbitrary nested structures of `chex.Array`.
25 | # A nested structure is either a single object, or a collection (list, tuple,
26 | # dictionary, etc.) of other nested structures.
27 | Params = chex.ArrayTree
28 |
29 |
30 | # The model used to search is expressed by a `RecurrentFn` function that takes
31 | # `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput` and
32 | # the new state embedding.
33 | @chex.dataclass(frozen=True)
34 | class RecurrentFnOutput:
35 | """The output of a `RecurrentFn`.
36 |
37 | reward: `[B]` an approximate reward from the state-action transition.
38 | discount: `[B]` the discount between the `reward` and the `value`.
39 | prior_logits: `[B, num_actions]` the logits produced by a policy network.
40 | value: `[B]` an approximate value of the state after the state-action
41 | transition.
42 | """
43 | reward: chex.Array
44 | discount: chex.Array
45 | prior_logits: chex.Array
46 | value: chex.Array
47 |
48 |
49 | Action = chex.Array
50 | RecurrentState = Any
51 | RecurrentFn = Callable[
52 | [Params, chex.PRNGKey, Action, RecurrentState],
53 | Tuple[RecurrentFnOutput, RecurrentState]]
54 |
55 |
56 | @chex.dataclass(frozen=True)
57 | class RootFnOutput:
58 | """The output of a representation network.
59 |
60 | prior_logits: `[B, num_actions]` the logits produced by a policy network.
61 | value: `[B]` an approximate value of the current state.
62 | embedding: `[B, ...]` the inputs to the next `recurrent_fn` call.
63 | """
64 | prior_logits: chex.Array
65 | value: chex.Array
66 | embedding: RecurrentState
67 |
68 |
69 | # Action selection functions specify how to pick nodes to expand in the tree.
70 | NodeIndices = chex.Array
71 | Depth = chex.Array
72 | RootActionSelectionFn = Callable[
73 | [chex.PRNGKey, tree.Tree, NodeIndices], chex.Array]
74 | InteriorActionSelectionFn = Callable[
75 | [chex.PRNGKey, tree.Tree, NodeIndices, Depth],
76 | chex.Array]
77 | QTransform = Callable[[tree.Tree, chex.Array], chex.Array]
78 | # LoopFn has the same interface as jax.lax.fori_loop.
79 | LoopFn = Callable[
80 | [int, int, Callable[[Any, Any], Any], Tuple[chex.PRNGKey, tree.Tree]],
81 | Tuple[chex.PRNGKey, tree.Tree]]
82 |
83 | T = TypeVar("T")
84 |
85 |
86 | @chex.dataclass(frozen=True)
87 | class PolicyOutput(Generic[T]):
88 | """The output of a policy.
89 |
90 | action: `[B]` the proposed action.
91 | action_weights: `[B, num_actions]` the targets used to train a policy network.
92 | The action weights sum to one. Usually, the policy network is trained by
93 | cross-entropy:
94 | `cross_entropy(labels=stop_gradient(action_weights), logits=prior_logits)`.
95 | search_tree: `[B, ...]` the search tree of the finished search.
96 | """
97 | action: chex.Array
98 | action_weights: chex.Array
99 | search_tree: tree.Tree[T]
100 |
101 |
102 | @chex.dataclass(frozen=True)
103 | class DecisionRecurrentFnOutput:
104 | """Output of the function for expanding decision nodes.
105 |
106 | Expanding a decision node takes an action and a state embedding and produces
107 | an afterstate, which represents the state of the environment after an action
108 | is taken but before the environment has updated its state. Accordingly, there
109 | is no discount factor or reward for transitioning from state `s` to afterstate
110 | `sa`.
111 |
112 | Attributes:
113 | chance_logits: `[B, C]` logits of `C` chance outcomes at the afterstate.
114 | afterstate_value: `[B]` values of the afterstates `v(sa)`.
115 | """
116 | chance_logits: chex.Array # [B, C]
117 | afterstate_value: chex.Array # [B]
118 |
119 |
120 | @chex.dataclass(frozen=True)
121 | class ChanceRecurrentFnOutput:
122 | """Output of the function for expanding chance nodes.
123 |
124 | Expanding a chance node takes a chance outcome and an afterstate embedding
126 |   and produces a state, which captures a potentially stochastic environment
127 |   transition. When this transition occurs, a reward and a discount are
128 |   produced as in a normal transition.
128 |
129 | Attributes:
130 | action_logits: `[B, A]` logits of different actions from the state.
131 | value: `[B]` values of the states `v(s)`.
132 | reward: `[B]` rewards at the states.
133 | discount: `[B]` discounts at the states.
134 | """
135 | action_logits: chex.Array # [B, A]
136 | value: chex.Array # [B]
137 | reward: chex.Array # [B]
138 | discount: chex.Array # [B]
139 |
140 |
141 | @chex.dataclass(frozen=True)
142 | class StochasticRecurrentState:
143 | """Wrapper that enables different treatment of decision and chance nodes.
144 |
145 |   In Stochastic MuZero, tree nodes can be either decision or chance nodes.
146 |   These nodes are treated differently during expansion, search and backup, and
147 |   a user could also pass differently structured embeddings for each type of
148 |   node. This wrapper enables treating chance and decision nodes differently and
149 |   supports potential differences between chance and decision node structures.
150 |
151 | Attributes:
152 | state_embedding: `[B ...]` an optionally meaningful state embedding.
153 | afterstate_embedding: `[B ...]` an optionally meaningful afterstate
154 | embedding.
155 | is_decision_node: `[B]` whether the node is a decision or chance node. If it
156 | is a decision node, `afterstate_embedding` is a dummy value. If it is a
157 | chance node, `state_embedding` is a dummy value.
158 | """
159 | state_embedding: chex.ArrayTree # [B, ...]
160 | afterstate_embedding: chex.ArrayTree # [B, ...]
161 | is_decision_node: chex.Array # [B]
162 |
163 |
164 | RecurrentState = chex.ArrayTree
165 |
166 | DecisionRecurrentFn = Callable[[Params, chex.PRNGKey, Action, RecurrentState],
167 | Tuple[DecisionRecurrentFnOutput, RecurrentState]]
168 |
169 | ChanceRecurrentFn = Callable[[Params, chex.PRNGKey, Action, RecurrentState],
170 | Tuple[ChanceRecurrentFnOutput, RecurrentState]]
171 |
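172 | # Illustration only (not part of the library): a function satisfying the
173 | # `RecurrentFn` signature for a trivial, zero-reward model with an integer
174 | # `[B]` embedding, an assumed `num_actions`, and `jax.numpy as jnp` could
175 | # look like:
176 | #
177 | #   def recurrent_fn(params, rng_key, action, embedding):
178 | #     batch_size = embedding.shape[0]
179 | #     output = RecurrentFnOutput(
180 | #         reward=jnp.zeros([batch_size]),
181 | #         discount=jnp.ones([batch_size]),
182 | #         prior_logits=jnp.zeros([batch_size, num_actions]),
183 | #         value=jnp.zeros([batch_size]))
184 | #     return output, embedding
185 | #
186 | # See `_make_bandit_recurrent_fn` in `examples/policy_improvement_demo.py`
187 | # for a complete working instance.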
--------------------------------------------------------------------------------
/mctx/_src/qtransforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Monotonic transformations for the Q-values."""
16 |
17 | import chex
18 | import jax
19 | import jax.numpy as jnp
20 |
21 | from mctx._src import tree as tree_lib
22 |
23 |
24 | def qtransform_by_min_max(
25 | tree: tree_lib.Tree,
26 | node_index: chex.Numeric,
27 | *,
28 | min_value: chex.Numeric,
29 | max_value: chex.Numeric,
30 | ) -> chex.Array:
31 | """Returns Q-values normalized by the given `min_value` and `max_value`.
32 |
33 | Args:
34 | tree: _unbatched_ MCTS tree state.
35 | node_index: scalar index of the parent node.
36 | min_value: given minimum value. Usually the `min_value` is minimum possible
37 | untransformed Q-value.
38 | max_value: given maximum value. Usually the `max_value` is maximum possible
39 | untransformed Q-value.
40 |
41 | Returns:
42 | Q-values normalized by `(qvalues - min_value) / (max_value - min_value)`.
43 | The unvisited actions will have zero Q-value. Shape `[num_actions]`.
44 | """
45 | chex.assert_shape(node_index, ())
46 | qvalues = tree.qvalues(node_index)
47 | visit_counts = tree.children_visits[node_index]
48 | value_score = jnp.where(visit_counts > 0, qvalues, min_value)
49 |   value_score = (value_score - min_value) / (max_value - min_value)
50 | return value_score
51 |
52 |
53 | def qtransform_by_parent_and_siblings(
54 | tree: tree_lib.Tree,
55 | node_index: chex.Numeric,
56 | *,
57 | epsilon: chex.Numeric = 1e-8,
58 | ) -> chex.Array:
59 | """Returns qvalues normalized by min, max over V(node) and qvalues.
60 |
61 | Args:
62 | tree: _unbatched_ MCTS tree state.
63 | node_index: scalar index of the parent node.
64 | epsilon: the minimum denominator for the normalization.
65 |
66 | Returns:
67 | Q-values normalized to be from the [0, 1] interval. The unvisited actions
68 | will have zero Q-value. Shape `[num_actions]`.
69 | """
70 | chex.assert_shape(node_index, ())
71 | qvalues = tree.qvalues(node_index)
72 | visit_counts = tree.children_visits[node_index]
73 | chex.assert_rank([qvalues, visit_counts, node_index], [1, 1, 0])
74 | node_value = tree.node_values[node_index]
75 | safe_qvalues = jnp.where(visit_counts > 0, qvalues, node_value)
76 | chex.assert_equal_shape([safe_qvalues, qvalues])
77 | min_value = jnp.minimum(node_value, jnp.min(safe_qvalues, axis=-1))
78 | max_value = jnp.maximum(node_value, jnp.max(safe_qvalues, axis=-1))
79 |
80 | completed_by_min = jnp.where(visit_counts > 0, qvalues, min_value)
81 | normalized = (completed_by_min - min_value) / (
82 | jnp.maximum(max_value - min_value, epsilon))
83 | chex.assert_equal_shape([normalized, qvalues])
84 | return normalized
85 |
86 |
87 | def qtransform_completed_by_mix_value(
88 | tree: tree_lib.Tree,
89 | node_index: chex.Numeric,
90 | *,
91 | value_scale: chex.Numeric = 0.1,
92 | maxvisit_init: chex.Numeric = 50.0,
93 | rescale_values: bool = True,
94 | use_mixed_value: bool = True,
95 | epsilon: chex.Numeric = 1e-8,
96 | ) -> chex.Array:
97 | """Returns completed qvalues.
98 |
99 | The missing Q-values of the unvisited actions are replaced by the
100 | mixed value, defined in Appendix D of
101 | "Policy improvement by planning with Gumbel":
102 | https://openreview.net/forum?id=bERaNdoegnO
103 |
104 | The Q-values are transformed by a linear transformation:
105 | `(maxvisit_init + max(visit_counts)) * value_scale * qvalues`.
106 |
107 | Args:
108 | tree: _unbatched_ MCTS tree state.
109 | node_index: scalar index of the parent node.
110 | value_scale: scale for the Q-values.
111 | maxvisit_init: offset to the `max(visit_counts)` in the scaling factor.
112 | rescale_values: if True, scale the qvalues by `1 / (max_q - min_q)`.
113 | use_mixed_value: if True, complete the Q-values with mixed value,
114 | otherwise complete the Q-values with the raw value.
115 | epsilon: the minimum denominator when using `rescale_values`.
116 |
117 | Returns:
118 | Completed Q-values. Shape `[num_actions]`.
119 | """
120 | chex.assert_shape(node_index, ())
121 | qvalues = tree.qvalues(node_index)
122 | visit_counts = tree.children_visits[node_index]
123 |
124 | # Computing the mixed value and producing completed_qvalues.
125 | raw_value = tree.raw_values[node_index]
126 | prior_probs = jax.nn.softmax(
127 | tree.children_prior_logits[node_index])
128 | if use_mixed_value:
129 | value = _compute_mixed_value(
130 | raw_value,
131 | qvalues=qvalues,
132 | visit_counts=visit_counts,
133 | prior_probs=prior_probs)
134 | else:
135 | value = raw_value
136 | completed_qvalues = _complete_qvalues(
137 | qvalues, visit_counts=visit_counts, value=value)
138 |
139 | # Scaling the Q-values.
140 | if rescale_values:
141 | completed_qvalues = _rescale_qvalues(completed_qvalues, epsilon)
142 | maxvisit = jnp.max(visit_counts, axis=-1)
143 | visit_scale = maxvisit_init + maxvisit
144 | return visit_scale * value_scale * completed_qvalues
145 |
146 |
147 | def _rescale_qvalues(qvalues, epsilon):
148 | """Rescales the given completed Q-values to be from the [0, 1] interval."""
149 | min_value = jnp.min(qvalues, axis=-1, keepdims=True)
150 | max_value = jnp.max(qvalues, axis=-1, keepdims=True)
151 | return (qvalues - min_value) / jnp.maximum(max_value - min_value, epsilon)
152 |
153 |
154 | def _complete_qvalues(qvalues, *, visit_counts, value):
155 | """Returns completed Q-values, with the `value` for unvisited actions."""
156 | chex.assert_equal_shape([qvalues, visit_counts])
157 | chex.assert_shape(value, [])
158 |
159 | # The missing qvalues are replaced by the value.
160 | completed_qvalues = jnp.where(
161 | visit_counts > 0,
162 | qvalues,
163 | value)
164 | chex.assert_equal_shape([completed_qvalues, qvalues])
165 | return completed_qvalues
166 |
167 |
168 | def _compute_mixed_value(raw_value, qvalues, visit_counts, prior_probs):
169 | """Interpolates the raw_value and weighted qvalues.
170 |
171 | Args:
172 | raw_value: an approximate value of the state. Shape `[]`.
173 | qvalues: Q-values for all actions. Shape `[num_actions]`. The unvisited
174 | actions have undefined Q-value.
175 | visit_counts: the visit counts for all actions. Shape `[num_actions]`.
176 | prior_probs: the action probabilities, produced by the policy network for
177 | each action. Shape `[num_actions]`.
178 |
179 | Returns:
180 | An estimator of the state value. Shape `[]`.
181 | """
182 | sum_visit_counts = jnp.sum(visit_counts, axis=-1)
183 | # Ensuring non-nan weighted_q, even if the visited actions have zero
184 | # prior probability.
185 | prior_probs = jnp.maximum(jnp.finfo(prior_probs.dtype).tiny, prior_probs)
186 | # Summing the probabilities of the visited actions.
187 | sum_probs = jnp.sum(jnp.where(visit_counts > 0, prior_probs, 0.0),
188 | axis=-1)
189 | weighted_q = jnp.sum(jnp.where(
190 | visit_counts > 0,
191 | prior_probs * qvalues / jnp.where(visit_counts > 0, sum_probs, 1.0),
192 | 0.0), axis=-1)
193 | return (raw_value + sum_visit_counts * weighted_q) / (sum_visit_counts + 1)
194 |
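195 | # For reference, `_compute_mixed_value` above implements, with
196 | # `N = sum(visit_counts)`, `pi = prior_probs` and the visited set
197 | # `V = {a : visit_counts[a] > 0}`:
198 | #
199 | #   weighted_q = sum_{a in V} pi(a) * qvalues[a] / sum_{b in V} pi(b)
200 | #   mixed_value = (raw_value + N * weighted_q) / (N + 1)
201 | #
202 | # This matches the mixed value estimator in Appendix D of
203 | # "Policy improvement by planning with Gumbel".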
--------------------------------------------------------------------------------
/examples/visualization_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A demo of Graphviz visualization of a search tree."""
16 |
17 | from typing import Optional, Sequence
18 |
19 | from absl import app
20 | from absl import flags
21 | import chex
22 | import jax
23 | import jax.numpy as jnp
24 | import mctx
25 | import pygraphviz
26 |
27 | FLAGS = flags.FLAGS
28 | flags.DEFINE_integer("seed", 42, "Random seed.")
29 | flags.DEFINE_integer("num_simulations", 32, "Number of simulations.")
30 | flags.DEFINE_integer("max_num_considered_actions", 16,
31 | "The maximum number of actions expanded at the root.")
32 | flags.DEFINE_integer("max_depth", None, "The maximum search depth.")
33 | flags.DEFINE_string("output_file", "/tmp/search_tree.png",
34 | "The output file for the visualization.")
35 |
36 |
37 | def convert_tree_to_graph(
38 | tree: mctx.Tree,
39 | action_labels: Optional[Sequence[str]] = None,
40 | batch_index: int = 0
41 | ) -> pygraphviz.AGraph:
42 | """Converts a search tree into a Graphviz graph.
43 |
44 | Args:
45 | tree: A `Tree` containing a batch of search data.
46 | action_labels: Optional labels for edges, defaults to the action index.
47 | batch_index: Index of the batch element to plot.
48 |
49 | Returns:
50 | A Graphviz graph representation of `tree`.
51 | """
52 | chex.assert_rank(tree.node_values, 2)
53 | batch_size = tree.node_values.shape[0]
54 | if action_labels is None:
55 | action_labels = range(tree.num_actions)
56 | elif len(action_labels) != tree.num_actions:
57 | raise ValueError(
58 | f"action_labels {action_labels} has the wrong number of actions "
59 | f"({len(action_labels)}). "
60 | f"Expecting {tree.num_actions}.")
61 |
62 | def node_to_str(node_i, reward=0, discount=1):
63 | return (f"{node_i}\n"
64 | f"Reward: {reward:.2f}\n"
65 | f"Discount: {discount:.2f}\n"
66 | f"Value: {tree.node_values[batch_index, node_i]:.2f}\n"
67 | f"Visits: {tree.node_visits[batch_index, node_i]}\n")
68 |
69 | def edge_to_str(node_i, a_i):
70 | node_index = jnp.full([batch_size], node_i)
71 | probs = jax.nn.softmax(tree.children_prior_logits[batch_index, node_i])
72 | return (f"{action_labels[a_i]}\n"
73 | f"Q: {tree.qvalues(node_index)[batch_index, a_i]:.2f}\n" # pytype: disable=unsupported-operands # always-use-return-annotations
74 | f"p: {probs[a_i]:.2f}\n")
75 |
76 | graph = pygraphviz.AGraph(directed=True)
77 |
78 | # Add root
79 | graph.add_node(0, label=node_to_str(node_i=0), color="green")
80 | # Add all other nodes and connect them up.
81 | for node_i in range(tree.num_simulations):
82 | for a_i in range(tree.num_actions):
83 | # Index of children, or -1 if not expanded
84 | children_i = tree.children_index[batch_index, node_i, a_i]
85 | if children_i >= 0:
86 | graph.add_node(
87 | children_i,
88 | label=node_to_str(
89 | node_i=children_i,
90 | reward=tree.children_rewards[batch_index, node_i, a_i],
91 | discount=tree.children_discounts[batch_index, node_i, a_i]),
92 | color="red")
93 | graph.add_edge(node_i, children_i, label=edge_to_str(node_i, a_i))
94 |
95 | return graph
96 |
97 |
98 | def _run_demo(rng_key: chex.PRNGKey):
99 | """Runs a search algorithm on a toy environment."""
100 |   # We will define a deterministic toy environment.
101 |   # Its `transition_matrix` has shape `[num_states, num_actions]`, and
102 |   # `transition_matrix[s, a]` holds the next state.
103 | transition_matrix = jnp.array([
104 | [1, 2, 3, 4],
105 | [0, 5, 0, 0],
106 | [0, 0, 0, 6],
107 | [0, 0, 0, 0],
108 | [0, 0, 0, 0],
109 | [0, 0, 0, 0],
110 | [0, 0, 0, 0],
111 | ], dtype=jnp.int32)
112 | # The `rewards` have shape `[num_states, num_actions]`. The `rewards[s, a]`
113 | # holds the reward for that (s, a) pair.
114 | rewards = jnp.array([
115 | [1, -1, 0, 0],
116 | [0, 0, 0, 0],
117 | [0, 0, 0, 0],
118 | [0, 0, 0, 0],
119 | [0, 0, 0, 0],
120 | [0, 0, 0, 0],
121 | [10, 0, 20, 0],
122 | ], dtype=jnp.float32)
123 | num_states = rewards.shape[0]
124 | # The discount for each (s, a) pair.
125 | discounts = jnp.where(transition_matrix > 0, 1.0, 0.0)
126 | # Using optimistic initial values to encourage exploration.
127 | values = jnp.full([num_states], 15.0)
128 | # The prior policies for each state.
129 | all_prior_logits = jnp.zeros_like(rewards)
130 | root, recurrent_fn = _make_batched_env_model(
131 | # Using batch_size=2 to test the batched search.
132 | batch_size=2,
133 | transition_matrix=transition_matrix,
134 | rewards=rewards,
135 | discounts=discounts,
136 | values=values,
137 | prior_logits=all_prior_logits)
138 |
139 | # Running the search.
140 | policy_output = mctx.gumbel_muzero_policy(
141 | params=(),
142 | rng_key=rng_key,
143 | root=root,
144 | recurrent_fn=recurrent_fn,
145 | num_simulations=FLAGS.num_simulations,
146 | max_depth=FLAGS.max_depth,
147 | max_num_considered_actions=FLAGS.max_num_considered_actions,
148 | )
149 | return policy_output
150 |
151 |
152 | def _make_batched_env_model(
153 | batch_size: int,
154 | *,
155 | transition_matrix: chex.Array,
156 | rewards: chex.Array,
157 | discounts: chex.Array,
158 | values: chex.Array,
159 | prior_logits: chex.Array):
160 | """Returns a batched `(root, recurrent_fn)`."""
161 | chex.assert_equal_shape([transition_matrix, rewards, discounts,
162 | prior_logits])
163 | num_states, num_actions = transition_matrix.shape
164 | chex.assert_shape(values, [num_states])
165 | # We will start the search at state zero.
166 | root_state = 0
167 | root = mctx.RootFnOutput(
168 | prior_logits=jnp.full([batch_size, num_actions],
169 | prior_logits[root_state]),
170 | value=jnp.full([batch_size], values[root_state]),
171 | # The embedding will hold the state index.
172 | embedding=jnp.zeros([batch_size], dtype=jnp.int32),
173 | )
174 |
175 | def recurrent_fn(params, rng_key, action, embedding):
176 | del params, rng_key
177 | chex.assert_shape(action, [batch_size])
178 | chex.assert_shape(embedding, [batch_size])
179 | recurrent_fn_output = mctx.RecurrentFnOutput(
180 | reward=rewards[embedding, action],
181 | discount=discounts[embedding, action],
182 | prior_logits=prior_logits[embedding],
183 | value=values[embedding])
184 | next_embedding = transition_matrix[embedding, action]
185 | return recurrent_fn_output, next_embedding
186 |
187 | return root, recurrent_fn
188 |
189 |
190 | def main(_):
191 | rng_key = jax.random.PRNGKey(FLAGS.seed)
192 | jitted_run_demo = jax.jit(_run_demo)
193 | print("Starting search.")
194 | policy_output = jitted_run_demo(rng_key)
195 | batch_index = 0
196 | selected_action = policy_output.action[batch_index]
197 | q_value = policy_output.search_tree.summary().qvalues[
198 | batch_index, selected_action]
199 | print("Selected action:", selected_action)
200 | # To estimate the value of the root state, use the Q-value of the selected
201 | # action. The Q-value is not affected by the exploration at the root node.
202 | print("Selected action Q-value:", q_value)
203 | graph = convert_tree_to_graph(policy_output.search_tree)
204 | print("Saving tree diagram to:", FLAGS.output_file)
205 | graph.draw(FLAGS.output_file, prog="dot")
206 |
207 |
208 | if __name__ == "__main__":
209 | app.run(main)
210 |
--------------------------------------------------------------------------------
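For reference, the demo above is driven entirely by its absl flags. A typical
invocation (the flag values shown are illustrative) is:

    python examples/visualization_demo.py --seed=42 --num_simulations=32 \
        --max_num_considered_actions=16 --output_file=/tmp/search_tree.png

The output PNG renders the root node in green and expanded nodes in red, with
each edge labeled by its action, Q-value and prior probability.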
/mctx/_src/action_selection.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A collection of action selection functions."""
16 | from typing import Optional, TypeVar
17 |
18 | import chex
19 | import jax
20 | import jax.numpy as jnp
21 |
22 | from mctx._src import base
23 | from mctx._src import qtransforms
24 | from mctx._src import seq_halving
25 | from mctx._src import tree as tree_lib
26 |
27 |
28 | def switching_action_selection_wrapper(
29 | root_action_selection_fn: base.RootActionSelectionFn,
30 | interior_action_selection_fn: base.InteriorActionSelectionFn
31 | ) -> base.InteriorActionSelectionFn:
32 | """Wraps root and interior action selection fns in a conditional statement."""
33 |
34 | def switching_action_selection_fn(
35 | rng_key: chex.PRNGKey,
36 | tree: tree_lib.Tree,
37 | node_index: base.NodeIndices,
38 | depth: base.Depth) -> chex.Array:
39 | return jax.lax.cond(
40 | depth == 0,
41 | lambda x: root_action_selection_fn(*x[:3]),
42 | lambda x: interior_action_selection_fn(*x),
43 | (rng_key, tree, node_index, depth))
44 |
45 | return switching_action_selection_fn
46 |
47 |
48 | def muzero_action_selection(
49 | rng_key: chex.PRNGKey,
50 | tree: tree_lib.Tree,
51 | node_index: chex.Numeric,
52 | depth: chex.Numeric,
53 | *,
54 | pb_c_init: float = 1.25,
55 | pb_c_base: float = 19652.0,
56 | qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
57 | ) -> chex.Array:
58 | """Returns the action selected for a node index.
59 |
60 | See Appendix B in https://arxiv.org/pdf/1911.08265.pdf for more details.
61 |
62 | Args:
63 | rng_key: random number generator state.
64 | tree: _unbatched_ MCTS tree state.
65 | node_index: scalar index of the node from which to select an action.
66 | depth: the scalar depth of the current node. The root has depth zero.
67 | pb_c_init: constant c_1 in the PUCT formula.
68 | pb_c_base: constant c_2 in the PUCT formula.
69 | qtransform: a monotonic transformation to convert the Q-values to [0, 1].
70 |
71 | Returns:
72 | action: the action selected from the given node.
73 | """
74 | visit_counts = tree.children_visits[node_index]
75 | node_visit = tree.node_visits[node_index]
76 | pb_c = pb_c_init + jnp.log((node_visit + pb_c_base + 1.) / pb_c_base)
77 | prior_logits = tree.children_prior_logits[node_index]
78 | prior_probs = jax.nn.softmax(prior_logits)
79 | policy_score = jnp.sqrt(node_visit) * pb_c * prior_probs / (visit_counts + 1)
80 | chex.assert_shape([node_index, node_visit], ())
81 | chex.assert_equal_shape([prior_probs, visit_counts, policy_score])
82 | value_score = qtransform(tree, node_index)
83 |
84 | # Add tiny bit of randomness for tie break
85 | node_noise_score = 1e-7 * jax.random.uniform(
86 | rng_key, (tree.num_actions,))
87 | to_argmax = value_score + policy_score + node_noise_score
88 |
89 | # Masking the invalid actions at the root.
90 | return masked_argmax(to_argmax, tree.root_invalid_actions * (depth == 0))
91 |
92 |
93 | @chex.dataclass(frozen=True)
94 | class GumbelMuZeroExtraData:
95 | """Extra data for Gumbel MuZero search."""
96 | root_gumbel: chex.Array
97 |
98 |
99 | GumbelMuZeroExtraDataType = TypeVar( # pylint: disable=invalid-name
100 | "GumbelMuZeroExtraDataType", bound=GumbelMuZeroExtraData)
101 |
102 |
103 | def gumbel_muzero_root_action_selection(
104 | rng_key: chex.PRNGKey,
105 | tree: tree_lib.Tree[GumbelMuZeroExtraDataType],
106 | node_index: chex.Numeric,
107 | *,
108 | num_simulations: chex.Numeric,
109 | max_num_considered_actions: chex.Numeric,
110 | qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
111 | ) -> chex.Array:
112 | """Returns the action selected by Sequential Halving with Gumbel.
113 |
114 | Initially, we sample `max_num_considered_actions` actions without replacement.
115 | From these, the actions with the highest `gumbel + logits + qvalues` are
116 | visited first.
117 |
118 | Args:
119 | rng_key: random number generator state.
120 | tree: _unbatched_ MCTS tree state.
121 | node_index: scalar index of the node from which to take an action.
122 | num_simulations: the simulation budget.
123 | max_num_considered_actions: the number of actions sampled without
124 | replacement.
125 | qtransform: a monotonic transformation for the Q-values.
126 |
127 | Returns:
128 | action: the action selected from the given node.
129 | """
130 | del rng_key
131 | chex.assert_shape([node_index], ())
132 | visit_counts = tree.children_visits[node_index]
133 | prior_logits = tree.children_prior_logits[node_index]
134 | chex.assert_equal_shape([visit_counts, prior_logits])
135 | completed_qvalues = qtransform(tree, node_index)
136 |
137 | table = jnp.array(seq_halving.get_table_of_considered_visits(
138 | max_num_considered_actions, num_simulations))
139 | num_valid_actions = jnp.sum(
140 | 1 - tree.root_invalid_actions, axis=-1).astype(jnp.int32)
141 | num_considered = jnp.minimum(
142 | max_num_considered_actions, num_valid_actions)
143 | chex.assert_shape(num_considered, ())
144 | # At the root, the simulation_index is equal to the sum of visit counts.
145 | simulation_index = jnp.sum(visit_counts, -1)
146 | chex.assert_shape(simulation_index, ())
147 | considered_visit = table[num_considered, simulation_index]
148 | chex.assert_shape(considered_visit, ())
149 | gumbel = tree.extra_data.root_gumbel
150 | to_argmax = seq_halving.score_considered(
151 | considered_visit, gumbel, prior_logits, completed_qvalues,
152 | visit_counts)
153 |
154 | # Masking the invalid actions at the root.
155 | return masked_argmax(to_argmax, tree.root_invalid_actions)
156 |
157 |
158 | def gumbel_muzero_interior_action_selection(
159 | rng_key: chex.PRNGKey,
160 | tree: tree_lib.Tree,
161 | node_index: chex.Numeric,
162 | depth: chex.Numeric,
163 | *,
164 | qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
165 | ) -> chex.Array:
166 | """Selects the action with a deterministic action selection.
167 |
168 | The action is selected based on the visit counts to produce visitation
169 | frequencies similar to softmax(prior_logits + qvalues).
170 |
171 | Args:
172 | rng_key: random number generator state.
173 | tree: _unbatched_ MCTS tree state.
174 | node_index: scalar index of the node from which to take an action.
175 | depth: the scalar depth of the current node. The root has depth zero.
176 | qtransform: function to obtain completed Q-values for a node.
177 |
178 | Returns:
179 | action: the action selected from the given node.
180 | """
181 | del rng_key, depth
182 | chex.assert_shape([node_index], ())
183 | visit_counts = tree.children_visits[node_index]
184 | prior_logits = tree.children_prior_logits[node_index]
185 | chex.assert_equal_shape([visit_counts, prior_logits])
186 | completed_qvalues = qtransform(tree, node_index)
187 |
188 |   # Together, `prior_logits + completed_qvalues` yield an improved policy,
189 |   # because the missing qvalues are replaced by v_{prior_logits}(node).
190 | to_argmax = _prepare_argmax_input(
191 | probs=jax.nn.softmax(prior_logits + completed_qvalues),
192 | visit_counts=visit_counts)
193 |
194 | chex.assert_rank(to_argmax, 1)
195 | return jnp.argmax(to_argmax, axis=-1).astype(jnp.int32)
196 |
197 |
198 | def masked_argmax(
199 | to_argmax: chex.Array,
200 | invalid_actions: Optional[chex.Array]) -> chex.Array:
201 | """Returns a valid action with the highest `to_argmax`."""
202 | if invalid_actions is not None:
203 | chex.assert_equal_shape([to_argmax, invalid_actions])
204 |     # Using -inf inside the argmax does not produce NaNs.
205 |     # Do not use -inf inside softmax, log-softmax or cross-entropy.
206 | to_argmax = jnp.where(invalid_actions, -jnp.inf, to_argmax)
207 | # If all actions are invalid, the argmax returns action 0.
208 | return jnp.argmax(to_argmax, axis=-1).astype(jnp.int32)
209 |
210 |
211 | def _prepare_argmax_input(probs, visit_counts):
212 | """Prepares the input for the deterministic selection.
213 |
214 | When calling argmax(_prepare_argmax_input(...)) multiple times
215 | with updated visit_counts, the produced visitation frequencies will
216 | approximate the probs.
217 |
218 | For the derivation, see Section 5 "Planning at non-root nodes" in
219 | "Policy improvement by planning with Gumbel":
220 | https://openreview.net/forum?id=bERaNdoegnO
221 |
222 | Args:
223 | probs: a policy or an improved policy. Shape `[num_actions]`.
224 | visit_counts: the existing visit counts. Shape `[num_actions]`.
225 |
226 | Returns:
227 | The input to an argmax. Shape `[num_actions]`.
228 | """
229 | chex.assert_equal_shape([probs, visit_counts])
230 | to_argmax = probs - visit_counts / (
231 | 1 + jnp.sum(visit_counts, keepdims=True, axis=-1))
232 | return to_argmax
233 |
--------------------------------------------------------------------------------
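Written out, the score that `muzero_action_selection` maximizes is the PUCT
rule from the MuZero paper (Appendix B), with `pb_c_init` and `pb_c_base`
in the roles of the constants c_1 and c_2:

$$\mathrm{score}(s, a) = \bar{Q}(s, a) + P(s, a)\,\frac{\sqrt{N(s)}}{1 + N(s, a)}\left(c_1 + \log\frac{N(s) + c_2 + 1}{c_2}\right)$$

Here N(s) is the parent visit count, N(s, a) the per-action visit count,
P(s, a) the softmax of the prior logits, and \bar{Q}(s, a) the Q-value mapped
into [0, 1] by the `qtransform`; the tiny uniform noise term in the code only
serves to break ties.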
/mctx/_src/tests/tree_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A unit test comparing the search tree to an expected search tree."""
16 | # pylint: disable=use-dict-literal
17 | import functools
18 | import json
19 |
20 | import chex
21 | import jax
22 | import jax.numpy as jnp
23 | import numpy as np
24 | from absl import logging
25 | from absl.testing import absltest, parameterized
26 |
27 | import mctx
28 |
29 | jax.config.update("jax_threefry_partitionable", False)
30 |
31 |
32 | def _prepare_root(batch_size, num_actions):
33 | """Returns a root consistent with the stored expected trees."""
34 | rng_key = jax.random.PRNGKey(0)
35 | # Using a different rng_key inside each batch element.
36 | rng_keys = [rng_key]
37 | for i in range(1, batch_size):
38 | rng_keys.append(jax.random.fold_in(rng_key, i))
39 | embedding = jnp.stack(rng_keys)
40 | output = jax.vmap(
41 | functools.partial(_produce_prediction_output, num_actions=num_actions))(
42 | embedding)
43 | return mctx.RootFnOutput(
44 | prior_logits=output["policy_logits"],
45 | value=output["value"],
46 | embedding=embedding,
47 | )
48 |
49 |
50 | def _produce_prediction_output(rng_key, num_actions):
51 |   """Produces the model output as in the stored expected trees."""
52 |   policy_rng, value_rng, reward_rng = jax.random.split(rng_key, 3)
53 |   # The consumed key should not be reused after splitting.
54 |   del rng_key
55 | # Producing value from [-1, +1).
56 | value = jax.random.uniform(value_rng, shape=(), minval=-1.0, maxval=1.0)
57 | # Producing reward from [-1, +1).
58 | reward = jax.random.uniform(reward_rng, shape=(), minval=-1.0, maxval=1.0)
59 | return dict(
60 | policy_logits=jax.random.normal(policy_rng, shape=[num_actions]),
61 | value=value,
62 | reward=reward,
63 | )
64 |
65 |
66 | def _prepare_recurrent_fn(num_actions, *, discount, zero_reward):
67 | """Returns a dynamics function consistent with the expected trees."""
68 |
69 | def recurrent_fn(params, rng_key, action, embedding):
70 | del params, rng_key
71 | # The embeddings serve as rng_keys.
72 | embedding = jax.vmap(
73 | functools.partial(_fold_action_in, num_actions=num_actions))(embedding,
74 | action)
75 | output = jax.vmap(
76 | functools.partial(_produce_prediction_output, num_actions=num_actions))(
77 | embedding)
78 | reward = output["reward"]
79 | if zero_reward:
80 | reward = jnp.zeros_like(reward)
81 | return mctx.RecurrentFnOutput(
82 | reward=reward,
83 | discount=jnp.full_like(reward, discount),
84 | prior_logits=output["policy_logits"],
85 | value=output["value"],
86 | ), embedding
87 |
88 | return recurrent_fn
89 |
90 |
91 | def _fold_action_in(rng_key, action, num_actions):
92 | """Returns a new rng key, selected by the given action."""
93 | chex.assert_shape(action, ())
94 | chex.assert_type(action, jnp.int32)
95 | sub_rngs = jax.random.split(rng_key, num_actions)
96 | return sub_rngs[action]
97 |
98 |
99 | def tree_to_pytree(tree: mctx.Tree, batch_i: int = 0):
100 | """Converts the MCTS tree to nested dicts."""
101 | nodes = {}
102 | if tree.node_visits[batch_i, 0] == 0:
103 | # The root node is unvisited, so there is no tree.
104 | return _create_bare_root(prior=1.0)
105 | nodes[0] = _create_pynode(
106 | tree, batch_i, 0, prior=1.0, action=None, reward=None)
107 | children_prior_probs = jax.nn.softmax(tree.children_prior_logits, axis=-1)
108 | for node_i in range(tree.num_simulations + 1):
109 | # Return early if we reach an unvisited node
110 | visits = tree.node_visits[batch_i, node_i]
111 | if visits == 0:
112 | return nodes[0]
113 | for a_i in range(tree.num_actions):
114 | prior = children_prior_probs[batch_i, node_i, a_i]
115 |       # Index of the child, or -1 if not expanded
116 | child_i = int(tree.children_index[batch_i, node_i, a_i])
117 | if child_i >= 0:
118 | reward = tree.children_rewards[batch_i, node_i, a_i]
119 | child = _create_pynode(
120 | tree, batch_i, child_i, prior=prior, action=a_i, reward=reward)
121 | nodes[child_i] = child
122 | else:
123 | child = _create_bare_pynode(prior=prior, action=a_i)
124 | # pylint: disable=line-too-long
125 | nodes[node_i]["child_stats"].append(child) # pytype: disable=attribute-error
126 | # pylint: enable=line-too-long
127 | return nodes[0]
128 |
129 |
130 | def _create_pynode(tree, batch_i, node_i, prior, action, reward):
131 | """Returns a dict with extracted search statistics."""
132 | node = dict(
133 | prior=_round_float(prior),
134 | visit=int(tree.node_visits[batch_i, node_i]),
135 | value_view=_round_float(tree.node_values[batch_i, node_i]),
136 | raw_value_view=_round_float(tree.raw_values[batch_i, node_i]),
137 | child_stats=[],
138 | evaluation_index=node_i,
139 | )
140 | if action is not None:
141 | node["action"] = action
142 | if reward is not None:
143 | node["reward"] = _round_float(reward)
144 | return node
145 |
146 |
147 | def _create_bare_pynode(prior, action):
148 | return dict(
149 | prior=_round_float(prior),
150 | child_stats=[],
151 | action=action,
152 | )
153 |
154 |
155 | def _create_bare_root(prior):
156 | return dict(
157 | prior=_round_float(prior),
158 | child_stats=[],
159 | )
160 |
161 |
162 | def _round_float(value, ndigits=10):
163 | return round(float(value), ndigits)
164 |
165 |
166 | class TreeTest(parameterized.TestCase):
167 | # Make sure to adjust the `shard_count` parameter in the build file to match
168 | # the number of parameter configurations passed to test_tree.
169 | # pylint: disable=line-too-long
170 | MUZERO_TREES = [("muzero_norescale",
171 | "./mctx/_src/tests/test_data/muzero_tree.json"),
172 | ("muzero_qtransform",
173 | "./mctx/_src/tests/test_data/muzero_qtransform_tree.json")]
174 | GUMBEL_MUZERO_TREES = [("gumbel_muzero_norescale",
175 | "./mctx/_src/tests/test_data/gumbel_muzero_tree.json"),
176 | ("gumbel_muzero_reward",
177 | "./mctx/_src/tests/test_data/gumbel_muzero_reward_tree.json")]
178 | TREES = MUZERO_TREES + GUMBEL_MUZERO_TREES
179 | # pylint: enable=line-too-long
180 |
181 | @parameterized.named_parameters(*TREES)
182 | def test_tree(self, tree_data_path):
183 | with open(tree_data_path, "rb") as fd:
184 | tree = json.load(fd)
185 | reproduced, _ = self._reproduce_tree(tree)
186 | chex.assert_trees_all_close(tree["tree"], reproduced, atol=1e-3)
187 |
188 |
189 | @parameterized.named_parameters(*MUZERO_TREES)
190 | def test_subtree(self, tree_data_path):
191 | with open(tree_data_path, "rb") as fd:
192 | tree = json.load(fd)
193 | reproduced, jax_tree = self._reproduce_tree(tree)
194 |
195 | def rm_evaluation_index(node):
196 | if isinstance(node, dict):
197 | node.pop("evaluation_index", None)
198 | for child in node["child_stats"]:
199 | rm_evaluation_index(child)
200 |
201 | # test populated subtree
202 | for child_tree in reproduced["child_stats"]:
203 | action = child_tree["action"]
204 |       # Reflect that the chosen child node is now the root node.
205 | child_tree["prior"] = 1.0
206 | child_tree.pop("action")
207 | if "reward" in child_tree:
208 | child_tree.pop("reward")
209 |
210 | subtree = mctx.get_subtree(jax_tree, jnp.array([action, action, action]))
211 | reproduced_subtree = tree_to_pytree(subtree)
212 |
213 |       # Evaluation indices will not match, since subtree indices are
214 |       # collapsed down; compare everything except the evaluation indices.
215 | rm_evaluation_index(child_tree)
216 | rm_evaluation_index(reproduced_subtree)
217 |
218 | chex.assert_trees_all_close(reproduced_subtree, child_tree, atol=1e-3)
219 |
220 |
221 | def _reproduce_tree(self, tree):
222 | """Reproduces the given JSON tree by running a search."""
223 | policy_fn = dict(
224 | gumbel_muzero=mctx.gumbel_muzero_policy,
225 | muzero=mctx.muzero_policy,
226 | )[tree["algorithm"]]
227 |
228 | env_config = tree["env_config"]
229 | root = tree["tree"]
230 | num_actions = len(root["child_stats"])
231 | num_simulations = root["visit"] - 1
232 | qtransform = functools.partial(
233 | getattr(mctx, tree["algorithm_config"].pop("qtransform")),
234 | **tree["algorithm_config"].pop("qtransform_kwargs", {}))
235 |
236 | batch_size = 3
237 | # To test the independence of the batch computation, we use different
238 | # invalid actions for the other elements of the batch. The different batch
239 | # elements will then have different search tree depths.
240 | invalid_actions = np.zeros([batch_size, num_actions])
241 | invalid_actions[1, 1:] = 1
242 | invalid_actions[2, 2:] = 1
243 |
244 | def run_policy():
245 | return policy_fn(
246 | params=(),
247 | rng_key=jax.random.PRNGKey(1),
248 | root=_prepare_root(batch_size=batch_size, num_actions=num_actions),
249 | recurrent_fn=_prepare_recurrent_fn(num_actions, **env_config),
250 | num_simulations=num_simulations,
251 | qtransform=qtransform,
252 | invalid_actions=invalid_actions,
253 | **tree["algorithm_config"])
254 |
255 | policy_output = jax.jit(run_policy)() # pylint: disable=not-callable
256 | logging.info("Done search.")
257 |
258 | return tree_to_pytree(policy_output.search_tree), policy_output.search_tree
259 |
260 |
261 | if __name__ == "__main__":
262 | jax.config.update("jax_numpy_rank_promotion", "raise")
263 | absltest.main()
264 |
--------------------------------------------------------------------------------
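For orientation, each node produced by `tree_to_pytree` above (and each node
stored under the "tree" key of the JSON test-data files) is a dict of the
following shape. A sketch with made-up values:

# Illustrative only: field names match _create_pynode above; values are fake.
example_node = {
    "prior": 0.25,               # softmax of the parent's prior logits
    "visit": 8,                  # node visit count
    "value_view": 0.1234567890,  # search value, rounded to 10 digits
    "raw_value_view": 0.0987654321,
    "evaluation_index": 3,       # node index within the search tree
    "action": 1,                 # omitted for the root node
    "reward": 0.5,               # omitted for the root node
    "child_stats": [],           # one dict per action; bare dicts (prior,
                                 # action, child_stats) for unexpanded children
}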
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/mctx/_src/tree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A data structure used to hold / inspect search data for a batch of inputs."""
16 |
17 | from __future__ import annotations
18 | from typing import Any, ClassVar, Generic, Optional, Tuple, TypeVar
19 |
20 | import chex
21 | import jax
22 | import jax.numpy as jnp
23 |
24 |
25 | T = TypeVar("T")
26 |
27 |
28 | @chex.dataclass(frozen=True)
29 | class Tree(Generic[T]):
30 | """State of a search tree.
31 |
32 | The `Tree` dataclass is used to hold and inspect search data for a batch of
33 | inputs. In the fields below `B` denotes the batch dimension, `N` represents
34 | the number of nodes in the tree, and `num_actions` is the number of discrete
35 | actions.
36 |
37 | node_visits: `[B, N]` the visit counts for each node.
38 | raw_values: `[B, N]` the raw value for each node.
39 | node_values: `[B, N]` the cumulative search value for each node.
40 | parents: `[B, N]` the node index for the parents for each node.
41 | action_from_parent: `[B, N]` action to take from the parent to reach each
42 | node.
43 | children_index: `[B, N, num_actions]` the node index of the children for each
44 | action.
45 | children_prior_logits: `[B, N, num_actions]` the action prior logits of each
46 | node.
47 | children_visits: `[B, N, num_actions]` the visit counts for children for
48 | each action.
49 | children_rewards: `[B, N, num_actions]` the immediate reward for each action.
50 | children_discounts: `[B, N, num_actions]` the discount between the
51 | `children_rewards` and the `children_values`.
52 | children_values: `[B, N, num_actions]` the value of the next node after the
53 | action.
54 | next_node_index: `[B]` the next free index where a new node can be inserted.
55 | embeddings: `[B, N, ...]` the state embeddings of each node.
56 | root_invalid_actions: `[B, num_actions]` a mask with invalid actions at the
57 | root. In the mask, invalid actions have ones, and valid actions have zeros.
58 | extra_data: `[B, ...]` extra data passed to the search.
59 | """
60 | node_visits: chex.Array # [B, N]
61 | raw_values: chex.Array # [B, N]
62 | node_values: chex.Array # [B, N]
63 | parents: chex.Array # [B, N]
64 | action_from_parent: chex.Array # [B, N]
65 | children_index: chex.Array # [B, N, num_actions]
66 | children_prior_logits: chex.Array # [B, N, num_actions]
67 | children_visits: chex.Array # [B, N, num_actions]
68 | children_rewards: chex.Array # [B, N, num_actions]
69 | children_discounts: chex.Array # [B, N, num_actions]
70 | children_values: chex.Array # [B, N, num_actions]
71 | next_node_index: chex.Array # [B]
72 | embeddings: Any # [B, N, ...]
73 | root_invalid_actions: chex.Array # [B, num_actions]
74 | extra_data: T # [B, ...]
75 |
76 | # The following attributes are class variables (and should not be set on
77 | # Tree instances).
78 | ROOT_INDEX: ClassVar[int] = 0
79 | NO_PARENT: ClassVar[int] = -1
80 | UNVISITED: ClassVar[int] = -1
81 |
82 | @property
83 | def num_actions(self):
84 | return self.children_index.shape[-1]
85 |
86 | @property
87 | def num_simulations(self):
88 | return self.node_visits.shape[-1] - 1
89 |
90 | def qvalues(self, indices):
91 | """Compute q-values for any node indices in the tree."""
92 | # pytype: disable=wrong-arg-types # jnp-type
93 | if jnp.asarray(indices).shape:
94 | return jax.vmap(_unbatched_qvalues)(self, indices)
95 | else:
96 | return _unbatched_qvalues(self, indices)
97 | # pytype: enable=wrong-arg-types
98 |
99 | def summary(self) -> SearchSummary:
100 | """Extract summary statistics for the root node."""
101 | # Get state and action values for the root nodes.
102 | chex.assert_rank(self.node_values, 2)
103 | value = self.node_values[:, Tree.ROOT_INDEX]
104 | batch_size, = value.shape
105 | root_indices = jnp.full((batch_size,), Tree.ROOT_INDEX)
106 | qvalues = self.qvalues(root_indices)
107 | # Extract visit counts and induced probabilities for the root nodes.
108 | visit_counts = self.children_visits[:, Tree.ROOT_INDEX].astype(value.dtype)
109 | total_counts = jnp.sum(visit_counts, axis=-1, keepdims=True)
110 | visit_probs = visit_counts / jnp.maximum(total_counts, 1)
111 | visit_probs = jnp.where(total_counts > 0, visit_probs, 1 / self.num_actions)
112 | # Return relevant stats.
113 | return SearchSummary( # pytype: disable=wrong-arg-types # numpy-scalars
114 | visit_counts=visit_counts,
115 | visit_probs=visit_probs,
116 | value=value,
117 | qvalues=qvalues)
118 |
119 |
120 | def infer_batch_size(tree: Tree) -> int:
121 | """Recovers batch size from `Tree` data structure."""
122 | if tree.node_values.ndim != 2:
123 | raise ValueError("Input tree is not batched.")
124 | chex.assert_equal_shape_prefix(jax.tree_util.tree_leaves(tree), 1)
125 | return tree.node_values.shape[0]
126 |
127 |
128 | # A number of aggregate statistics and predictions are extracted from the
129 | # search data and returned to the user for further processing.
130 | @chex.dataclass(frozen=True)
131 | class SearchSummary:
132 | """Stats from MCTS search."""
133 | visit_counts: chex.Array
134 | visit_probs: chex.Array
135 | value: chex.Array
136 | qvalues: chex.Array
137 |
138 |
139 | def _unbatched_qvalues(tree: Tree, index: int) -> int:
140 | chex.assert_rank(tree.children_discounts, 2)
141 | return ( # pytype: disable=bad-return-type # numpy-scalars
142 | tree.children_rewards[index]
143 | + tree.children_discounts[index] * tree.children_values[index]
144 | )
145 |
146 |
147 | def _get_translation(
148 | tree: Tree,
149 | child_index: chex.Array
150 | ) -> Tuple[chex.Array, chex.Array, chex.Array]:
151 |   subtrees = jnp.arange(tree.num_simulations + 1)
152 |
153 | def propagate_fun(_, subtrees):
154 | parents_subtrees = jnp.where(
155 | tree.parents != tree.NO_PARENT,
156 | subtrees[tree.parents],
157 | 0
158 | )
159 | return jnp.where(
160 | jnp.greater(parents_subtrees, 0),
161 | parents_subtrees,
162 | subtrees
163 | )
164 |
165 | subtrees = jax.lax.fori_loop(0, tree.num_simulations, propagate_fun, subtrees)
166 |   slots_arranged = jnp.arange(tree.num_simulations + 1)
167 | subtree_master_idx = tree.children_index[tree.ROOT_INDEX, child_index]
168 | nodes_to_retain = subtrees == subtree_master_idx
169 |   old_subtree_idxs = nodes_to_retain * slots_arranged
170 | cumsum = jnp.cumsum(nodes_to_retain)
171 | new_next_node_index = cumsum[-1]
172 |
173 | translation = jnp.where(
174 | nodes_to_retain,
175 | nodes_to_retain * (cumsum-1),
176 | tree.UNVISITED
177 | )
178 |   erase_idxs = slots_arranged >= new_next_node_index
179 |
180 | return old_subtree_idxs, translation, erase_idxs
181 |
182 |
183 | @jax.vmap
184 | def get_subtree(
185 | tree: Tree,
186 | child_index: chex.Array
187 | ) -> Tree:
188 |   """Extracts the subtree rooted at a chosen child of each root node,
189 |   across a batch of trees. Remaps node indices and compacts the node
190 |   data so that the retained nodes are contiguous and start at index 0.
191 |
192 | Assumes `tree` elements and `child_index` have a batch dimension.
193 |
194 | Args:
195 | tree: the tree to extract subtrees from
196 | child_index: `[B]` the index of the child (from the root) to extract each
197 | subtree from
198 | """
199 |   # Get the mapping from old node indices to new node indices,
200 |   # and a mask of which node indices to erase.
201 | old_subtree_idxs, translation, erase_idxs = _get_translation(
202 | tree, child_index)
203 | new_next_node_index = translation.max(axis=-1) + 1
204 |
205 | def translate(x, null_value=0):
206 | return jnp.where(
207 | erase_idxs.reshape((-1,) + (1,) * (x.ndim - 1)),
208 | jnp.full_like(x, null_value),
209 |       # Cases where translation == -1 write to the last index,
210 |       # but since we always remove at least the root node
211 |       # (making one of its children the new root),
212 |       # the last index is always freed
213 |       # and then overwritten with zeros.
214 | x.at[translation].set(x[old_subtree_idxs]),
215 | )
216 |
217 | def translate_idx(x, null_value=tree.UNVISITED):
218 | return jnp.where(
219 | erase_idxs.reshape((-1,) + (1,) * (x.ndim - 1)),
220 | jnp.full_like(x, null_value),
221 |       # In this case we need to explicitly check for index
222 |       # mappings to UNVISITED, since otherwise these would
223 |       # map to the value at the last index of the translation.
224 | x.at[translation].set(jnp.where(
225 | x == null_value,
226 | null_value,
227 | translation[x])))
228 |
229 | def translate_pytree(x, null_value=0):
230 | return jax.tree.map(
231 | lambda t: translate(t, null_value=null_value), x)
232 |
233 | return tree.replace(
234 | node_visits=translate(tree.node_visits),
235 | raw_values=translate(tree.raw_values),
236 | node_values=translate(tree.node_values),
237 | parents=translate_idx(tree.parents),
238 | action_from_parent=translate(
239 | tree.action_from_parent,
240 | null_value=tree.NO_PARENT).at[tree.ROOT_INDEX].set(tree.NO_PARENT),
241 | children_index=translate_idx(tree.children_index),
242 | children_prior_logits=translate(tree.children_prior_logits),
243 | children_visits=translate(tree.children_visits),
244 | children_rewards=translate(tree.children_rewards),
245 | children_discounts=translate(tree.children_discounts),
246 | children_values=translate(tree.children_values),
247 | next_node_index=new_next_node_index,
248 | root_invalid_actions=jnp.zeros_like(tree.root_invalid_actions),
249 | embeddings=translate_pytree(tree.embeddings)
250 | )
251 |
252 |
253 | def reset_search_tree(
254 | tree: Tree,
255 | select_batch: Optional[chex.Array] = None) -> Tree:
256 | """Fills search tree with default values for selected batches.
257 |
258 | Useful for resetting the search tree after a terminated episode.
259 |
260 | Args:
261 | tree: the tree to reset
262 | select_batch: `[B]` a boolean mask to select which batch elements to reset.
263 | If `None`, all batch elements are reset.
264 | """
265 | if select_batch is None:
266 | select_batch = jnp.ones(tree.node_visits.shape[0], dtype=bool)
267 |
268 | return tree.replace(
269 | node_visits=tree.node_visits.at[select_batch].set(0),
270 | raw_values=tree.raw_values.at[select_batch].set(0),
271 | node_values=tree.node_values.at[select_batch].set(0),
272 | parents=tree.parents.at[select_batch].set(tree.NO_PARENT),
273 | action_from_parent=tree.action_from_parent.at[select_batch].set(
274 | tree.NO_PARENT),
275 | children_index=tree.children_index.at[select_batch].set(tree.UNVISITED),
276 | children_prior_logits=tree.children_prior_logits.at[select_batch].set(0),
277 | children_values=tree.children_values.at[select_batch].set(0),
278 | children_visits=tree.children_visits.at[select_batch].set(0),
279 | children_rewards=tree.children_rewards.at[select_batch].set(0),
280 | children_discounts=tree.children_discounts.at[select_batch].set(0),
281 | next_node_index=tree.next_node_index.at[select_batch].set(1),
282 | embeddings=jax.tree_util.tree_map(
283 | lambda t: t.at[select_batch].set(0),
284 | tree.embeddings),
285 | root_invalid_actions=tree.root_invalid_actions.at[select_batch].set(0)
286 | # extra_data is always overwritten by a call to search()
287 | )
288 |
--------------------------------------------------------------------------------
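Note that `Tree.qvalues` returns the one-step bootstrapped action values
r(s, a) + gamma(s, a) * V(child), matching `_unbatched_qvalues` above. A
pattern that `get_subtree` and `reset_search_tree` enable together is tree
reuse across environment steps: keep the subtree under the chosen action as
the next root, and wipe only the batch elements whose episodes terminated.
A minimal sketch; `policy_output` (from a previous policy call) and the `[B]`
boolean mask `done` are assumed context, and `reset_search_tree` is imported
from its private module since its public export is not shown here:

import mctx
from mctx._src import tree as tree_lib

# Keep, per batch element, the subtree under the action that was taken.
subtree = mctx.get_subtree(policy_output.search_tree, policy_output.action)
# Start terminated episodes from an empty tree at the next search.
subtree = tree_lib.reset_search_tree(subtree, select_batch=done)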
/mctx/_src/tests/policies_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `policies.py`."""
16 | import functools
17 |
18 | from absl.testing import absltest
19 | import jax
20 | import jax.numpy as jnp
21 | import mctx
22 | from mctx._src import policies
23 | import numpy as np
24 |
25 | jax.config.update("jax_threefry_partitionable", False)
26 |
27 |
28 | def _make_bandit_recurrent_fn(rewards, dummy_embedding=()):
29 | """Returns a recurrent_fn with discount=0."""
30 |
31 | def recurrent_fn(params, rng_key, action, embedding):
32 | del params, rng_key, embedding
33 | reward = rewards[jnp.arange(action.shape[0]), action]
34 | return mctx.RecurrentFnOutput(
35 | reward=reward,
36 | discount=jnp.zeros_like(reward),
37 | prior_logits=jnp.zeros_like(rewards),
38 | value=jnp.zeros_like(reward),
39 | ), dummy_embedding
40 |
41 | return recurrent_fn
42 |
43 |
44 | def _make_bandit_decision_and_chance_fns(rewards, num_chance_outcomes):
45 |
46 | def decision_recurrent_fn(params, rng_key, action, embedding):
47 | del params, rng_key
48 | batch_size = action.shape[0]
49 | reward = rewards[jnp.arange(batch_size), action]
50 | dummy_chance_logits = jnp.full([batch_size, num_chance_outcomes],
51 | -jnp.inf).at[:, 0].set(1.0)
52 | afterstate_embedding = (action, embedding)
53 | return mctx.DecisionRecurrentFnOutput(
54 | chance_logits=dummy_chance_logits,
55 | afterstate_value=jnp.zeros_like(reward)), afterstate_embedding
56 |
57 | def chance_recurrent_fn(params, rng_key, chance_outcome,
58 | afterstate_embedding):
59 | del params, rng_key, chance_outcome
60 | afterstate_action, embedding = afterstate_embedding
61 | batch_size = afterstate_action.shape[0]
62 |
63 | reward = rewards[jnp.arange(batch_size), afterstate_action]
64 | return mctx.ChanceRecurrentFnOutput(
65 | action_logits=jnp.zeros_like(rewards),
66 | value=jnp.zeros_like(reward),
67 | discount=jnp.zeros_like(reward),
68 | reward=reward), embedding
69 |
70 | return decision_recurrent_fn, chance_recurrent_fn
71 |
72 |
73 | def _get_deepest_leaf(tree, node_index):
74 | """Returns `(leaf, depth)` with maximum depth and visit count.
75 |
76 | Args:
77 | tree: _unbatched_ MCTS tree state.
78 | node_index: the node of the inspected subtree.
79 |
80 | Returns:
81 | `(leaf, depth)` of a deepest leaf. If multiple leaves have the same depth,
82 | the leaf with the highest visit count is returned.
83 | """
84 | np.testing.assert_equal(len(tree.children_index.shape), 2)
85 | leaf = node_index
86 | max_found_depth = 0
87 | for action in range(tree.children_index.shape[-1]):
88 | next_node_index = tree.children_index[node_index, action]
89 | if next_node_index != tree.UNVISITED:
90 | found_leaf, found_depth = _get_deepest_leaf(tree, next_node_index)
91 | if ((1 + found_depth, tree.node_visits[found_leaf]) >
92 | (max_found_depth, tree.node_visits[leaf])):
93 | leaf = found_leaf
94 | max_found_depth = 1 + found_depth
95 | return leaf, max_found_depth
96 |
97 |
98 | class PoliciesTest(absltest.TestCase):
99 |
100 | def test_apply_temperature_one(self):
101 | """Tests temperature=1."""
102 | logits = jnp.arange(6, dtype=jnp.float32)
103 | new_logits = policies._apply_temperature(logits, temperature=1.0)
104 | np.testing.assert_allclose(logits - logits.max(), new_logits)
105 |
106 | def test_apply_temperature_two(self):
107 | """Tests temperature=2."""
108 | logits = jnp.arange(6, dtype=jnp.float32)
109 | temperature = 2.0
110 | new_logits = policies._apply_temperature(logits, temperature)
111 | np.testing.assert_allclose((logits - logits.max()) / temperature,
112 | new_logits)
113 |
114 | def test_apply_temperature_zero(self):
115 | """Tests temperature=0."""
116 | logits = jnp.arange(4, dtype=jnp.float32)
117 | new_logits = policies._apply_temperature(logits, temperature=0.0)
118 | np.testing.assert_allclose(
119 | jnp.array([-2.552118e+38, -1.701412e+38, -8.507059e+37, 0.0]),
120 | new_logits,
121 | rtol=1e-3)
122 |
123 | def test_apply_temperature_zero_on_large_logits(self):
124 | """Tests temperature=0 on large logits."""
125 | logits = jnp.array([100.0, 3.4028235e+38, -jnp.inf, -3.4028235e+38])
126 | new_logits = policies._apply_temperature(logits, temperature=0.0)
127 | np.testing.assert_allclose(
128 | jnp.array([-jnp.inf, 0.0, -jnp.inf, -jnp.inf]), new_logits)
129 |
130 | def test_mask_invalid_actions(self):
131 | """Tests action masking."""
132 | logits = jnp.array([1e6, -jnp.inf, 1e6 + 1, -100.0])
133 | invalid_actions = jnp.array([0.0, 1.0, 0.0, 1.0])
134 | masked_logits = policies._mask_invalid_actions(
135 | logits, invalid_actions)
136 | valid_probs = jax.nn.softmax(jnp.array([0.0, 1.0]))
137 | np.testing.assert_allclose(
138 | jnp.array([valid_probs[0], 0.0, valid_probs[1], 0.0]),
139 | jax.nn.softmax(masked_logits))
140 |
141 | def test_mask_all_invalid_actions(self):
142 | """Tests a state with no valid action."""
143 | logits = jnp.array([-jnp.inf, -jnp.inf, -jnp.inf, -jnp.inf])
144 | invalid_actions = jnp.array([1.0, 1.0, 1.0, 1.0])
145 | masked_logits = policies._mask_invalid_actions(
146 | logits, invalid_actions)
147 | np.testing.assert_allclose(
148 | jnp.array([0.25, 0.25, 0.25, 0.25]),
149 | jax.nn.softmax(masked_logits))
150 |
151 | def test_muzero_policy(self):
152 | root = mctx.RootFnOutput(
153 | prior_logits=jnp.array([
154 | [-1.0, 0.0, 2.0, 3.0],
155 | ]),
156 | value=jnp.array([0.0]),
157 | embedding=(),
158 | )
159 | rewards = jnp.zeros_like(root.prior_logits)
160 | invalid_actions = jnp.array([
161 | [0.0, 0.0, 0.0, 1.0],
162 | ])
163 |
164 | policy_output = mctx.muzero_policy(
165 | params=(),
166 | rng_key=jax.random.PRNGKey(0),
167 | root=root,
168 | recurrent_fn=_make_bandit_recurrent_fn(rewards),
169 | num_simulations=1,
170 | invalid_actions=invalid_actions,
171 | dirichlet_fraction=0.0)
172 | expected_action = jnp.array([2], dtype=jnp.int32)
173 | np.testing.assert_array_equal(expected_action, policy_output.action)
174 | expected_action_weights = jnp.array([
175 | [0.0, 0.0, 1.0, 0.0],
176 | ])
177 | np.testing.assert_allclose(expected_action_weights,
178 | policy_output.action_weights)
179 |
180 | def test_gumbel_muzero_policy(self):
181 | root_value = jnp.array([-5.0])
182 | root = mctx.RootFnOutput(
183 | prior_logits=jnp.array([
184 | [0.0, -1.0, 2.0, 3.0],
185 | ]),
186 | value=root_value,
187 | embedding=(),
188 | )
189 | rewards = jnp.array([
190 | [20.0, 3.0, -1.0, 10.0],
191 | ])
192 | invalid_actions = jnp.array([
193 | [1.0, 0.0, 0.0, 1.0],
194 | ])
195 |
196 | value_scale = 0.05
197 | maxvisit_init = 60
198 | num_simulations = 17
199 | max_depth = 3
200 | qtransform = functools.partial(
201 | mctx.qtransform_completed_by_mix_value,
202 | value_scale=value_scale,
203 | maxvisit_init=maxvisit_init,
204 | rescale_values=True)
205 | policy_output = mctx.gumbel_muzero_policy(
206 | params=(),
207 | rng_key=jax.random.PRNGKey(0),
208 | root=root,
209 | recurrent_fn=_make_bandit_recurrent_fn(rewards),
210 | num_simulations=num_simulations,
211 | invalid_actions=invalid_actions,
212 | max_depth=max_depth,
213 | qtransform=qtransform,
214 | gumbel_scale=1.0)
215 | # Testing the action.
216 | expected_action = jnp.array([1], dtype=jnp.int32)
217 | np.testing.assert_array_equal(expected_action, policy_output.action)
218 |
219 | # Testing the action_weights.
220 | probs = jax.nn.softmax(jnp.where(
221 | invalid_actions, -jnp.inf, root.prior_logits))
222 | mix_value = 1.0 / (num_simulations + 1) * (root_value + num_simulations * (
223 | probs[:, 1] * rewards[:, 1] + probs[:, 2] * rewards[:, 2]))
224 |
225 | completed_qvalues = jnp.array([
226 | [mix_value[0], rewards[0, 1], rewards[0, 2], mix_value[0]],
227 | ])
228 | max_value = jnp.max(completed_qvalues, axis=-1, keepdims=True)
229 | min_value = jnp.min(completed_qvalues, axis=-1, keepdims=True)
230 | total_value_scale = (maxvisit_init + np.ceil(num_simulations / 2)
231 | ) * value_scale
232 | rescaled_qvalues = total_value_scale * (completed_qvalues - min_value) / (
233 | max_value - min_value)
234 | expected_action_weights = jax.nn.softmax(
235 | jnp.where(invalid_actions,
236 | -jnp.inf,
237 | root.prior_logits + rescaled_qvalues))
238 | np.testing.assert_allclose(expected_action_weights,
239 | policy_output.action_weights,
240 | atol=1e-6)
241 |
242 | # Testing the visit_counts.
243 | summary = policy_output.search_tree.summary()
244 | expected_visit_counts = jnp.array(
245 | [[0.0, np.ceil(num_simulations / 2), num_simulations // 2, 0.0]])
246 | np.testing.assert_array_equal(expected_visit_counts, summary.visit_counts)
247 |
248 | # Testing max_depth.
249 | leaf, max_found_depth = _get_deepest_leaf(
250 | jax.tree.map(lambda x: x[0], policy_output.search_tree),
251 | policy_output.search_tree.ROOT_INDEX)
252 | self.assertEqual(max_depth, max_found_depth)
253 | self.assertEqual(6, policy_output.search_tree.node_visits[0, leaf])
254 |
255 | def test_gumbel_muzero_policy_without_invalid_actions(self):
256 | root_value = jnp.array([-5.0])
257 | root = mctx.RootFnOutput(
258 | prior_logits=jnp.array([
259 | [0.0, -1.0, 2.0, 3.0],
260 | ]),
261 | value=root_value,
262 | embedding=(),
263 | )
264 | rewards = jnp.array([
265 | [20.0, 3.0, -1.0, 10.0],
266 | ])
267 |
268 | value_scale = 0.05
269 | maxvisit_init = 60
270 | num_simulations = 17
271 | max_depth = 3
272 | qtransform = functools.partial(
273 | mctx.qtransform_completed_by_mix_value,
274 | value_scale=value_scale,
275 | maxvisit_init=maxvisit_init,
276 | rescale_values=True)
277 | policy_output = mctx.gumbel_muzero_policy(
278 | params=(),
279 | rng_key=jax.random.PRNGKey(0),
280 | root=root,
281 | recurrent_fn=_make_bandit_recurrent_fn(rewards),
282 | num_simulations=num_simulations,
283 | invalid_actions=None,
284 | max_depth=max_depth,
285 | qtransform=qtransform,
286 | gumbel_scale=1.0)
287 | # Testing the action.
288 | expected_action = jnp.array([3], dtype=jnp.int32)
289 | np.testing.assert_array_equal(expected_action, policy_output.action)
290 |
291 | # Testing the action_weights.
292 | summary = policy_output.search_tree.summary()
293 | completed_qvalues = rewards
294 | max_value = jnp.max(completed_qvalues, axis=-1, keepdims=True)
295 | min_value = jnp.min(completed_qvalues, axis=-1, keepdims=True)
296 | total_value_scale = (maxvisit_init + summary.visit_counts.max()
297 | ) * value_scale
298 | rescaled_qvalues = total_value_scale * (completed_qvalues - min_value) / (
299 | max_value - min_value)
300 | expected_action_weights = jax.nn.softmax(
301 | root.prior_logits + rescaled_qvalues)
302 | np.testing.assert_allclose(expected_action_weights,
303 | policy_output.action_weights,
304 | atol=1e-6)
305 |
306 | # Testing the visit_counts.
307 | expected_visit_counts = jnp.array(
308 | [[6, 2, 2, 7]])
309 | np.testing.assert_array_equal(expected_visit_counts, summary.visit_counts)
310 |
311 | def test_stochastic_muzero_policy(self):
312 | """Tests that SMZ is equivalent to MZ with a dummy chance function."""
313 | root = mctx.RootFnOutput(
314 | prior_logits=jnp.array([
315 | [-1.0, 0.0, 2.0, 3.0],
316 | [0.0, 2.0, 5.0, -4.0],
317 | ]),
318 | value=jnp.array([1.0, 0.0]),
319 | embedding=jnp.zeros([2, 4])
320 | )
321 | rewards = jnp.zeros_like(root.prior_logits)
322 | invalid_actions = jnp.array([
323 | [0.0, 0.0, 0.0, 1.0],
324 | [1.0, 0.0, 1.0, 0.0],
325 | ])
326 |
327 | num_simulations = 10
328 |
329 | policy_output = mctx.muzero_policy(
330 | params=(),
331 | rng_key=jax.random.PRNGKey(0),
332 | root=root,
333 | recurrent_fn=_make_bandit_recurrent_fn(
334 | rewards,
335 | dummy_embedding=jnp.zeros_like(root.embedding)),
336 | num_simulations=num_simulations,
337 | invalid_actions=invalid_actions,
338 | dirichlet_fraction=0.0)
339 |
340 | num_chance_outcomes = 5
341 |
342 | decision_rec_fn, chance_rec_fn = _make_bandit_decision_and_chance_fns(
343 | rewards, num_chance_outcomes)
344 |
345 | stochastic_policy_output = mctx.stochastic_muzero_policy(
346 | params=(),
347 | rng_key=jax.random.PRNGKey(0),
348 | root=root,
349 | decision_recurrent_fn=decision_rec_fn,
350 | chance_recurrent_fn=chance_rec_fn,
351 | num_simulations=2 * num_simulations,
352 | invalid_actions=invalid_actions,
353 | dirichlet_fraction=0.0)
354 |
355 | np.testing.assert_array_equal(stochastic_policy_output.action,
356 | policy_output.action)
357 |
358 | np.testing.assert_allclose(stochastic_policy_output.action_weights,
359 | policy_output.action_weights)
360 |
361 |
362 | if __name__ == "__main__":
363 | absltest.main()
364 |
--------------------------------------------------------------------------------
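The Gumbel MuZero tests above re-derive the expected action weights by hand. The standalone sketch below (array values copied from `test_gumbel_muzero_policy`; the mixing and rescaling mirror `qtransform_completed_by_mix_value` as exercised there, not the library internals verbatim) reproduces the same completion-and-rescale arithmetic:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Inputs copied from test_gumbel_muzero_policy above.
prior_logits = jnp.array([[0.0, -1.0, 2.0, 3.0]])
invalid_actions = jnp.array([[1.0, 0.0, 0.0, 1.0]])
rewards = jnp.array([[20.0, 3.0, -1.0, 10.0]])
root_value = jnp.array([-5.0])
value_scale, maxvisit_init, num_simulations = 0.05, 60, 17

# Prior restricted to the valid actions (actions 1 and 2).
probs = jax.nn.softmax(jnp.where(invalid_actions, -jnp.inf, prior_logits))
# Mix the root value with the prior-weighted Q-values of the visited actions.
mix_value = (root_value + num_simulations * (
    probs[:, 1] * rewards[:, 1] + probs[:, 2] * rewards[:, 2])) / (
        num_simulations + 1)
# Unvisited actions are "completed" with the mixed value.
completed_q = jnp.array(
    [[mix_value[0], rewards[0, 1], rewards[0, 2], mix_value[0]]])
# Rescale completed Q-values to [0, total_value_scale], add to the priors.
lo = completed_q.min(axis=-1, keepdims=True)
hi = completed_q.max(axis=-1, keepdims=True)
total_value_scale = (maxvisit_init + np.ceil(num_simulations / 2)) * value_scale
rescaled_q = total_value_scale * (completed_q - lo) / (hi - lo)
action_weights = jax.nn.softmax(
    jnp.where(invalid_actions, -jnp.inf, prior_logits + rescaled_q))
```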
/.pylintrc:
--------------------------------------------------------------------------------
1 | # This Pylint rcfile contains a best-effort configuration to uphold the
2 | # best-practices and style described in the Google Python style guide:
3 | # https://google.github.io/styleguide/pyguide.html
4 | #
5 | # Its canonical open-source location is:
6 | # https://google.github.io/styleguide/pylintrc
7 |
8 | [MASTER]
9 |
10 | # Files or directories to be skipped. They should be base names, not paths.
11 | ignore=third_party
12 |
13 | # Files or directories matching the regex patterns are skipped. The regex
14 | # matches against base names, not paths.
15 | ignore-patterns=
16 |
17 | # Pickle collected data for later comparisons.
18 | persistent=no
19 |
20 | # List of plugins (as comma separated values of python module names) to load,
21 | # usually to register additional checkers.
22 | load-plugins=
23 |
24 | # Use multiple processes to speed up Pylint.
25 | jobs=4
26 |
27 | # Allow loading of arbitrary C extensions. Extensions are imported into the
28 | # active Python interpreter and may run arbitrary code.
29 | unsafe-load-any-extension=no
30 |
31 |
32 | [MESSAGES CONTROL]
33 |
34 | # Only show warnings with the listed confidence levels. Leave empty to show
35 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
36 | confidence=
37 |
38 | # Enable the message, report, category or checker with the given id(s). You can
39 | # either give multiple identifiers separated by comma (,) or put this option
40 | # multiple times (only on the command line, not in the configuration file where
41 | # it should appear only once). See also the "--disable" option for examples.
42 | #enable=
43 |
44 | # Disable the message, report, category or checker with the given id(s). You
45 | # can either give multiple identifiers separated by comma (,) or put this
46 | # option multiple times (only on the command line, not in the configuration
47 | # file where it should appear only once). You can also use "--disable=all" to
48 | # disable everything first and then reenable specific checks. For example, if
49 | # you want to run only the similarities checker, you can use "--disable=all
50 | # --enable=similarities". If you want to run only the classes checker, but have
51 | # no Warning level messages displayed, use "--disable=all --enable=classes
52 | # --disable=W"
53 | disable=abstract-method,
54 | apply-builtin,
55 | arguments-differ,
56 | attribute-defined-outside-init,
57 | backtick,
58 | bad-option-value,
59 | basestring-builtin,
60 | buffer-builtin,
61 | c-extension-no-member,
62 | consider-using-enumerate,
63 | cmp-builtin,
64 | cmp-method,
65 | coerce-builtin,
66 | coerce-method,
67 | delslice-method,
68 | div-method,
69 | duplicate-code,
70 | eq-without-hash,
71 | execfile-builtin,
72 | file-builtin,
73 | filter-builtin-not-iterating,
74 | fixme,
75 | getslice-method,
76 | global-statement,
77 | hex-method,
78 | idiv-method,
79 | implicit-str-concat,
80 | import-error,
81 | import-self,
82 | import-star-module-level,
83 | inconsistent-return-statements,
84 | input-builtin,
85 | intern-builtin,
86 | invalid-str-codec,
87 | locally-disabled,
88 | long-builtin,
89 | long-suffix,
90 | map-builtin-not-iterating,
91 | misplaced-comparison-constant,
92 | missing-function-docstring,
93 | metaclass-assignment,
94 | next-method-called,
95 | next-method-defined,
96 | no-absolute-import,
97 | no-else-break,
98 | no-else-continue,
99 | no-else-raise,
100 | no-else-return,
101 | no-init, # added
102 | no-member,
103 | no-name-in-module,
104 | no-self-use,
105 | nonzero-method,
106 | oct-method,
107 | old-division,
108 | old-ne-operator,
109 | old-octal-literal,
110 | old-raise-syntax,
111 | parameter-unpacking,
112 | print-statement,
113 | raising-string,
114 | range-builtin-not-iterating,
115 | raw_input-builtin,
116 | rdiv-method,
117 | reduce-builtin,
118 | relative-import,
119 | reload-builtin,
120 | round-builtin,
121 | setslice-method,
122 | signature-differs,
123 | standarderror-builtin,
124 | suppressed-message,
125 | sys-max-int,
126 | too-few-public-methods,
127 | too-many-ancestors,
128 | too-many-arguments,
129 | too-many-boolean-expressions,
130 | too-many-branches,
131 | too-many-instance-attributes,
132 | too-many-locals,
133 | too-many-nested-blocks,
134 | too-many-public-methods,
135 | too-many-return-statements,
136 | too-many-statements,
137 | trailing-newlines,
138 | unichr-builtin,
139 | unicode-builtin,
140 | unnecessary-pass,
141 | unpacking-in-except,
142 | useless-else-on-loop,
143 | useless-object-inheritance,
144 | useless-suppression,
145 | using-cmp-argument,
146 | wrong-import-order,
147 | xrange-builtin,
148 | zip-builtin-not-iterating,
149 |
150 |
151 | [REPORTS]
152 |
153 | # Set the output format. Available formats are text, parseable, colorized, msvs
154 | # (visual studio) and html. You can also give a reporter class, eg
155 | # mypackage.mymodule.MyReporterClass.
156 | output-format=text
157 |
158 | # Tells whether to display a full report or only the messages
159 | reports=no
160 |
161 | # Python expression which should return a note less than 10 (10 is the highest
162 | # note). You have access to the variables error, warning, refactor, convention
163 | # and statement, which respectively contain the number of errors, warnings,
164 | # refactors and conventions messages and the total number of statements
165 | # analyzed. This is used by the global evaluation report (RP0004).
166 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
167 |
168 | # Template used to display messages. This is a python new-style format string
169 | # used to format the message information. See doc for all details
170 | #msg-template=
171 |
172 |
173 | [BASIC]
174 |
175 | # Good variable names which should always be accepted, separated by a comma
176 | good-names=main,_
177 |
178 | # Bad variable names which should always be refused, separated by a comma
179 | bad-names=
180 |
181 | # Colon-delimited sets of names that determine each other's naming style when
182 | # the name regexes allow several styles.
183 | name-group=
184 |
185 | # Include a hint for the correct naming format with invalid-name
186 | include-naming-hint=no
187 |
188 | # List of decorators that produce properties, such as abc.abstractproperty. Add
189 | # to this list to register other decorators that produce valid properties.
190 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
191 |
192 | # Regular expression matching correct function names
193 | function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
194 |
195 | # Regular expression matching correct variable names
196 | variable-rgx=^[a-z][a-z0-9_]*$
197 |
198 | # Regular expression matching correct constant names
199 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
200 |
201 | # Regular expression matching correct attribute names
202 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
203 |
204 | # Regular expression matching correct argument names
205 | argument-rgx=^[a-z][a-z0-9_]*$
206 |
207 | # Regular expression matching correct class attribute names
208 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
209 |
210 | # Regular expression matching correct inline iteration names
211 | inlinevar-rgx=^[a-z][a-z0-9_]*$
212 |
213 | # Regular expression matching correct class names
214 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$
215 |
216 | # Regular expression matching correct module names
217 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
218 |
219 | # Regular expression matching correct method names
220 | method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
221 |
222 | # Regular expression which should only match function or class names that do
223 | # not require a docstring.
224 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
225 |
226 | # Minimum line length for functions/classes that require docstrings; shorter
227 | # ones are exempt.
228 | docstring-min-length=10
229 |
230 |
231 | [TYPECHECK]
232 |
233 | # List of decorators that produce context managers, such as
234 | # contextlib.contextmanager. Add to this list to register other decorators that
235 | # produce valid context managers.
236 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
237 |
238 | # Tells whether missing members accessed in mixin class should be ignored. A
239 | # mixin class is detected if its name ends with "mixin" (case insensitive).
240 | ignore-mixin-members=yes
241 |
242 | # List of module names for which member attributes should not be checked
243 | # (useful for modules/projects where namespaces are manipulated during runtime
244 | # and thus existing member attributes cannot be deduced by static analysis. It
245 | # supports qualified module names, as well as Unix pattern matching.
246 | ignored-modules=
247 |
248 | # List of class names for which member attributes should not be checked (useful
249 | # for classes with dynamically set attributes). This supports the use of
250 | # qualified names.
251 | ignored-classes=optparse.Values,thread._local,_thread._local
252 |
253 | # List of members which are set dynamically and missed by pylint inference
254 | # system, and so shouldn't trigger E1101 when accessed. Python regular
255 | # expressions are accepted.
256 | generated-members=
257 |
258 |
259 | [FORMAT]
260 |
261 | # Maximum number of characters on a single line.
262 | max-line-length=80
263 |
264 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
265 | # lines made too long by directives to pytype.
266 |
267 | # Regexp for a line that is allowed to be longer than the limit.
268 | ignore-long-lines=(?x)(
269 | ^\s*(\#\ )??$|
270 | ^\s*(from\s+\S+\s+)?import\s+.+$)
271 |
272 | # Allow the body of an if to be on the same line as the test if there is no
273 | # else.
274 | single-line-if-stmt=yes
275 |
276 | # Maximum number of lines in a module
277 | max-module-lines=99999
278 |
279 | # String used as indentation unit. The internal Google style guide mandates 2
280 | # spaces. Google's externally-published style guide says 4, consistent with
281 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
282 | # projects (like TensorFlow).
283 | indent-string=' '
284 |
285 | # Number of spaces of indent required inside a hanging or continued line.
286 | indent-after-paren=4
287 |
288 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
289 | expected-line-ending-format=
290 |
291 |
292 | [MISCELLANEOUS]
293 |
294 | # List of note tags to take into consideration, separated by a comma.
295 | notes=TODO
296 |
297 |
298 | [STRING]
299 |
300 | # This flag controls whether inconsistent-quotes generates a warning when the
301 | # character used as a quote delimiter is used inconsistently within a module.
302 | check-quote-consistency=yes
303 |
304 |
305 | [VARIABLES]
306 |
307 | # Tells whether we should check for unused import in __init__ files.
308 | init-import=no
309 |
310 | # A regular expression matching the name of dummy variables (i.e. expectedly
311 | # not used).
312 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
313 |
314 | # List of additional names supposed to be defined in builtins. Remember that
315 | # you should avoid to define new builtins when possible.
316 | additional-builtins=
317 |
318 | # List of strings which can identify a callback function by name. A callback
319 | # name must start or end with one of those strings.
320 | callbacks=cb_,_cb
321 |
322 | # List of qualified module names which can have objects that can redefine
323 | # builtins.
324 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
325 |
326 |
327 | [LOGGING]
328 |
329 | # Logging modules to check that the string format arguments are in logging
330 | # function parameter format
331 | logging-modules=logging,absl.logging,tensorflow.io.logging
332 |
333 |
334 | [SIMILARITIES]
335 |
336 | # Minimum lines number of a similarity.
337 | min-similarity-lines=4
338 |
339 | # Ignore comments when computing similarities.
340 | ignore-comments=yes
341 |
342 | # Ignore docstrings when computing similarities.
343 | ignore-docstrings=yes
344 |
345 | # Ignore imports when computing similarities.
346 | ignore-imports=no
347 |
348 |
349 | [SPELLING]
350 |
351 | # Spelling dictionary name. Available dictionaries: none. To make it work,
352 | # install the python-enchant package.
353 | spelling-dict=
354 |
355 | # List of comma separated words that should not be checked.
356 | spelling-ignore-words=
357 |
358 | # A path to a file that contains private dictionary; one word per line.
359 | spelling-private-dict-file=
360 |
361 | # Tells whether to store unknown words to the private dictionary indicated by
362 | # the --spelling-private-dict-file option instead of raising a message.
363 | spelling-store-unknown-words=no
364 |
365 |
366 | [IMPORTS]
367 |
368 | # Deprecated modules which should not be used, separated by a comma
369 | deprecated-modules=regsub,
370 | TERMIOS,
371 | Bastion,
372 | rexec,
373 | sets
374 |
375 | # Create a graph of every (i.e. internal and external) dependency in the
376 | # given file (report RP0402 must not be disabled)
377 | import-graph=
378 |
379 | # Create a graph of external dependencies in the given file (report RP0402 must
380 | # not be disabled)
381 | ext-import-graph=
382 |
383 | # Create a graph of internal dependencies in the given file (report RP0402 must
384 | # not be disabled)
385 | int-import-graph=
386 |
387 | # Force import order to recognize a module as part of the standard
388 | # compatibility libraries.
389 | known-standard-library=
390 |
391 | # Force import order to recognize a module as part of a third party library.
392 | known-third-party=enchant, absl
393 |
394 | # Analyse import fallback blocks. This can be used to support both Python 2 and
395 | # 3 compatible code, which means that the block might have code that exists
396 | # only in one or another interpreter, leading to false positives when analysed.
397 | analyse-fallback-blocks=no
398 |
399 |
400 | [CLASSES]
401 |
402 | # List of method names used to declare (i.e. assign) instance attributes.
403 | defining-attr-methods=__init__,
404 | __new__,
405 | setUp
406 |
407 | # List of member names, which should be excluded from the protected access
408 | # warning.
409 | exclude-protected=_asdict,
410 | _fields,
411 | _replace,
412 | _source,
413 | _make
414 |
415 | # List of valid names for the first argument in a class method.
416 | valid-classmethod-first-arg=cls,
417 | class_
418 |
419 | # List of valid names for the first argument in a metaclass class method.
420 | valid-metaclass-classmethod-first-arg=mcs
421 |
422 |
423 | [EXCEPTIONS]
424 |
425 | # Exceptions that will emit a warning when being caught. Defaults to
426 | # "Exception"
427 | overgeneral-exceptions=builtins.StandardError,
428 | builtins.Exception,
429 | builtins.BaseException
430 |
--------------------------------------------------------------------------------
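To lint a local change against this configuration, pylint can be pointed at the rcfile explicitly. A minimal sketch (assumes `pylint` is installed and the command runs from the repository root; the repository's own CI may invoke it differently):

```python
# Run pylint over the mctx package with the repository's rcfile.
import subprocess

result = subprocess.run(
    ["pylint", "--rcfile=.pylintrc", "mctx"],
    capture_output=True, text=True, check=False)
print(result.stdout)  # pylint exits non-zero whenever it emits messages.
```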
/mctx/_src/search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A JAX implementation of batched MCTS."""
16 | import functools
17 | from typing import Any, NamedTuple, Optional, Tuple, TypeVar
18 |
19 | import chex
20 | import jax
21 | import jax.numpy as jnp
22 |
23 | from mctx._src import action_selection, base
24 | from mctx._src import tree as tree_lib
25 |
26 | Tree = tree_lib.Tree
27 | T = TypeVar("T")
28 |
29 |
30 | def search(
31 | params: base.Params,
32 | rng_key: chex.PRNGKey,
33 | tree: Tree,
34 | *,
35 | recurrent_fn: base.RecurrentFn,
36 | root_action_selection_fn: base.RootActionSelectionFn,
37 | interior_action_selection_fn: base.InteriorActionSelectionFn,
38 | num_simulations: int,
39 | max_depth: Optional[int] = None,
40 | loop_fn: base.LoopFn = jax.lax.fori_loop) -> Tree:
41 |   """Performs a full search and returns the updated search tree.
42 |
43 | In the shape descriptions, `B` denotes the batch dimension.
44 |
45 | Args:
46 | params: params to be forwarded to root and recurrent functions.
47 | rng_key: random number generator state, the key is consumed.
48 | tree: the initialized MCTS tree state to search upon.
49 | recurrent_fn: a callable to be called on the leaf nodes and unvisited
50 | actions retrieved by the simulation step, which takes as args
51 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
52 | and the new state embedding. The `rng_key` argument is consumed.
53 | root_action_selection_fn: function used to select an action at the root.
54 | interior_action_selection_fn: function used to select an action during
55 | simulation.
56 | num_simulations: the number of simulations.
57 | max_depth: maximum search tree depth allowed during simulation, defined as
58 | the number of edges from the root to a leaf node.
59 | loop_fn: Function used to run the simulations. It may be required to pass
60 | hk.fori_loop if using this function inside a Haiku module.
61 |
62 | Returns:
63 |     The updated search `Tree`, containing outcomes of the search, e.g.
64 |     `visit_counts` of shape `[B, num_actions]`.
65 | """
66 | action_selection_fn = action_selection.switching_action_selection_wrapper(
67 | root_action_selection_fn=root_action_selection_fn,
68 | interior_action_selection_fn=interior_action_selection_fn
69 | )
70 |
71 | # Do simulation, expansion, and backward steps.
72 | batch_size = tree.node_visits.shape[0]
73 | batch_range = jnp.arange(batch_size)
74 | tree_capacity = tree.children_visits.shape[1]
75 | if max_depth is None:
76 | max_depth = tree_capacity - 1
77 |
78 | def body_fun(_, loop_state):
79 | rng_key, tree = loop_state
80 | rng_key, simulate_key, expand_key = jax.random.split(rng_key, 3)
81 | # simulate is vmapped and expects batched rng keys.
82 | simulate_keys = jax.random.split(simulate_key, batch_size)
83 | parent_index, action = simulate(
84 | simulate_keys, tree, action_selection_fn, max_depth)
85 | next_node_index = tree.children_index[batch_range, parent_index, action]
86 | unvisited = next_node_index == Tree.UNVISITED
87 |
88 | next_node_index = jnp.where(unvisited,
89 | tree.next_node_index, next_node_index)
90 | tree = expand(
91 | params, expand_key, tree, recurrent_fn, parent_index,
92 | action, next_node_index)
93 |     # If next_node_index goes out of bounds (i.e. there is no room left for
94 |     # new nodes), run the backward pass from its (in-bounds) parent instead.
95 | out_of_bounds = next_node_index >= tree_capacity
96 | next_node_index = jnp.where(out_of_bounds,
97 | parent_index,
98 | next_node_index)
99 |
100 | tree = backward(tree, next_node_index)
101 |     # Increment next_node_index if a new leaf was expanded.
102 |     tree = tree.replace(next_node_index=tree.next_node_index +
103 |                         jnp.logical_and(unvisited, ~out_of_bounds))
104 | loop_state = rng_key, tree
105 | return loop_state
106 |
107 | _, tree = loop_fn(
108 | 0, num_simulations, body_fun, (rng_key, tree))
109 |
110 | return tree
111 |
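# Usage note (illustrative sketch, not part of the library source): since
# `jax.lax.fori_loop` cannot carry Haiku state, callers running `search`
# inside a Haiku module can pass `hk.fori_loop` as `loop_fn`, as the
# docstring above suggests. Assumes dm-haiku is available and that
# `recurrent_fn`, `root_fn` and `interior_fn` are defined by the caller:
#
#   import haiku as hk
#   tree = search(params, rng_key, tree,
#                 recurrent_fn=recurrent_fn,
#                 root_action_selection_fn=root_fn,
#                 interior_action_selection_fn=interior_fn,
#                 num_simulations=32,
#                 loop_fn=hk.fori_loop)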
112 |
113 | class _SimulationState(NamedTuple):
114 | """The state for the simulation while loop."""
115 | rng_key: chex.PRNGKey
116 | node_index: int
117 | action: int
118 | next_node_index: int
119 | depth: int
120 | is_continuing: bool
121 |
122 |
123 | @functools.partial(jax.vmap, in_axes=[0, 0, None, None], out_axes=0)
124 | def simulate(
125 | rng_key: chex.PRNGKey,
126 | tree: Tree,
127 | action_selection_fn: base.InteriorActionSelectionFn,
128 | max_depth: int) -> Tuple[chex.Array, chex.Array]:
129 | """Traverses the tree until reaching an unvisited action or `max_depth`.
130 |
131 | Each simulation starts from the root and keeps selecting actions traversing
132 | the tree until a leaf or `max_depth` is reached.
133 |
134 | Args:
135 | rng_key: random number generator state, the key is consumed.
136 | tree: _unbatched_ MCTS tree state.
137 | action_selection_fn: function used to select an action during simulation.
138 | max_depth: maximum search tree depth allowed during simulation.
139 |
140 | Returns:
141 | `(parent_index, action)` tuple, where `parent_index` is the index of the
142 | node reached at the end of the simulation, and the `action` is the action to
143 | evaluate from the `parent_index`.
144 | """
145 | def cond_fun(state):
146 | return state.is_continuing
147 |
148 | def body_fun(state):
149 | # Preparing the next simulation state.
150 | node_index = state.next_node_index
151 | rng_key, action_selection_key = jax.random.split(state.rng_key)
152 | action = action_selection_fn(action_selection_key, tree, node_index,
153 | state.depth)
154 | next_node_index = tree.children_index[node_index, action]
155 | # The returned action will be visited.
156 | depth = state.depth + 1
157 | is_before_depth_cutoff = depth < max_depth
158 | is_visited = next_node_index != Tree.UNVISITED
159 | is_continuing = jnp.logical_and(is_visited, is_before_depth_cutoff)
160 | return _SimulationState( # pytype: disable=wrong-arg-types # jax-types
161 | rng_key=rng_key,
162 | node_index=node_index,
163 | action=action,
164 | next_node_index=next_node_index,
165 | depth=depth,
166 | is_continuing=is_continuing)
167 |
168 | node_index = jnp.array(Tree.ROOT_INDEX, dtype=jnp.int32)
169 |   depth = jnp.zeros((), dtype=tree.children_visits.dtype)
170 | # pytype: disable=wrong-arg-types # jnp-type
171 | initial_state = _SimulationState(
172 | rng_key=rng_key,
173 | node_index=tree.NO_PARENT,
174 | action=tree.NO_PARENT,
175 | next_node_index=node_index,
176 | depth=depth,
177 | is_continuing=jnp.array(True))
178 | # pytype: enable=wrong-arg-types
179 | end_state = jax.lax.while_loop(cond_fun, body_fun, initial_state)
180 |
181 | # Returning a node with a selected action.
182 |   # The action may already have been visited, if max_depth was reached.
183 | return end_state.node_index, end_state.action
184 |
185 |
186 | def expand(
187 | params: chex.Array,
188 | rng_key: chex.PRNGKey,
189 | tree: Tree[T],
190 | recurrent_fn: base.RecurrentFn,
191 | parent_index: chex.Array,
192 | action: chex.Array,
193 | next_node_index: chex.Array) -> Tree[T]:
194 | """Create and evaluate child nodes from given nodes and unvisited actions.
195 |
196 | Args:
197 | params: params to be forwarded to recurrent function.
198 | rng_key: random number generator state.
199 | tree: the MCTS tree state to update.
200 | recurrent_fn: a callable to be called on the leaf nodes and unvisited
201 | actions retrieved by the simulation step, which takes as args
202 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
203 | and the new state embedding. The `rng_key` argument is consumed.
204 | parent_index: the index of the parent node, from which the action will be
205 | expanded. Shape `[B]`.
206 | action: the action to expand. Shape `[B]`.
207 | next_node_index: the index of the newly expanded node. This can be the index
208 | of an existing node, if `max_depth` is reached. Shape `[B]`.
209 |
210 | Returns:
211 | tree: updated MCTS tree state.
212 | """
213 | batch_size = tree_lib.infer_batch_size(tree)
214 | batch_range = jnp.arange(batch_size)
215 | chex.assert_shape([parent_index, action, next_node_index], (batch_size,))
216 |
217 | # Retrieve states for nodes to be evaluated.
218 | embedding = jax.tree.map(
219 | lambda x: x[batch_range, parent_index], tree.embeddings)
220 |
221 | # Evaluate and create a new node.
222 | step, embedding = recurrent_fn(params, rng_key, action, embedding)
223 | chex.assert_shape(step.prior_logits, [batch_size, tree.num_actions])
224 | chex.assert_shape(step.reward, [batch_size])
225 | chex.assert_shape(step.discount, [batch_size])
226 | chex.assert_shape(step.value, [batch_size])
227 | tree = update_tree_node(
228 | tree, next_node_index, step.prior_logits, step.value, embedding)
229 |
230 |   # Handle an out-of-bounds next_node_index: when there is no room for a new
231 |   # node, the parent's child entry should still point to UNVISITED.
232 | next_node_index_check_oob = jnp.where(next_node_index > tree.num_simulations,
233 | tree.UNVISITED,
234 | next_node_index)
235 | # Return updated tree topology.
236 | return tree.replace(
237 | children_index=batch_update(
238 | tree.children_index, next_node_index_check_oob, parent_index, action),
239 | children_rewards=batch_update(
240 | tree.children_rewards, step.reward, parent_index, action),
241 | children_discounts=batch_update(
242 | tree.children_discounts, step.discount, parent_index, action),
243 | parents=batch_update(tree.parents, parent_index, next_node_index),
244 | action_from_parent=batch_update(
245 | tree.action_from_parent, action, next_node_index))
246 |
247 |
248 | @jax.vmap
249 | def backward(
250 | tree: Tree[T],
251 | leaf_index: chex.Numeric) -> Tree[T]:
252 |   """Propagates the leaf value up the tree, updating nodes until the root.
253 |
254 | Args:
255 | tree: the MCTS tree state to update, without the batch size.
256 | leaf_index: the node index from which to do the backward.
257 |
258 | Returns:
259 | Updated MCTS tree state.
260 | """
261 |
262 | def cond_fun(loop_state):
263 | _, _, index = loop_state
264 | return index != Tree.ROOT_INDEX
265 |
266 | def body_fun(loop_state):
267 |     # Each iteration updates the parent of the current node, moving rootward.
268 | tree, leaf_value, index = loop_state
269 | parent = tree.parents[index]
270 | count = tree.node_visits[parent]
271 | action = tree.action_from_parent[index]
272 | reward = tree.children_rewards[parent, action]
273 | leaf_value = reward + tree.children_discounts[parent, action] * leaf_value
274 | parent_value = (
275 | tree.node_values[parent] * count + leaf_value) / (count + 1.0)
276 | children_values = tree.node_values[index]
277 | children_counts = tree.children_visits[parent, action] + 1
278 |
279 | tree = tree.replace(
280 | node_values=update(tree.node_values, parent_value, parent),
281 | node_visits=update(tree.node_visits, count + 1, parent),
282 | children_values=update(
283 | tree.children_values, children_values, parent, action),
284 | children_visits=update(
285 | tree.children_visits, children_counts, parent, action))
286 |
287 | return tree, leaf_value, parent
288 |
289 | leaf_index = jnp.asarray(leaf_index, dtype=jnp.int32)
290 | loop_state = (tree, tree.node_values[leaf_index], leaf_index)
291 | tree, _, _ = jax.lax.while_loop(cond_fun, body_fun, loop_state)
292 | return tree
293 |
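# Worked example of the running-mean update in `body_fun` above: if a
# parent has count = 2 visits with node_value = 1.0, and a backward pass
# arrives with leaf_value = 4.0 (already combined with the edge reward and
# discount), the new parent value is (1.0 * 2 + 4.0) / (2 + 1) = 2.0 and
# the parent's visit count becomes 3.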
294 |
295 | # Utility function to set the values of certain indices to prescribed values.
296 | # This is vmapped to operate seamlessly on batches.
297 | def update(x, vals, *indices):
298 | return x.at[indices].set(vals)
299 |
300 |
301 | batch_update = jax.vmap(update)
302 |
303 |
304 | def update_tree_node(
305 | tree: Tree[T],
306 | node_index: chex.Array,
307 | prior_logits: chex.Array,
308 | value: chex.Array,
309 | embedding: chex.Array) -> Tree[T]:
310 | """Updates the tree at node index.
311 |
312 | Args:
313 |     tree: the `Tree` whose node is to be updated.
314 | node_index: the index of the expanded node. Shape `[B]`.
315 | prior_logits: the prior logits to fill in for the new node, of shape
316 | `[B, num_actions]`.
317 | value: the value to fill in for the new node. Shape `[B]`.
318 | embedding: the state embeddings for the node. Shape `[B, ...]`.
319 |
320 | Returns:
321 | The new tree with updated nodes.
322 | """
323 | batch_size = tree_lib.infer_batch_size(tree)
324 | batch_range = jnp.arange(batch_size)
325 | chex.assert_shape(prior_logits, (batch_size, tree.num_actions))
326 |
327 | # When using max_depth, a leaf can be expanded multiple times.
328 | new_visit = tree.node_visits[batch_range, node_index] + 1
329 | updates = dict( # pylint: disable=use-dict-literal
330 | children_prior_logits=batch_update(
331 | tree.children_prior_logits, prior_logits, node_index),
332 | raw_values=batch_update(
333 | tree.raw_values, value, node_index),
334 | node_values=batch_update(
335 | tree.node_values, value, node_index),
336 | node_visits=batch_update(
337 | tree.node_visits, new_visit, node_index),
338 | embeddings=jax.tree.map(
339 | lambda t, s: batch_update(t, s, node_index),
340 | tree.embeddings, embedding))
341 |
342 | return tree.replace(**updates)
343 |
344 |
345 | def instantiate_tree_from_root(
346 | root: base.RootFnOutput,
347 | num_nodes: int,
348 | extra_data: Any,
349 | root_invalid_actions: Optional[chex.Array] = None) -> Tree:
350 | """Initializes tree state at search root."""
351 | chex.assert_rank(root.prior_logits, 2)
352 | batch_size, num_actions = root.prior_logits.shape
353 | chex.assert_shape(root.value, [batch_size])
354 |
355 | data_dtype = root.value.dtype
356 | batch_node = (batch_size, num_nodes)
357 | batch_node_action = (batch_size, num_nodes, num_actions)
358 |
359 | if root_invalid_actions is None:
360 | root_invalid_actions = jnp.zeros_like(root.prior_logits)
361 |
362 | def _zeros(x):
363 | return jnp.zeros(batch_node + x.shape[1:], dtype=x.dtype)
364 |
365 | # Create a new empty tree state and fill its root.
366 | tree = Tree(
367 | node_visits=jnp.zeros(batch_node, dtype=jnp.int32),
368 | raw_values=jnp.zeros(batch_node, dtype=data_dtype),
369 | node_values=jnp.zeros(batch_node, dtype=data_dtype),
370 | parents=jnp.full(batch_node, Tree.NO_PARENT, dtype=jnp.int32),
371 | action_from_parent=jnp.full(
372 | batch_node, Tree.NO_PARENT, dtype=jnp.int32),
373 | children_index=jnp.full(
374 | batch_node_action, Tree.UNVISITED, dtype=jnp.int32),
375 | children_prior_logits=jnp.zeros(
376 | batch_node_action, dtype=root.prior_logits.dtype),
377 | children_values=jnp.zeros(batch_node_action, dtype=data_dtype),
378 | children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32),
379 | children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype),
380 | children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype),
381 | next_node_index=jnp.ones((batch_size,), dtype=jnp.int32),
382 | embeddings=jax.tree_util.tree_map(_zeros, root.embedding),
383 | root_invalid_actions=root_invalid_actions,
384 | extra_data=extra_data)
385 |
386 | root_index = jnp.full([batch_size], Tree.ROOT_INDEX)
387 | tree = update_tree_node(
388 | tree, root_index, root.prior_logits, root.value, root.embedding)
389 | return tree
390 |
391 |
392 | def update_tree_with_root(
393 | tree: Tree,
394 | root: base.RootFnOutput,
395 | extra_data: Any,
396 | root_invalid_actions: Optional[chex.Array] = None) -> Tree:
397 |   """Writes the given root output into the tree root, preserving the
398 |   existing root value and visit count if the root is already populated."""
399 | root_uninitialized = tree.node_visits[:, Tree.ROOT_INDEX] == 0
400 | batch_size = tree_lib.infer_batch_size(tree)
401 | root_index = jnp.full([batch_size], Tree.ROOT_INDEX)
402 | ones = jnp.ones([batch_size], dtype=jnp.int32)
403 |
404 | if root_invalid_actions is None:
405 | root_invalid_actions = jnp.zeros_like(root.prior_logits)
406 |
407 | updates = dict( # pylint: disable=use-dict-literal
408 | children_prior_logits=batch_update(
409 | tree.children_prior_logits, root.prior_logits, root_index),
410 | raw_values=batch_update(
411 | tree.raw_values, root.value, root_index),
412 | node_values=jnp.where(
413 | root_uninitialized[..., None],
414 | batch_update(tree.node_values, root.value, root_index),
415 | tree.node_values),
416 | node_visits=jnp.where(
417 | root_uninitialized[..., None],
418 | batch_update(tree.node_visits, ones, root_index),
419 | tree.node_visits),
420 | embeddings=jax.tree_util.tree_map(
421 | lambda t, s: batch_update(t, s, root_index),
422 | tree.embeddings, root.embedding),
423 | root_invalid_actions=root_invalid_actions,
424 | extra_data=extra_data)
425 |
426 | return tree.replace(**updates)
427 |
--------------------------------------------------------------------------------
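The policies in the next file are the public entry points built on top of `search`. As a minimal end-to-end sketch (a toy single-state bandit; `params` is unused, and the names `root` and `recurrent_fn` are local to this example), `mctx.muzero_policy` can be driven like this:

```python
import jax
import jax.numpy as jnp
import mctx

batch_size, num_actions = 1, 4
root = mctx.RootFnOutput(
    prior_logits=jnp.zeros([batch_size, num_actions]),
    value=jnp.zeros([batch_size]),
    embedding=jnp.zeros([batch_size, 1]))

def recurrent_fn(params, rng_key, action, embedding):
  del params, rng_key
  # Toy bandit: the reward equals the chosen action, so search should
  # concentrate visits on the highest-indexed action.
  output = mctx.RecurrentFnOutput(
      reward=action.astype(jnp.float32),
      discount=jnp.ones([batch_size]),
      prior_logits=jnp.zeros([batch_size, num_actions]),
      value=jnp.zeros([batch_size]))
  return output, embedding

policy_output = mctx.muzero_policy(
    params=(),
    rng_key=jax.random.PRNGKey(0),
    root=root,
    recurrent_fn=recurrent_fn,
    num_simulations=32)
print(policy_output.action, policy_output.action_weights)
```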
/mctx/_src/policies.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Search policies."""
16 | import functools
17 | from typing import Optional, Tuple
18 |
19 | import chex
20 | import jax
21 | import jax.numpy as jnp
22 |
23 | from mctx._src import action_selection
24 | from mctx._src import base
25 | from mctx._src import qtransforms
26 | from mctx._src import search
27 | from mctx._src import seq_halving
28 |
29 |
30 | def muzero_policy(
31 | params: base.Params,
32 | rng_key: chex.PRNGKey,
33 | root: base.RootFnOutput,
34 | recurrent_fn: base.RecurrentFn,
35 | num_simulations: int,
36 | invalid_actions: Optional[chex.Array] = None,
37 | max_depth: Optional[int] = None,
38 | loop_fn: base.LoopFn = jax.lax.fori_loop,
39 | *,
40 | qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
41 | dirichlet_fraction: chex.Numeric = 0.25,
42 | dirichlet_alpha: chex.Numeric = 0.3,
43 | pb_c_init: chex.Numeric = 1.25,
44 | pb_c_base: chex.Numeric = 19652,
45 | temperature: chex.Numeric = 1.0) -> base.PolicyOutput[None]:
46 | """Runs MuZero search and returns the `PolicyOutput`.
47 |
48 | In the shape descriptions, `B` denotes the batch dimension.
49 |
50 | Args:
51 | params: params to be forwarded to root and recurrent functions.
52 | rng_key: random number generator state, the key is consumed.
53 | root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
54 | `prior_logits` are from a policy network. The shapes are
55 | `([B, num_actions], [B], [B, ...])`, respectively.
56 | recurrent_fn: a callable to be called on the leaf nodes and unvisited
57 | actions retrieved by the simulation step, which takes as args
58 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
59 | and the new state embedding. The `rng_key` argument is consumed.
60 | num_simulations: the number of simulations.
61 | invalid_actions: a mask with invalid actions. Invalid actions
62 | have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`.
63 | max_depth: maximum search tree depth allowed during simulation.
64 | loop_fn: Function used to run the simulations. It may be required to pass
65 | hk.fori_loop if using this function inside a Haiku module.
66 | qtransform: function to obtain completed Q-values for a node.
67 | dirichlet_fraction: float from 0 to 1 interpolating between using only the
68 | prior policy or just the Dirichlet noise.
69 | dirichlet_alpha: concentration parameter to parametrize the Dirichlet
70 | distribution.
71 | pb_c_init: constant c_1 in the PUCT formula.
72 | pb_c_base: constant c_2 in the PUCT formula.
73 | temperature: temperature for acting proportionally to
74 | `visit_counts**(1 / temperature)`.
75 |
76 | Returns:
77 | `PolicyOutput` containing the proposed action, action_weights and the used
78 | search tree.
79 | """
80 | rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3)
81 |
82 | # Adding Dirichlet noise.
83 | noisy_logits = _get_logits_from_probs(
84 | _add_dirichlet_noise(
85 | dirichlet_rng_key,
86 | jax.nn.softmax(root.prior_logits),
87 | dirichlet_fraction=dirichlet_fraction,
88 | dirichlet_alpha=dirichlet_alpha))
89 | root = root.replace(
90 | prior_logits=_mask_invalid_actions(noisy_logits, invalid_actions))
91 |
92 | # Running the search.
93 | interior_action_selection_fn = functools.partial(
94 | action_selection.muzero_action_selection,
95 | pb_c_base=pb_c_base,
96 | pb_c_init=pb_c_init,
97 | qtransform=qtransform)
98 | root_action_selection_fn = functools.partial(
99 | interior_action_selection_fn,
100 | depth=0)
101 | search_tree = search.search(
102 | params=params,
103 | rng_key=search_rng_key,
104 | tree=search.instantiate_tree_from_root(
105 | root, num_simulations+1,
106 | root_invalid_actions=invalid_actions,
107 | extra_data=None
108 | ),
109 | recurrent_fn=recurrent_fn,
110 | root_action_selection_fn=root_action_selection_fn,
111 | interior_action_selection_fn=interior_action_selection_fn,
112 | num_simulations=num_simulations,
113 | max_depth=max_depth,
114 | loop_fn=loop_fn)
115 |
116 | # Sampling the proposed action proportionally to the visit counts.
117 | summary = search_tree.summary()
118 | action_weights = summary.visit_probs
119 | action_logits = _apply_temperature(
120 | _get_logits_from_probs(action_weights), temperature)
121 | action = jax.random.categorical(rng_key, action_logits)
122 | return base.PolicyOutput(
123 | action=action,
124 | action_weights=action_weights,
125 | search_tree=search_tree)
126 |
127 |
128 | def gumbel_muzero_policy(
129 | params: base.Params,
130 | rng_key: chex.PRNGKey,
131 | root: base.RootFnOutput,
132 | recurrent_fn: base.RecurrentFn,
133 | num_simulations: int,
134 | invalid_actions: Optional[chex.Array] = None,
135 | max_depth: Optional[int] = None,
136 | loop_fn: base.LoopFn = jax.lax.fori_loop,
137 | *,
138 | qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
139 | max_num_considered_actions: int = 16,
140 | gumbel_scale: chex.Numeric = 1.,
141 | ) -> base.PolicyOutput[action_selection.GumbelMuZeroExtraData]:
142 | """Runs Gumbel MuZero search and returns the `PolicyOutput`.
143 |
144 | This policy implements Full Gumbel MuZero from
145 | "Policy improvement by planning with Gumbel".
146 | https://openreview.net/forum?id=bERaNdoegnO
147 |
148 | At the root of the search tree, actions are selected by Sequential Halving
149 | with Gumbel. At non-root nodes (aka interior nodes), actions are selected by
150 | the Full Gumbel MuZero deterministic action selection.
151 |
152 | In the shape descriptions, `B` denotes the batch dimension.
153 |
154 | Args:
155 | params: params to be forwarded to root and recurrent functions.
156 | rng_key: random number generator state, the key is consumed.
157 | root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
158 | `prior_logits` are from a policy network. The shapes are
159 | `([B, num_actions], [B], [B, ...])`, respectively.
160 | recurrent_fn: a callable to be called on the leaf nodes and unvisited
161 | actions retrieved by the simulation step, which takes as args
162 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
163 | and the new state embedding. The `rng_key` argument is consumed.
164 | num_simulations: the number of simulations.
165 | invalid_actions: a mask with invalid actions. Invalid actions
166 | have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`.
167 | max_depth: maximum search tree depth allowed during simulation.
168 | loop_fn: Function used to run the simulations. It may be required to pass
169 | hk.fori_loop if using this function inside a Haiku module.
170 | qtransform: function to obtain completed Q-values for a node.
171 | max_num_considered_actions: the maximum number of actions expanded at the
172 | root node. A smaller number of actions will be expanded if the number of
173 | valid actions is smaller.
174 |     gumbel_scale: scale for the Gumbel noise. Evaluation on perfect-information
175 | games can use gumbel_scale=0.0.
176 |
177 | Returns:
178 | `PolicyOutput` containing the proposed action, action_weights and the used
179 | search tree.
180 | """
181 | # Masking invalid actions.
182 | root = root.replace(
183 | prior_logits=_mask_invalid_actions(root.prior_logits, invalid_actions))
184 |
185 | # Generating Gumbel.
186 | rng_key, gumbel_rng = jax.random.split(rng_key)
187 | gumbel = gumbel_scale * jax.random.gumbel(
188 | gumbel_rng, shape=root.prior_logits.shape, dtype=root.prior_logits.dtype)
189 |
190 |
191 | extra_data = action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel)
192 | # Searching.
193 | search_tree = search.instantiate_tree_from_root(
194 | root, num_simulations+1,
195 | root_invalid_actions=invalid_actions,
196 | extra_data=extra_data)
197 | search_tree = search.search(
198 | params=params,
199 | rng_key=rng_key,
200 |       # Reuse the tree instantiated above instead of building an
201 |       # identical tree a second time; `search_tree` already holds the
202 |       # root priors, value, embedding, the invalid-action mask and the
203 |       # Gumbel noise stored in `extra_data`.
204 |       tree=search_tree,
205 | recurrent_fn=recurrent_fn,
206 | root_action_selection_fn=functools.partial(
207 | action_selection.gumbel_muzero_root_action_selection,
208 | num_simulations=num_simulations,
209 | max_num_considered_actions=max_num_considered_actions,
210 | qtransform=qtransform,
211 | ),
212 | interior_action_selection_fn=functools.partial(
213 | action_selection.gumbel_muzero_interior_action_selection,
214 | qtransform=qtransform,
215 | ),
216 | num_simulations=num_simulations,
217 | max_depth=max_depth,
218 | loop_fn=loop_fn)
219 | summary = search_tree.summary()
220 |
221 | # Acting with the best action from the most visited actions.
222 | # The "best" action has the highest `gumbel + logits + q`.
223 | # Inside the minibatch, the considered_visit can be different on states with
224 | # a smaller number of valid actions.
225 | considered_visit = jnp.max(summary.visit_counts, axis=-1, keepdims=True)
226 | # The completed_qvalues include imputed values for unvisited actions.
227 | completed_qvalues = jax.vmap(qtransform, in_axes=[0, None])( # pytype: disable=wrong-arg-types # numpy-scalars # pylint: disable=line-too-long
228 | search_tree, search_tree.ROOT_INDEX)
229 | to_argmax = seq_halving.score_considered(
230 | considered_visit, gumbel, root.prior_logits, completed_qvalues,
231 | summary.visit_counts)
232 | action = action_selection.masked_argmax(to_argmax, invalid_actions)
233 |
234 | # Producing action_weights usable to train the policy network.
235 | completed_search_logits = _mask_invalid_actions(
236 | root.prior_logits + completed_qvalues, invalid_actions)
237 | action_weights = jax.nn.softmax(completed_search_logits)
238 | return base.PolicyOutput(
239 | action=action,
240 | action_weights=action_weights,
241 | search_tree=search_tree)
242 |
243 |
244 | def alphazero_policy(
245 | params: base.Params,
246 | rng_key: chex.PRNGKey,
247 | root: base.RootFnOutput,
248 | recurrent_fn: base.RecurrentFn,
249 | num_simulations: int,
250 | search_tree: Optional[search.Tree] = None,
251 | max_nodes: Optional[int] = None,
252 | invalid_actions: Optional[chex.Array] = None,
253 | max_depth: Optional[int] = None,
254 | loop_fn: base.LoopFn = jax.lax.fori_loop,
255 | *,
256 | qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
257 | dirichlet_fraction: chex.Numeric = 0.25,
258 | dirichlet_alpha: chex.Numeric = 0.3,
259 | pb_c_init: chex.Numeric = 1.25,
260 | pb_c_base: chex.Numeric = 19652,
261 | temperature: chex.Numeric = 1.0) -> base.PolicyOutput[None]:
262 | """Runs AlphaZero search and returns the `PolicyOutput`.
263 |
264 | Allows for continuing search from a provided `Tree` instance.
265 | If no tree is provided, a new tree with capacity `max_nodes` is created.
266 | This policy is otherwise identical to `muzero_policy`.
267 |
268 | In the shape descriptions, `B` denotes the batch dimension.
269 |
270 | Args:
271 | params: params to be forwarded to root and recurrent functions.
272 | rng_key: random number generator state, the key is consumed.
273 | root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
274 | `prior_logits` are from a policy network. The shapes are
275 | `([B, num_actions], [B], [B, ...])`, respectively.
276 | recurrent_fn: a callable to be called on the leaf nodes and unvisited
277 | actions retrieved by the simulation step, which takes as args
278 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
279 | and the new state embedding. The `rng_key` argument is consumed.
280 | num_simulations: the number of simulations.
281 | search_tree: If provided, continue the search from this tree. If not
282 | provided, a new tree with capacity `max_nodes` is created.
283 | max_nodes: Specifies the capacity to initialize the search tree with, if
284 | `search_tree` is not provided.
285 | invalid_actions: a mask with invalid actions. Invalid actions
286 | have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`.
287 | max_depth: maximum search tree depth allowed during simulation.
288 | loop_fn: Function used to run the simulations. It may be required to pass
289 | hk.fori_loop if using this function inside a Haiku module.
290 | qtransform: function to obtain completed Q-values for a node.
291 | dirichlet_fraction: float from 0 to 1 interpolating between using only the
292 | prior policy or just the Dirichlet noise.
293 | dirichlet_alpha: concentration parameter to parametrize the Dirichlet
294 | distribution.
295 | pb_c_init: constant c_1 in the PUCT formula.
296 | pb_c_base: constant c_2 in the PUCT formula.
297 | temperature: temperature for acting proportionally to
298 | `visit_counts**(1 / temperature)`.
299 |
300 | Returns:
301 | `PolicyOutput` containing the proposed action, action_weights and the used
302 | search tree.
303 | """
304 | rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3)
305 |
306 | # Adding Dirichlet noise.
307 | noisy_logits = _get_logits_from_probs(
308 | _add_dirichlet_noise(
309 | dirichlet_rng_key,
310 | jax.nn.softmax(root.prior_logits),
311 | dirichlet_fraction=dirichlet_fraction,
312 | dirichlet_alpha=dirichlet_alpha))
313 | root = root.replace(
314 | prior_logits=_mask_invalid_actions(noisy_logits, invalid_actions))
315 |
316 | # Running the search.
317 | interior_action_selection_fn = functools.partial(
318 | action_selection.muzero_action_selection,
319 | pb_c_base=pb_c_base,
320 | pb_c_init=pb_c_init,
321 | qtransform=qtransform)
322 | root_action_selection_fn = functools.partial(
323 | interior_action_selection_fn,
324 | depth=0)
325 |
326 | if search_tree is None:
327 | if max_nodes is None:
328 | max_nodes = num_simulations + 1
329 | search_tree = search.instantiate_tree_from_root(
330 | root, max_nodes,
331 | root_invalid_actions=invalid_actions,
332 | extra_data=None
333 | )
334 | else:
335 | search_tree = search.update_tree_with_root(
336 | search_tree, root,
337 | root_invalid_actions=invalid_actions, extra_data=None)
338 |
339 | search_tree = search.search(
340 | params=params,
341 | rng_key=search_rng_key,
342 | tree=search_tree,
343 | recurrent_fn=recurrent_fn,
344 | root_action_selection_fn=root_action_selection_fn,
345 | interior_action_selection_fn=interior_action_selection_fn,
346 | num_simulations=num_simulations,
347 | max_depth=max_depth,
348 | loop_fn=loop_fn)
349 |
350 | # Sampling the proposed action proportionally to the visit counts.
351 | summary = search_tree.summary()
352 | action_weights = summary.visit_probs
353 | action_logits = _apply_temperature(
354 | _get_logits_from_probs(action_weights), temperature)
355 | action = jax.random.categorical(rng_key, action_logits)
356 | return base.PolicyOutput(
357 | action=action,
358 | action_weights=action_weights,
359 | search_tree=search_tree)
360 |
361 |
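# Usage sketch for the tree-continuation feature above (illustrative; the
# capacity must cover all simulations, hence max_nodes=101 for the root
# plus 50 + 50 expansions):
#
#   out1 = alphazero_policy(params, rng_key_1, root, recurrent_fn,
#                           num_simulations=50, max_nodes=101)
#   out2 = alphazero_policy(params, rng_key_2, root, recurrent_fn,
#                           num_simulations=50,
#                           search_tree=out1.search_tree)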
362 | def stochastic_muzero_policy(
363 | params: chex.ArrayTree,
364 | rng_key: chex.PRNGKey,
365 | root: base.RootFnOutput,
366 | decision_recurrent_fn: base.DecisionRecurrentFn,
367 | chance_recurrent_fn: base.ChanceRecurrentFn,
368 | num_simulations: int,
369 | invalid_actions: Optional[chex.Array] = None,
370 | max_depth: Optional[int] = None,
371 | loop_fn: base.LoopFn = jax.lax.fori_loop,
372 | *,
373 | qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
374 | dirichlet_fraction: chex.Numeric = 0.25,
375 | dirichlet_alpha: chex.Numeric = 0.3,
376 | pb_c_init: chex.Numeric = 1.25,
377 | pb_c_base: chex.Numeric = 19652,
378 | temperature: chex.Numeric = 1.0) -> base.PolicyOutput[None]:
379 | """Runs Stochastic MuZero search.
380 |
381 | Implements search as described in the Stochastic MuZero paper:
382 | (https://openreview.net/forum?id=X6D9bAHhBQ1).
383 |
384 | In the shape descriptions, `B` denotes the batch dimension.
385 | Args:
386 | params: params to be forwarded to root and recurrent functions.
387 | rng_key: random number generator state, the key is consumed.
388 | root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
389 | `prior_logits` are from a policy network. The shapes are `([B,
390 | num_actions], [B], [B, ...])`, respectively.
391 | decision_recurrent_fn: a callable to be called on the leaf decision nodes
392 | and unvisited actions retrieved by the simulation step, which takes as
393 | args `(params, rng_key, action, state_embedding)` and returns a
394 | `(DecisionRecurrentFnOutput, afterstate_embedding)`.
395 | chance_recurrent_fn: a callable to be called on the leaf chance nodes and
396 | unvisited actions retrieved by the simulation step, which takes as args
397 | `(params, rng_key, chance_outcome, afterstate_embedding)` and returns a
398 | `(ChanceRecurrentFnOutput, state_embedding)`.
399 | num_simulations: the number of simulations.
400 | invalid_actions: a mask with invalid actions. Invalid actions have ones,
401 | valid actions have zeros in the mask. Shape `[B, num_actions]`.
402 | max_depth: maximum search tree depth allowed during simulation.
403 | loop_fn: Function used to run the simulations. It may be required to pass
404 | hk.fori_loop if using this function inside a Haiku module.
405 | qtransform: function to obtain completed Q-values for a node.
406 | dirichlet_fraction: float from 0 to 1 interpolating between using only the
407 | prior policy or just the Dirichlet noise.
408 | dirichlet_alpha: concentration parameter to parametrize the Dirichlet
409 | distribution.
410 | pb_c_init: constant c_1 in the PUCT formula.
411 | pb_c_base: constant c_2 in the PUCT formula.
412 | temperature: temperature for acting proportionally to `visit_counts**(1 /
413 | temperature)`.
414 |
415 | Returns:
416 | `PolicyOutput` containing the proposed action, action_weights and the used
417 | search tree.
418 | """
419 |
420 | num_actions = root.prior_logits.shape[-1]
421 |
422 | rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3)
423 |
424 | # Adding Dirichlet noise.
425 | noisy_logits = _get_logits_from_probs(
426 | _add_dirichlet_noise(
427 | dirichlet_rng_key,
428 | jax.nn.softmax(root.prior_logits),
429 | dirichlet_fraction=dirichlet_fraction,
430 | dirichlet_alpha=dirichlet_alpha))
431 |
432 | root = root.replace(
433 | prior_logits=_mask_invalid_actions(noisy_logits, invalid_actions))
434 |
435 | # construct a dummy afterstate embedding
436 | batch_size = jax.tree_util.tree_leaves(root.embedding)[0].shape[0]
437 | dummy_action = jnp.zeros([batch_size], dtype=jnp.int32)
438 | dummy_output, dummy_afterstate_embedding = decision_recurrent_fn(
439 | params, rng_key, dummy_action, root.embedding)
440 | num_chance_outcomes = dummy_output.chance_logits.shape[-1]
441 |
442 | root = root.replace(
443 | # pad action logits with num_chance_outcomes so dim is A + C
444 | prior_logits=jnp.concatenate([
445 | root.prior_logits,
446 | jnp.full([batch_size, num_chance_outcomes], fill_value=-jnp.inf)
447 | ], axis=-1),
448 | # replace embedding with wrapper.
449 | embedding=base.StochasticRecurrentState(
450 | state_embedding=root.embedding,
451 | afterstate_embedding=dummy_afterstate_embedding,
452 | is_decision_node=jnp.ones([batch_size], dtype=bool)))
453 |
454 | # Stochastic MuZero Change: We need to be able to tell if different nodes are
455 | # decision or chance. This is accomplished by imposing a special structure
456 | # on the embeddings stored in each node. Each embedding is an instance of
457 | # StochasticRecurrentState which maintains this information.
458 | recurrent_fn = _make_stochastic_recurrent_fn(
459 | decision_node_fn=decision_recurrent_fn,
460 | chance_node_fn=chance_recurrent_fn,
461 | num_actions=num_actions,
462 | num_chance_outcomes=num_chance_outcomes,
463 | )
464 |
465 | # Running the search.
466 |
467 | interior_decision_node_selection_fn = functools.partial(
468 | action_selection.muzero_action_selection,
469 | pb_c_base=pb_c_base,
470 | pb_c_init=pb_c_init,
471 | qtransform=qtransform)
472 |
473 | interior_action_selection_fn = _make_stochastic_action_selection_fn(
474 | interior_decision_node_selection_fn, num_actions)
475 |
476 | root_action_selection_fn = functools.partial(
477 | interior_action_selection_fn, depth=0)
478 |
479 | search_tree = search.instantiate_tree_from_root(
480 | root, num_simulations+1,
481 | root_invalid_actions=invalid_actions,
482 | extra_data=None)
483 | search_tree = search.search(
484 | params=params,
485 | rng_key=search_rng_key,
486 | tree=search_tree,
487 | recurrent_fn=recurrent_fn,
488 | root_action_selection_fn=root_action_selection_fn,
489 | interior_action_selection_fn=interior_action_selection_fn,
490 | num_simulations=num_simulations,
491 | max_depth=max_depth,
492 | loop_fn=loop_fn)
493 |
494 | # Sampling the proposed action proportionally to the visit counts.
495 | search_tree = _mask_tree(search_tree, num_actions, 'decision')
496 | summary = search_tree.summary()
497 | action_weights = summary.visit_probs
498 | action_logits = _apply_temperature(
499 | _get_logits_from_probs(action_weights), temperature)
500 | action = jax.random.categorical(rng_key, action_logits)
501 | return base.PolicyOutput(
502 | action=action, action_weights=action_weights, search_tree=search_tree)
503 |
504 |
505 | def _mask_invalid_actions(logits, invalid_actions):
506 |   """Returns logits assigning zero probability mass to invalid actions."""
507 | if invalid_actions is None:
508 | return logits
509 | chex.assert_equal_shape([logits, invalid_actions])
510 | logits = logits - jnp.max(logits, axis=-1, keepdims=True)
511 | # At the end of an episode, all actions can be invalid. A softmax would then
512 | # produce NaNs, if using -inf for the logits. We avoid the NaNs by using
513 | # a finite `min_logit` for the invalid actions.
514 | min_logit = jnp.finfo(logits.dtype).min
515 | return jnp.where(invalid_actions, min_logit, logits)
516 |
517 |
518 | def _get_logits_from_probs(probs):
519 | tiny = jnp.finfo(probs.dtype).tiny
520 | return jnp.log(jnp.maximum(probs, tiny))
521 |
522 |
523 | def _add_dirichlet_noise(rng_key, probs, *, dirichlet_alpha,
524 | dirichlet_fraction):
525 | """Mixes the probs with Dirichlet noise."""
526 | chex.assert_rank(probs, 2)
527 | chex.assert_type([dirichlet_alpha, dirichlet_fraction], float)
528 |
529 | batch_size, num_actions = probs.shape
530 | noise = jax.random.dirichlet(
531 | rng_key,
532 | alpha=jnp.full([num_actions], fill_value=dirichlet_alpha),
533 | shape=(batch_size,))
534 | noisy_probs = (1 - dirichlet_fraction) * probs + dirichlet_fraction * noise
535 | return noisy_probs
536 |
537 |
538 | def _apply_temperature(logits, temperature):
539 | """Returns `logits / temperature`, supporting also temperature=0."""
540 | # The max subtraction prevents +inf after dividing by a small temperature.
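541 |   # With temperature=0, dividing by `tiny` below sends every non-maximal
542 |   # logit to a huge negative value while the maximum stays at 0, so sampling
543 |   # from the result reduces to an argmax.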
541 | logits = logits - jnp.max(logits, keepdims=True, axis=-1)
542 | tiny = jnp.finfo(logits.dtype).tiny
543 | return logits / jnp.maximum(tiny, temperature)
544 |
545 |
546 | def _make_stochastic_recurrent_fn(
547 | decision_node_fn: base.DecisionRecurrentFn,
548 | chance_node_fn: base.ChanceRecurrentFn,
549 | num_actions: int,
550 | num_chance_outcomes: int,
551 | ) -> base.RecurrentFn:
552 | """Make Stochastic Recurrent Fn."""
553 |
554 | def stochastic_recurrent_fn(
555 | params: base.Params,
556 | rng: chex.PRNGKey,
557 | action_or_chance: base.Action, # [B]
558 | state: base.StochasticRecurrentState
559 | ) -> Tuple[base.RecurrentFnOutput, base.StochasticRecurrentState]:
560 | batch_size = jax.tree_util.tree_leaves(state.state_embedding)[0].shape[0]
561 |     # Internally we assume that there are `A' = A + C` "actions";
562 |     # `action_or_chance` can take on values in `{0, 1, ..., A' - 1}`.
563 | # To interpret it as an action we can leave it as is:
564 | action = action_or_chance - 0
565 | # To interpret it as a chance outcome we subtract num_actions:
566 | chance_outcome = action_or_chance - num_actions
567 |
568 | decision_output, afterstate_embedding = decision_node_fn(
569 | params, rng, action, state.state_embedding)
570 |     # Outputs from DecisionRecurrentFunction produce chance logits with
571 |     # dim `C`. To respect our internal convention that there are `A' = A + C`
572 |     # "actions", we pad with `A` dummy logits which are ultimately ignored:
573 |     # see `_mask_tree`.
574 | output_if_decision_node = base.RecurrentFnOutput(
575 | prior_logits=jnp.concatenate([
576 | jnp.full([batch_size, num_actions], fill_value=-jnp.inf),
577 | decision_output.chance_logits], axis=-1),
578 | value=decision_output.afterstate_value,
579 | reward=jnp.zeros_like(decision_output.afterstate_value),
580 | discount=jnp.ones_like(decision_output.afterstate_value))
581 |
582 | chance_output, state_embedding = chance_node_fn(params, rng, chance_outcome,
583 | state.afterstate_embedding)
584 |     # Outputs from ChanceRecurrentFunction produce action logits with dim `A`.
585 |     # To respect our internal convention that there are `A' = A + C` "actions",
586 |     # we pad with `C` dummy logits which are ultimately ignored: see
587 |     # `_mask_tree`.
588 | output_if_chance_node = base.RecurrentFnOutput(
589 | prior_logits=jnp.concatenate([
590 | chance_output.action_logits,
591 | jnp.full([batch_size, num_chance_outcomes], fill_value=-jnp.inf)
592 | ], axis=-1),
593 | value=chance_output.value,
594 | reward=chance_output.reward,
595 | discount=chance_output.discount)
596 |
597 | new_state = base.StochasticRecurrentState(
598 | state_embedding=state_embedding,
599 | afterstate_embedding=afterstate_embedding,
600 | is_decision_node=jnp.logical_not(state.is_decision_node))
601 |
602 | def _broadcast_where(decision_leaf, chance_leaf):
603 | extra_dims = [1] * (len(decision_leaf.shape) - 1)
604 | expanded_is_decision = jnp.reshape(state.is_decision_node,
605 | [-1] + extra_dims)
606 | return jnp.where(
607 |           # Ensure state.is_decision_node has the appropriate shape.
608 | expanded_is_decision,
609 | decision_leaf, chance_leaf)
610 |
611 | output = jax.tree.map(_broadcast_where,
612 | output_if_decision_node,
613 | output_if_chance_node)
614 | return output, new_state
615 |
616 | return stochastic_recurrent_fn
617 |
618 |
619 | def _mask_tree(tree: search.Tree, num_actions: int, mode: str) -> search.Tree:
620 | """Masks out parts of the tree based upon node type.
621 |
622 | "Actions" in our tree can either be action or chance values: A' = A + C. This
623 | utility function masks the parts of the tree containing dimensions of shape
624 | A' to be either A or C depending upon `mode`.
625 |
626 | Args:
627 | tree: The tree to be masked.
628 | num_actions: The number of environment actions A.
629 | mode: Either "decision" or "chance".
630 |
631 | Returns:
632 | An appropriately masked tree.
633 | """
634 |
635 | def _take_slice(x):
636 | if mode == 'decision':
637 | return x[..., :num_actions]
638 | elif mode == 'chance':
639 | return x[..., num_actions:]
640 | else:
641 | raise ValueError(f'Unknown mode: {mode}.')
642 |
643 | return tree.replace(
644 | children_index=_take_slice(tree.children_index),
645 | children_prior_logits=_take_slice(tree.children_prior_logits),
646 | children_visits=_take_slice(tree.children_visits),
647 | children_rewards=_take_slice(tree.children_rewards),
648 | children_discounts=_take_slice(tree.children_discounts),
649 | children_values=_take_slice(tree.children_values),
650 | root_invalid_actions=_take_slice(tree.root_invalid_actions))
651 |
652 |
653 | def _make_stochastic_action_selection_fn(
654 | decision_node_selection_fn: base.InteriorActionSelectionFn,
655 | num_actions: int,
656 | ) -> base.InteriorActionSelectionFn:
657 | """Make Stochastic Action Selection Fn."""
658 |
659 | # NOTE: trees are unbatched here.
660 |
661 | def _chance_node_selection_fn(
662 | tree: search.Tree,
663 | node_index: chex.Array,
664 | ) -> chex.Array:
665 | num_chance = tree.children_visits[node_index]
666 | chance_logits = tree.children_prior_logits[node_index]
667 | prob_chance = jax.nn.softmax(chance_logits)
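668 |     # Deterministic chance sampling: picking the outcome that maximizes
669 |     # prior / (visits + 1) drives each outcome's visit count towards a share
670 |     # proportional to its prior probability.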
668 | argmax_chance = jnp.argmax(prob_chance / (num_chance + 1), axis=-1).astype(
669 | jnp.int32
670 | )
671 | return argmax_chance
672 |
673 | def _action_selection_fn(key: chex.PRNGKey, tree: search.Tree,
674 | node_index: chex.Array,
675 | depth: chex.Array) -> chex.Array:
676 | is_decision = tree.embeddings.is_decision_node[node_index]
677 | chance_selection = _chance_node_selection_fn(
678 | tree=_mask_tree(tree, num_actions, 'chance'),
679 | node_index=node_index) + num_actions
680 | decision_selection = decision_node_selection_fn(
681 | key, _mask_tree(tree, num_actions, 'decision'), node_index, depth)
682 | return jax.lax.cond(is_decision, lambda: decision_selection,
683 | lambda: chance_selection)
684 |
685 | return _action_selection_fn
686 |
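687 | 
688 | # A minimal sketch with toy values (not part of the library) of the
689 | # `A' = A + C` padding convention used above, assuming `A = 3` actions,
690 | # `C = 2` chance outcomes and a batch of size 1:
691 | #
692 | #   import jax.numpy as jnp
693 | #
694 | #   A, C = 3, 2
695 | #   chance_logits = jnp.array([[0.1, 0.9]])       # decision node, [1, C]
696 | #   action_logits = jnp.array([[0.3, 0.3, 0.4]])  # chance node, [1, A]
697 | #   # Decision nodes only have chance children, so action slots get -inf:
698 | #   decision_prior = jnp.concatenate(
699 | #       [jnp.full([1, A], -jnp.inf), chance_logits], axis=-1)  # [1, A + C]
700 | #   # Chance nodes only have action children, so chance slots get -inf:
701 | #   chance_prior = jnp.concatenate(
702 | #       [action_logits, jnp.full([1, C], -jnp.inf)], axis=-1)  # [1, A + C]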
--------------------------------------------------------------------------------
/connect4.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# MCTS in MCTX\n",
9 | "\n",
10 | "In this example, we will use the `mctx` library to play the game of Connect 4.\n",
11 | "We will implement the Monte Carlo Tree Search (MCTS) algorithm with random rollouts."
12 | ]
13 | },
14 | {
15 | "attachments": {},
16 | "cell_type": "markdown",
17 | "metadata": {},
18 | "source": [
19 | "## Game mechanics"
20 | ]
21 | },
22 | {
23 | "attachments": {},
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
27 | "Let's start by defining some type aliases to make the code more readable."
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 1,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "import chex\n",
37 | "\n",
38 | "# 6x7 board\n",
39 | "# We use the following coordinate system:\n",
40 | "# ^\n",
41 | "# |\n",
42 | "# |\n",
43 | "# |\n",
44 | "# 0 +-------->\n",
45 | "# 0\n",
46 | "Board = chex.Array\n",
47 | "\n",
48 | "# Index of the column to play\n",
49 | "Action = chex.Array\n",
50 | "\n",
51 | "# Let's assume the game is played by players X and O.\n",
52 | "# 1 if player X, -1 if player O\n",
53 | "Player = chex.Array\n",
54 | "\n",
55 | "# Reward for the player who played the action.\n",
56 | "# 1 for winning, 0 for draw, -1 for losing\n",
57 | "Reward = chex.Array\n",
58 | "\n",
59 | "# True/False if the game is over\n",
60 | "Done = chex.Array"
61 | ]
62 | },
63 | {
64 | "attachments": {},
65 | "cell_type": "markdown",
66 | "metadata": {},
67 | "source": [
68 | "Now, we create a class defining the game state at a given time."
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 2,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "@chex.dataclass\n",
78 | "class Env:\n",
79 | " board: Board\n",
80 | " player: Player\n",
81 | " done: Done\n",
82 | " reward: Reward"
83 | ]
84 | },
85 | {
86 | "attachments": {},
87 | "cell_type": "markdown",
88 | "metadata": {},
89 | "source": [
90 | "To visualize the game state, we define a function that prints the board."
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": 3,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "from rich import print\n",
100 | "\n",
101 | "BOARD_STRING = \"\"\"\n",
102 | " ? | ? | ? | ? | ? | ? | ?\n",
103 | "---|---|---|---|---|---|---\n",
104 | " ? | ? | ? | ? | ? | ? | ?\n",
105 | "---|---|---|---|---|---|---\n",
106 | " ? | ? | ? | ? | ? | ? | ?\n",
107 | "---|---|---|---|---|---|---\n",
108 | " ? | ? | ? | ? | ? | ? | ?\n",
109 | "---|---|---|---|---|---|---\n",
110 | " ? | ? | ? | ? | ? | ? | ?\n",
111 | "---|---|---|---|---|---|---\n",
112 | " ? | ? | ? | ? | ? | ? | ?\n",
113 | "---|---|---|---|---|---|---\n",
114 | " 1 2 3 4 5 6 7\n",
115 | "\"\"\"\n",
116 | "\n",
117 | "def print_board(board: Board):\n",
118 | " board_str = BOARD_STRING\n",
119 | " for i in reversed(range(board.shape[0])):\n",
120 | " for j in range(board.shape[1]):\n",
121 | " board_str = board_str.replace('?', '[green]X[/green]' if board[i, j] == 1 else '[red]O[/red]' if board[i, j] == -1 else ' ', 1)\n",
122 | " print(board_str)"
123 | ]
124 | },
125 | {
126 | "attachments": {},
127 | "cell_type": "markdown",
128 | "metadata": {},
129 | "source": [
130 |     "Let's write a function that determines whether a player has won.\n",
131 |     "We enumerate all horizontal/vertical/diagonal/anti-diagonal lines of length 4 and check whether any of them is filled with pieces of a single color."
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 4,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "import jax.numpy as jnp\n",
141 | "\n",
142 | "def horizontals(board: Board) -> chex.Array:\n",
143 | " return jnp.stack([\n",
144 | " board[i, j:j+4]\n",
145 | " for i in range(board.shape[0])\n",
146 | " for j in range(board.shape[1] - 3)\n",
147 | " ])\n",
148 | "\n",
149 | "def verticals(board: Board) -> chex.Array:\n",
150 | " return jnp.stack([\n",
151 | " board[i:i+4, j]\n",
152 | " for i in range(board.shape[0] - 3)\n",
153 | " for j in range(board.shape[1])\n",
154 | " ])\n",
155 | "\n",
156 | "def diagonals(board: Board) -> chex.Array:\n",
157 | " return jnp.stack([\n",
158 | " jnp.diag(board[i:i+4, j:j+4])\n",
159 | " for i in range(board.shape[0] - 3)\n",
160 | " for j in range(board.shape[1] - 3)\n",
161 | " ])\n",
162 | "\n",
163 | "def antidiagonals(board: Board) -> chex.Array:\n",
164 | " return jnp.stack([\n",
165 | " jnp.diag(board[i:i+4, j:j+4][::-1])\n",
166 | " for i in range(board.shape[0] - 3)\n",
167 | " for j in range(board.shape[1] - 3)\n",
168 | " ])\n",
169 | "\n",
170 | "def get_winner(board: Board) -> Player:\n",
171 | " all_lines = jnp.concatenate((\n",
172 | " horizontals(board),\n",
173 | " verticals(board),\n",
174 | " diagonals(board),\n",
175 | " antidiagonals(board),\n",
176 | " ))\n",
177 | " # x_won and o_won are 1 if the player won, 0 otherwise\n",
178 | " x_won = jnp.any(jnp.all(all_lines == 1, axis=1)).astype(jnp.int8)\n",
179 | " o_won = jnp.any(jnp.all(all_lines == -1, axis=1)).astype(jnp.int8)\n",
180 | " # We consider the following cases:\n",
181 | " # - !x_won and !o_won -> 0 - 0 = 0 -> draw OR not finished\n",
182 | " # - x_won and !o_won -> 1 - 0 = 1 -> Player 1 (X) won\n",
183 | " # - !x_won and o_won -> 0 - 1 = -1 -> Player -1 (O) won\n",
184 | " # - x_won and o_won -> impossible, the game would have ended earlier\n",
185 | " return x_won - o_won"
186 | ]
187 | },
188 | {
189 | "attachments": {},
190 | "cell_type": "markdown",
191 | "metadata": {},
192 | "source": [
193 | "Finally, we can implement the environment dynamics:\n",
194 | "\n",
195 | "* `env_reset` creates a brand new game\n",
196 | "* `env_step` plays a move and returns the new state, the reward **for the player who played the move** and whether the game has ended"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 5,
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "import jax.numpy as jnp\n",
206 | "\n",
207 | "def env_reset(_):\n",
208 | " return Env(\n",
209 | " board=jnp.zeros((6, 7), dtype=jnp.int8),\n",
210 | " player=jnp.int8(1),\n",
211 | " done=jnp.bool_(False),\n",
212 | " reward=jnp.int8(0),\n",
213 | " )"
214 | ]
215 | },
216 | {
217 | "attachments": {},
218 | "cell_type": "markdown",
219 | "metadata": {},
220 | "source": [
221 | "You might be wondering why `env_reset` takes an unused argument.\n",
222 | "This allows us to parallelize the environment with the `jax.vmap` function:"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": 6,
228 | "metadata": {},
229 | "outputs": [
230 | {
231 | "data": {
232 | "text/plain": [
233 | "Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int8)"
234 | ]
235 | },
236 | "execution_count": 6,
237 | "metadata": {},
238 | "output_type": "execute_result"
239 | }
240 | ],
241 | "source": [
242 | "import jax\n",
243 | "\n",
244 |     "# 10 games are created, one per element of the dummy batch:\n",
245 |     "jax.vmap(env_reset)(jnp.arange(10)).player"
246 | ]
247 | },
248 | {
249 | "attachments": {},
250 | "cell_type": "markdown",
251 | "metadata": {},
252 | "source": [
253 | "You need to perfectly understand the `env_step` function to implement the MCTS algorithm.\n",
254 | "\n",
255 | "In particular, you need to understand that the reward returned by `env_step` is the reward **for the player who played the move**.\n",
256 |     "This means that the reward should always be either 0, for a move that does not end the game or that ends it in a draw, or 1, for a move that wins the game for the player who played it.\n",
257 | "\n",
258 | "A reward of -1 is reserved for punishing illegal moves.\n",
259 | "Since `mctx` does not support masking illegal moves below the root node, we need another way to stop the MCTS algorithm from exploring illegal moves.\n",
260 | "Later, we will also write a custom policy function that will never select illegal moves."
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": 7,
266 | "metadata": {},
267 | "outputs": [],
268 | "source": [
269 | "def env_step(env: Env, action: Action) -> tuple[Env, Reward, Done]:\n",
270 | " col = action\n",
271 | "\n",
272 | " # Find the first empty row in the column.\n",
273 |     "    # If the column is full, argmax returns 0, i.e. the (already occupied) bottom row.\n",
274 | " row = jnp.argmax(env.board[:, col] == 0)\n",
275 | "\n",
276 | " # If the column is full, the move is invalid.\n",
277 | " invalid_move = env.board[row, col] != 0\n",
278 | "\n",
279 | " # Place the player's piece in the board only if the move is valid and the game is not over.\n",
280 | " board = env.board.at[row, col].set(jnp.where(env.done | invalid_move, env.board[row, col], env.player))\n",
281 | "\n",
282 | " # The reward is computed as follows:\n",
283 | " # * 0 if the game is **already** over. This is to ignore nodes below terminal nodes.\n",
284 | " # * -1 if the move is invalid\n",
285 | " # * 1 if the move won the game for the current player\n",
286 | " # * 0 if the move caused a draw\n",
287 | " # * (impossible for Connect 4) -1 if the move lost the game for the current player\n",
288 | " reward = jnp.where(env.done, 0, jnp.where(invalid_move, -1, get_winner(board) * env.player)).astype(jnp.int8)\n",
289 | "\n",
290 | " # We end the game if:\n",
291 | " # * the game was already over\n",
292 | " # * the move won or lost the game\n",
293 | " # * the move was invalid\n",
294 | " # * the board is full (draw)\n",
295 |     "    done = env.done | (reward != 0) | invalid_move | jnp.all(board[-1] != 0)\n",
296 | "\n",
297 | " env = Env(\n",
298 | " board=board,\n",
299 |     "        # switch player, unless the game is over (player then stays as the last mover)\n",
300 | " player=jnp.where(done, env.player, -env.player),\n",
301 | " done=done,\n",
302 | " reward=reward,\n",
303 | " )\n",
304 | "\n",
305 | " return env, reward, done"
306 | ]
307 | },
308 | {
309 | "attachments": {},
310 | "cell_type": "markdown",
311 | "metadata": {},
312 | "source": [
313 | "## The MCTS algorithm\n",
314 | "\n",
315 | "We can now implement the MCTS algorithm."
316 | ]
317 | },
318 | {
319 | "attachments": {},
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "Let's start with some helper functions."
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 8,
329 | "metadata": {},
330 | "outputs": [],
331 | "source": [
332 | "def valid_action_mask(env: Env) -> chex.Array:\n",
333 | " '''\n",
334 | " Computes which actions are valid in the current state.\n",
335 | " Returns an array of booleans, indicating which columns are not full.\n",
336 | " In case the game is over, all columns are considered invalid.\n",
337 | " '''\n",
338 | " return jnp.where(env.done, jnp.array([False] * env.board.shape[1]), env.board[-1] == 0)"
339 | ]
340 | },
341 | {
342 | "cell_type": "code",
343 | "execution_count": 9,
344 | "metadata": {},
345 | "outputs": [],
346 | "source": [
347 | "def winning_action_mask(env: Env, player: Player) -> chex.Array:\n",
348 | " '''\n",
349 | " Finds all actions that would immediately win the game for the given player.\n",
350 | " '''\n",
351 | " # Override the next player in the environment with the given player.\n",
352 | " env = Env(board=env.board, player=player, done=env.done, reward=env.reward)\n",
353 | "\n",
354 | " # Play all actions and check the reward.\n",
355 | " # Remember that the reward is for the current player, so we expect it to be 1.\n",
356 | " env, reward, done = jax.vmap(env_step, (None, 0))(env, jnp.arange(7, dtype=jnp.int8))\n",
357 | " return reward == 1"
358 | ]
359 | },
360 | {
361 | "attachments": {},
362 | "cell_type": "markdown",
363 | "metadata": {},
364 | "source": [
365 | "Next, we define a policy function. The policy function is used in two places:\n",
366 | "\n",
367 | "* during random rollouts to select the next action to play\n",
368 | "* during node expansion as the prior distribution\n",
369 | "\n",
370 | "We could have used a uniform distribution for both, but we can do better.\n",
371 | "Let's use three simple heuristics:\n",
372 | "\n",
373 | "* never select illegal moves\n",
374 | "* always select winning moves if available\n",
375 | "* always block opponent's winning moves if we have no winning move\n",
376 | "\n",
377 | "Our policy function implementation returns unnormalized logits,\n",
378 | "so we cannot return pure 0% and 100% probabilities (using `inf` and `-inf` causes downstream numerical issues).\n",
379 |     "Instead, we add bonuses in steps of `100` to the logits of the actions we want to prioritize (the bonuses stack, and a gap of `100` between logits is already decisive after the softmax):\n",
380 | "\n",
381 | "* the lowest priority of `0` is assigned to illegal moves\n",
382 | "* a medium priority of `100` is assigned to legal moves\n",
383 | "* a higher priority of `200` is assigned to the opponent's winning moves\n",
384 | "* the highest priority of `300` is assigned to our winning moves"
385 | ]
386 | },
387 | {
388 | "cell_type": "code",
389 | "execution_count": 10,
390 | "metadata": {},
391 | "outputs": [],
392 | "source": [
393 | "def policy_function(env: Env) -> chex.Array:\n",
394 | " return sum((\n",
395 | " valid_action_mask(env).astype(jnp.float32) * 100,\n",
396 | " winning_action_mask(env, -env.player).astype(jnp.float32) * 200,\n",
397 | " winning_action_mask(env, env.player).astype(jnp.float32) * 300,\n",
398 | " ))"
399 | ]
400 | },
401 | {
402 | "attachments": {},
403 | "cell_type": "markdown",
404 | "metadata": {},
405 | "source": [
406 | "This is how these logits translate to probabilities:"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": 11,
412 | "metadata": {},
413 | "outputs": [
414 | {
415 | "data": {
416 | "text/plain": [
417 | "Array([0.14285715, 0.14285715, 0.14285715, 0.14285715, 0.14285715,\n",
418 | " 0.14285715, 0.14285715], dtype=float32)"
419 | ]
420 | },
421 | "execution_count": 11,
422 | "metadata": {},
423 | "output_type": "execute_result"
424 | }
425 | ],
426 | "source": [
427 | "jax.nn.softmax(jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": 12,
433 | "metadata": {},
434 | "outputs": [
435 | {
436 | "data": {
437 | "text/plain": [
438 | "Array([1., 0., 0., 0., 0., 0., 0.], dtype=float32)"
439 | ]
440 | },
441 | "execution_count": 12,
442 | "metadata": {},
443 | "output_type": "execute_result"
444 | }
445 | ],
446 | "source": [
447 | "jax.nn.softmax(jnp.array([100.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))"
448 | ]
449 | },
450 | {
451 | "cell_type": "code",
452 | "execution_count": 13,
453 | "metadata": {},
454 | "outputs": [
455 | {
456 | "data": {
457 | "text/plain": [
458 | "Array([0., 1., 0., 0., 0., 0., 0.], dtype=float32)"
459 | ]
460 | },
461 | "execution_count": 13,
462 | "metadata": {},
463 | "output_type": "execute_result"
464 | }
465 | ],
466 | "source": [
467 | "jax.nn.softmax(jnp.array([100.0, 200.0, 0.0, 0.0, 0.0, 0.0, 0.0]))"
468 | ]
469 | },
470 | {
471 | "cell_type": "code",
472 | "execution_count": 14,
473 | "metadata": {},
474 | "outputs": [
475 | {
476 | "data": {
477 | "text/plain": [
478 | "Array([0., 0., 1., 0., 0., 0., 0.], dtype=float32)"
479 | ]
480 | },
481 | "execution_count": 14,
482 | "metadata": {},
483 | "output_type": "execute_result"
484 | }
485 | ],
486 | "source": [
487 | "jax.nn.softmax(jnp.array([100.0, 200.0, 300.0, 0.0, 0.0, 0.0, 0.0]))"
488 | ]
489 | },
490 | {
491 | "cell_type": "code",
492 | "execution_count": 15,
493 | "metadata": {},
494 | "outputs": [
495 | {
496 | "data": {
497 | "text/plain": [
498 | "Array([0. , 0. , 0.5, 0.5, 0. , 0. , 0. ], dtype=float32)"
499 | ]
500 | },
501 | "execution_count": 15,
502 | "metadata": {},
503 | "output_type": "execute_result"
504 | }
505 | ],
506 | "source": [
507 | "jax.nn.softmax(jnp.array([100.0, 200.0, 300.0, 300.0, 0.0, 0.0, 0.0]))"
508 | ]
509 | },
510 | {
511 | "attachments": {},
512 | "cell_type": "markdown",
513 | "metadata": {},
514 | "source": [
515 | "You can see that a difference of 100 is enough to make a move almost certain to be selected."
516 | ]
517 | },
518 | {
519 | "attachments": {},
520 | "cell_type": "markdown",
521 | "metadata": {},
522 | "source": [
523 | "The MCTS algorithm uses random rollouts to estimate the value of each state.\n",
524 | "We can implement this with a simple function that plays moves according to our policy function until the game ends."
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": 16,
530 | "metadata": {},
531 | "outputs": [],
532 | "source": [
533 | "def rollout(env: Env, rng_key: chex.PRNGKey) -> Reward:\n",
534 | " '''\n",
535 | " Plays a game until the end and returns the reward from the perspective of the initial player.\n",
536 | " '''\n",
537 | " def cond(a):\n",
538 | " env, key = a\n",
539 | " return ~env.done\n",
540 | " def step(a):\n",
541 | " env, key = a\n",
542 | " key, subkey = jax.random.split(key)\n",
543 | " action = jax.random.categorical(subkey, policy_function(env))\n",
544 | " env, reward, done = env_step(env, action)\n",
545 | " return env, key\n",
546 | " leaf, key = jax.lax.while_loop(cond, step, (env, rng_key))\n",
547 | " # The leaf reward is from the perspective of the last player.\n",
548 | " # We negate it if the last player is not the initial player.\n",
549 | " return leaf.reward * leaf.player * env.player"
550 | ]
551 | },
552 | {
553 | "attachments": {},
554 | "cell_type": "markdown",
555 | "metadata": {},
556 | "source": [
557 | "The value function is simply the result of the rollout."
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "execution_count": 17,
563 | "metadata": {},
564 | "outputs": [],
565 | "source": [
566 | "def value_function(env: Env, rng_key: chex.PRNGKey) -> chex.Array:\n",
567 | " return rollout(env, rng_key).astype(jnp.float32)"
568 | ]
569 | },
570 | {
571 | "attachments": {},
572 | "cell_type": "markdown",
573 | "metadata": {},
574 | "source": [
575 | "## Running mctx\n",
576 | "\n",
577 | "Finally, we get to use the `mctx` library.\n",
578 |     "We need two functions:\n",
579 | "\n",
580 | "* `root_fn` returns the root node of the MCTS tree\n",
581 | "* `recurrent_fn` expands a new node given a parent node and an action"
582 | ]
583 | },
584 | {
585 | "cell_type": "code",
586 | "execution_count": 18,
587 | "metadata": {},
588 | "outputs": [],
589 | "source": [
590 | "import mctx\n",
591 | "\n",
592 | "def root_fn(env: Env, rng_key: chex.PRNGKey) -> mctx.RootFnOutput:\n",
593 | " return mctx.RootFnOutput(\n",
594 | " prior_logits=policy_function(env),\n",
595 | " value=value_function(env, rng_key),\n",
596 | " # We will use the `embedding` field to store the environment.\n",
597 | " embedding=env,\n",
598 | " )"
599 | ]
600 | },
601 | {
602 | "cell_type": "code",
603 | "execution_count": 19,
604 | "metadata": {},
605 | "outputs": [],
606 | "source": [
607 | "def recurrent_fn(params, rng_key, action, embedding):\n",
608 | " # Extract the environment from the embedding.\n",
609 | " env = embedding\n",
610 | "\n",
611 | " # Play the action.\n",
612 | " env, reward, done = env_step(env, action)\n",
613 | "\n",
614 | " # Create the new MCTS node.\n",
615 | " recurrent_fn_output = mctx.RecurrentFnOutput(\n",
616 | " # reward for playing `action`\n",
617 | " reward=reward,\n",
618 | " # discount explained in the next section\n",
619 | " discount=jnp.where(done, 0, -1).astype(jnp.float32),\n",
620 | " # prior for the new state\n",
621 | " prior_logits=policy_function(env),\n",
622 | " # value for the new state\n",
623 | " value=jnp.where(done, 0, value_function(env, rng_key)).astype(jnp.float32),\n",
624 | " )\n",
625 | "\n",
626 | " # Return the new node and the new environment.\n",
627 | " return recurrent_fn_output, env"
628 | ]
629 | },
630 | {
631 | "attachments": {},
632 | "cell_type": "markdown",
633 | "metadata": {},
634 | "source": [
635 | "The `discount` field is used to flip the sign of the reward and value at each tree level.\n",
636 | "It would be possible to implement this in the environment dynamics, but it is more convenient to do it here.\n",
637 |     "By flipping the sign, MCTS selects the best action for the player to move at each node: with `discount = -1`, a child value `v` backs up to its parent as `reward - v`, so a position that is good for one player looks correspondingly bad to the other.\n",
638 | "\n",
639 | "Note that we set `discount` to `0` when the game is over.\n",
640 | "This discards all the rewards and values after the end of the game."
641 | ]
642 | },
643 | {
644 | "attachments": {},
645 | "cell_type": "markdown",
646 | "metadata": {},
647 | "source": [
648 | "And now we can run MCTS!"
649 | ]
650 | },
651 | {
652 | "cell_type": "code",
653 | "execution_count": 20,
654 | "metadata": {},
655 | "outputs": [],
656 | "source": [
657 | "import functools\n",
658 | "from typing import Optional\n",
659 | "\n",
660 | "@functools.partial(jax.jit, static_argnums=(2,))\n",
661 | "def run_mcts(rng_key: chex.PRNGKey, env: Env, num_simulations: int, tree: Optional[mctx.Tree] = None) -> chex.Array:\n",
662 | " batch_size = 1\n",
663 | " key1, key2 = jax.random.split(rng_key)\n",
664 | " policy_output = mctx.alphazero_policy(\n",
665 | " # params can be used to pass additional data to the recurrent_fn like neural network weights\n",
666 | " params=None,\n",
667 | "\n",
668 | " rng_key=key1,\n",
669 | "\n",
670 | " # create a batch of environments (in this case, a batch of size 1)\n",
671 | " root=jax.vmap(root_fn, (None, 0))(env, jax.random.split(key2, batch_size)),\n",
672 | "\n",
673 | " # automatically vectorize the recurrent_fn\n",
674 | " recurrent_fn=jax.vmap(recurrent_fn, (None, None, 0, 0)),\n",
675 | "\n",
676 | " num_simulations=num_simulations,\n",
677 | "\n",
678 | " # we limit the depth of the search tree to 42, since we know that Connect Four can't last longer\n",
679 | " max_depth=42,\n",
680 | " max_nodes=int(num_simulations * 1.5),\n",
681 | " search_tree=tree,\n",
682 | "\n",
683 | " # our value is in the range [-1, 1], so we can use the min_max qtransform to map it to [0, 1]\n",
684 | " qtransform=functools.partial(mctx.qtransform_by_min_max, min_value=-1, max_value=1),\n",
685 | "\n",
686 |     "        # Dirichlet noise is used for exploration, which we don't need in this example (we aren't training)\n",
687 | " dirichlet_fraction=0.0,\n",
688 | " temperature=0.0\n",
689 | " )\n",
690 | " return policy_output"
691 | ]
692 | },
693 | {
694 | "attachments": {},
695 | "cell_type": "markdown",
696 | "metadata": {},
697 | "source": [
698 |     "Let's play the middle column twice and see what MCTS does:"
699 | ]
700 | },
701 | {
702 | "cell_type": "code",
703 | "execution_count": 21,
704 | "metadata": {},
705 | "outputs": [
706 | {
707 | "data": {
708 | "text/plain": [
709 | "Array([[0.016, 0.097, 0.189, 0.436, 0.177, 0.048, 0.037]], dtype=float32)"
710 | ]
711 | },
712 | "execution_count": 21,
713 | "metadata": {},
714 | "output_type": "execute_result"
715 | }
716 | ],
717 | "source": [
718 | "env = env_reset(0)\n",
719 | "env, reward, done = env_step(env, 3)\n",
720 | "env, reward, done = env_step(env, 3)\n",
721 | "policy_output = run_mcts(jax.random.PRNGKey(0), env, 1000)\n",
722 | "policy_output.action_weights"
723 | ]
724 | },
725 | {
726 | "cell_type": "code",
727 | "execution_count": 22,
728 | "metadata": {},
729 | "outputs": [
730 | {
731 | "data": {
732 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZgklEQVR4nO3df6zVdf3A8de90L1XBa4oeRG8cvNHISkX5codmll5kxyz2PpBzuLu5twqNO3OJlTjai4vlTpMGIRlNZuDamk/VMxuQnNdh0Isf6SlySDtXmC1e/G6Xdy95/sH6zK+gnLwcl9ceDy2s8mHz+ec13nPeZ9+zud8bkmhUCgEAECS0uwBAICjmxgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKNzB7gQPT398err74ao0ePjpKSkuxxAIADUCgUYufOnTFhwoQoLd3/+Y9hESOvvvpqVFdXZ48BAByErVu3ximnnLLfvx8WMTJ69OiI2P1mxowZkzwNAHAguru7o7q6euDn+P4Mixj530czY8aMESMAMMy83SUWLmAFAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAglRgBAFKJEQAg1cjsAYDDU82CB7NHGHSbF8/OHgHYB2dGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASHVQMbJs2bKoqamJioqKqK+vj/Xr1x/QcatWrYqSkpKYM2fOwbwsAHAEKjpGVq9eHc3NzdHS0hIbN26M2tramDVrVmzbtu0tj9u8eXPccMMNcdFFFx30sADAkafoGLnjjjvi6quvjqamppgyZUqsWLEijj322Ljnnnv2e0xfX19ceeWVcfPNN8dpp532jgYGAI4sRcXIrl27YsOGDdHQ0LDnCUpLo6GhIdrb2/d73Le+9a046aST4qqrrjqg1+nt7Y3u7u69HgDAkamoGNmxY0f09fVFVVXVXturqqqio6Njn8c8/vjj8aMf/SjuvvvuA36d1tbWqKysHHhUV1cXMyYAMIwc0m/T7Ny5Mz7/+c/H3XffHePGjTvg4xYuXBhdXV0Dj61btx7CKQGATCOL2XncuHExYsSI6Ozs3Gt7Z2dnjB8//k37v/TSS7F58+a4/PLLB7b19/fvfuGRI+OFF16I008//U3HlZeXR3l5eTGjAQDDVFFnRsrKymL69OnR1tY2sK2/vz/a2tpi5syZb9p/8uTJ8fTTT8emTZsGHh//+Mfjwx/+cGzatMnHLwBAcWdGIiKam5ujsbEx6urqYsaMGbFkyZLo6emJpqamiIiYN29eTJw4MVpbW6OioiLOPvvsvY4//vjjIyLetB0AODoVHSNz586N7du3x6JFi6KjoyOmTZsWa9asGbiodcuWLVFa6sauAMCBKSkUCoXsId5Od3d3VFZWRldXV4wZMyZ7HDgq1Cx4MHuEQbd58ezsEeCocqA/v53CAABSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAINVBxciyZcuipqYmKioqor6+PtavX7/ffX/1q19FXV1dHH/88XHcccfFtGnT4t577z3ogQGAI0vRMbJ69epobm6OlpaW2LhxY9TW1sasWbNi27Zt+9z/hBNOiG984xvR3t4ef/3rX6OpqSmamprikUceecfDAwDDX0mhUCgUc0B9fX2cf/75sXTp0oiI6O/vj+rq6rj22mtjwYIFB/Qc5513XsyePTtuueWWA9q/u7s7Kisro6urK8aMGVPMuMBBqlnwYPYIg27z4tnZI8BR5UB/fhd1ZmTXrl2xYcOGaGho2PMEpaXR0NAQ7e3tb3t8oVCItra2eOGFF+KDH/zgfvfr7e2N7u7uvR4AwJGpqBjZsWNH9PX1RVVV1V7bq6qqoqOjY7/HdXV1xahRo6KsrCxmz54dd911V3z0ox/d7/6tra1RWVk58Kiuri5mTABgGBmSb9OMHj06Nm3aFE8++WR8+9vfjubm5li7du1+91+4cGF0dXUNPLZu3ToUYwIACUYWs/O4ceNixIgR0dnZudf2zs7OGD9+/H6PKy0tjTPOOCMiIqZNmxZ/+9vforW1NT70oQ/tc//y8vIoLy8vZjQAYJgq6sxIWVlZTJ8+Pdra2ga29ff3R1tbW8ycOfOAn6e/vz96e3uLeWkA4AhV1JmRiIjm5uZobGyMurq6mDFjRixZsiR6enqiqakpIiLmzZsXEydOjNbW1ojYff1HXV1dnH766dHb2xsPPfRQ3HvvvbF8+fLBfScAwLBUdIzMnTs3tm/fHosWLYqOjo6YNm1arFmzZuCi1i1btkRp6Z4TLj09PfHlL385/vWvf8UxxxwTkydPjp/97Gcxd+7cwXsXAMCwVfR9RjK4zwgMPfcZAd6pQ3KfEQCAwSZGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUBxUjy5Yti5qamqioqIj6+vpYv379fve9++6746KLLoqxY8fG2LFjo6Gh4S33BwCOLkXHyOrVq6O5uTlaWlpi48aNUVtbG7NmzYpt27btc/+1a9fGFVdcEY899li0t7dHdXV1XHrppfHKK6+84+EBgOGvpFAoFIo5oL6+Ps4///xYunRpRET09/dHdXV1XHvttbFgwYK3Pb6vry/Gjh0bS5cujXnz5h3
Qa3Z3d0dlZWV0dXXFmDFjihkXOEg1Cx7MHmHQbV48O3sEOKoc6M/vos6M7Nq1KzZs2BANDQ17nqC0NBoaGqK9vf2AnuP111+PN954I0444YT97tPb2xvd3d17PQCAI1NRMbJjx47o6+uLqqqqvbZXVVVFR0fHAT3HjTfeGBMmTNgraP6/1tbWqKysHHhUV1cXMyYAMIwM6bdpFi9eHKtWrYr7778/Kioq9rvfwoULo6ura+CxdevWIZwSABhKI4vZedy4cTFixIjo7Ozca3tnZ2eMHz/+LY+97bbbYvHixfGHP/whpk6d+pb7lpeXR3l5eTGjAQDDVFFnRsrKymL69OnR1tY2sK2/vz/a2tpi5syZ+z3uu9/9btxyyy2xZs2aqKurO/hpAYAjTlFnRiIimpubo7GxMerq6mLGjBmxZMmS6OnpiaampoiImDdvXkycODFaW1sjIuI73/lOLFq0KO67776oqakZuLZk1KhRMWrUqEF8KwDAcFR0jMydOze2b98eixYtio6Ojpg2bVqsWbNm4KLWLVu2RGnpnhMuy5cvj127dsWnPvWpvZ6npaUlbrrppnc2PQAw7BV9n5EM7jMCQ899RoB36pDcZwQAYLCJEQAgVdHXjMCRzscTAEPLmREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSuekZwNtwIzw4tJwZAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSHVSMLFu2LGpqaqKioiLq6+tj/fr1+9332WefjU9+8pNRU1MTJSUlsWTJkoOdFQA4AhUdI6tXr47m5uZoaWmJjRs3Rm1tbcyaNSu2bdu2z/1ff/31OO2002Lx4sUxfvz4dzwwAHBkKTpG7rjjjrj66qujqakppkyZEitWrIhjjz027rnnnn3uf/7558f3vve9+OxnPxvl5eXveGAA4MhSVIzs2rUrNmzYEA0NDXueoLQ0Ghoaor29fdCHAwCOfCOL2XnHjh3R19cXVVVVe22vqqqK559/ftCG6u3tjd7e3oE/d3d3D9pzAwCHl8Py2zStra1RWVk58Kiurs4eCQA4RIqKkXHjxsWIESOis7Nzr+2dnZ2DenHqwoULo6ura+CxdevWQXtuAODwUtTHNGVlZTF9+vRoa2uLOXPmREREf39/tLW1xTXXXDNoQ5WXl7vYNUHNggezRxh0mxfPzh4BgLdRVIxERDQ3N0djY2PU1dXFjBkzYsmSJdHT0xNNTU0RETFv3ryYOHFitLa2RsTui16fe+65gX9+5ZVXYtOmTTFq1Kg444wzBvGtAADDUdExMnfu3Ni+fXssWrQoOjo6Ytq0abFmzZqBi1q3bNkSpaV7Pv159dVX49xzzx3482233Ra33XZbXHzxxbF27dp3/g4AgGGt6BiJiLjmmmv2+7HM/w+MmpqaKBQKB/MyAMBR4LD8Ng0AcPQQIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAKjECAKQSIwBAqpHZAwAwPNQseDB7hEG3efHs7BEIZ0YAgGRiBABIJUYAgFRiBABIJUYAgFRiBABIJUYAgFTuMwIARXC/lcHnzAgAkEqMAACpxAgAkEqMAACpxAgAkEqMAACpxAgAkEqMAACpxAgAkEqMAACpxAgAkEqMAACpxAgAkEqMAACpxAgAkEqMAACpRmYPkK1mwYPZIwy6zYtnZ48AAAfMmREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSiREAIJUYAQBSHVSMLFu2LGpqaqKioiLq6+tj/fr1b7n/L37xi5g8eXJUVFTEOeecEw899NBBDQsAHHmKjpHVq1dHc3NztLS0xMaNG6O2tjZmzZoV27Zt2+f+f/7zn+OKK66Iq666Kv7yl7/EnDlzYs6cOfHMM8+84+EBgOGv6Bi544474uqrr46mpqaYMmVKrFixIo499ti455579rn/nXfeGR/72Mfia1/7Wpx11llxyy23xHnnnRdLly59x8MDAMNfUb+bZteuXbFhw4ZYuHDhwLbS0tJoaGiI9vb2fR7T3t4ezc3Ne22bNWtWPPDAA/t9nd7e3ujt7R34c1dXV0REdHd3FzPuAenvfX3QnzPbwa6TtdjNOuxmHfawFrtZh92sQ/HPWygU3nK/omJkx44d0dfXF1VVVXttr6qqiueff36fx3R0dOxz/46Ojv2+Tmtra9x8881v2l5dXV3MuEetyiXZExw+rMVu1mE367CHtdjNOux2qNdh586dUVlZud+/Pyx/a+/ChQv3OpvS398f//nPf+LEE0+MkpKSxMkOXnd3d1RXV8fWrVtjzJgx2eOksQ67WYc9rMVu1mE367DHkbAWhUIhdu7cGRMmTHjL/YqKkXHjxsWIESOis7Nzr+2dnZ0xfvz4fR4zfvz4ovaPiCgvL4/y8vK9th1//PHFjHrYGjNmzLD9l2owWYfdrMMe1mI367CbddhjuK/FW50R+Z+iLmAtKyuL6dOnR1tb28C2/v7+aGtri5kzZ+7zmJkzZ+61f0TEo48+ut/9AYCjS9Ef0zQ3N0djY2PU1dXFjBkzYsmSJdHT0xNNTU0RETFv3ryYOHFitLa2RkTEddddFxdffHHcfvvtMXv27Fi1alU89dRTsXLlysF9JwDAsFR0jMydOze2b98eixYtio6Ojpg2bVqsWbNm4CLVLVu2RGnpnhMuF1xwQdx3333xzW9+M77+9a/HmWeeGQ888ECcffbZg/cuhoHy8vJoaWl508dPRxvrsJt12MNa7GYddrMOexxNa1FSeLvv2wAAHEJ+Nw0AkEqMAACpxAgAkEqMAACpxMgQWLZsWdTU1ERFRUXU19fH+vXrs0cacn/605/i8ssvjwkTJkRJSclb/m6iI1lra2ucf/75MXr06DjppJNizpw58cILL2SPNeSWL18eU6dOHbiZ08yZM+Phhx/OHivd4sWLo6SkJK6//vrsUYbcTTfdFCUlJXs9Jk+enD1WildeeSU+97nPxYknnhjHHHNMnHPOOfHUU09lj3VIiZFDbPXq1dHc3BwtLS2xcePGqK2tjVmzZsW2bduyRxtSPT09UVtbG8uWLcseJd
W6deti/vz58cQTT8Sjjz4ab7zxRlx66aXR09OTPdqQOuWUU2Lx4sWxYcOGeOqpp+IjH/lIfOITn4hnn302e7Q0Tz75ZPzgBz+IqVOnZo+S5v3vf3/8+9//Hng8/vjj2SMNuf/+979x4YUXxrve9a54+OGH47nnnovbb789xo4dmz3aoVXgkJoxY0Zh/vz5A3/u6+srTJgwodDa2po4Va6IKNx///3ZYxwWtm3bVoiIwrp167JHSTd27NjCD3/4w+wxUuzcubNw5plnFh599NHCxRdfXLjuuuuyRxpyLS0thdra2uwx0t14442FD3zgA9ljDDlnRg6hXbt2xYYNG6KhoWFgW2lpaTQ0NER7e3viZBwuurq6IiLihBNOSJ4kT19fX6xatSp6enqO2l8TMX/+/Jg9e/Ze/604Gv3jH/+ICRMmxGmnnRZXXnllbNmyJXukIfeb3/wm6urq4tOf/nScdNJJce6558bdd9+dPdYhJ0YOoR07dkRfX9/A3Wn/p6qqKjo6OpKm4nDR398f119/fVx44YVH3R2JIyKefvrpGDVqVJSXl8cXv/jFuP/++2PKlCnZYw25VatWxcaNGwd+hcbRqr6+Pn7yk5/EmjVrYvny5fHyyy/HRRddFDt37swebUj985//jOXLl8eZZ54ZjzzySHzpS1+Kr3zlK/HTn/40e7RDqujbwQODY/78+fHMM88clZ+LR0S8733vi02bNkVXV1f88pe/jMbGxli3bt1RFSRbt26N6667Lh599NGoqKjIHifVZZddNvDPU6dOjfr6+pg0aVL8/Oc/j6uuuipxsqHV398fdXV1ceutt0ZExLnnnhvPPPNMrFixIhobG5OnO3ScGTmExo0bFyNGjIjOzs69tnd2dsb48eOTpuJwcM0118Tvfve7eOyxx+KUU07JHidFWVlZnHHGGTF9+vRobW2N2trauPPOO7PHGlIbNmyIbdu2xXnnnRcjR46MkSNHxrp16+L73/9+jBw5Mvr6+rJHTHP88cfHe9/73njxxRezRxlSJ5988puC/KyzzjriP7ISI4dQWVlZTJ8+Pdra2ga29ff3R1tb21H72fjRrlAoxDXXXBP3339//PGPf4z3vOc92SMdNvr7+6O3tzd7jCF1ySWXxNNPPx2bNm0aeNTV1cWVV14ZmzZtihEjRmSPmOa1116Ll156KU4++eTsUYbUhRde+Kav+//973+PSZMmJU00NHxMc4g1NzdHY2Nj1NXVxYwZM2LJkiXR09MTTU1N2aMNqddee22v/8N5+eWXY9OmTXHCCSfEqaeemjjZ0Jo/f37cd9998etf/zpGjx49cO1QZWVlHHPMMcnTDZ2FCxfGZZddFqeeemrs3Lkz7rvvvli7dm088sgj2aMNqdGjR7/peqHjjjsuTjzxxKPuOqIbbrghLr/88pg0aVK8+uqr0dLSEiNGjIgrrrgie7Qh9dWvfjUuuOCCuPXWW+Mzn/lMrF+/PlauXBkrV67MHu3Qyv46z9HgrrvuKpx66qmFsrKywowZMwpPPPFE9khD7rHHHitExJsejY2N2aMNqX2tQUQUfvzjH2ePNqS+8IUvFCZNmlQoKysrvPvd7y5ccsklhd///vfZYx0Wjtav9s6dO7dw8sknF8rKygoTJ04szJ07t/Diiy9mj5Xit7/9beHss88ulJeXFyZPnlxYuXJl9kiHXEmhUCgkdRAAgGtGAIBcYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASCVGAIBUYgQASPV/yIqCzOtC/tEAAAAASUVORK5CYII=",
733 | "text/plain": [
734 | ""
735 | ]
736 | },
737 | "metadata": {},
738 | "output_type": "display_data"
739 | }
740 | ],
741 | "source": [
742 | "import matplotlib.pyplot as plt\n",
743 | "plt.bar(jnp.arange(7), policy_output.action_weights.mean(axis=0))\n",
744 | "plt.show()"
745 | ]
746 | },
747 | {
748 | "attachments": {},
749 | "cell_type": "markdown",
750 | "metadata": {},
751 | "source": [
752 | "It chose the middle column as the best move, as expected."
753 | ]
754 | },
755 | {
756 | "attachments": {},
757 | "cell_type": "markdown",
758 | "metadata": {},
759 | "source": [
760 | "Let's write a simple script to play against MCTS:"
761 | ]
762 | },
763 | {
764 | "cell_type": "code",
765 | "execution_count": null,
766 | "metadata": {},
767 | "outputs": [],
768 | "source": [
769 | "# set to False to enable human input\n",
770 | "player_1_ai = True\n",
771 | "player_2_ai = True\n",
772 | "player_1_use_subtree = True\n",
773 | "player_2_use_subtree = False\n",
774 | "simulations_p1 = 2000\n",
775 | "simulations_p2 = 2000\n",
776 | "key = jax.random.PRNGKey(0)\n",
777 | "env = env_reset(0)\n",
778 | "print_board(env.board)\n",
779 | "tree=None\n",
780 | "tree2=None\n",
781 | "while True:\n",
782 | " if player_1_ai:\n",
783 | " output = run_mcts(key, env, simulations_p1, tree)\n",
784 |     "        # output.action is the move chosen by the search\n",
785 | " \n",
786 | " if player_1_use_subtree:\n",
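787 |     "            # reuse the search tree below the chosen action on the next turn\n",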
787 | " tree = mctx.get_subtree(output.search_tree, output.action)\n",
788 | " action = output.action.item()\n",
789 | " else:\n",
790 | " action = int(input()) - 1\n",
791 | "\n",
792 | " if player_2_ai and player_2_use_subtree:\n",
793 | " if tree2:\n",
794 | " tree2 = mctx.get_subtree(tree2, jnp.array([action], dtype=jnp.int32).reshape(1, -1))\n",
795 | "\n",
796 | " env, reward, done = env_step(env, action)\n",
797 | " print_board(env.board)\n",
798 | " if done: break\n",
799 | "\n",
800 | " if player_2_ai:\n",
801 | " # you can give it more simulations to make it stronger\n",
802 | " output = run_mcts(key, env, simulations_p2, tree2)\n",
803 | " \n",
804 | " if player_2_use_subtree:\n",
805 | " tree2 = mctx.get_subtree(output.search_tree, output.action)\n",
806 | " action = output.action.item()\n",
807 | " else:\n",
808 | " action = int(input()) - 1\n",
809 | " \n",
810 | " if player_1_ai and player_1_use_subtree:\n",
811 | " if tree:\n",
812 | " tree = mctx.get_subtree(tree, jnp.array([action], dtype=jnp.int32).reshape(1, -1))\n",
813 | "\n",
814 | " env, reward, done = env_step(env, action)\n",
815 | " print_board(env.board)\n",
816 | " if done: break\n",
817 | "\n",
818 | "players = {\n",
819 | " 1: \"[green]X[/green]\",\n",
820 | " -1: \"[red]O[/red]\",\n",
821 | "}\n",
822 |     "print(reward)\n",
823 |     "print(f\"Winner: {players[env.player.item()]}\" if reward.item() != 0 else \"Draw\")"
824 | ]
825 | },
826 | {
827 | "attachments": {},
828 | "cell_type": "markdown",
829 | "metadata": {},
830 | "source": [
831 | "At `10_000` simulations, the agent was able to beat all online Connect 4 bots I could find."
832 | ]
833 | }
834 | ],
835 | "metadata": {
836 | "kernelspec": {
837 | "display_name": "mctx-classic",
838 | "language": "python",
839 | "name": "python3"
840 | },
841 | "language_info": {
842 | "codemirror_mode": {
843 | "name": "ipython",
844 | "version": 3
845 | },
846 | "file_extension": ".py",
847 | "mimetype": "text/x-python",
848 | "name": "python",
849 | "nbconvert_exporter": "python",
850 | "pygments_lexer": "ipython3",
851 | "version": "3.10.9"
852 | },
853 | "orig_nbformat": 4
854 | },
855 | "nbformat": 4,
856 | "nbformat_minor": 2
857 | }
858 |
--------------------------------------------------------------------------------