├── mctx
├── py.typed
├── _src
│ ├── __init__.py
│ ├── tests
│ │ ├── mctx_test.py
│ │ ├── qtransforms_test.py
│ │ ├── seq_halving_test.py
│ │ ├── tree_test.py
│ │ └── policies_test.py
│ ├── seq_halving.py
│ ├── tree.py
│ ├── base.py
│ ├── qtransforms.py
│ ├── action_selection.py
│ ├── search.py
│ └── policies.py
└── __init__.py
├── requirements
├── requirements-test.txt
├── requirements_examples.txt
└── requirements.txt
├── MANIFEST.in
├── .gitignore
├── .github
└── workflows
│ ├── ci.yml
│ └── pypi-publish.yml
├── CONTRIBUTING.md
├── test.sh
├── setup.py
├── examples
├── policy_improvement_demo.py
└── visualization_demo.py
├── README.md
├── LICENSE
└── .pylintrc
/mctx/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements/requirements-test.txt:
--------------------------------------------------------------------------------
1 | absl-py>=2.3.1
2 | numpy>=1.24.1
3 |
--------------------------------------------------------------------------------
/requirements/requirements_examples.txt:
--------------------------------------------------------------------------------
1 | absl-py>=2.3.1
2 | pygraphviz>=1.7
3 |
--------------------------------------------------------------------------------
/requirements/requirements.txt:
--------------------------------------------------------------------------------
1 | chex>=0.1.91
2 | jax>=0.7.0
3 | jaxlib>=0.7.0
4 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE
3 | include requirements/*
4 | include mctx/py.typed
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Building and releasing library:
2 | *.egg-info
3 | *.pyc
4 | *.so
5 | build/
6 | dist/
7 | venv/
8 |
9 | # Mac OS
10 | .DS_Store
11 |
12 | # Python tools
13 | .mypy_cache/
14 | .pytype/
15 | .ipynb_checkpoints
16 |
17 | # Editors
18 | .idea
19 | .vscode
20 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: ci
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | pull_request:
7 | branches: ["main"]
8 |
9 | jobs:
10 | build-and-test:
11 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
12 | runs-on: "${{ matrix.os }}"
13 |
14 | strategy:
15 | matrix:
16 | python-version: ["3.11", "3.12", "3.13"]
17 | os: [ubuntu-latest]
18 |
19 | steps:
20 | - uses: "actions/checkout@v5"
21 | - uses: "actions/setup-python@v5.3.0"
22 | with:
23 | python-version: "${{ matrix.python-version }}"
24 | - name: Run CI tests
25 | run: bash test.sh
26 | shell: bash
27 |
--------------------------------------------------------------------------------
/mctx/_src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
1 | name: pypi
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | permissions:
11 | id-token: write
12 | steps:
13 | - uses: actions/checkout@v4
14 | - name: Set up Python
15 | uses: actions/setup-python@v4
16 | with:
17 | python-version: '3.x'
18 | - name: Install dependencies
19 | run: |
20 | python -m pip install --upgrade pip
21 | pip install setuptools wheel twine
22 | - name: Check consistency between the package version and release tag
23 | run: |
24 | RELEASE_VER=${GITHUB_REF#refs/*/}
25 | PACKAGE_VER="v`python setup.py --version`"
26 | if [ $RELEASE_VER != $PACKAGE_VER ]
27 | then
28 | echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1
29 | fi
30 | - name: Build
31 | run: |
32 | python setup.py sdist bdist_wheel
33 | - name: Publish package distributions to PyPI
34 | uses: pypa/gh-action-pypi-publish@release/v1
35 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Testing
26 |
27 | Please make sure that your PR passes all tests by running `bash test.sh` on your
28 | local machine. Also, you can run only tests that are affected by your code
29 | changes, but you will need to select them manually.
30 |
31 | ## Community Guidelines
32 |
33 | This project follows [Google's Open Source Community
34 | Guidelines](https://opensource.google.com/conduct/).
35 |
--------------------------------------------------------------------------------
/mctx/_src/tests/mctx_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for Mctx."""
16 |
17 | from absl.testing import absltest
18 | import mctx
19 |
20 |
class MctxTest(absltest.TestCase):
  """Test mctx can be imported correctly."""

  def test_import(self):
    """Checks that the core public symbols are exposed by `mctx`."""
    self.assertTrue(hasattr(mctx, "gumbel_muzero_policy"))
    self.assertTrue(hasattr(mctx, "muzero_policy"))
    self.assertTrue(hasattr(mctx, "qtransform_by_min_max"))
    self.assertTrue(hasattr(mctx, "qtransform_by_parent_and_siblings"))
    self.assertTrue(hasattr(mctx, "qtransform_completed_by_mix_value"))
    self.assertTrue(hasattr(mctx, "PolicyOutput"))
    self.assertTrue(hasattr(mctx, "RootFnOutput"))
    self.assertTrue(hasattr(mctx, "RecurrentFnOutput"))

  def test_all_exports_resolve(self):
    """Checks that every name listed in `mctx.__all__` is defined."""
    # A stale entry in `__all__` only surfaces as an error on
    # `from mctx import *`; check each entry explicitly so it is reported.
    for name in mctx.__all__:
      with self.subTest(name=name):
        self.assertTrue(hasattr(mctx, name), msg=f"mctx.{name} is missing")


if __name__ == "__main__":
  absltest.main()
37 |
--------------------------------------------------------------------------------
/mctx/_src/tests/qtransforms_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `qtransforms.py`."""
16 | from absl.testing import absltest
17 | import jax
18 | import jax.numpy as jnp
19 | from mctx._src import qtransforms
20 | import numpy as np
21 |
22 |
class QtransformsTest(absltest.TestCase):

  def test_mix_value(self):
    """Tests the output of _compute_mixed_value()."""
    value = jnp.array(-0.8)
    logits = jnp.array([-jnp.inf, -1.0, 2.0, -jnp.inf])
    prior = jax.nn.softmax(logits)
    counts = jnp.array([0, 4.0, 4.0, 0])
    q = 10.0 / 54 * jnp.array([20.0, 3.0, -1.0, 10.0])
    actual = qtransforms._compute_mixed_value(value, q, counts, prior)

    # Only actions 1 and 2 were visited, so only they enter the weighted sum.
    total_visits = jnp.sum(counts)
    weighted_q = prior[1] * q[1] + prior[2] * q[2]
    expected = 1.0 / (total_visits + 1) * (value + total_visits * weighted_q)
    np.testing.assert_allclose(expected, actual)

  def test_mix_value_with_zero_visits(self):
    """Tests that zero visit counts do not divide by zero."""
    value = jnp.array(-0.8)
    prior = jax.nn.softmax(jnp.array([-jnp.inf, -1.0, 2.0, -jnp.inf]))
    counts = jnp.array([0, 0, 0, 0])
    q = jnp.zeros_like(prior)
    # Run under `jax.debug_nans()` so a NaN from a division by zero surfaces.
    with jax.debug_nans():
      actual = qtransforms._compute_mixed_value(value, q, counts, prior)

    np.testing.assert_allclose(value, actual)


if __name__ == "__main__":
  absltest.main()
57 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
# Runs CI tests on a local machine.
set -xeuo pipefail

# Install deps in a virtual env.
readonly VENV_DIR=/tmp/mctx-env
rm -rf "${VENV_DIR}"
python3 -m venv "${VENV_DIR}"
source "${VENV_DIR}/bin/activate"
python --version

# Install dependencies.
pip install --upgrade pip setuptools wheel
pip install flake8 pytest-xdist pylint pylint-exit
pip install -r requirements/requirements.txt
pip install -r requirements/requirements-test.txt

# Lint with flake8.
flake8 $(find mctx -name '*.py' | xargs) --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics

# Lint with pylint.
# Fail on errors, warning, and conventions.
PYLINT_ARGS="-efail -wfail -cfail"
# Lint modules and tests separately.
pylint --rcfile=.pylintrc $(find mctx -name '*.py' | grep -v 'test.py' | xargs) || pylint-exit $PYLINT_ARGS $?
# Disable `protected-access` warnings for tests.
pylint --rcfile=.pylintrc $(find mctx -name '*_test.py' | xargs) -d W0212 || pylint-exit $PYLINT_ARGS $?

# Build the package.
# Remove artifacts from a previous local run first: if several versions are on
# disk, the `dist/mctx*.tar.gz` and `mctx*.whl` globs below match more than
# one file.
rm -f dist/mctx*.tar.gz mctx*.whl
python setup.py sdist
pip wheel --verbose --no-deps --no-clean dist/mctx*.tar.gz
pip install mctx*.whl

# Check types with pytype.
# Note: pytype does not support 3.12 as of 23.11.23
# See https://github.com/google/pytype/issues/1308
if [ "$(python -c 'import sys; print(sys.version_info.minor)')" -lt 12 ];
then
  pip install pytype
  pytype $(find mctx/_src/ -name "*py" | xargs) -k
fi;

# Run tests using pytest.
# Change directory to avoid importing the package from repo root. Use a fresh
# temporary directory: a fixed `_testing` directory is never removed, so
# `mkdir` would fail (and `set -e` abort the script) on the next local run.
TEST_DIR="$(mktemp -d)"
pushd "${TEST_DIR}" > /dev/null

# `-n auto` lets pytest-xdist choose the worker count; unlike counting entries
# in /proc/cpuinfo, this also works outside Linux (e.g. macOS).
pytest -n auto --pyargs mctx
popd > /dev/null
rm -rf "${TEST_DIR}"

set +u
deactivate
echo "All tests passed. Congrats!"
68 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Install script for setuptools."""
16 |
17 | import os
18 | from setuptools import find_namespace_packages
19 | from setuptools import setup
20 |
21 | _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
22 |
23 |
def _get_version():
  """Returns the `__version__` string declared in `mctx/__init__.py`.

  The file is located relative to this script rather than the current working
  directory, so the build also works when `setup.py` is invoked from another
  directory. This matches how `_parse_requirements` resolves its paths.

  Raises:
    ValueError: if no non-empty `__version__` assignment is found.
  """
  init_path = os.path.join(_CURRENT_DIR, 'mctx', '__init__.py')
  with open(init_path, encoding='utf-8') as fp:
    for line in fp:
      if line.startswith('__version__') and '=' in line:
        # Keep the right-hand side; drop surrounding spaces, quotes, newline.
        version = line[line.find('=') + 1:].strip(' \'"\n')
        if version:
          return version
  raise ValueError('`__version__` not defined in `mctx/__init__.py`')
32 |
33 |
def _parse_requirements(path):
  """Reads a pip requirements file into a list of requirement specifiers.

  Args:
    path: requirements file path, joined onto the directory of this script
      (an absolute `path` is used as-is by `os.path.join`).

  Returns:
    One entry per line with trailing whitespace removed. Lines made up only of
    whitespace, and lines starting with `#`, are skipped.
  """
  requirements = []
  with open(os.path.join(_CURRENT_DIR, path)) as req_file:
    for raw_line in req_file:
      if raw_line.isspace() or raw_line.startswith('#'):
        continue
      requirements.append(raw_line.rstrip())
  return requirements
42 |
43 |
# Read the long description up front so the file handle is closed
# deterministically and decoding does not depend on the build machine's locale.
with open(os.path.join(_CURRENT_DIR, 'README.md'), encoding='utf-8') as fp:
  _LONG_DESCRIPTION = fp.read()

setup(
    name='mctx',
    version=_get_version(),
    url='https://github.com/google-deepmind/mctx',
    license='Apache 2.0',
    author='DeepMind',
    description=('Monte Carlo tree search in JAX.'),
    long_description=_LONG_DESCRIPTION,
    long_description_content_type='text/markdown',
    author_email='mctx-dev@google.com',
    keywords='jax planning reinforcement-learning python machine learning',
    # `find_namespace_packages` treats every directory as a package. Without
    # an `include` filter it would also ship top-level `examples`,
    # `requirements`, `build`, `dist`, etc. (The previous
    # `exclude=['*_test.py']` matched package names, not files, and so
    # excluded nothing.) Restrict discovery to `mctx` and its subpackages,
    # including `mctx._src.tests`, which `test.sh` runs from the installed
    # package.
    packages=find_namespace_packages(include=['mctx', 'mctx.*']),
    install_requires=_parse_requirements(
        os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')),
    tests_require=_parse_requirements(
        os.path.join(_CURRENT_DIR, 'requirements', 'requirements-test.txt')),
    zip_safe=False,  # Required for full installation.
    python_requires='>=3.11',
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
)
72 |
--------------------------------------------------------------------------------
/mctx/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Mctx: Monte Carlo tree search in JAX."""
16 |
17 | from mctx._src.action_selection import gumbel_muzero_interior_action_selection
18 | from mctx._src.action_selection import gumbel_muzero_root_action_selection
19 | from mctx._src.action_selection import GumbelMuZeroExtraData
20 | from mctx._src.action_selection import muzero_action_selection
21 | from mctx._src.base import ChanceRecurrentFnOutput
22 | from mctx._src.base import DecisionRecurrentFnOutput
23 | from mctx._src.base import InteriorActionSelectionFn
24 | from mctx._src.base import LoopFn
25 | from mctx._src.base import PolicyOutput
26 | from mctx._src.base import RecurrentFn
27 | from mctx._src.base import RecurrentFnOutput
28 | from mctx._src.base import RecurrentState
29 | from mctx._src.base import RootActionSelectionFn
30 | from mctx._src.base import RootFnOutput
31 | from mctx._src.policies import gumbel_muzero_policy
32 | from mctx._src.policies import muzero_policy
33 | from mctx._src.policies import stochastic_muzero_policy
34 | from mctx._src.qtransforms import qtransform_by_min_max
35 | from mctx._src.qtransforms import qtransform_by_parent_and_siblings
36 | from mctx._src.qtransforms import qtransform_completed_by_mix_value
37 | from mctx._src.search import search
38 | from mctx._src.tree import Tree
39 |
40 | __version__ = "0.0.6"
41 |
42 | __all__ = (
43 | "ChanceRecurrentFnOutput",
44 | "DecisionRecurrentFnOutput",
45 | "GumbelMuZeroExtraData",
46 | "InteriorActionSelectionFn",
47 | "LoopFn",
48 | "PolicyOutput",
49 | "RecurrentFn",
50 | "RecurrentFnOutput",
51 | "RecurrentState",
52 | "RootActionSelectionFn",
53 | "RootFnOutput",
54 | "Tree",
55 | "gumbel_muzero_interior_action_selection",
56 | "gumbel_muzero_policy",
57 | "gumbel_muzero_root_action_selection",
58 | "muzero_action_selection",
59 | "muzero_policy",
60 | "qtransform_by_min_max",
61 | "qtransform_by_parent_and_siblings",
62 | "qtransform_completed_by_mix_value",
63 | "search",
64 | "stochastic_muzero_policy",
65 | )
66 |
67 | # _________________________________________
68 | # / Please don't use symbols in `_src` they \
69 | # \ are not part of the Mctx public API. /
70 | # -----------------------------------------
71 | # \ ^__^
72 | # \ (oo)\_______
73 | # (__)\ )\/\
74 | # ||----w |
75 | # || ||
76 | #
77 |
--------------------------------------------------------------------------------
/mctx/_src/seq_halving.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Functions for Sequential Halving."""
16 |
17 | import math
18 |
19 | import chex
20 | import jax.numpy as jnp
21 |
22 |
def score_considered(considered_visit, gumbel, logits, normalized_qvalues,
                     visit_counts):
  """Returns a score usable for an argmax.

  Children whose visit count differs from `considered_visit` score -inf. Each
  remaining child scores `gumbel + logits + normalized_qvalues`, where the
  logits are first shifted so that their maximum along the last axis is zero.
  That sum is floored at -1e9, so a considered child with a -inf term still
  gets a finite score. It can therefore be visited when it is the only
  considered child.
  """
  shifted_logits = logits - jnp.max(logits, axis=-1, keepdims=True)
  # 0 for children at the considered visit count, -inf for all others.
  mask_penalty = jnp.where(visit_counts == considered_visit, 0, -jnp.inf)
  chex.assert_equal_shape(
      [gumbel, shifted_logits, normalized_qvalues, mask_penalty])
  raw_score = gumbel + shifted_logits + normalized_qvalues
  return jnp.maximum(-1e9, raw_score) + mask_penalty
34 |
35 |
def get_sequence_of_considered_visits(max_num_considered_actions,
                                      num_simulations):
  """Returns a sequence of visit counts considered by Sequential Halving.

  Sequential Halving is a "pure exploration" algorithm for bandits, introduced
  in "Almost Optimal Exploration in Multi-Armed Bandits":
  http://proceedings.mlr.press/v28/karnin13.pdf

  Sequential Halving can be implemented from these visit counts: at each
  step, pick the best action among those whose visit count equals the
  current entry.

  Args:
    max_num_considered_actions: The maximum number of considered actions.
      The `max_num_considered_actions` can be smaller than the number of
      actions.
    num_simulations: The total simulation budget.

  Returns:
    A tuple with visit counts. Length `num_simulations`.
  """
  if max_num_considered_actions <= 1:
    return tuple(range(num_simulations))

  num_phases = int(math.ceil(math.log2(max_num_considered_actions)))
  visit_counts = [0] * max_num_considered_actions
  considered = max_num_considered_actions
  result = []
  while len(result) < num_simulations:
    # Split the budget evenly across phases and currently considered actions,
    # but always give each considered action at least one more visit.
    rounds = max(1, int(num_simulations / (num_phases * considered)))
    for _ in range(rounds):
      result += visit_counts[:considered]
      visit_counts[:considered] = [v + 1 for v in visit_counts[:considered]]
    # Halve the considered set, but never go below two actions.
    considered = max(2, considered // 2)
  return tuple(result[:num_simulations])
71 |
72 |
def get_table_of_considered_visits(max_num_considered_actions, num_simulations):
  """Returns one Sequential Halving visit sequence per considered-action count.

  Row `m` of the table equals
  `get_sequence_of_considered_visits(m, num_simulations)`, for every `m` from
  0 to `max_num_considered_actions` inclusive.

  Args:
    max_num_considered_actions: The maximum number of considered actions.
      The `max_num_considered_actions` can be smaller than the number of
      actions.
    num_simulations: The total simulation budget.

  Returns:
    A tuple of sequences of visit counts.
    Shape [max_num_considered_actions + 1, num_simulations].
  """
  rows = []
  for num_considered in range(max_num_considered_actions + 1):
    rows.append(
        get_sequence_of_considered_visits(num_considered, num_simulations))
  return tuple(rows)
89 |
90 |
--------------------------------------------------------------------------------
/mctx/_src/tests/seq_halving_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `seq_halving.py`."""
16 | from absl.testing import absltest
17 | from mctx._src import seq_halving
18 |
19 |
class SeqHalvingTest(absltest.TestCase):

  def _check_visits(self, expected_results, max_num_considered_actions,
                    num_simulations):
    """Compares the expected results to the returned considered visits."""
    self.assertLen(expected_results, num_simulations)
    actual = seq_halving.get_sequence_of_considered_visits(
        max_num_considered_actions, num_simulations)
    self.assertEqual(tuple(expected_results), actual)

  def test_considered_min_sims(self):
    # Using exactly `num_simulations = max_num_considered_actions *
    # log2(max_num_considered_actions)`.
    expected = (
        [0] * 8 +  # Considering 8 actions.
        [1] * 4 +  # Considering 4 actions.
        [2] * 4 +  # Considering 4 actions, round 2.
        [3, 3, 4, 4, 5, 5, 6, 6]  # Considering 2 actions.
    )
    self._check_visits(expected, max_num_considered_actions=8,
                       num_simulations=24)

  def test_considered_extra_sims(self):
    # Using more simulations than `max_num_considered_actions *
    # log2(max_num_considered_actions)`.
    expected = (
        [0] * 8 +  # Considering 8 actions.
        [1] * 4 +  # Considering 4 actions.
        [2] * 4 +  # Considering 4 actions, round 2.
        [3] * 4 +  # Considering 4 actions, round 3.
        # Considering 2 actions: every count 4..16 appears twice, and the
        # budget runs out after the first visit at count 17.
        [v for v in range(4, 17) for _ in range(2)] + [17]
    )
    self._check_visits(expected, max_num_considered_actions=8,
                       num_simulations=47)

  def test_considered_less_sims(self):
    # Using a very small number of simulations.
    self._check_visits([0, 0], max_num_considered_actions=8,
                       num_simulations=2)

  def test_considered_less_sims2(self):
    # Using `num_simulations < max_num_considered_actions *
    # log2(max_num_considered_actions)`.
    expected = (
        [0] * 8 +  # Considering 8 actions.
        [1] * 4 +  # Considering 4 actions.
        [2]
    )
    self._check_visits(expected, max_num_considered_actions=8,
                       num_simulations=13)

  def test_considered_not_power_of_2(self):
    # Using max_num_considered_actions that is not a power of 2.
    expected = (
        [0] * 7 +  # Considering 7 actions.
        [1] * 3 + [2] * 3 +  # Considering 3 actions.
        [v for v in range(3, 8) for _ in range(2)] + [8]
    )
    self._check_visits(expected, max_num_considered_actions=7,
                       num_simulations=24)

  def test_considered_action0(self):
    self._check_visits(range(16), max_num_considered_actions=0,
                       num_simulations=16)

  def test_considered_action1(self):
    self._check_visits(range(16), max_num_considered_actions=1,
                       num_simulations=16)


if __name__ == "__main__":
  absltest.main()
103 |
--------------------------------------------------------------------------------
/mctx/_src/tree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A data structure used to hold / inspect search data for a batch of inputs."""
16 |
17 | from __future__ import annotations
18 | from typing import Any, ClassVar, Generic, TypeVar
19 |
20 | import chex
21 | import jax
22 | import jax.numpy as jnp
23 |
24 |
25 | T = TypeVar("T")
26 |
27 |
28 | @chex.dataclass(frozen=True)
29 | class Tree(Generic[T]):
30 | """State of a search tree.
31 |
32 | The `Tree` dataclass is used to hold and inspect search data for a batch of
33 | inputs. In the fields below `B` denotes the batch dimension, `N` represents
34 | the number of nodes in the tree, and `num_actions` is the number of discrete
35 | actions.
36 |
37 | node_visits: `[B, N]` the visit counts for each node.
38 | raw_values: `[B, N]` the raw value for each node.
39 | node_values: `[B, N]` the cumulative search value for each node.
40 | parents: `[B, N]` the node index for the parents for each node.
41 | action_from_parent: `[B, N]` action to take from the parent to reach each
42 | node.
43 | children_index: `[B, N, num_actions]` the node index of the children for each
44 | action.
45 | children_prior_logits: `[B, N, num_actions]` the action prior logits of each
46 | node.
47 | children_visits: `[B, N, num_actions]` the visit counts for children for
48 | each action.
49 | children_rewards: `[B, N, num_actions]` the immediate reward for each action.
50 | children_discounts: `[B, N, num_actions]` the discount between the
51 | `children_rewards` and the `children_values`.
52 | children_values: `[B, N, num_actions]` the value of the next node after the
53 | action.
54 | embeddings: `[B, N, ...]` the state embeddings of each node.
55 | root_invalid_actions: `[B, num_actions]` a mask with invalid actions at the
56 | root. In the mask, invalid actions have ones, and valid actions have zeros.
57 | extra_data: `[B, ...]` extra data passed to the search.
58 | """
59 | node_visits: chex.Array # [B, N]
60 | raw_values: chex.Array # [B, N]
61 | node_values: chex.Array # [B, N]
62 | parents: chex.Array # [B, N]
63 | action_from_parent: chex.Array # [B, N]
64 | children_index: chex.Array # [B, N, num_actions]
65 | children_prior_logits: chex.Array # [B, N, num_actions]
66 | children_visits: chex.Array # [B, N, num_actions]
67 | children_rewards: chex.Array # [B, N, num_actions]
68 | children_discounts: chex.Array # [B, N, num_actions]
69 | children_values: chex.Array # [B, N, num_actions]
70 | embeddings: Any # [B, N, ...]
71 | root_invalid_actions: chex.Array # [B, num_actions]
72 | extra_data: T # [B, ...]
73 |
74 | # The following attributes are class variables (and should not be set on
75 | # Tree instances).
76 | ROOT_INDEX: ClassVar[int] = 0
77 | NO_PARENT: ClassVar[int] = -1
78 | UNVISITED: ClassVar[int] = -1
79 |
80 | @property
81 | def num_actions(self):
82 | return self.children_index.shape[-1]
83 |
84 | @property
85 | def num_simulations(self):
86 | return self.node_visits.shape[-1] - 1
87 |
def qvalues(self, indices):
  """Returns the Q-values of every action at the given node indices.

  A scalar index is looked up directly on an unbatched tree. An index array
  with a leading batch axis (e.g. `[B]`) is mapped over a batched tree with
  `jax.vmap`, one index per batch element.

  Args:
    indices: a scalar node index, or an array of node indices with a leading
      batch axis.

  Returns:
    `reward + discount * value` for each action of the selected node(s).
  """
  # pytype: disable=wrong-arg-types  # jnp-type
  is_batched = bool(jnp.asarray(indices).shape)
  if not is_batched:
    return _unbatched_qvalues(self, indices)
  return jax.vmap(_unbatched_qvalues)(self, indices)
  # pytype: enable=wrong-arg-types
96 |
def summary(self) -> SearchSummary:
  """Extracts summary statistics for the root node of every batch element.

  Returns:
    A `SearchSummary` with, per batch element, the root's child visit counts,
    the visit distribution (uniform over actions when the root has no
    visits), the root value and the root Q-values.
  """
  chex.assert_rank(self.node_values, 2)
  root = Tree.ROOT_INDEX
  root_value = self.node_values[:, root]
  (batch_size,) = root_value.shape
  root_qvalues = self.qvalues(jnp.full((batch_size,), root))

  # Cast the counts to the value dtype so the derived probabilities are floats.
  counts = self.children_visits[:, root].astype(root_value.dtype)
  count_sum = jnp.sum(counts, axis=-1, keepdims=True)
  # The maximum(…, 1) guards the division; rows with zero total visits fall
  # back to a uniform distribution instead.
  probs = jnp.where(
      count_sum > 0,
      counts / jnp.maximum(count_sum, 1),
      1 / self.num_actions)

  return SearchSummary(  # pytype: disable=wrong-arg-types  # numpy-scalars
      visit_counts=counts,
      visit_probs=probs,
      value=root_value,
      qvalues=root_qvalues)
116 |
117 |
def infer_batch_size(tree: Tree) -> int:
  """Returns the batch size of a batched `Tree`.

  The batch size is the leading dimension of `tree.node_values`. All leaves of
  the tree are additionally checked (with chex) to share that leading
  dimension.

  Raises:
    ValueError: if `tree.node_values` is not rank 2, i.e. the tree is not
      batched.
  """
  node_values = tree.node_values
  if node_values.ndim == 2:
    chex.assert_equal_shape_prefix(jax.tree_util.tree_leaves(tree), 1)
    return node_values.shape[0]
  raise ValueError("Input tree is not batched.")
124 |
125 |
# A number of aggregate statistics and predictions are extracted from the
# search data and returned to the user for further processing.
@chex.dataclass(frozen=True)
class SearchSummary:
  """Stats from MCTS search.

  Shapes below are those produced by `Tree.summary` on a batched tree.

  Attributes:
    visit_counts: `[B, num_actions]` visit counts of the root's children.
    visit_probs: `[B, num_actions]` the visit counts normalized to sum to one
      (uniform when the root has no visits).
    value: `[B]` the value of the root node.
    qvalues: `[B, num_actions]` the Q-values of the root's actions.
  """
  visit_counts: chex.Array
  visit_probs: chex.Array
  value: chex.Array
  qvalues: chex.Array
135 |
136 |
def _unbatched_qvalues(tree: Tree, index: chex.Numeric) -> chex.Array:
  """Returns the Q-values of every action at node `index` of an unbatched tree.

  Q(s, a) = r(s, a) + discount(s, a) * V(child(s, a)).

  Args:
    tree: an _unbatched_ tree; its `children_*` fields are `[N, num_actions]`.
    index: scalar node index.

  Returns:
    `[num_actions]` Q-values of the node's actions.
  """
  chex.assert_rank(tree.children_discounts, 2)
  return (
      tree.children_rewards[index]
      + tree.children_discounts[index] * tree.children_values[index]
  )
143 |
--------------------------------------------------------------------------------
/examples/policy_improvement_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A demonstration of the policy improvement by planning with Gumbel."""
16 |
17 | import functools
18 | from typing import Tuple
19 |
20 | from absl import app
21 | from absl import flags
22 | import chex
23 | import jax
24 | import jax.numpy as jnp
25 | import mctx
26 |
FLAGS = flags.FLAGS
# Command-line configuration of the policy-improvement demo.
flags.DEFINE_integer(name="seed", default=42, help="Random seed.")
flags.DEFINE_integer(name="batch_size", default=256, help="Batch size.")
flags.DEFINE_integer(name="num_actions", default=82, help="Number of actions.")
flags.DEFINE_integer(
    name="num_simulations", default=4, help="Number of simulations.")
flags.DEFINE_integer(
    name="max_num_considered_actions",
    default=16,
    help="The maximum number of actions expanded at the root.")
flags.DEFINE_integer(
    name="num_runs", default=1, help="Number of runs on random data.")
35 |
36 |
@chex.dataclass(frozen=True)
class DemoOutput:
  """Per-batch-element values collected by one run of the demo.

  Attributes:
    prior_policy_value: `[B]` expected Q-value under the prior policy
      (softmax of the prior logits).
    prior_policy_action_value: `[B]` Q-value of the action chosen by the prior
      policy using the same Gumbel noise as the search.
    selected_action_value: `[B]` Q-value of the action selected by the search.
    action_weights_policy_value: `[B]` expected Q-value under the search's
      `action_weights`.
  """
  prior_policy_value: chex.Array
  prior_policy_action_value: chex.Array
  selected_action_value: chex.Array
  action_weights_policy_value: chex.Array
44 |
def _run_demo(rng_key: chex.PRNGKey) -> Tuple[chex.PRNGKey, DemoOutput]:
  """Runs Gumbel MuZero search on a random bandit and collects its values.

  Args:
    rng_key: the random key consumed by this run.

  Returns:
    A tuple `(next_rng_key, output)`: a fresh key for the next run and a
    `DemoOutput` comparing the search against the prior policy.
  """
  batch_size = FLAGS.batch_size
  num_actions = FLAGS.num_actions
  next_key, logits_key, qvalues_key, search_key = jax.random.split(rng_key, 4)
  batch_range = jnp.arange(batch_size)

  # Random prior logits stand in for the output of a policy network.
  prior_logits = jax.random.normal(logits_key, shape=[batch_size, num_actions])
  # A bandit with random Q-values; the search only observes the Q-values of
  # the actions it visits.
  qvalues = jax.random.uniform(qvalues_key, shape=prior_logits.shape)
  # The value of the prior policy is used to complete the unvisited Q-values,
  # which yields the improved policy in `policy_output.action_weights`.
  prior_value = jnp.sum(jax.nn.softmax(prior_logits) * qvalues, axis=-1)

  # The root output and recurrent_fn would normally come from the MuZero
  # representation and dynamics networks. Here the embedding only tracks the
  # depth of the node.
  root = mctx.RootFnOutput(
      prior_logits=prior_logits,
      value=prior_value,
      embedding=jnp.zeros([batch_size]),
  )
  policy_output = mctx.gumbel_muzero_policy(
      params=(),
      rng_key=search_key,
      root=root,
      recurrent_fn=_make_bandit_recurrent_fn(qvalues),
      num_simulations=FLAGS.num_simulations,
      max_num_considered_actions=FLAGS.max_num_considered_actions,
      qtransform=functools.partial(
          mctx.qtransform_completed_by_mix_value, use_mixed_value=False),
  )

  # The prior policy picks its action with the same Gumbel noise as the search.
  root_gumbel = policy_output.search_tree.extra_data.root_gumbel
  prior_action = jnp.argmax(root_gumbel + prior_logits, axis=-1)

  output = DemoOutput(
      prior_policy_value=prior_value,
      prior_policy_action_value=qvalues[batch_range, prior_action],
      selected_action_value=qvalues[batch_range, policy_output.action],
      action_weights_policy_value=jnp.sum(
          policy_output.action_weights * qvalues, axis=-1),
  )
  return next_key, output
106 |
107 |
def _make_bandit_recurrent_fn(qvalues):
  """Returns a recurrent_fn for a deterministic bandit.

  Args:
    qvalues: `[B, num_actions]` Q-value of every action for each batch element.

  Returns:
    A `recurrent_fn(params, rng_key, action, embedding)`. The embedding counts
    the depth of the node: the transition from the root (embedding 0) pays the
    Q-value of the chosen action as reward, and every deeper transition pays
    zero.
  """

  def recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key  # The bandit is deterministic and has no parameters.
    is_root = embedding == 0
    chosen_qvalue = qvalues[jnp.arange(action.shape[0]), action]
    reward = jnp.where(is_root, chosen_qvalue, 0.0)
    # A discount in [0, 1] suits a single-player environment; a zero-sum
    # self-play environment would use discount=-1 instead.
    output = mctx.RecurrentFnOutput(
        reward=reward,
        discount=jnp.ones_like(reward),
        prior_logits=jnp.zeros_like(qvalues),
        value=jnp.zeros_like(reward),
    )
    return output, embedding + 1

  return recurrent_fn
129 |
130 |
def main(_):
  """Runs the demo `FLAGS.num_runs` times and prints the value improvements."""
  run_demo = jax.jit(_run_demo)
  rng_key = jax.random.PRNGKey(FLAGS.seed)
  for _ in range(FLAGS.num_runs):
    rng_key, output = run_demo(rng_key)
    # The printed improvements of the policy value should be non-negative.
    action_gain = (
        output.selected_action_value - output.prior_policy_action_value)
    weights_gain = (
        output.action_weights_policy_value - output.prior_policy_value)
    print("action value improvement: %.3f (min=%.3f)" %
          (action_gain.mean(), action_gain.min()))
    print("action_weights value improvement: %.3f (min=%.3f)" %
          (weights_gain.mean(), weights_gain.min()))


if __name__ == "__main__":
  app.run(main)
150 |
--------------------------------------------------------------------------------
/mctx/_src/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Core types used in mctx."""
16 |
17 | from typing import Any, Callable, Generic, TypeVar, Tuple
18 |
19 | import chex
20 |
21 | from mctx._src import tree
22 |
23 |
# Parameters are arbitrary nested structures of `chex.Array`.
# A nested structure is either a single object, or a collection (list, tuple,
# dictionary, etc.) of other nested structures. `Params` is the type of the
# first argument of a `RecurrentFn`.
Params = chex.ArrayTree
28 |
29 |
# The model used to search is expressed by a `RecurrentFn` function that takes
# `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput` and
# the new state embedding.
@chex.dataclass(frozen=True)
class RecurrentFnOutput:
  """The output of a `RecurrentFn`.

  Attributes:
    reward: `[B]` an approximate reward from the state-action transition.
    discount: `[B]` the discount between the `reward` and the `value` (e.g. in
      `[0, 1]` for a single-player environment, or `-1` for zero-sum
      self-play).
    prior_logits: `[B, num_actions]` the logits produced by a policy network.
    value: `[B]` an approximate value of the state after the state-action
      transition.
  """
  reward: chex.Array
  discount: chex.Array
  prior_logits: chex.Array
  value: chex.Array
47 |
48 |
# Type aliases describing the `RecurrentFn` interface.
Action = chex.Array
# NOTE(review): `RecurrentState` is rebound to `chex.ArrayTree` later in this
# module. Annotations evaluated before that point (`RecurrentFn` below and
# `RootFnOutput.embedding`) see `Any`. Consider defining it only once.
RecurrentState = Any
RecurrentFn = Callable[
    [Params, chex.PRNGKey, Action, RecurrentState],
    Tuple[RecurrentFnOutput, RecurrentState]]
54 |
55 |
@chex.dataclass(frozen=True)
class RootFnOutput:
  """The output of a representation network.

  Attributes:
    prior_logits: `[B, num_actions]` the logits produced by a policy network.
    value: `[B]` an approximate value of the current state.
    embedding: `[B, ...]` the inputs to the next `recurrent_fn` call.
  """
  prior_logits: chex.Array
  value: chex.Array
  embedding: RecurrentState
67 |
68 |
# Action selection functions specify how to pick nodes to expand in the tree.
NodeIndices = chex.Array
Depth = chex.Array
# Selects an action at the root from `(rng_key, tree, node_indices)`.
RootActionSelectionFn = Callable[
    [chex.PRNGKey, tree.Tree, NodeIndices], chex.Array]
# Selects an action at an interior node; additionally receives the depth.
InteriorActionSelectionFn = Callable[
    [chex.PRNGKey, tree.Tree, NodeIndices, Depth],
    chex.Array]
# Maps `(tree, node_index)` to transformed Q-values, like the functions in
# `qtransforms.py`.
QTransform = Callable[[tree.Tree, chex.Array], chex.Array]
# LoopFn has the same interface as jax.lax.fori_loop.
LoopFn = Callable[
    [int, int, Callable[[Any, Any], Any], Tuple[chex.PRNGKey, tree.Tree]],
    Tuple[chex.PRNGKey, tree.Tree]]

# Type parameter of `tree.Tree[T]` carried by `PolicyOutput.search_tree`.
T = TypeVar("T")
84 |
85 |
@chex.dataclass(frozen=True)
class PolicyOutput(Generic[T]):
  """The output of a policy.

  Attributes:
    action: `[B]` the proposed action.
    action_weights: `[B, num_actions]` the targets used to train a policy
      network. The action weights sum to one. Usually, the policy network is
      trained by cross-entropy:
      `cross_entropy(labels=stop_gradient(action_weights), logits=prior_logits)`.
    search_tree: `[B, ...]` the search tree of the finished search.
  """
  action: chex.Array
  action_weights: chex.Array
  search_tree: tree.Tree[T]
100 |
101 |
@chex.dataclass(frozen=True)
class DecisionRecurrentFnOutput:
  """Output of the function for expanding decision nodes.

  Expanding a decision node takes an action and a state embedding and produces
  an afterstate, which represents the state of the environment after an action
  is taken but before the environment has updated its state. Accordingly, there
  is no discount factor or reward for transitioning from state `s` to afterstate
  `sa`.

  Attributes:
    chance_logits: `[B, C]` logits of `C` chance outcomes at the afterstate.
    afterstate_value: `[B]` values of the afterstates `v(sa)`.
  """
  chance_logits: chex.Array  # [B, C]
  afterstate_value: chex.Array  # [B]
118 |
119 |
@chex.dataclass(frozen=True)
class ChanceRecurrentFnOutput:
  """Output of the function for expanding chance nodes.

  Expanding a chance node takes a chance outcome and an afterstate embedding
  and produces a state, which captures a potentially stochastic environment
  transition. When this transition occurs, a reward and a discount are
  produced, as in a normal transition.

  Attributes:
    action_logits: `[B, A]` logits of different actions from the state.
    value: `[B]` values of the states `v(s)`.
    reward: `[B]` rewards at the states.
    discount: `[B]` discounts at the states.
  """
  action_logits: chex.Array  # [B, A]
  value: chex.Array  # [B]
  reward: chex.Array  # [B]
  discount: chex.Array  # [B]
139 |
140 |
@chex.dataclass(frozen=True)
class StochasticRecurrentState:
  """Wrapper that enables different treatment of decision and chance nodes.

  In Stochastic MuZero, tree nodes can be either decision or chance nodes.
  These nodes are treated differently during expansion, search and backup, and
  a user could also pass differently structured embeddings for each type of
  node. This wrapper enables treating chance and decision nodes differently
  and supports potential differences between chance and decision node
  structures.

  Attributes:
    state_embedding: `[B, ...]` an optionally meaningful state embedding.
    afterstate_embedding: `[B, ...]` an optionally meaningful afterstate
      embedding.
    is_decision_node: `[B]` whether the node is a decision or chance node. If it
      is a decision node, `afterstate_embedding` is a dummy value. If it is a
      chance node, `state_embedding` is a dummy value.
  """
  state_embedding: chex.ArrayTree  # [B, ...]
  afterstate_embedding: chex.ArrayTree  # [B, ...]
  is_decision_node: chex.Array  # [B]
162 |
163 |
# NOTE(review): this rebinds `RecurrentState`, which was bound to `Any` earlier
# in this module. The aliases below and importers of the module see
# `chex.ArrayTree`.
RecurrentState = chex.ArrayTree

# Expands a decision node: takes an action and a state embedding and returns a
# `DecisionRecurrentFnOutput` and the afterstate embedding.
DecisionRecurrentFn = Callable[[Params, chex.PRNGKey, Action, RecurrentState],
                               Tuple[DecisionRecurrentFnOutput, RecurrentState]]

# Expands a chance node: takes a chance outcome and an afterstate embedding and
# returns a `ChanceRecurrentFnOutput` and the next state embedding.
ChanceRecurrentFn = Callable[[Params, chex.PRNGKey, Action, RecurrentState],
                             Tuple[ChanceRecurrentFnOutput, RecurrentState]]
171 |
--------------------------------------------------------------------------------
/mctx/_src/qtransforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Monotonic transformations for the Q-values."""
16 |
17 | import chex
18 | import jax
19 | import jax.numpy as jnp
20 |
21 | from mctx._src import tree as tree_lib
22 |
23 |
def qtransform_by_min_max(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    min_value: chex.Numeric,
    max_value: chex.Numeric,
) -> chex.Array:
  """Returns Q-values normalized by the given `min_value` and `max_value`.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the parent node.
    min_value: given minimum value. Usually the `min_value` is the minimum
      possible untransformed Q-value.
    max_value: given maximum value. Usually the `max_value` is the maximum
      possible untransformed Q-value. It must differ from `min_value`; no
      guard against a zero denominator is applied.

  Returns:
    `(qvalues - min_value) / (max_value - min_value)` for visited actions and
    zero for unvisited actions. Shape `[num_actions]`.
  """
  chex.assert_shape(node_index, ())
  visited = tree.children_visits[node_index] > 0
  value_range = max_value - min_value
  # Unvisited actions are treated as `min_value`, so they normalize to zero.
  raw_scores = jnp.where(visited, tree.qvalues(node_index), min_value)
  return (raw_scores - min_value) / value_range
51 |
52 |
def qtransform_by_parent_and_siblings(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    epsilon: chex.Numeric = 1e-8,
) -> chex.Array:
  """Returns Q-values min-max normalized over V(node) and the visited Q-values.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the parent node.
    epsilon: the minimum denominator for the normalization.

  Returns:
    Q-values normalized to be from the [0, 1] interval. The unvisited actions
    will have zero Q-value. Shape `[num_actions]`.
  """
  chex.assert_shape(node_index, ())
  qvalues = tree.qvalues(node_index)
  visit_counts = tree.children_visits[node_index]
  chex.assert_rank([qvalues, visit_counts, node_index], [1, 1, 0])
  is_visited = visit_counts > 0
  node_value = tree.node_values[node_index]

  # Unvisited actions borrow the node's own value so they cannot widen the
  # range; the node value itself always takes part in the min/max.
  safe_qvalues = jnp.where(is_visited, qvalues, node_value)
  chex.assert_equal_shape([safe_qvalues, qvalues])
  lowest = jnp.minimum(node_value, jnp.min(safe_qvalues, axis=-1))
  highest = jnp.maximum(node_value, jnp.max(safe_qvalues, axis=-1))
  spread = jnp.maximum(highest - lowest, epsilon)

  # Unvisited actions are then pinned to the minimum, i.e. normalized to zero.
  completed = jnp.where(is_visited, qvalues, lowest)
  normalized = (completed - lowest) / spread
  chex.assert_equal_shape([normalized, qvalues])
  return normalized
85 |
86 |
def qtransform_completed_by_mix_value(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    value_scale: chex.Numeric = 0.1,
    maxvisit_init: chex.Numeric = 50.0,
    rescale_values: bool = True,
    use_mixed_value: bool = True,
    epsilon: chex.Numeric = 1e-8,
) -> chex.Array:
  """Returns completed qvalues.

  The missing Q-values of the unvisited actions are replaced by the
  mixed value, defined in Appendix D of
  "Policy improvement by planning with Gumbel":
  https://openreview.net/forum?id=bERaNdoegnO

  The Q-values are transformed by a linear transformation:
  `(maxvisit_init + max(visit_counts)) * value_scale * qvalues`.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the parent node.
    value_scale: scale for the Q-values.
    maxvisit_init: offset to the `max(visit_counts)` in the scaling factor.
    rescale_values: if True, min-max rescale the completed Q-values to [0, 1]
      with `(q - min_q) / max(max_q - min_q, epsilon)`.
    use_mixed_value: if True, complete the Q-values with mixed value,
      otherwise complete the Q-values with the raw value.
    epsilon: the minimum denominator when using `rescale_values`.

  Returns:
    Completed Q-values. Shape `[num_actions]`.
  """
  chex.assert_shape(node_index, ())
  qvalues = tree.qvalues(node_index)
  visit_counts = tree.children_visits[node_index]
  raw_value = tree.raw_values[node_index]

  # Choose the value that fills in the unvisited actions.
  if use_mixed_value:
    prior_probs = jax.nn.softmax(tree.children_prior_logits[node_index])
    completion_value = _compute_mixed_value(
        raw_value,
        qvalues=qvalues,
        visit_counts=visit_counts,
        prior_probs=prior_probs)
  else:
    completion_value = raw_value
  completed_qvalues = _complete_qvalues(
      qvalues, visit_counts=visit_counts, value=completion_value)

  if rescale_values:
    completed_qvalues = _rescale_qvalues(completed_qvalues, epsilon)
  # The scale grows with the visit count of the most visited action.
  visit_scale = maxvisit_init + jnp.max(visit_counts, axis=-1)
  return visit_scale * value_scale * completed_qvalues
145 |
146 |
def _rescale_qvalues(qvalues, epsilon):
  """Min-max normalizes `qvalues` along the last axis into [0, 1].

  The denominator is clamped at `epsilon`, so constant Q-values map to zero
  instead of dividing by zero.
  """
  lowest = jnp.min(qvalues, axis=-1, keepdims=True)
  spread = jnp.max(qvalues, axis=-1, keepdims=True) - lowest
  return (qvalues - lowest) / jnp.maximum(spread, epsilon)
152 |
153 |
def _complete_qvalues(qvalues, *, visit_counts, value):
  """Replaces the Q-values of unvisited actions with the scalar `value`.

  Args:
    qvalues: Q-values of all actions; entries of unvisited actions are ignored.
    visit_counts: visit counts, with the same shape as `qvalues`.
    value: scalar used for every unvisited action.

  Returns:
    Completed Q-values with the same shape as `qvalues`.
  """
  chex.assert_equal_shape([qvalues, visit_counts])
  chex.assert_shape(value, [])
  is_visited = visit_counts > 0
  completed_qvalues = jnp.where(is_visited, qvalues, value)
  chex.assert_equal_shape([completed_qvalues, qvalues])
  return completed_qvalues
166 |
167 |
def _compute_mixed_value(raw_value, qvalues, visit_counts, prior_probs):
  """Interpolates the raw value and the prior-weighted visited Q-values.

  With N the total visit count and pi the prior, the result is
  `(raw_value + N * sum_visited(pi * q) / sum_visited(pi)) / (N + 1)`.

  Args:
    raw_value: an approximate value of the state. Shape `[]`.
    qvalues: Q-values for all actions. Shape `[num_actions]`. The unvisited
      actions have undefined Q-value.
    visit_counts: the visit counts for all actions. Shape `[num_actions]`.
    prior_probs: the action probabilities, produced by the policy network for
      each action. Shape `[num_actions]`.

  Returns:
    An estimator of the state value. Shape `[]`.
  """
  visited = visit_counts > 0
  total_visits = jnp.sum(visit_counts, axis=-1)
  # Flooring the priors at the smallest positive float keeps the normalizer
  # strictly positive, so the weighted Q is never NaN even if every visited
  # action has zero prior probability.
  floored_probs = jnp.maximum(jnp.finfo(prior_probs.dtype).tiny, prior_probs)
  visited_mass = jnp.sum(jnp.where(visited, floored_probs, 0.0), axis=-1)
  # Unvisited entries divide by 1 instead; they are masked out below anyway.
  normalizer = jnp.where(visited, visited_mass, 1.0)
  weighted_terms = floored_probs * qvalues / normalizer
  weighted_q = jnp.sum(jnp.where(visited, weighted_terms, 0.0), axis=-1)
  return (raw_value + total_visits * weighted_q) / (total_visits + 1)
194 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mctx: MCTS-in-JAX
2 |
3 | Mctx is a library with a [JAX](https://github.com/google/jax)-native
4 | implementation of Monte Carlo tree search (MCTS) algorithms such as
5 | [AlphaZero](https://deepmind.com/blog/article/alphazero-shedding-new-light-grand-games-chess-shogi-and-go),
6 | [MuZero](https://deepmind.com/blog/article/muzero-mastering-go-chess-shogi-and-atari-without-rules), and
7 | [Gumbel MuZero](https://openreview.net/forum?id=bERaNdoegnO). For computation
8 | speed up, the implementation fully supports JIT-compilation. Search algorithms
in Mctx are defined for and operate on batches of inputs, in parallel. This
makes the most of the accelerators and enables the algorithms to work
11 | with large learned environment models parameterized by deep neural networks.
12 |
13 | ## Installation
14 |
15 | You can install the latest released version of Mctx from PyPI via:
16 |
17 | ```sh
18 | pip install mctx
19 | ```
20 |
21 | or you can install the latest development version from GitHub:
22 |
23 | ```sh
24 | pip install git+https://github.com/google-deepmind/mctx.git
25 | ```
26 |
27 | ## Motivation
28 |
29 | Learning and search have been important topics since the early days of AI
30 | research. In the [words of Rich Sutton](http://www.incompleteideas.net/IncIdeas/BitterLesson.html):
31 |
32 | > One thing that should be learned [...] is the great power of general purpose
33 | > methods, of methods that continue to scale with increased computation even as
34 | > the available computation becomes very great. The two methods that seem to
35 | > scale arbitrarily in this way are *search* and *learning*.
36 |
37 | Recently, search algorithms have been successfully combined with learned models
38 | parameterized by deep neural networks, resulting in some of the most powerful
39 | and general reinforcement learning algorithms to date (e.g. MuZero).
40 | However, using search algorithms in combination with deep neural networks
41 | requires efficient implementations, typically written in fast compiled
42 | languages; this can come at the expense of usability and hackability,
43 | especially for researchers that are not familiar with C++. In turn, this limits
44 | adoption and further research on this critical topic.
45 |
46 | Through this library, we hope to help researchers everywhere to contribute to
47 | such an exciting area of research. We provide JAX-native implementations of core
48 | search algorithms such as MCTS, that we believe strike a good balance between
49 | performance and usability for researchers that want to investigate search-based
50 | algorithms in Python. The search methods provided by Mctx are
51 | heavily configurable to allow researchers to explore a variety of ideas in
this space, and contribute to the next generation of search-based agents.
53 |
54 | ## Search in Reinforcement Learning
55 |
56 | In Reinforcement Learning the *agent* must learn to interact with the
57 | *environment* in order to maximize a scalar *reward* signal. On each step the
agent must select an action and, in exchange, receives an observation and a
59 | reward. We may call whatever mechanism the agent uses to select the action the
60 | agent's *policy*.
61 |
62 | Classically, policies are parameterized directly by a function approximator (as
63 | in REINFORCE), or policies are inferred by inspecting a set of learned estimates
of the value of each action (as in Q-learning). Alternatively, search makes it
possible to select actions by constructing on the fly, in each state, a policy or a value
66 | function local to the current state, by *searching* using a learned *model* of
67 | the environment.
68 |
69 | Exhaustive search over all possible future courses of actions is computationally
prohibitive in any non-trivial environment, hence we need search algorithms
that can make the best use of a finite computational budget. Typically, priors
72 | are needed to guide which nodes in the search tree to expand (to reduce the
73 | *breadth* of the tree that we construct), and value functions are used to
74 | estimate the value of incomplete paths in the tree that don't reach an episode
75 | termination (to reduce the *depth* of the search tree).
76 |
77 | ## Quickstart
78 |
79 | Mctx provides a low-level generic `search` function and high-level concrete
80 | policies: `muzero_policy` and `gumbel_muzero_policy`.
81 |
82 | The user needs to provide several learned components to specify the
83 | representation, dynamics and prediction used by [MuZero](https://deepmind.com/blog/article/muzero-mastering-go-chess-shogi-and-atari-without-rules).
84 | In the context of the Mctx library, the representation of the *root* state is
85 | specified by a `RootFnOutput`. The `RootFnOutput` contains the `prior_logits`
86 | from a policy network, the estimated `value` of the root state, and any
87 | `embedding` suitable to represent the root state for the environment model.
88 |
89 | The dynamics environment model needs to be specified by a `recurrent_fn`.
90 | A `recurrent_fn(params, rng_key, action, embedding)` call takes an `action` and
91 | a state `embedding`. The call should return a tuple `(recurrent_fn_output,
92 | new_embedding)` with a `RecurrentFnOutput` and the embedding of the next state.
93 | The `RecurrentFnOutput` contains the `reward` and `discount` for the transition,
94 | and `prior_logits` and `value` for the new state.
95 |
96 | In [`examples/visualization_demo.py`](https://github.com/google-deepmind/mctx/blob/main/examples/visualization_demo.py), you can
97 | see calls to a policy:
98 |
99 | ```python
100 | policy_output = mctx.gumbel_muzero_policy(params, rng_key, root, recurrent_fn,
101 | num_simulations=32)
102 | ```
103 |
104 | The `policy_output.action` contains the action proposed by the search. That
105 | action can be passed to the environment. To improve the policy, the
106 | `policy_output.action_weights` contain targets usable to train the policy
107 | probabilities.
108 |
We recommend using the `gumbel_muzero_policy`.
110 | [Gumbel MuZero](https://openreview.net/forum?id=bERaNdoegnO) guarantees a policy
111 | improvement if the action values are correctly evaluated. The policy improvement
112 | is demonstrated in
113 | [`examples/policy_improvement_demo.py`](https://github.com/google-deepmind/mctx/blob/main/examples/policy_improvement_demo.py).
114 |
115 | ### Example projects
116 | The following projects demonstrate the Mctx usage:
117 |
118 | - [Pgx](https://github.com/sotetsuk/pgx) — A collection of 20+ vectorized
119 | JAX environments, including backgammon, chess, shogi, Go, and an AlphaZero
120 | example.
121 | - [Basic Learning Demo with Mctx](https://github.com/kenjyoung/mctx_learning_demo) —
122 | AlphaZero on random mazes.
123 | - [a0-jax](https://github.com/NTT123/a0-jax) — AlphaZero on Connect Four,
124 | Gomoku, and Go.
125 | - [muax](https://github.com/bwfbowen/muax) — MuZero on gym-style environments
126 | (CartPole, LunarLander).
127 | - [Classic MCTS](https://github.com/Carbon225/mctx-classic) — A simple example on Connect Four.
128 | - [mctx-az](https://github.com/lowrollr/mctx-az) — Mctx with AlphaZero subtree persistence.
129 |
130 | Tell us about your project.
131 |
132 | ## Citing Mctx
133 |
134 | This repository is part of the DeepMind JAX Ecosystem, to cite Mctx
135 | please use the citation:
136 |
137 | ```bibtex
138 | @software{deepmind2020jax,
139 | title = {The {D}eep{M}ind {JAX} {E}cosystem},
140 | author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
141 | url = {http://github.com/deepmind},
142 | year = {2020},
143 | }
144 | ```
145 |
--------------------------------------------------------------------------------
/examples/visualization_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A demo of Graphviz visualization of a search tree."""
16 |
17 | from typing import Optional, Sequence
18 |
19 | from absl import app
20 | from absl import flags
21 | import chex
22 | import jax
23 | import jax.numpy as jnp
24 | import mctx
25 | import pygraphviz
26 |
FLAGS = flags.FLAGS
flags.DEFINE_integer(name="seed", default=42, help="Random seed.")
flags.DEFINE_integer(
    name="num_simulations", default=32, help="Number of simulations.")
flags.DEFINE_integer(
    name="max_num_considered_actions",
    default=16,
    help="The maximum number of actions expanded at the root.")
# The default of None is forwarded unchanged as the policy's `max_depth`.
flags.DEFINE_integer(
    name="max_depth", default=None, help="The maximum search depth.")
flags.DEFINE_string(
    name="output_file",
    default="/tmp/search_tree.png",
    help="The output file for the visualization.")
35 |
36 |
def convert_tree_to_graph(
    tree: mctx.Tree,
    action_labels: Optional[Sequence[str]] = None,
    batch_index: int = 0
) -> pygraphviz.AGraph:
  """Converts a search tree into a Graphviz graph.

  Args:
    tree: A `Tree` containing a batch of search data.
    action_labels: Optional labels for edges, defaults to the action index.
    batch_index: Index of the batch element to plot. Negative values index
      from the end of the batch, as with ordinary Python indexing.

  Returns:
    A Graphviz graph representation of `tree`. The root is drawn in green and
    every expanded child in red; edges are labelled with the action label,
    its Q-value and its prior probability.

  Raises:
    ValueError: if `batch_index` is outside the batch, or if `action_labels`
      does not contain exactly one label per action.
  """
  chex.assert_rank(tree.node_values, 2)
  batch_size, num_nodes = tree.node_values.shape
  # JAX clamps out-of-range indices instead of raising, so an invalid
  # `batch_index` would otherwise silently plot a different batch element.
  if not -batch_size <= batch_index < batch_size:
    raise ValueError(
        f"batch_index {batch_index} is out of range for a batch of size "
        f"{batch_size}.")
  if action_labels is None:
    action_labels = range(tree.num_actions)
  elif len(action_labels) != tree.num_actions:
    raise ValueError(
        f"action_labels {action_labels} has the wrong number of actions "
        f"({len(action_labels)}). "
        f"Expecting {tree.num_actions}.")

  def node_to_str(node_i, reward=0, discount=1):
    return (f"{node_i}\n"
            f"Reward: {reward:.2f}\n"
            f"Discount: {discount:.2f}\n"
            f"Value: {tree.node_values[batch_index, node_i]:.2f}\n"
            f"Visits: {tree.node_visits[batch_index, node_i]}\n")

  def edge_to_str(a_i, qvalues, probs):
    return (f"{action_labels[a_i]}\n"
            f"Q: {qvalues[a_i]:.2f}\n"
            f"p: {probs[a_i]:.2f}\n")

  graph = pygraphviz.AGraph(directed=True)

  # Add root
  graph.add_node(0, label=node_to_str(node_i=0), color="green")
  # Add all other nodes and connect them up. `node_values` is indexed as
  # `[batch, node]`, so its second dimension bounds every valid node index.
  for node_i in range(num_nodes):
    # Index of each child, or -1 if not expanded.
    children = tree.children_index[batch_index, node_i]
    if not bool(jnp.any(children >= 0)):
      continue
    # The Q-values and priors are shared by all edges leaving this node, so
    # compute them once per node rather than once per edge.
    node_index = jnp.full([batch_size], node_i)
    qvalues = tree.qvalues(node_index)[batch_index]  # pytype: disable=unsupported-operands  # always-use-return-annotations
    probs = jax.nn.softmax(tree.children_prior_logits[batch_index, node_i])
    for a_i in range(tree.num_actions):
      children_i = int(children[a_i])
      if children_i >= 0:
        graph.add_node(
            children_i,
            label=node_to_str(
                node_i=children_i,
                reward=tree.children_rewards[batch_index, node_i, a_i],
                discount=tree.children_discounts[batch_index, node_i, a_i]),
            color="red")
        graph.add_edge(
            node_i, children_i, label=edge_to_str(a_i, qvalues, probs))

  return graph
96 |
97 |
def _run_demo(rng_key: chex.PRNGKey):
  """Runs Gumbel MuZero search on a small deterministic toy environment.

  Args:
    rng_key: PRNG key forwarded to `mctx.gumbel_muzero_policy`.

  Returns:
    The policy output produced by `mctx.gumbel_muzero_policy`.
  """
  # Deterministic dynamics, shape `[num_states, num_actions]`: entry [s, a]
  # holds the state reached by taking action a in state s.
  next_state = jnp.array(
      [
          [1, 2, 3, 4],
          [0, 5, 0, 0],
          [0, 0, 0, 6],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
      ],
      dtype=jnp.int32)
  # Immediate rewards, same shape: entry [s, a] is the reward for (s, a).
  reward_table = jnp.array(
      [
          [1, -1, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [10, 0, 20, 0],
      ],
      dtype=jnp.float32)
  n_states = reward_table.shape[0]
  # A transition into state 0 gets discount 0; every other one gets 1.
  discount_table = jnp.where(next_state > 0, 1.0, 0.0)
  # Optimistic initial values make the search explore.
  value_table = jnp.full([n_states], 15.0)
  # All-zero logits, i.e. a uniform prior in every state.
  logits_table = jnp.zeros_like(reward_table)

  root, recurrent_fn = _make_batched_env_model(
      # A batch of two exercises the batched search code path.
      batch_size=2,
      transition_matrix=next_state,
      rewards=reward_table,
      discounts=discount_table,
      values=value_table,
      prior_logits=logits_table)

  return mctx.gumbel_muzero_policy(
      params=(),
      rng_key=rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      num_simulations=FLAGS.num_simulations,
      max_depth=FLAGS.max_depth,
      max_num_considered_actions=FLAGS.max_num_considered_actions,
  )
150 |
151 |
def _make_batched_env_model(
    batch_size: int,
    *,
    transition_matrix: chex.Array,
    rewards: chex.Array,
    discounts: chex.Array,
    values: chex.Array,
    prior_logits: chex.Array):
  """Wraps tabular dynamics into a batched search root and recurrent function.

  Args:
    batch_size: number of identical search problems run side by side.
    transition_matrix: `[num_states, num_actions]` table of next states.
    rewards: table of rewards, same shape as `transition_matrix`.
    discounts: table of discounts, same shape as `transition_matrix`.
    values: `[num_states]` table of state values.
    prior_logits: table of prior logits, same shape as `transition_matrix`.

  Returns:
    A `(root, recurrent_fn)` pair. Every batch element starts in state 0 and
    the embedding carried through the search is the integer state index.
  """
  chex.assert_equal_shape(
      [transition_matrix, rewards, discounts, prior_logits])
  num_states, num_actions = transition_matrix.shape
  chex.assert_shape(values, [num_states])

  start = 0  # Every search begins in state zero.
  root = mctx.RootFnOutput(
      prior_logits=jnp.full([batch_size, num_actions], prior_logits[start]),
      value=jnp.full([batch_size], values[start]),
      # The embedding is the state index, so the start state is all zeros.
      embedding=jnp.zeros([batch_size], dtype=jnp.int32),
  )

  def recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key  # The tabular model ignores both.
    chex.assert_shape(action, [batch_size])
    chex.assert_shape(embedding, [batch_size])
    step = mctx.RecurrentFnOutput(
        reward=rewards[embedding, action],
        discount=discounts[embedding, action],
        prior_logits=prior_logits[embedding],
        value=values[embedding])
    return step, transition_matrix[embedding, action]

  return root, recurrent_fn
188 |
189 |
def main(_):
  """Runs the toy search, prints the chosen action and saves the tree plot."""
  rng_key = jax.random.PRNGKey(FLAGS.seed)
  jitted_run_demo = jax.jit(_run_demo)
  print("Starting search.")
  policy_output = jitted_run_demo(rng_key)
  batch_index = 0
  selected_action = policy_output.action[batch_index]
  q_value = policy_output.search_tree.summary().qvalues[
      batch_index, selected_action]
  print("Selected action:", selected_action)
  # To estimate the value of the root state, use the Q-value of the selected
  # action. The Q-value is not affected by the exploration at the root node.
  print("Selected action Q-value:", q_value)
  # Plot the same batch element whose action and Q-value were printed above.
  graph = convert_tree_to_graph(
      policy_output.search_tree, batch_index=batch_index)
  print("Saving tree diagram to:", FLAGS.output_file)
  graph.draw(FLAGS.output_file, prog="dot")


if __name__ == "__main__":
  app.run(main)
210 |
--------------------------------------------------------------------------------
/mctx/_src/tests/tree_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A unit test comparing the search tree to an expected search tree."""
16 | # pylint: disable=use-dict-literal
17 | import functools
18 | import json
19 |
20 | from absl import logging
21 | from absl.testing import absltest
22 | from absl.testing import parameterized
23 | import chex
24 | import jax
25 | import jax.numpy as jnp
26 | import mctx
27 | import numpy as np
28 |
# Use the non-partitionable threefry implementation; presumably this is what
# the stored expected trees were generated with (TODO confirm).
jax.config.update("jax_threefry_partitionable", False)


def _prepare_root(batch_size, num_actions):
  """Builds a batched root whose outputs match the stored expected trees.

  Batch element 0 uses `PRNGKey(0)` directly; element i > 0 folds `i` into
  that key. The resulting keys double as the root embeddings.
  """
  base_key = jax.random.PRNGKey(0)
  keys = [base_key] + [
      jax.random.fold_in(base_key, i) for i in range(1, batch_size)]
  embedding = jnp.stack(keys)
  predict = functools.partial(
      _produce_prediction_output, num_actions=num_actions)
  prediction = jax.vmap(predict)(embedding)
  return mctx.RootFnOutput(
      prior_logits=prediction["policy_logits"],
      value=prediction["value"],
      embedding=embedding,
  )
48 |
49 |
def _produce_prediction_output(rng_key, num_actions):
  """Produces the model output used when the expected trees were stored.

  Args:
    rng_key: PRNG key that fully determines the output.
    num_actions: length of the produced policy logits.

  Returns:
    A dict with `policy_logits` of shape `[num_actions]` drawn from a standard
    normal, plus scalar `value` and `reward`, each drawn uniformly from
    [-1, 1).
  """
  policy_rng, value_rng, reward_rng = jax.random.split(rng_key, 3)
  del rng_key
  # Producing value from [-1, +1).
  value = jax.random.uniform(value_rng, shape=(), minval=-1.0, maxval=1.0)
  # Producing reward from [-1, +1).
  reward = jax.random.uniform(reward_rng, shape=(), minval=-1.0, maxval=1.0)
  return dict(
      policy_logits=jax.random.normal(policy_rng, shape=[num_actions]),
      value=value,
      reward=reward,
  )
64 |
65 |
def _prepare_recurrent_fn(num_actions, *, discount, zero_reward):
  """Builds the dynamics function that reproduces the expected trees.

  Args:
    num_actions: size of the action space.
    discount: constant discount returned for every transition.
    zero_reward: if True, the produced rewards are replaced by zeros.

  Returns:
    A `recurrent_fn(params, rng_key, action, embedding)`. Each embedding is a
    PRNG key, and the next embedding is the sub-key selected by the action.
  """
  fold_in_action = jax.vmap(
      functools.partial(_fold_action_in, num_actions=num_actions))
  predict = jax.vmap(
      functools.partial(_produce_prediction_output, num_actions=num_actions))

  def recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key
    # Embeddings are PRNG keys; a step selects the sub-key of the action.
    next_embedding = fold_in_action(embedding, action)
    prediction = predict(next_embedding)
    if zero_reward:
      reward = jnp.zeros_like(prediction["reward"])
    else:
      reward = prediction["reward"]
    step = mctx.RecurrentFnOutput(
        reward=reward,
        discount=jnp.full_like(reward, discount),
        prior_logits=prediction["policy_logits"],
        value=prediction["value"],
    )
    return step, next_embedding

  return recurrent_fn
89 |
90 |
def _fold_action_in(rng_key, action, num_actions):
  """Selects the sub-key of `action` among `num_actions` splits of `rng_key`.

  Args:
    rng_key: the parent PRNG key.
    action: scalar int32 action index.
    num_actions: number of sub-keys to split `rng_key` into.

  Returns:
    The PRNG key associated with `action`.
  """
  chex.assert_shape(action, ())
  chex.assert_type(action, jnp.int32)
  return jax.random.split(rng_key, num_actions)[action]
97 |
98 |
def tree_to_pytree(tree: mctx.Tree, batch_i: int = 0):
  """Converts one batch element of an MCTS tree into nested dicts.

  Args:
    tree: the batched search tree.
    batch_i: index of the batch element to convert.

  Returns:
    The root node dict. Each node's `child_stats` lists one entry per action:
    a full node dict for an expanded child, otherwise a bare dict holding
    only the prior and the action.
  """
  nodes = {
      0: _create_pynode(tree, batch_i, 0, prior=1.0, action=None, reward=None)
  }
  priors = jax.nn.softmax(tree.children_prior_logits, axis=-1)
  for parent in range(tree.num_simulations + 1):
    for action in range(tree.num_actions):
      prior = priors[batch_i, parent, action]
      # A negative child index marks an action that was never expanded.
      child_index = int(tree.children_index[batch_i, parent, action])
      if child_index < 0:
        child = _create_bare_pynode(prior=prior, action=action)
      else:
        child = _create_pynode(
            tree, batch_i, child_index, prior=prior, action=action,
            reward=tree.children_rewards[batch_i, parent, action])
        nodes[child_index] = child
      # pylint: disable=line-too-long
      nodes[parent]["child_stats"].append(child)  # pytype: disable=attribute-error
      # pylint: enable=line-too-long
  return nodes[0]
121 |
122 |
def _create_pynode(tree, batch_i, node_i, prior, action, reward):
  """Returns the search statistics of one expanded node as a dict.

  The `action` and `reward` entries are added only when they are not None.
  """
  stats = {
      "prior": _round_float(prior),
      "visit": int(tree.node_visits[batch_i, node_i]),
      "value_view": _round_float(tree.node_values[batch_i, node_i]),
      "raw_value_view": _round_float(tree.raw_values[batch_i, node_i]),
      "child_stats": [],
      "evaluation_index": node_i,
  }
  if action is not None:
    stats.update(action=action)
  if reward is not None:
    stats.update(reward=_round_float(reward))
  return stats
138 |
139 |
def _create_bare_pynode(prior, action):
  """Returns the dict of an unexpanded child: only its prior and action."""
  return {"prior": _round_float(prior), "child_stats": [], "action": action}
146 |
147 |
def _round_float(value, ndigits=10):
  """Converts `value` to a Python float rounded to `ndigits` decimals."""
  as_float = float(value)
  return round(as_float, ndigits)
150 |
151 |
class TreeTest(parameterized.TestCase):

  # Make sure to adjust the `shard_count` parameter in the build file to match
  # the number of parameter configurations passed to test_tree.
  # pylint: disable=line-too-long
  @parameterized.named_parameters(
      ("muzero_norescale",
       "../mctx/_src/tests/test_data/muzero_tree.json"),
      ("muzero_qtransform",
       "../mctx/_src/tests/test_data/muzero_qtransform_tree.json"),
      ("gumbel_muzero_norescale",
       "../mctx/_src/tests/test_data/gumbel_muzero_tree.json"),
      ("gumbel_muzero_reward",
       "../mctx/_src/tests/test_data/gumbel_muzero_reward_tree.json"))
  # pylint: enable=line-too-long
  def test_tree(self, tree_data_path):
    with open(tree_data_path, "rb") as fd:
      tree = json.load(fd)
    reproduced = self._reproduce_tree(tree)
    chex.assert_trees_all_close(tree["tree"], reproduced, atol=1e-3)

  def _reproduce_tree(self, tree):
    """Reproduces the given JSON tree by running a search.

    Args:
      tree: the parsed JSON test case. It is left unmodified.

    Returns:
      The reproduced search tree, converted by `tree_to_pytree`.
    """
    policy_fn = dict(
        gumbel_muzero=mctx.gumbel_muzero_policy,
        muzero=mctx.muzero_policy,
    )[tree["algorithm"]]

    env_config = tree["env_config"]
    root = tree["tree"]
    num_actions = len(root["child_stats"])
    num_simulations = root["visit"] - 1
    # Work on a shallow copy so the caller's dict is not mutated: the
    # qtransform entries are consumed here and the rest become policy kwargs.
    algorithm_config = dict(tree["algorithm_config"])
    qtransform = functools.partial(
        getattr(mctx, algorithm_config.pop("qtransform")),
        **algorithm_config.pop("qtransform_kwargs", {}))

    batch_size = 3
    # To test the independence of the batch computation, we use different
    # invalid actions for the other elements of the batch. The different batch
    # elements will then have different search tree depths.
    invalid_actions = np.zeros([batch_size, num_actions])
    invalid_actions[1, 1:] = 1
    invalid_actions[2, 2:] = 1

    def run_policy():
      return policy_fn(
          params=(),
          rng_key=jax.random.PRNGKey(1),
          root=_prepare_root(batch_size=batch_size, num_actions=num_actions),
          recurrent_fn=_prepare_recurrent_fn(num_actions, **env_config),
          num_simulations=num_simulations,
          qtransform=qtransform,
          invalid_actions=invalid_actions,
          **algorithm_config)

    policy_output = jax.jit(run_policy)()  # pylint: disable=not-callable
    logging.info("Done search.")

    return tree_to_pytree(policy_output.search_tree)


if __name__ == "__main__":
  jax.config.update("jax_numpy_rank_promotion", "raise")
  absltest.main()
216 |
--------------------------------------------------------------------------------
/mctx/_src/action_selection.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A collection of action selection functions."""
16 | from typing import Optional, TypeVar
17 |
18 | import chex
19 | import jax
20 | import jax.numpy as jnp
21 |
22 | from mctx._src import base
23 | from mctx._src import qtransforms
24 | from mctx._src import seq_halving
25 | from mctx._src import tree as tree_lib
26 |
27 |
def switching_action_selection_wrapper(
    root_action_selection_fn: base.RootActionSelectionFn,
    interior_action_selection_fn: base.InteriorActionSelectionFn
) -> base.InteriorActionSelectionFn:
  """Combines a root and an interior action selection fn into a single fn.

  The returned function branches with `jax.lax.cond` on `depth == 0`. The root
  fn receives `(rng_key, tree, node_index)`; the interior fn additionally
  receives `depth`.
  """

  def _select_at_root(operands):
    rng_key, tree, node_index, _ = operands
    return root_action_selection_fn(rng_key, tree, node_index)

  def _select_in_interior(operands):
    return interior_action_selection_fn(*operands)

  def switching_action_selection_fn(
      rng_key: chex.PRNGKey,
      tree: tree_lib.Tree,
      node_index: base.NodeIndices,
      depth: base.Depth) -> chex.Array:
    operands = (rng_key, tree, node_index, depth)
    is_root = depth == 0
    return jax.lax.cond(
        is_root, _select_at_root, _select_in_interior, operands)

  return switching_action_selection_fn
46 |
47 |
def muzero_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    depth: chex.Numeric,
    *,
    pb_c_init: float = 1.25,
    pb_c_base: float = 19652.0,
    qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
) -> chex.Array:
  """Selects an action at `node_index` with the MuZero pUCT rule.

  See Appendix B in https://arxiv.org/pdf/1911.08265.pdf for more details.

  Args:
    rng_key: random number generator state, used only for tie-breaking noise.
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the node from which to select an action.
    depth: the scalar depth of the current node. The root has depth zero.
    pb_c_init: constant c_1 in the PUCT formula.
    pb_c_base: constant c_2 in the PUCT formula.
    qtransform: a monotonic transformation to convert the Q-values to [0, 1].

  Returns:
    action: int32 index of the highest-scoring action. Invalid actions are
      masked out only at the root (depth zero).
  """
  n_children = tree.children_visits[node_index]
  n_parent = tree.node_visits[node_index]
  exploration = pb_c_init + jnp.log((n_parent + pb_c_base + 1.) / pb_c_base)
  priors = jax.nn.softmax(tree.children_prior_logits[node_index])
  policy_score = jnp.sqrt(n_parent) * exploration * priors / (n_children + 1)
  chex.assert_shape([node_index, n_parent], ())
  chex.assert_equal_shape([priors, n_children, policy_score])
  value_score = qtransform(tree, node_index)

  # A tiny random perturbation breaks ties between equal scores.
  tie_breaker = 1e-7 * jax.random.uniform(rng_key, (tree.num_actions,))
  scores = value_score + policy_score + tie_breaker

  # For depth > 0 the mask is multiplied by False, so nothing is masked.
  root_only_mask = tree.root_invalid_actions * (depth == 0)
  return masked_argmax(scores, root_only_mask)
91 |
92 |
@chex.dataclass(frozen=True)
class GumbelMuZeroExtraData:
  """Extra data for Gumbel MuZero search.

  Attributes:
    root_gumbel: Gumbel values for the root. Root action selection reads them
      as `tree.extra_data.root_gumbel` and combines them with the root prior
      logits in `seq_halving.score_considered` (presumably one value per
      action, matching the prior logits' shape; TODO confirm).
  """
  root_gumbel: chex.Array


# Type variable bound to `GumbelMuZeroExtraData`, so trees whose `extra_data`
# is this class (or a subclass) can be accepted by the root action selection.
GumbelMuZeroExtraDataType = TypeVar(  # pylint: disable=invalid-name
    "GumbelMuZeroExtraDataType", bound=GumbelMuZeroExtraData)
101 |
102 |
def gumbel_muzero_root_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree[GumbelMuZeroExtraDataType],
    node_index: chex.Numeric,
    *,
    num_simulations: chex.Numeric,
    max_num_considered_actions: chex.Numeric,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> chex.Array:
  """Selects the root action by Sequential Halving with Gumbel.

  Initially, `max_num_considered_actions` actions are sampled without
  replacement. Among them, the actions with the highest
  `gumbel + logits + qvalues` are visited first.

  Args:
    rng_key: random number generator state (unused; the Gumbel values stored
      in `tree.extra_data` supply the randomness).
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the node from which to take an action.
    num_simulations: the simulation budget.
    max_num_considered_actions: the number of actions sampled without
      replacement.
    qtransform: a monotonic transformation for the Q-values.

  Returns:
    action: int32 index of the selected action; invalid root actions are
      masked out.
  """
  del rng_key
  chex.assert_shape([node_index], ())
  counts = tree.children_visits[node_index]
  logits = tree.children_prior_logits[node_index]
  chex.assert_equal_shape([counts, logits])
  completed_qvalues = qtransform(tree, node_index)

  # Lookup table indexed by [number of considered actions, simulation index].
  visit_table = jnp.array(seq_halving.get_table_of_considered_visits(
      max_num_considered_actions, num_simulations))
  n_valid = jnp.sum(1 - tree.root_invalid_actions, axis=-1).astype(jnp.int32)
  n_considered = jnp.minimum(max_num_considered_actions, n_valid)
  chex.assert_shape(n_considered, ())
  # At the root, the sum of the children's visit counts is the simulation
  # index.
  sim_index = jnp.sum(counts, -1)
  chex.assert_shape(sim_index, ())
  considered_visit = visit_table[n_considered, sim_index]
  chex.assert_shape(considered_visit, ())
  scores = seq_halving.score_considered(
      considered_visit, tree.extra_data.root_gumbel, logits,
      completed_qvalues, counts)

  # Invalid root actions are excluded from the argmax.
  return masked_argmax(scores, tree.root_invalid_actions)
156 |
157 |
def gumbel_muzero_interior_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    depth: chex.Numeric,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> chex.Array:
  """Deterministically selects an action at an interior node.

  Visit counts are steered so that, over repeated calls, the visitation
  frequencies approach softmax(prior_logits + completed qvalues).

  Args:
    rng_key: random number generator state (unused).
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the node from which to take an action.
    depth: the scalar depth of the current node (unused).
    qtransform: function to obtain completed Q-values for a node.

  Returns:
    action: int32 index of the selected action.
  """
  del rng_key, depth
  chex.assert_shape([node_index], ())
  counts = tree.children_visits[node_index]
  logits = tree.children_prior_logits[node_index]
  chex.assert_equal_shape([counts, logits])
  # Adding the completed Q-values to the logits gives an improved policy;
  # unvisited actions get the value estimate filled in by `qtransform`.
  improved_policy = jax.nn.softmax(logits + qtransform(tree, node_index))

  scores = _prepare_argmax_input(probs=improved_policy, visit_counts=counts)
  chex.assert_rank(scores, 1)
  return jnp.argmax(scores, axis=-1).astype(jnp.int32)
196 |
197 |
def masked_argmax(
    to_argmax: chex.Array,
    invalid_actions: Optional[chex.Array]) -> chex.Array:
  """Returns the index of the best valid action.

  Args:
    to_argmax: scores to maximize over the last axis.
    invalid_actions: optional mask with the same shape as `to_argmax`; truthy
      entries are excluded. `None` disables masking.

  Returns:
    int32 argmax over the last axis. If every action is invalid, all masked
    scores are -inf and action 0 is returned.
  """
  scores = to_argmax
  if invalid_actions is not None:
    chex.assert_equal_shape([to_argmax, invalid_actions])
    # -inf is safe inside argmax (it does not produce NaN), but must never be
    # fed to softmax, log-softmax or cross-entropy.
    scores = jnp.where(invalid_actions, -jnp.inf, to_argmax)
  return jnp.argmax(scores, axis=-1).astype(jnp.int32)
209 |
210 |
def _prepare_argmax_input(probs, visit_counts):
  """Builds the argmax input for deterministic, frequency-matching selection.

  Calling argmax(_prepare_argmax_input(...)) repeatedly, with the visit counts
  updated after each call, produces visitation frequencies that approximate
  `probs`.

  For the derivation, see Section 5 "Planning at non-root nodes" in
  "Policy improvement by planning with Gumbel":
  https://openreview.net/forum?id=bERaNdoegnO

  Args:
    probs: a policy or an improved policy. Shape `[num_actions]`.
    visit_counts: the existing visit counts. Shape `[num_actions]`.

  Returns:
    The input to an argmax. Shape `[num_actions]`.
  """
  chex.assert_equal_shape([probs, visit_counts])
  total_visits = jnp.sum(visit_counts, keepdims=True, axis=-1)
  visit_fraction = visit_counts / (1 + total_visits)
  return probs - visit_fraction
233 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/mctx/_src/tests/policies_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests for `policies.py`."""
16 | import functools
17 |
18 | from absl.testing import absltest
19 | import jax
20 | import jax.numpy as jnp
21 | import mctx
22 | from mctx._src import policies
23 | import numpy as np
24 |
25 | jax.config.update("jax_threefry_partitionable", False)
26 |
27 |
def _make_bandit_recurrent_fn(rewards, dummy_embedding=()):
  """Builds a bandit `recurrent_fn` whose transitions always have discount 0.

  Args:
    rewards: reward table indexed as `rewards[batch_row, action]`.
    dummy_embedding: object handed back unchanged as the next embedding.

  Returns:
    A `recurrent_fn(params, rng_key, action, embedding)` that looks up the
    reward of `action` in each batch row and reports zero discount, zero value
    and all-zero prior logits (shaped like `rewards`).
  """

  def recurrent_fn(params, rng_key, action, embedding):
    # The bandit is stateless: parameters, randomness and embedding are unused.
    del params, rng_key, embedding
    batch_rows = jnp.arange(action.shape[0])
    reward = rewards[batch_rows, action]
    output = mctx.RecurrentFnOutput(
        reward=reward,
        discount=jnp.zeros_like(reward),
        prior_logits=jnp.zeros_like(rewards),
        value=jnp.zeros_like(reward),
    )
    return output, dummy_embedding

  return recurrent_fn
42 |
43 |
def _make_bandit_decision_and_chance_fns(rewards, num_chance_outcomes):
  """Builds decision/chance recurrent fns for a bandit with a dummy chance node.

  Args:
    rewards: reward table indexed as `rewards[batch_row, action]`.
    num_chance_outcomes: number of chance logits emitted by the decision step.

  Returns:
    A `(decision_recurrent_fn, chance_recurrent_fn)` pair. The decision step
    puts all chance mass on outcome 0 and carries the chosen action inside the
    afterstate embedding; the chance step then emits that action's reward with
    zero value and zero discount.
  """

  def decision_recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key
    num_rows = action.shape[0]
    reward = rewards[jnp.arange(num_rows), action]
    # Only outcome 0 has a finite logit, so it carries all of the mass.
    chance_logits = jnp.full(
        [num_rows, num_chance_outcomes], -jnp.inf).at[:, 0].set(1.0)
    output = mctx.DecisionRecurrentFnOutput(
        chance_logits=chance_logits,
        afterstate_value=jnp.zeros_like(reward))
    # Remember the action so the chance step can look up its reward.
    return output, (action, embedding)

  def chance_recurrent_fn(params, rng_key, chance_outcome,
                          afterstate_embedding):
    del params, rng_key, chance_outcome
    chosen_action, embedding = afterstate_embedding
    num_rows = chosen_action.shape[0]
    reward = rewards[jnp.arange(num_rows), chosen_action]
    output = mctx.ChanceRecurrentFnOutput(
        action_logits=jnp.zeros_like(rewards),
        value=jnp.zeros_like(reward),
        discount=jnp.zeros_like(reward),
        reward=reward)
    return output, embedding

  return decision_recurrent_fn, chance_recurrent_fn
71 |
72 |
73 | def _get_deepest_leaf(tree, node_index):
74 | """Returns `(leaf, depth)` with maximum depth and visit count.
75 |
76 | Args:
77 | tree: _unbatched_ MCTS tree state.
78 | node_index: the node of the inspected subtree.
79 |
80 | Returns:
81 | `(leaf, depth)` of a deepest leaf. If multiple leaves have the same depth,
82 | the leaf with the highest visit count is returned.
83 | """
84 | np.testing.assert_equal(len(tree.children_index.shape), 2)
85 | leaf = node_index
86 | max_found_depth = 0
87 | for action in range(tree.children_index.shape[-1]):
88 | next_node_index = tree.children_index[node_index, action]
89 | if next_node_index != tree.UNVISITED:
90 | found_leaf, found_depth = _get_deepest_leaf(tree, next_node_index)
91 | if ((1 + found_depth, tree.node_visits[found_leaf]) >
92 | (max_found_depth, tree.node_visits[leaf])):
93 | leaf = found_leaf
94 | max_found_depth = 1 + found_depth
95 | return leaf, max_found_depth
96 |
97 |
class PoliciesTest(absltest.TestCase):
  """Checks the policy helpers and the MuZero-family policies on bandits."""

  @staticmethod
  def _mix_value_qtransform(value_scale, maxvisit_init):
    """Returns `qtransform_completed_by_mix_value` with rescaling enabled."""
    return functools.partial(
        mctx.qtransform_completed_by_mix_value,
        value_scale=value_scale,
        maxvisit_init=maxvisit_init,
        rescale_values=True)

  @staticmethod
  def _rescale_qvalues(completed_qvalues, total_value_scale):
    """Min-max normalizes each row of Q-values and multiplies by the scale."""
    q_max = jnp.max(completed_qvalues, axis=-1, keepdims=True)
    q_min = jnp.min(completed_qvalues, axis=-1, keepdims=True)
    return total_value_scale * (completed_qvalues - q_min) / (q_max - q_min)

  def test_apply_temperature_one(self):
    """Tests temperature=1."""
    raw = jnp.arange(6, dtype=jnp.float32)
    scaled = policies._apply_temperature(raw, temperature=1.0)
    # A unit temperature only shifts the logits so that the maximum is zero.
    np.testing.assert_allclose(raw - raw.max(), scaled)

  def test_apply_temperature_two(self):
    """Tests temperature=2."""
    temperature = 2.0
    raw = jnp.arange(6, dtype=jnp.float32)
    scaled = policies._apply_temperature(raw, temperature)
    expected = (raw - raw.max()) / temperature
    np.testing.assert_allclose(expected, scaled)

  def test_apply_temperature_zero(self):
    """Tests temperature=0."""
    raw = jnp.arange(4, dtype=jnp.float32)
    scaled = policies._apply_temperature(raw, temperature=0.0)
    # The maximum maps to 0 and every other logit becomes hugely negative.
    expected = jnp.array([-2.552118e+38, -1.701412e+38, -8.507059e+37, 0.0])
    np.testing.assert_allclose(expected, scaled, rtol=1e-3)

  def test_apply_temperature_zero_on_large_logits(self):
    """Tests temperature=0 on large logits."""
    raw = jnp.array([100.0, 3.4028235e+38, -jnp.inf, -3.4028235e+38])
    scaled = policies._apply_temperature(raw, temperature=0.0)
    expected = jnp.array([-jnp.inf, 0.0, -jnp.inf, -jnp.inf])
    np.testing.assert_allclose(expected, scaled)

  def test_mask_invalid_actions(self):
    """Tests action masking."""
    raw = jnp.array([1e6, -jnp.inf, 1e6 + 1, -100.0])
    mask = jnp.array([0.0, 1.0, 0.0, 1.0])
    masked = policies._mask_invalid_actions(raw, mask)
    # The two valid logits differ by exactly one, so their probabilities match
    # softmax([0, 1]) while the masked actions receive no mass.
    valid = jax.nn.softmax(jnp.array([0.0, 1.0]))
    expected = jnp.array([valid[0], 0.0, valid[1], 0.0])
    np.testing.assert_allclose(expected, jax.nn.softmax(masked))

  def test_mask_all_invalid_actions(self):
    """Tests a state with no valid action."""
    raw = jnp.array([-jnp.inf] * 4)
    mask = jnp.array([1.0] * 4)
    masked = policies._mask_invalid_actions(raw, mask)
    # With nothing valid, the masked logits still yield a uniform softmax.
    np.testing.assert_allclose(jnp.array([0.25] * 4), jax.nn.softmax(masked))

  def test_muzero_policy(self):
    """Tests MuZero with a single simulation and one invalid action."""
    prior_logits = jnp.array([[-1.0, 0.0, 2.0, 3.0]])
    root = mctx.RootFnOutput(
        prior_logits=prior_logits,
        value=jnp.array([0.0]),
        embedding=(),
    )
    invalid_actions = jnp.array([[0.0, 0.0, 0.0, 1.0]])
    bandit_fn = _make_bandit_recurrent_fn(jnp.zeros_like(prior_logits))

    policy_output = mctx.muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        recurrent_fn=bandit_fn,
        num_simulations=1,
        invalid_actions=invalid_actions,
        dirichlet_fraction=0.0)

    # Action 3 is invalid; action 2 has the largest prior among the rest.
    np.testing.assert_array_equal(
        jnp.array([2], dtype=jnp.int32), policy_output.action)
    np.testing.assert_allclose(
        jnp.array([[0.0, 0.0, 1.0, 0.0]]), policy_output.action_weights)

  def test_gumbel_muzero_policy(self):
    """Tests Gumbel MuZero with invalid actions and a depth limit."""
    root_value = jnp.array([-5.0])
    prior_logits = jnp.array([[0.0, -1.0, 2.0, 3.0]])
    root = mctx.RootFnOutput(
        prior_logits=prior_logits, value=root_value, embedding=())
    rewards = jnp.array([[20.0, 3.0, -1.0, 10.0]])
    invalid_actions = jnp.array([[1.0, 0.0, 0.0, 1.0]])
    value_scale, maxvisit_init = 0.05, 60
    num_simulations, max_depth = 17, 3

    policy_output = mctx.gumbel_muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        recurrent_fn=_make_bandit_recurrent_fn(rewards),
        num_simulations=num_simulations,
        invalid_actions=invalid_actions,
        max_depth=max_depth,
        qtransform=self._mix_value_qtransform(value_scale, maxvisit_init),
        gumbel_scale=1.0)
    search_tree = policy_output.search_tree

    # The selected action.
    np.testing.assert_array_equal(
        jnp.array([1], dtype=jnp.int32), policy_output.action)

    # The action weights: actions 0 and 3 are completed with the mixed value,
    # actions 1 and 2 with their bandit rewards.
    valid_probs = jax.nn.softmax(
        jnp.where(invalid_actions, -jnp.inf, prior_logits))
    mix_value = 1.0 / (num_simulations + 1) * (root_value + num_simulations * (
        valid_probs[:, 1] * rewards[:, 1] + valid_probs[:, 2] * rewards[:, 2]))
    completed_qvalues = jnp.array(
        [[mix_value[0], rewards[0, 1], rewards[0, 2], mix_value[0]]])
    max_visits = np.ceil(num_simulations / 2)
    rescaled_qvalues = self._rescale_qvalues(
        completed_qvalues, (maxvisit_init + max_visits) * value_scale)
    expected_weights = jax.nn.softmax(
        jnp.where(invalid_actions, -jnp.inf, prior_logits + rescaled_qvalues))
    np.testing.assert_allclose(
        expected_weights, policy_output.action_weights, atol=1e-6)

    # The visit counts: the two valid actions split the simulations.
    np.testing.assert_array_equal(
        jnp.array([[0.0, max_visits, num_simulations // 2, 0.0]]),
        search_tree.summary().visit_counts)

    # The depth limit is reached, at a leaf visited six times.
    leaf, depth = _get_deepest_leaf(
        jax.tree.map(lambda x: x[0], search_tree), search_tree.ROOT_INDEX)
    self.assertEqual(max_depth, depth)
    self.assertEqual(6, search_tree.node_visits[0, leaf])

  def test_gumbel_muzero_policy_without_invalid_actions(self):
    """Tests Gumbel MuZero when every action is valid."""
    root_value = jnp.array([-5.0])
    prior_logits = jnp.array([[0.0, -1.0, 2.0, 3.0]])
    root = mctx.RootFnOutput(
        prior_logits=prior_logits, value=root_value, embedding=())
    rewards = jnp.array([[20.0, 3.0, -1.0, 10.0]])
    value_scale, maxvisit_init = 0.05, 60
    num_simulations, max_depth = 17, 3

    policy_output = mctx.gumbel_muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        recurrent_fn=_make_bandit_recurrent_fn(rewards),
        num_simulations=num_simulations,
        invalid_actions=None,
        max_depth=max_depth,
        qtransform=self._mix_value_qtransform(value_scale, maxvisit_init),
        gumbel_scale=1.0)

    # The selected action.
    np.testing.assert_array_equal(
        jnp.array([3], dtype=jnp.int32), policy_output.action)

    # The action weights: the completed Q-values are the rewards themselves.
    visit_counts = policy_output.search_tree.summary().visit_counts
    rescaled_qvalues = self._rescale_qvalues(
        rewards, (maxvisit_init + visit_counts.max()) * value_scale)
    np.testing.assert_allclose(
        jax.nn.softmax(prior_logits + rescaled_qvalues),
        policy_output.action_weights,
        atol=1e-6)

    # The visit counts.
    np.testing.assert_array_equal(jnp.array([[6, 2, 2, 7]]), visit_counts)

  def test_stochastic_muzero_policy(self):
    """Tests that SMZ is equivalent to MZ with a dummy chance function."""
    root = mctx.RootFnOutput(
        prior_logits=jnp.array([
            [-1.0, 0.0, 2.0, 3.0],
            [0.0, 2.0, 5.0, -4.0],
        ]),
        value=jnp.array([1.0, 0.0]),
        embedding=jnp.zeros([2, 4]))
    rewards = jnp.zeros_like(root.prior_logits)
    invalid_actions = jnp.array([
        [0.0, 0.0, 0.0, 1.0],
        [1.0, 0.0, 1.0, 0.0],
    ])
    num_simulations = 10
    num_chance_outcomes = 5
    common_kwargs = {
        "params": (),
        "rng_key": jax.random.PRNGKey(0),
        "root": root,
        "invalid_actions": invalid_actions,
        "dirichlet_fraction": 0.0,
    }

    policy_output = mctx.muzero_policy(
        recurrent_fn=_make_bandit_recurrent_fn(
            rewards, dummy_embedding=jnp.zeros_like(root.embedding)),
        num_simulations=num_simulations,
        **common_kwargs)

    decision_rec_fn, chance_rec_fn = _make_bandit_decision_and_chance_fns(
        rewards, num_chance_outcomes)
    # NOTE(review): twice the simulations, presumably because each stochastic
    # step spends one simulation on the decision node and one on the chance
    # node.
    stochastic_policy_output = mctx.stochastic_muzero_policy(
        decision_recurrent_fn=decision_rec_fn,
        chance_recurrent_fn=chance_rec_fn,
        num_simulations=2 * num_simulations,
        **common_kwargs)

    np.testing.assert_array_equal(
        stochastic_policy_output.action, policy_output.action)
    np.testing.assert_allclose(
        stochastic_policy_output.action_weights, policy_output.action_weights)
360 |
361 |
if __name__ == "__main__":
  # Run this module's absltest test cases when executed as a script.
  absltest.main()
364 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | # This Pylint rcfile contains a best-effort configuration to uphold the
2 | # best-practices and style described in the Google Python style guide:
3 | # https://google.github.io/styleguide/pyguide.html
4 | #
5 | # Its canonical open-source location is:
6 | # https://google.github.io/styleguide/pylintrc
7 |
8 | [MASTER]
9 |
10 | # Files or directories to be skipped. They should be base names, not paths.
11 | ignore=third_party
12 |
13 | # Files or directories matching the regex patterns are skipped. The regex
14 | # matches against base names, not paths.
15 | ignore-patterns=
16 |
17 | # Pickle collected data for later comparisons.
18 | persistent=no
19 |
20 | # List of plugins (as comma separated values of python modules names) to load,
21 | # usually to register additional checkers.
22 | load-plugins=
23 |
24 | # Use multiple processes to speed up Pylint.
25 | jobs=4
26 |
27 | # Allow loading of arbitrary C extensions. Extensions are imported into the
28 | # active Python interpreter and may run arbitrary code.
29 | unsafe-load-any-extension=no
30 |
31 |
32 | [MESSAGES CONTROL]
33 |
34 | # Only show warnings with the listed confidence levels. Leave empty to show
35 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
36 | confidence=
37 |
38 | # Enable the message, report, category or checker with the given id(s). You can
39 | # either give multiple identifiers separated by comma (,) or put this option
40 | # multiple times (only on the command line, not in the configuration file where
41 | # it should appear only once). See also the "--disable" option for examples.
42 | #enable=
43 |
44 | # Disable the message, report, category or checker with the given id(s). You
45 | # can either give multiple identifiers separated by comma (,) or put this
46 | # option multiple times (only on the command line, not in the configuration
47 | # file where it should appear only once). You can also use "--disable=all" to
48 | # disable everything first and then reenable specific checks. For example, if
49 | # you want to run only the similarities checker, you can use "--disable=all
50 | # --enable=similarities". If you want to run only the classes checker, but have
51 | # no Warning level messages displayed, use "--disable=all --enable=classes
52 | # --disable=W"
53 | disable=abstract-method,
54 | apply-builtin,
55 | arguments-differ,
56 | attribute-defined-outside-init,
57 | backtick,
58 | bad-option-value,
59 | basestring-builtin,
60 | buffer-builtin,
61 | c-extension-no-member,
62 | consider-using-enumerate,
63 | cmp-builtin,
64 | cmp-method,
65 | coerce-builtin,
66 | coerce-method,
67 | delslice-method,
68 | div-method,
69 | duplicate-code,
70 | eq-without-hash,
71 | execfile-builtin,
72 | file-builtin,
73 | filter-builtin-not-iterating,
74 | fixme,
75 | getslice-method,
76 | global-statement,
77 | hex-method,
78 | idiv-method,
79 | implicit-str-concat,
80 | import-error,
81 | import-self,
82 | import-star-module-level,
83 | inconsistent-return-statements,
84 | input-builtin,
85 | intern-builtin,
86 | invalid-str-codec,
87 | locally-disabled,
88 | long-builtin,
89 | long-suffix,
90 | map-builtin-not-iterating,
91 | misplaced-comparison-constant,
92 | missing-function-docstring,
93 | metaclass-assignment,
94 | next-method-called,
95 | next-method-defined,
96 | no-absolute-import,
97 | no-else-break,
98 | no-else-continue,
99 | no-else-raise,
100 | no-else-return,
101 | no-init, # added
102 | no-member,
103 | no-name-in-module,
104 | no-self-use,
105 | nonzero-method,
106 | oct-method,
107 | old-division,
108 | old-ne-operator,
109 | old-octal-literal,
110 | old-raise-syntax,
111 | parameter-unpacking,
112 | print-statement,
113 | raising-string,
114 | range-builtin-not-iterating,
115 | raw_input-builtin,
116 | rdiv-method,
117 | reduce-builtin,
118 | relative-import,
119 | reload-builtin,
120 | round-builtin,
121 | setslice-method,
122 | signature-differs,
123 | standarderror-builtin,
124 | suppressed-message,
125 | sys-max-int,
126 | too-few-public-methods,
127 | too-many-ancestors,
128 | too-many-arguments,
129 | too-many-boolean-expressions,
130 | too-many-branches,
131 | too-many-instance-attributes,
132 | too-many-locals,
133 | too-many-nested-blocks,
134 | too-many-public-methods,
135 | too-many-return-statements,
136 | too-many-statements,
137 | trailing-newlines,
138 | unichr-builtin,
139 | unicode-builtin,
140 | unnecessary-pass,
141 | unpacking-in-except,
142 | useless-else-on-loop,
143 | useless-object-inheritance,
144 | useless-suppression,
145 | using-cmp-argument,
146 | wrong-import-order,
147 | xrange-builtin,
148 | zip-builtin-not-iterating,
149 |
150 |
151 | [REPORTS]
152 |
153 | # Set the output format. Available formats are text, parseable, colorized, msvs
154 | # (visual studio) and html. You can also give a reporter class, eg
155 | # mypackage.mymodule.MyReporterClass.
156 | output-format=text
157 |
158 | # Tells whether to display a full report or only the messages
159 | reports=no
160 |
161 | # Python expression which should return a score less than 10 (10 is the
162 | # highest score). You have access to the variables error, warning, refactor,
163 | # convention and statement, which respectively contain the number of messages
164 | # in each category and the total number of statements analyzed. This is used
165 | # by the global evaluation report (RP0004).
166 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
167 |
168 | # Template used to display messages. This is a python new-style format string
169 | # used to format the message information. See doc for all details
170 | #msg-template=
171 |
172 |
173 | [BASIC]
174 |
175 | # Good variable names which should always be accepted, separated by a comma
176 | good-names=main,_
177 |
178 | # Bad variable names which should always be refused, separated by a comma
179 | bad-names=
180 |
181 | # Colon-delimited sets of names that determine each other's naming style when
182 | # the name regexes allow several styles.
183 | name-group=
184 |
185 | # Include a hint for the correct naming format with invalid-name
186 | include-naming-hint=no
187 |
188 | # List of decorators that produce properties, such as abc.abstractproperty. Add
189 | # to this list to register other decorators that produce valid properties.
190 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
191 |
192 | # Regular expression matching correct function names
193 | function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
194 |
195 | # Regular expression matching correct variable names
196 | variable-rgx=^[a-z][a-z0-9_]*$
197 |
198 | # Regular expression matching correct constant names
199 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
200 |
201 | # Regular expression matching correct attribute names
202 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
203 |
204 | # Regular expression matching correct argument names
205 | argument-rgx=^[a-z][a-z0-9_]*$
206 |
207 | # Regular expression matching correct class attribute names
208 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
209 |
210 | # Regular expression matching correct inline iteration names
211 | inlinevar-rgx=^[a-z][a-z0-9_]*$
212 |
213 | # Regular expression matching correct class names
214 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$
215 |
216 | # Regular expression matching correct module names
217 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
218 |
219 | # Regular expression matching correct method names
220 | method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
221 |
222 | # Regular expression which should only match function or class names that do
223 | # not require a docstring.
224 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
225 |
226 | # Minimum line length for functions/classes that require docstrings, shorter
227 | # ones are exempt.
228 | docstring-min-length=10
229 |
230 |
231 | [TYPECHECK]
232 |
233 | # List of decorators that produce context managers, such as
234 | # contextlib.contextmanager. Add to this list to register other decorators that
235 | # produce valid context managers.
236 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
237 |
238 | # Tells whether missing members accessed in mixin class should be ignored. A
239 | # mixin class is detected if its name ends with "mixin" (case insensitive).
240 | ignore-mixin-members=yes
241 |
242 | # List of module names for which member attributes should not be checked
243 | # (useful for modules/projects where namespaces are manipulated during runtime
244 | # and thus existing member attributes cannot be deduced by static analysis. It
245 | # supports qualified module names, as well as Unix pattern matching.
246 | ignored-modules=
247 |
248 | # List of class names for which member attributes should not be checked (useful
249 | # for classes with dynamically set attributes). This supports the use of
250 | # qualified names.
251 | ignored-classes=optparse.Values,thread._local,_thread._local
252 |
253 | # List of members which are set dynamically and missed by pylint inference
254 | # system, and so shouldn't trigger E1101 when accessed. Python regular
255 | # expressions are accepted.
256 | generated-members=
257 |
258 |
259 | [FORMAT]
260 |
261 | # Maximum number of characters on a single line.
262 | max-line-length=80
263 |
264 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
265 | # lines made too long by directives to pytype.
266 |
267 | # Regexp for a line that is allowed to be longer than the limit.
268 | ignore-long-lines=(?x)(
269 | ^\s*(\#\ )?<?https?://\S+>?$|
270 | ^\s*(from\s+\S+\s+)?import\s+.+$)
271 |
272 | # Allow the body of an if to be on the same line as the test if there is no
273 | # else.
274 | single-line-if-stmt=yes
275 |
276 | # Maximum number of lines in a module
277 | max-module-lines=99999
278 |
279 | # String used as indentation unit. The internal Google style guide mandates 2
280 | # spaces. Google's externally-published style guide says 4, consistent with
281 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
282 | # projects (like TensorFlow).
283 | indent-string=' '
284 |
285 | # Number of spaces of indent required inside a hanging or continued line.
286 | indent-after-paren=4
287 |
288 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
289 | expected-line-ending-format=
290 |
291 |
292 | [MISCELLANEOUS]
293 |
294 | # List of note tags to take in consideration, separated by a comma.
295 | notes=TODO
296 |
297 |
298 | [STRING]
299 |
300 | # This flag controls whether inconsistent-quotes generates a warning when the
301 | # character used as a quote delimiter is used inconsistently within a module.
302 | check-quote-consistency=yes
303 |
304 |
305 | [VARIABLES]
306 |
307 | # Tells whether we should check for unused import in __init__ files.
308 | init-import=no
309 |
310 | # A regular expression matching the name of dummy variables (i.e. expectedly
311 | # not used).
312 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
313 |
314 | # List of additional names supposed to be defined in builtins. Remember that
315 | # you should avoid defining new builtins when possible.
316 | additional-builtins=
317 |
318 | # List of strings which can identify a callback function by name. A callback
319 | # name must start or end with one of those strings.
320 | callbacks=cb_,_cb
321 |
322 | # List of qualified module names which can have objects that can redefine
323 | # builtins.
324 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
325 |
326 |
327 | [LOGGING]
328 |
329 | # Logging modules to check that the string format arguments are in logging
330 | # function parameter format
331 | logging-modules=logging,absl.logging,tensorflow.io.logging
332 |
333 |
334 | [SIMILARITIES]
335 |
336 | # Minimum lines number of a similarity.
337 | min-similarity-lines=4
338 |
339 | # Ignore comments when computing similarities.
340 | ignore-comments=yes
341 |
342 | # Ignore docstrings when computing similarities.
343 | ignore-docstrings=yes
344 |
345 | # Ignore imports when computing similarities.
346 | ignore-imports=no
347 |
348 |
349 | [SPELLING]
350 |
351 | # Spelling dictionary name. Available dictionaries: none. To make it work,
352 | # install the python-enchant package.
353 | spelling-dict=
354 |
355 | # List of comma separated words that should not be checked.
356 | spelling-ignore-words=
357 |
358 | # A path to a file that contains private dictionary; one word per line.
359 | spelling-private-dict-file=
360 |
361 | # Tells whether to store unknown words to indicated private dictionary in
362 | # --spelling-private-dict-file option instead of raising a message.
363 | spelling-store-unknown-words=no
364 |
365 |
366 | [IMPORTS]
367 |
368 | # Deprecated modules which should not be used, separated by a comma
369 | deprecated-modules=regsub,
370 | TERMIOS,
371 | Bastion,
372 | rexec,
373 | sets
374 |
375 | # Create a graph of all (i.e. internal and external) dependencies in the
376 | # given file (report RP0402 must not be disabled)
377 | import-graph=
378 |
379 | # Create a graph of external dependencies in the given file (report RP0402 must
380 | # not be disabled)
381 | ext-import-graph=
382 |
383 | # Create a graph of internal dependencies in the given file (report RP0402 must
384 | # not be disabled)
385 | int-import-graph=
386 |
387 | # Force import order to recognize a module as part of the standard
388 | # compatibility libraries.
389 | known-standard-library=
390 |
391 | # Force import order to recognize a module as part of a third party library.
392 | known-third-party=enchant, absl
393 |
394 | # Analyse import fallback blocks. This can be used to support both Python 2 and
395 | # 3 compatible code, which means that the block might have code that exists
396 | # only in one or another interpreter, leading to false positives when analysed.
397 | analyse-fallback-blocks=no
398 |
399 |
400 | [CLASSES]
401 |
402 | # List of method names used to declare (i.e. assign) instance attributes.
403 | defining-attr-methods=__init__,
404 | __new__,
405 | setUp
406 |
407 | # List of member names, which should be excluded from the protected access
408 | # warning.
409 | exclude-protected=_asdict,
410 | _fields,
411 | _replace,
412 | _source,
413 | _make
414 |
415 | # List of valid names for the first argument in a class method.
416 | valid-classmethod-first-arg=cls,
417 | class_
418 |
419 | # List of valid names for the first argument in a metaclass class method.
420 | valid-metaclass-classmethod-first-arg=mcs
421 |
422 |
423 | [EXCEPTIONS]
424 |
425 | # Exceptions that will emit a warning when being caught. Defaults to
426 | # "Exception"
427 | overgeneral-exceptions=builtins.StandardError,
428 | builtins.Exception,
429 | builtins.BaseException
430 |
--------------------------------------------------------------------------------
/mctx/_src/search.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A JAX implementation of batched MCTS."""
16 | import functools
17 | from typing import Any, NamedTuple, Optional, Tuple, TypeVar
18 |
19 | import chex
20 | import jax
21 | import jax.numpy as jnp
22 |
23 | from mctx._src import action_selection
24 | from mctx._src import base
25 | from mctx._src import tree as tree_lib
26 |
# Local alias for the tree container type used throughout this module.
Tree = tree_lib.Tree
# Type variable used to parameterize `Tree[T]` in the signatures below;
# presumably it stands for the type of `Tree.extra_data` — TODO confirm in
# tree.py.
T = TypeVar("T")
29 |
30 |
def search(
    params: base.Params,
    rng_key: chex.PRNGKey,
    *,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    root_action_selection_fn: base.RootActionSelectionFn,
    interior_action_selection_fn: base.InteriorActionSelectionFn,
    num_simulations: int,
    max_depth: Optional[int] = None,
    invalid_actions: Optional[chex.Array] = None,
    extra_data: Any = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop) -> Tree:
  """Performs a full batched search and returns the resulting search tree.

  Each of the `num_simulations` iterations runs three steps for every batch
  element:
  - a simulation step (`simulate`) walks down from the root to a
    `(parent, action)` pair;
  - an expansion step (`expand`) evaluates that pair with `recurrent_fn` and
    stores the resulting node;
  - a backup step (`backward`) propagates the new value back up to the root.

  In the shape descriptions, `B` denotes the batch dimension.

  Args:
    params: params to be forwarded to root and recurrent functions.
    rng_key: random number generator state, the key is consumed.
    root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
      `prior_logits` are from a policy network. The shapes are
      `([B, num_actions], [B], [B, ...])`, respectively.
    recurrent_fn: a callable to be called on the leaf nodes and unvisited
      actions retrieved by the simulation step, which takes as args
      `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
      and the new state embedding. The `rng_key` argument is consumed.
    root_action_selection_fn: function used to select an action at the root.
    interior_action_selection_fn: function used to select an action during
      simulation.
    num_simulations: the number of simulations.
    max_depth: maximum search tree depth allowed during simulation, defined as
      the number of edges from the root to a leaf node. Defaults to
      `num_simulations` when `None`.
    invalid_actions: a mask with invalid actions at the root. In the
      mask, invalid actions have ones, and valid actions have zeros.
      Shape `[B, num_actions]`. Defaults to all zeros (every action valid)
      when `None`.
    extra_data: extra data passed to `tree.extra_data`. Shape `[B, ...]`.
    loop_fn: Function used to run the simulations. It may be required to pass
      hk.fori_loop if using this function inside a Haiku module.

  Returns:
    The search `Tree`, with `num_simulations + 1` node slots per batch element
    (node 0 is the root). Search outcomes such as the root visit counts
    (`[B, num_actions]`) can be read from it, e.g. via `tree.summary()`.
  """
  action_selection_fn = action_selection.switching_action_selection_wrapper(
      root_action_selection_fn=root_action_selection_fn,
      interior_action_selection_fn=interior_action_selection_fn
  )

  # Do simulation, expansion, and backward steps.
  batch_size = root.value.shape[0]
  batch_range = jnp.arange(batch_size)
  if max_depth is None:
    max_depth = num_simulations
  if invalid_actions is None:
    invalid_actions = jnp.zeros_like(root.prior_logits)

  def body_fun(sim, loop_state):
    rng_key, tree = loop_state
    rng_key, simulate_key, expand_key = jax.random.split(rng_key, 3)
    # simulate is vmapped and expects batched rng keys.
    simulate_keys = jax.random.split(simulate_key, batch_size)
    parent_index, action = simulate(
        simulate_keys, tree, action_selection_fn, max_depth)
    # A node first expanded on simulation `i`, will have node index `i`.
    # Node 0 corresponds to the root node.
    # If the selected edge already has a child (possible when `max_depth`
    # truncated the walk), that existing node is re-evaluated instead.
    next_node_index = tree.children_index[batch_range, parent_index, action]
    next_node_index = jnp.where(next_node_index == Tree.UNVISITED,
                                sim + 1, next_node_index)
    tree = expand(
        params, expand_key, tree, recurrent_fn, parent_index,
        action, next_node_index)
    tree = backward(tree, next_node_index)
    loop_state = rng_key, tree
    return loop_state

  # Allocate all necessary storage.
  tree = instantiate_tree_from_root(root, num_simulations,
                                    root_invalid_actions=invalid_actions,
                                    extra_data=extra_data)
  _, tree = loop_fn(
      0, num_simulations, body_fun, (rng_key, tree))

  return tree
115 |
116 |
class _SimulationState(NamedTuple):
  """The state for the simulation while loop.

  The fields are carried through `jax.lax.while_loop`, so they hold JAX scalars
  (only the initial `node_index`/`action` placeholders may be Python ints).
  They are therefore annotated with `chex` array types rather than `int`/`bool`.
  """
  # Key consumed by the next action selection.
  rng_key: chex.PRNGKey
  # Node from which `action` was selected.
  node_index: chex.Numeric
  # Action selected at `node_index`.
  action: chex.Numeric
  # Child reached by following `action` (or `Tree.UNVISITED`).
  next_node_index: chex.Numeric
  # Number of actions selected so far in this simulation.
  depth: chex.Numeric
  # Whether the traversal should keep descending.
  is_continuing: chex.Array
125 |
126 |
@functools.partial(jax.vmap, in_axes=[0, 0, None, None], out_axes=0)
def simulate(
    rng_key: chex.PRNGKey,
    tree: Tree,
    action_selection_fn: base.InteriorActionSelectionFn,
    max_depth: int) -> Tuple[chex.Array, chex.Array]:
  """Traverses the tree until reaching an unvisited action or `max_depth`.

  Each simulation starts from the root and keeps selecting actions traversing
  the tree until a leaf or `max_depth` is reached.

  The decorator vmaps this function over axis 0 of `rng_key` and `tree`.
  `action_selection_fn` and `max_depth` are shared by all batch elements, so
  the body below sees a single, unbatched tree.

  Args:
    rng_key: random number generator state, the key is consumed.
    tree: _unbatched_ MCTS tree state.
    action_selection_fn: function used to select an action during simulation.
      It is called as `action_selection_fn(key, tree, node_index, depth)`.
    max_depth: maximum search tree depth allowed during simulation.

  Returns:
    `(parent_index, action)` tuple, where `parent_index` is the index of the
    node reached at the end of the simulation, and the `action` is the action to
    evaluate from the `parent_index`.
  """
  def cond_fun(state):
    # Continue while the last selected action led to an already-expanded child
    # and the depth cutoff has not been reached.
    return state.is_continuing

  def body_fun(state):
    # Preparing the next simulation state.
    # The child reached in the previous iteration becomes the node to select
    # an action from (on the first iteration this is the root).
    node_index = state.next_node_index
    rng_key, action_selection_key = jax.random.split(state.rng_key)
    action = action_selection_fn(action_selection_key, tree, node_index,
                                 state.depth)
    next_node_index = tree.children_index[node_index, action]
    # The returned action will be visited.
    depth = state.depth + 1
    is_before_depth_cutoff = depth < max_depth
    # `Tree.UNVISITED` marks an edge whose child node has not been created yet.
    is_visited = next_node_index != Tree.UNVISITED
    is_continuing = jnp.logical_and(is_visited, is_before_depth_cutoff)
    return _SimulationState(  # pytype: disable=wrong-arg-types  # jax-types
        rng_key=rng_key,
        node_index=node_index,
        action=action,
        next_node_index=next_node_index,
        depth=depth,
        is_continuing=is_continuing)

  node_index = jnp.array(Tree.ROOT_INDEX, dtype=jnp.int32)
  # NOTE(review): the depth counter uses the prior-logits dtype rather than an
  # integer dtype. With low-precision logits (e.g. bfloat16), large integer
  # depths are not exactly representable, which could affect the `max_depth`
  # comparison for very deep searches — confirm this dtype is intended.
  depth = jnp.zeros((), dtype=tree.children_prior_logits.dtype)
  # The loop starts "above" the root. `next_node_index` is the root, and the
  # current node/action are `NO_PARENT` placeholders. The first body
  # iteration always runs (`is_continuing` starts True) and overwrites them.
  # pytype: disable=wrong-arg-types  # jnp-type
  initial_state = _SimulationState(
      rng_key=rng_key,
      node_index=tree.NO_PARENT,
      action=tree.NO_PARENT,
      next_node_index=node_index,
      depth=depth,
      is_continuing=jnp.array(True))
  # pytype: enable=wrong-arg-types
  end_state = jax.lax.while_loop(cond_fun, body_fun, initial_state)

  # Returning a node with a selected action.
  # The action can be already visited, if the max_depth is reached.
  return end_state.node_index, end_state.action
188 |
189 |
def expand(
    params: chex.Array,
    rng_key: chex.PRNGKey,
    tree: Tree[T],
    recurrent_fn: base.RecurrentFn,
    parent_index: chex.Array,
    action: chex.Array,
    next_node_index: chex.Array) -> Tree[T]:
  """Evaluates one `(parent, action)` edge per batch element and stores it.

  The parent's embedding is gathered and `recurrent_fn` is applied to it. The
  resulting node statistics are written at `next_node_index`, and the edge
  `(parent_index, action)` is linked to that node.

  Args:
    params: params forwarded to `recurrent_fn`.
    rng_key: random number generator state, handed to `recurrent_fn`.
    tree: the batched MCTS tree state to update.
    recurrent_fn: callable invoked as `(params, rng_key, action, embedding)`.
      It returns a `RecurrentFnOutput` and the new state embedding, and it
      consumes `rng_key`.
    parent_index: node from which `action` is expanded. Shape `[B]`.
    action: action to expand. Shape `[B]`.
    next_node_index: slot that receives the evaluated node. When `max_depth`
      truncated the simulation, this may be an already existing node.
      Shape `[B]`.

  Returns:
    The updated tree.
  """
  batch_size = tree_lib.infer_batch_size(tree)
  chex.assert_shape([parent_index, action, next_node_index], (batch_size,))
  rows = jnp.arange(batch_size)

  # Gather each batch element's parent embedding (works for any pytree).
  parent_embedding = jax.tree.map(
      lambda leaf: leaf[rows, parent_index], tree.embeddings)

  # Evaluate the edge and sanity-check the recurrent function's outputs.
  step, child_embedding = recurrent_fn(
      params, rng_key, action, parent_embedding)
  chex.assert_shape(step.prior_logits, [batch_size, tree.num_actions])
  for per_element in (step.reward, step.discount, step.value):
    chex.assert_shape(per_element, [batch_size])

  # Write the node statistics first, then wire up the edge and back-links.
  tree = update_tree_node(
      tree, next_node_index, step.prior_logits, step.value, child_embedding)
  edge = (parent_index, action)
  return tree.replace(
      children_index=batch_update(
          tree.children_index, next_node_index, *edge),
      children_rewards=batch_update(
          tree.children_rewards, step.reward, *edge),
      children_discounts=batch_update(
          tree.children_discounts, step.discount, *edge),
      parents=batch_update(tree.parents, parent_index, next_node_index),
      action_from_parent=batch_update(
          tree.action_from_parent, action, next_node_index))
245 |
246 |
@jax.vmap
def backward(
    tree: Tree[T],
    leaf_index: chex.Numeric) -> Tree[T]:
  """Goes up and updates the tree until all nodes reached the root.

  Starting from `leaf_index`, this follows `tree.parents` links up to the root.
  At each step the propagated value is discounted and added to the edge reward.
  The result is folded into the parent's running mean in `node_values`. The
  parent's visit count and the edge statistics (`children_values`,
  `children_visits`) are updated as well. Vmapped over axis 0 of both
  arguments.

  Args:
    tree: the MCTS tree state to update, without the batch size.
    leaf_index: the node index from which to do the backward.

  Returns:
    Updated MCTS tree state.
  """

  def cond_fun(loop_state):
    _, _, index = loop_state
    # Stop once the walk has reached the root; it has no parent to update.
    return index != Tree.ROOT_INDEX

  def body_fun(loop_state):
    # Here we update the value of our parent, so we start by reversing.
    tree, leaf_value, index = loop_state
    parent = tree.parents[index]
    count = tree.node_visits[parent]
    action = tree.action_from_parent[index]
    reward = tree.children_rewards[parent, action]
    # Return as seen from `parent`: reward + discount * value from below.
    leaf_value = reward + tree.children_discounts[parent, action] * leaf_value
    # Incremental mean: fold the new return into the parent's average.
    parent_value = (
        tree.node_values[parent] * count + leaf_value) / (count + 1.0)
    # The edge value copies the child's node value, which the expansion step
    # or the previous loop iteration has already updated.
    children_values = tree.node_values[index]
    children_counts = tree.children_visits[parent, action] + 1

    tree = tree.replace(
        node_values=update(tree.node_values, parent_value, parent),
        node_visits=update(tree.node_visits, count + 1, parent),
        children_values=update(
            tree.children_values, children_values, parent, action),
        children_visits=update(
            tree.children_visits, children_counts, parent, action))

    return tree, leaf_value, parent

  leaf_index = jnp.asarray(leaf_index, dtype=jnp.int32)
  # Seed the propagated value with the leaf's own stored value.
  loop_state = (tree, tree.node_values[leaf_index], leaf_index)
  tree, _, _ = jax.lax.while_loop(cond_fun, body_fun, loop_state)

  return tree
293 |
294 |
def update(x, vals, *indices):
  """Returns a copy of `x` whose entry at `indices` is replaced by `vals`.

  This is a thin functional-update helper. `batch_update` (defined right after
  it) maps it over a leading batch axis, so each batch element writes at its
  own indices.
  """
  target = x.at[indices]
  return target.set(vals)


# Batched variant: every argument carries a leading batch axis (vmap axis 0).
batch_update = jax.vmap(update)
302 |
303 |
def update_tree_node(
    tree: Tree[T],
    node_index: chex.Array,
    prior_logits: chex.Array,
    value: chex.Array,
    embedding: chex.Array) -> Tree[T]:
  """Writes evaluation results into one node per batch element.

  Args:
    tree: `Tree` whose nodes are written.
    node_index: index of the node to write, per batch element. Shape `[B]`.
    prior_logits: prior logits for the node. Shape `[B, num_actions]`.
    value: value estimate for the node, stored both as the raw value and as
      the node value. Shape `[B]`.
    embedding: state embedding (any pytree) for the node. Shape `[B, ...]`.

  Returns:
    A new tree with the node's prior logits, values, visit count and
    embedding updated.
  """
  batch_size = tree_lib.infer_batch_size(tree)
  chex.assert_shape(prior_logits, (batch_size, tree.num_actions))

  def _write(field, new_values):
    return batch_update(field, new_values, node_index)

  # With `max_depth`, the same leaf can be evaluated repeatedly, so the visit
  # count is incremented rather than reset.
  visits = tree.node_visits[jnp.arange(batch_size), node_index] + 1
  return tree.replace(
      children_prior_logits=_write(tree.children_prior_logits, prior_logits),
      raw_values=_write(tree.raw_values, value),
      node_values=_write(tree.node_values, value),
      node_visits=_write(tree.node_visits, visits),
      embeddings=jax.tree.map(_write, tree.embeddings, embedding))
343 |
344 |
def instantiate_tree_from_root(
    root: base.RootFnOutput,
    num_simulations: int,
    root_invalid_actions: chex.Array,
    extra_data: Any) -> Tree:
  """Allocates an empty tree and fills node 0 from the root output.

  There are `num_simulations + 1` node slots per batch element: one for the
  root plus one per simulation.
  """
  chex.assert_rank(root.prior_logits, 2)
  batch_size, num_actions = root.prior_logits.shape
  chex.assert_shape(root.value, [batch_size])
  value_dtype = root.value.dtype
  per_node = (batch_size, num_simulations + 1)
  per_edge = per_node + (num_actions,)

  def _int_full(shape, fill):
    return jnp.full(shape, fill, dtype=jnp.int32)

  def _empty_like_batch(leaf):
    # Same trailing shape and dtype as the root embedding leaf, one per node.
    return jnp.zeros(per_node + leaf.shape[1:], dtype=leaf.dtype)

  empty_tree = Tree(
      node_visits=jnp.zeros(per_node, dtype=jnp.int32),
      raw_values=jnp.zeros(per_node, dtype=value_dtype),
      node_values=jnp.zeros(per_node, dtype=value_dtype),
      parents=_int_full(per_node, Tree.NO_PARENT),
      action_from_parent=_int_full(per_node, Tree.NO_PARENT),
      children_index=_int_full(per_edge, Tree.UNVISITED),
      children_prior_logits=jnp.zeros(
          per_edge, dtype=root.prior_logits.dtype),
      children_values=jnp.zeros(per_edge, dtype=value_dtype),
      children_visits=jnp.zeros(per_edge, dtype=jnp.int32),
      children_rewards=jnp.zeros(per_edge, dtype=value_dtype),
      children_discounts=jnp.zeros(per_edge, dtype=value_dtype),
      embeddings=jax.tree.map(_empty_like_batch, root.embedding),
      root_invalid_actions=root_invalid_actions,
      extra_data=extra_data)

  # Populate the root slot of every batch element.
  return update_tree_node(
      empty_tree,
      jnp.full([batch_size], Tree.ROOT_INDEX),
      root.prior_logits,
      root.value,
      root.embedding)
386 |
--------------------------------------------------------------------------------
/mctx/_src/policies.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Search policies."""
16 | import functools
17 | from typing import Optional, Tuple
18 |
19 | import chex
20 | import jax
21 | import jax.numpy as jnp
22 |
23 | from mctx._src import action_selection
24 | from mctx._src import base
25 | from mctx._src import qtransforms
26 | from mctx._src import search
27 | from mctx._src import seq_halving
28 |
29 |
def muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
    dirichlet_fraction: chex.Numeric = 0.25,
    dirichlet_alpha: chex.Numeric = 0.3,
    pb_c_init: chex.Numeric = 1.25,
    pb_c_base: chex.Numeric = 19652,
    temperature: chex.Numeric = 1.0) -> base.PolicyOutput[None]:
  """Runs a MuZero search and returns the resulting `PolicyOutput`.

  Dirichlet noise is mixed into the root prior, a PUCT-guided search is run,
  and the action is sampled from the tempered root visit distribution.

  In the shape descriptions, `B` denotes the batch dimension.

  Args:
    params: params to be forwarded to root and recurrent functions.
    rng_key: random number generator state, the key is consumed.
    root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
      `prior_logits` are from a policy network. The shapes are
      `([B, num_actions], [B], [B, ...])`, respectively.
    recurrent_fn: a callable to be called on the leaf nodes and unvisited
      actions retrieved by the simulation step, which takes as args
      `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
      and the new state embedding. The `rng_key` argument is consumed.
    num_simulations: the number of simulations.
    invalid_actions: a mask with invalid actions. Invalid actions have ones,
      valid actions have zeros in the mask. Shape `[B, num_actions]`.
    max_depth: maximum search tree depth allowed during simulation.
    loop_fn: Function used to run the simulations. It may be required to pass
      hk.fori_loop if using this function inside a Haiku module.
    qtransform: function to obtain completed Q-values for a node.
    dirichlet_fraction: float from 0 to 1 interpolating between using only the
      prior policy or just the Dirichlet noise.
    dirichlet_alpha: concentration parameter of the Dirichlet distribution.
    pb_c_init: constant c_1 in the PUCT formula.
    pb_c_base: constant c_2 in the PUCT formula.
    temperature: temperature for acting proportionally to
      `visit_counts**(1 / temperature)`.

  Returns:
    `PolicyOutput` with the proposed action, the action weights (root visit
    probabilities) and the search tree that was built.
  """
  act_key, noise_key, search_key = jax.random.split(rng_key, 3)

  # Exploration: mix Dirichlet noise into the root prior, then mask.
  prior_probs = jax.nn.softmax(root.prior_logits)
  noisy_probs = _add_dirichlet_noise(
      noise_key,
      prior_probs,
      dirichlet_fraction=dirichlet_fraction,
      dirichlet_alpha=dirichlet_alpha)
  noisy_logits = _get_logits_from_probs(noisy_probs)
  root = root.replace(
      prior_logits=_mask_invalid_actions(noisy_logits, invalid_actions))

  # The same PUCT rule is used everywhere; the root variant pins depth=0.
  interior_fn = functools.partial(
      action_selection.muzero_action_selection,
      pb_c_base=pb_c_base,
      pb_c_init=pb_c_init,
      qtransform=qtransform)
  root_fn = functools.partial(interior_fn, depth=0)
  search_tree = search.search(
      params=params,
      rng_key=search_key,
      root=root,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=root_fn,
      interior_action_selection_fn=interior_fn,
      num_simulations=num_simulations,
      max_depth=max_depth,
      invalid_actions=invalid_actions,
      loop_fn=loop_fn)

  # Act by sampling from the tempered root visit distribution.
  action_weights = search_tree.summary().visit_probs
  visit_logits = _get_logits_from_probs(action_weights)
  tempered_logits = _apply_temperature(visit_logits, temperature)
  return base.PolicyOutput(
      action=jax.random.categorical(act_key, tempered_logits),
      action_weights=action_weights,
      search_tree=search_tree)
123 |
124 |
def gumbel_muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
    max_num_considered_actions: int = 16,
    gumbel_scale: chex.Numeric = 1.,
) -> base.PolicyOutput[action_selection.GumbelMuZeroExtraData]:
  """Runs a Gumbel MuZero search and returns the resulting `PolicyOutput`.

  Implements Full Gumbel MuZero from "Policy improvement by planning with
  Gumbel" (https://openreview.net/forum?id=bERaNdoegnO).

  At the root, actions are chosen by Sequential Halving with Gumbel noise. At
  non-root (interior) nodes, the deterministic Full Gumbel MuZero action
  selection is used.

  In the shape descriptions, `B` denotes the batch dimension.

  Args:
    params: params to be forwarded to root and recurrent functions.
    rng_key: random number generator state, the key is consumed.
    root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
      `prior_logits` are from a policy network. The shapes are
      `([B, num_actions], [B], [B, ...])`, respectively.
    recurrent_fn: a callable to be called on the leaf nodes and unvisited
      actions retrieved by the simulation step, which takes as args
      `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
      and the new state embedding. The `rng_key` argument is consumed.
    num_simulations: the number of simulations.
    invalid_actions: a mask with invalid actions. Invalid actions have ones,
      valid actions have zeros in the mask. Shape `[B, num_actions]`.
    max_depth: maximum search tree depth allowed during simulation.
    loop_fn: Function used to run the simulations. It may be required to pass
      hk.fori_loop if using this function inside a Haiku module.
    qtransform: function to obtain completed Q-values for a node.
    max_num_considered_actions: the maximum number of actions expanded at the
      root node. Fewer actions are expanded when fewer actions are valid.
    gumbel_scale: scale for the Gumbel noise. Evaluation on
      perfect-information games can use gumbel_scale=0.0.

  Returns:
    `PolicyOutput` with the proposed action, the action weights (softmax of
    the masked prior logits plus completed Q-values) and the search tree.
  """
  root = root.replace(
      prior_logits=_mask_invalid_actions(root.prior_logits, invalid_actions))
  masked_logits = root.prior_logits

  # One Gumbel sample per root action, shared by the search and final choice.
  search_key, gumbel_key = jax.random.split(rng_key)
  gumbel = gumbel_scale * jax.random.gumbel(
      gumbel_key, shape=masked_logits.shape, dtype=masked_logits.dtype)

  root_selection_fn = functools.partial(
      action_selection.gumbel_muzero_root_action_selection,
      num_simulations=num_simulations,
      max_num_considered_actions=max_num_considered_actions,
      qtransform=qtransform)
  interior_selection_fn = functools.partial(
      action_selection.gumbel_muzero_interior_action_selection,
      qtransform=qtransform)
  search_tree = search.search(
      params=params,
      rng_key=search_key,
      root=root,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=root_selection_fn,
      interior_action_selection_fn=interior_selection_fn,
      num_simulations=num_simulations,
      max_depth=max_depth,
      invalid_actions=invalid_actions,
      extra_data=action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel),
      loop_fn=loop_fn)
  visit_counts = search_tree.summary().visit_counts

  # Pick the best of the most-visited actions by `gumbel + logits + q`.
  # States with fewer valid actions can have a different maximum visit count,
  # hence the per-row `considered_visit`.
  considered_visit = jnp.max(visit_counts, axis=-1, keepdims=True)
  # Completed Q-values impute a value for unvisited actions as well.
  completed_qvalues = jax.vmap(qtransform, in_axes=[0, None])(  # pytype: disable=wrong-arg-types  # numpy-scalars  # pylint: disable=line-too-long
      search_tree, search_tree.ROOT_INDEX)
  scores = seq_halving.score_considered(
      considered_visit, gumbel, masked_logits, completed_qvalues,
      visit_counts)
  action = action_selection.masked_argmax(scores, invalid_actions)

  # Policy-training target: softmax over the completed search logits.
  completed_search_logits = _mask_invalid_actions(
      masked_logits + completed_qvalues, invalid_actions)
  return base.PolicyOutput(
      action=action,
      action_weights=jax.nn.softmax(completed_search_logits),
      search_tree=search_tree)
232 |
233 |
234 | def stochastic_muzero_policy(
235 | params: chex.ArrayTree,
236 | rng_key: chex.PRNGKey,
237 | root: base.RootFnOutput,
238 | decision_recurrent_fn: base.DecisionRecurrentFn,
239 | chance_recurrent_fn: base.ChanceRecurrentFn,
240 | num_simulations: int,
241 | invalid_actions: Optional[chex.Array] = None,
242 | max_depth: Optional[int] = None,
243 | loop_fn: base.LoopFn = jax.lax.fori_loop,
244 | *,
245 | qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
246 | dirichlet_fraction: chex.Numeric = 0.25,
247 | dirichlet_alpha: chex.Numeric = 0.3,
248 | pb_c_init: chex.Numeric = 1.25,
249 | pb_c_base: chex.Numeric = 19652,
250 | temperature: chex.Numeric = 1.0) -> base.PolicyOutput[None]:
251 | """Runs Stochastic MuZero search.
252 |
253 | Implements search as described in the Stochastic MuZero paper:
254 | (https://openreview.net/forum?id=X6D9bAHhBQ1).
255 |
256 | In the shape descriptions, `B` denotes the batch dimension.
257 | Args:
258 | params: params to be forwarded to root and recurrent functions.
259 | rng_key: random number generator state, the key is consumed.
260 | root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
261 | `prior_logits` are from a policy network. The shapes are `([B,
262 | num_actions], [B], [B, ...])`, respectively.
263 | decision_recurrent_fn: a callable to be called on the leaf decision nodes
264 | and unvisited actions retrieved by the simulation step, which takes as
265 | args `(params, rng_key, action, state_embedding)` and returns a
266 | `(DecisionRecurrentFnOutput, afterstate_embedding)`.
267 | chance_recurrent_fn: a callable to be called on the leaf chance nodes and
268 | unvisited actions retrieved by the simulation step, which takes as args
269 | `(params, rng_key, chance_outcome, afterstate_embedding)` and returns a
270 | `(ChanceRecurrentFnOutput, state_embedding)`.
271 | num_simulations: the number of simulations.
272 | invalid_actions: a mask with invalid actions. Invalid actions have ones,
273 | valid actions have zeros in the mask. Shape `[B, num_actions]`.
274 | max_depth: maximum search tree depth allowed during simulation.
275 | loop_fn: Function used to run the simulations. It may be required to pass
276 | hk.fori_loop if using this function inside a Haiku module.
277 | qtransform: function to obtain completed Q-values for a node.
278 | dirichlet_fraction: float from 0 to 1 interpolating between using only the
279 | prior policy or just the Dirichlet noise.
280 | dirichlet_alpha: concentration parameter to parametrize the Dirichlet
281 | distribution.
282 | pb_c_init: constant c_1 in the PUCT formula.
283 | pb_c_base: constant c_2 in the PUCT formula.
284 | temperature: temperature for acting proportionally to `visit_counts**(1 /
285 | temperature)`.
286 |
287 | Returns:
288 | `PolicyOutput` containing the proposed action, action_weights and the used
289 | search tree.
290 | """
291 |
292 | num_actions = root.prior_logits.shape[-1]
293 |
294 | rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3)
295 |
296 | # Adding Dirichlet noise.
297 | noisy_logits = _get_logits_from_probs(
298 | _add_dirichlet_noise(
299 | dirichlet_rng_key,
300 | jax.nn.softmax(root.prior_logits),
301 | dirichlet_fraction=dirichlet_fraction,
302 | dirichlet_alpha=dirichlet_alpha))
303 |
304 | root = root.replace(
305 | prior_logits=_mask_invalid_actions(noisy_logits, invalid_actions))
306 |
307 | # construct a dummy afterstate embedding
308 | batch_size = jax.tree_util.tree_leaves(root.embedding)[0].shape[0]
309 | dummy_action = jnp.zeros([batch_size], dtype=jnp.int32)
310 | dummy_output, dummy_afterstate_embedding = decision_recurrent_fn(
311 | params, rng_key, dummy_action, root.embedding)
312 | num_chance_outcomes = dummy_output.chance_logits.shape[-1]
313 |
314 | root = root.replace(
315 | # pad action logits with num_chance_outcomes so dim is A + C
316 | prior_logits=jnp.concatenate([
317 | root.prior_logits,
318 | jnp.full([batch_size, num_chance_outcomes], fill_value=-jnp.inf)
319 | ], axis=-1),
320 | # replace embedding with wrapper.
321 | embedding=base.StochasticRecurrentState(
322 | state_embedding=root.embedding,
323 | afterstate_embedding=dummy_afterstate_embedding,
324 | is_decision_node=jnp.ones([batch_size], dtype=bool)))
325 |
326 | # Stochastic MuZero Change: We need to be able to tell if different nodes are
327 | # decision or chance. This is accomplished by imposing a special structure
328 | # on the embeddings stored in each node. Each embedding is an instance of
329 | # StochasticRecurrentState which maintains this information.
330 | recurrent_fn = _make_stochastic_recurrent_fn(
331 | decision_node_fn=decision_recurrent_fn,
332 | chance_node_fn=chance_recurrent_fn,
333 | num_actions=num_actions,
334 | num_chance_outcomes=num_chance_outcomes,
335 | )
336 |
337 | # Running the search.
338 |
339 | interior_decision_node_selection_fn = functools.partial(
340 | action_selection.muzero_action_selection,
341 | pb_c_base=pb_c_base,
342 | pb_c_init=pb_c_init,
343 | qtransform=qtransform)
344 |
345 | interior_action_selection_fn = _make_stochastic_action_selection_fn(
346 | interior_decision_node_selection_fn, num_actions)
347 |
348 | root_action_selection_fn = functools.partial(
349 | interior_action_selection_fn, depth=0)
350 |
351 | search_tree = search.search(
352 | params=params,
353 | rng_key=search_rng_key,
354 | root=root,
355 | recurrent_fn=recurrent_fn,
356 | root_action_selection_fn=root_action_selection_fn,
357 | interior_action_selection_fn=interior_action_selection_fn,
358 | num_simulations=num_simulations,
359 | max_depth=max_depth,
360 | invalid_actions=invalid_actions,
361 | loop_fn=loop_fn)
362 |
363 | # Sampling the proposed action proportionally to the visit counts.
364 | search_tree = _mask_tree(search_tree, num_actions, 'decision')
365 | summary = search_tree.summary()
366 | action_weights = summary.visit_probs
367 | action_logits = _apply_temperature(
368 | _get_logits_from_probs(action_weights), temperature)
369 | action = jax.random.categorical(rng_key, action_logits)
370 | return base.PolicyOutput(
371 | action=action, action_weights=action_weights, search_tree=search_tree)
372 |
373 |
374 | def _mask_invalid_actions(logits, invalid_actions):
375 | """Returns logits with zero mass to invalid actions."""
376 | if invalid_actions is None:
377 | return logits
378 | chex.assert_equal_shape([logits, invalid_actions])
379 | logits = logits - jnp.max(logits, axis=-1, keepdims=True)
380 | # At the end of an episode, all actions can be invalid. A softmax would then
381 | # produce NaNs, if using -inf for the logits. We avoid the NaNs by using
382 | # a finite `min_logit` for the invalid actions.
383 | min_logit = jnp.finfo(logits.dtype).min
384 | return jnp.where(invalid_actions, min_logit, logits)
385 |
386 |
387 | def _get_logits_from_probs(probs):
388 | tiny = jnp.finfo(probs.dtype).tiny
389 | return jnp.log(jnp.maximum(probs, tiny))
390 |
391 |
def _add_dirichlet_noise(rng_key, probs, *, dirichlet_alpha,
                         dirichlet_fraction):
  """Blends a batch of distributions with symmetric Dirichlet noise.

  Args:
    rng_key: PRNG key used to sample the noise.
    probs: rank-2 array of probabilities, `[batch_size, num_actions]`.
    dirichlet_alpha: float concentration shared by every component of the
      symmetric Dirichlet distribution.
    dirichlet_fraction: float weight of the noise in the mixture. 0 keeps
      `probs` unchanged and 1 returns pure noise.

  Returns:
    `(1 - dirichlet_fraction) * probs + dirichlet_fraction * noise`, where
    `noise` holds one Dirichlet sample per batch row.
  """
  chex.assert_rank(probs, 2)
  chex.assert_type([dirichlet_alpha, dirichlet_fraction], float)

  batch_size, num_actions = probs.shape
  concentration = jnp.full([num_actions], fill_value=dirichlet_alpha)
  noise = jax.random.dirichlet(
      rng_key, alpha=concentration, shape=(batch_size,))

  kept = (1 - dirichlet_fraction) * probs
  injected = dirichlet_fraction * noise
  return kept + injected
405 |
406 |
407 | def _apply_temperature(logits, temperature):
408 | """Returns `logits / temperature`, supporting also temperature=0."""
409 | # The max subtraction prevents +inf after dividing by a small temperature.
410 | logits = logits - jnp.max(logits, keepdims=True, axis=-1)
411 | tiny = jnp.finfo(logits.dtype).tiny
412 | return logits / jnp.maximum(tiny, temperature)
413 |
414 |
def _make_stochastic_recurrent_fn(
    decision_node_fn: base.DecisionRecurrentFn,
    chance_node_fn: base.ChanceRecurrentFn,
    num_actions: int,
    num_chance_outcomes: int,
) -> base.RecurrentFn:
  """Builds a `RecurrentFn` over the combined action/chance index space.

  The search treats every node as having `A' = A + C` children. Indices
  `[0, A)` are environment actions (`A = num_actions`) and indices
  `[A, A + C)` are chance outcomes (`C = num_chance_outcomes`). The returned
  function calls both `decision_node_fn` and `chance_node_fn` on the whole
  batch with the same `rng`. For each batch element it then keeps the output
  that matches `state.is_decision_node`.

  Args:
    decision_node_fn: called as `(params, rng, action, state_embedding)`.
      Returns an output exposing `chance_logits` and `afterstate_value`,
      together with the next afterstate embedding.
    chance_node_fn: called as
      `(params, rng, chance_outcome, afterstate_embedding)`. Returns an output
      exposing `action_logits`, `value`, `reward` and `discount`, together
      with the next state embedding.
    num_actions: the number of environment actions `A`.
    num_chance_outcomes: the number of chance outcomes `C`.

  Returns:
    A `RecurrentFn` mapping `(params, rng, action_or_chance, state)` to a
    `RecurrentFnOutput` with `A + C` prior logits, plus the successor
    `StochasticRecurrentState` whose `is_decision_node` flag is negated.
  """

  def stochastic_recurrent_fn(
      params: base.Params,
      rng: chex.PRNGKey,
      action_or_chance: base.Action,  # [B]
      state: base.StochasticRecurrentState
  ) -> Tuple[base.RecurrentFnOutput, base.StochasticRecurrentState]:
    batch_size = jax.tree_util.tree_leaves(state.state_embedding)[0].shape[0]

    def _neg_inf_block(width):
      # Dummy logits that fill the part of the `A + C` dimension which does
      # not apply to the node type; they are ultimately ignored (see
      # `_mask_tree`).
      return jnp.full([batch_size, width], fill_value=-jnp.inf)

    # Decision-node view: the index is used directly as an environment action.
    # For chance nodes this output is computed but discarded below.
    decision_out, next_afterstate_embedding = decision_node_fn(
        params, rng, action_or_chance, state.state_embedding)
    as_decision = base.RecurrentFnOutput(
        prior_logits=jnp.concatenate(
            [_neg_inf_block(num_actions), decision_out.chance_logits],
            axis=-1),
        value=decision_out.afterstate_value,
        # Moving from a state to its afterstate yields no reward and no
        # discounting.
        reward=jnp.zeros_like(decision_out.afterstate_value),
        discount=jnp.ones_like(decision_out.afterstate_value))

    # Chance-node view: indices past the actions are chance outcomes, so the
    # action offset is removed. For decision nodes this output is discarded.
    chance_out, next_state_embedding = chance_node_fn(
        params, rng, action_or_chance - num_actions,
        state.afterstate_embedding)
    as_chance = base.RecurrentFnOutput(
        prior_logits=jnp.concatenate(
            [chance_out.action_logits, _neg_inf_block(num_chance_outcomes)],
            axis=-1),
        value=chance_out.value,
        reward=chance_out.reward,
        discount=chance_out.discount)

    def _select_by_node_type(if_decision, if_chance):
      # Reshape the [B] flag so it broadcasts against leaves of any rank.
      flag_shape = [-1] + [1] * (len(if_decision.shape) - 1)
      return jnp.where(
          jnp.reshape(state.is_decision_node, flag_shape),
          if_decision, if_chance)

    output = jax.tree.map(_select_by_node_type, as_decision, as_chance)
    # Both embeddings are stored; the next expansion reads whichever one
    # matches the flipped node type.
    next_state = base.StochasticRecurrentState(
        state_embedding=next_state_embedding,
        afterstate_embedding=next_afterstate_embedding,
        is_decision_node=jnp.logical_not(state.is_decision_node))
    return output, next_state

  return stochastic_recurrent_fn
486 |
487 |
def _mask_tree(tree: search.Tree, num_actions: int, mode: str) -> search.Tree:
  """Restricts the per-child fields of `tree` to one kind of edge.

  Internally, "actions" span `A' = A + C` indices: `[0, A)` are environment
  actions and `[A, A + C)` are chance outcomes. This helper slices the
  trailing dimension of every per-child field, and of `root_invalid_actions`,
  down to the part selected by `mode`.

  Args:
    tree: the search tree to slice.
    num_actions: the number of environment actions `A`.
    mode: `'decision'` keeps the first `A` entries; `'chance'` keeps the
      entries from index `A` onwards.

  Returns:
    The result of `tree.replace(...)` with the sliced fields substituted.

  Raises:
    ValueError: if `mode` is neither `'decision'` nor `'chance'`.
  """
  sliced_fields = (
      'children_index',
      'children_prior_logits',
      'children_visits',
      'children_rewards',
      'children_discounts',
      'children_values',
      'root_invalid_actions',
  )

  def _restrict(array):
    if mode == 'decision':
      return array[..., :num_actions]
    if mode == 'chance':
      return array[..., num_actions:]
    raise ValueError(f'Unknown mode: {mode}.')

  return tree.replace(
      **{field: _restrict(getattr(tree, field)) for field in sliced_fields})
520 |
521 |
def _make_stochastic_action_selection_fn(
    decision_node_selection_fn: base.InteriorActionSelectionFn,
    num_actions: int,
) -> base.InteriorActionSelectionFn:
  """Wraps a decision-node selector so that it also handles chance nodes.

  The returned function reads `tree.embeddings.is_decision_node` for the
  current node.

  * At decision nodes it delegates to `decision_node_selection_fn` on the tree
    restricted to its first `num_actions` children.
  * At chance nodes it picks the outcome that maximizes
    `prior_prob / (visit_count + 1)` among the remaining children. The chosen
    outcome is offset by `num_actions` so it indexes the combined `A + C`
    child dimension.

  Args:
    decision_node_selection_fn: selector used at decision nodes, called as
      `(key, tree, node_index, depth)`.
    num_actions: the number of environment actions `A`.

  Returns:
    A selection function `(key, tree, node_index, depth) -> index`.
  """

  # The trees passed to the functions below are unbatched.

  def _select_chance_outcome(
      tree: search.Tree,
      node_index: chex.Array,
  ) -> chex.Array:
    """Returns argmax_c p(c) / (N(c) + 1) over the node's children."""
    outcome_visits = tree.children_visits[node_index]
    outcome_probs = jax.nn.softmax(tree.children_prior_logits[node_index])
    scores = outcome_probs / (outcome_visits + 1)
    return jnp.argmax(scores, axis=-1).astype(jnp.int32)

  def _action_selection_fn(key: chex.PRNGKey, tree: search.Tree,
                           node_index: chex.Array,
                           depth: chex.Array) -> chex.Array:
    at_decision_node = tree.embeddings.is_decision_node[node_index]

    chance_tree = _mask_tree(tree, num_actions, 'chance')
    # Chance outcomes sit after the `A` action slots in the combined
    # `A + C` child dimension.
    chance_choice = (
        _select_chance_outcome(tree=chance_tree, node_index=node_index)
        + num_actions)

    decision_tree = _mask_tree(tree, num_actions, 'decision')
    decision_choice = decision_node_selection_fn(
        key, decision_tree, node_index, depth)

    # Both candidates are computed up front; `cond` only picks one of them.
    return jax.lax.cond(
        at_decision_node,
        lambda: decision_choice,
        lambda: chance_choice)

  return _action_selection_fn
555 |
--------------------------------------------------------------------------------