├── .github └── workflows │ ├── ci.yml │ └── pypi-publish.yml ├── .gitignore ├── .pylintrc ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples ├── policy_improvement_demo.py └── visualization_demo.py ├── mctx ├── __init__.py ├── _src │ ├── __init__.py │ ├── action_selection.py │ ├── base.py │ ├── policies.py │ ├── qtransforms.py │ ├── search.py │ ├── seq_halving.py │ ├── tests │ │ ├── mctx_test.py │ │ ├── policies_test.py │ │ ├── qtransforms_test.py │ │ ├── seq_halving_test.py │ │ ├── test_data │ │ │ ├── gumbel_muzero_reward_tree.json │ │ │ ├── gumbel_muzero_tree.json │ │ │ ├── muzero_qtransform_tree.json │ │ │ └── muzero_tree.json │ │ └── tree_test.py │ └── tree.py └── py.typed ├── requirements ├── requirements-test.txt ├── requirements.txt └── requirements_examples.txt ├── setup.py └── test.sh /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | jobs: 10 | build-and-test: 11 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 12 | runs-on: "${{ matrix.os }}" 13 | 14 | strategy: 15 | matrix: 16 | python-version: ["3.11", "3.12", "3.13"] 17 | os: [ubuntu-latest] 18 | 19 | steps: 20 | - uses: "actions/checkout@v5" 21 | - uses: "actions/setup-python@v5.3.0" 22 | with: 23 | python-version: "${{ matrix.python-version }}" 24 | - name: Run CI tests 25 | run: bash test.sh 26 | shell: bash 27 | -------------------------------------------------------------------------------- /.github/workflows/pypi-publish.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | id-token: write 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up Python 15 | uses: actions/setup-python@v4 16 | with: 17 | python-version: 
'3.x' 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install setuptools wheel twine 22 | - name: Check consistency between the package version and release tag 23 | run: | 24 | RELEASE_VER=${GITHUB_REF#refs/*/} 25 | PACKAGE_VER="v`python setup.py --version`" 26 | if [ $RELEASE_VER != $PACKAGE_VER ] 27 | then 28 | echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1 29 | fi 30 | - name: Build 31 | run: | 32 | python setup.py sdist bdist_wheel 33 | - name: Publish package distributions to PyPI 34 | uses: pypa/gh-action-pypi-publish@release/v1 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Building and releasing library: 2 | *.egg-info 3 | *.pyc 4 | *.so 5 | build/ 6 | dist/ 7 | venv/ 8 | 9 | # Mac OS 10 | .DS_Store 11 | 12 | # Python tools 13 | .mypy_cache/ 14 | .pytype/ 15 | .ipynb_checkpoints 16 | 17 | # Editors 18 | .idea 19 | .vscode 20 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # This Pylint rcfile contains a best-effort configuration to uphold the 2 | # best-practices and style described in the Google Python style guide: 3 | # https://google.github.io/styleguide/pyguide.html 4 | # 5 | # Its canonical open-source location is: 6 | # https://google.github.io/styleguide/pylintrc 7 | 8 | [MASTER] 9 | 10 | # Files or directories to be skipped. They should be base names, not paths. 11 | ignore=third_party 12 | 13 | # Files or directories matching the regex patterns are skipped. The regex 14 | # matches against base names, not paths. 15 | ignore-patterns= 16 | 17 | # Pickle collected data for later comparisons. 
18 | persistent=no 19 | 20 | # List of plugins (as comma separated values of python modules names) to load, 21 | # usually to register additional checkers. 22 | load-plugins= 23 | 24 | # Use multiple processes to speed up Pylint. 25 | jobs=4 26 | 27 | # Allow loading of arbitrary C extensions. Extensions are imported into the 28 | # active Python interpreter and may run arbitrary code. 29 | unsafe-load-any-extension=no 30 | 31 | 32 | [MESSAGES CONTROL] 33 | 34 | # Only show warnings with the listed confidence levels. Leave empty to show 35 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 36 | confidence= 37 | 38 | # Enable the message, report, category or checker with the given id(s). You can 39 | # either give multiple identifier separated by comma (,) or put this option 40 | # multiple time (only on the command line, not in the configuration file where 41 | # it should appear only once). See also the "--disable" option for examples. 42 | #enable= 43 | 44 | # Disable the message, report, category or checker with the given id(s). You 45 | # can either give multiple identifiers separated by comma (,) or put this 46 | # option multiple times (only on the command line, not in the configuration 47 | # file where it should appear only once).You can also use "--disable=all" to 48 | # disable everything first and then reenable specific checks. For example, if 49 | # you want to run only the similarities checker, you can use "--disable=all 50 | # --enable=similarities". 
If you want to run only the classes checker, but have 51 | # no Warning level messages displayed, use"--disable=all --enable=classes 52 | # --disable=W" 53 | disable=abstract-method, 54 | apply-builtin, 55 | arguments-differ, 56 | attribute-defined-outside-init, 57 | backtick, 58 | bad-option-value, 59 | basestring-builtin, 60 | buffer-builtin, 61 | c-extension-no-member, 62 | consider-using-enumerate, 63 | cmp-builtin, 64 | cmp-method, 65 | coerce-builtin, 66 | coerce-method, 67 | delslice-method, 68 | div-method, 69 | duplicate-code, 70 | eq-without-hash, 71 | execfile-builtin, 72 | file-builtin, 73 | filter-builtin-not-iterating, 74 | fixme, 75 | getslice-method, 76 | global-statement, 77 | hex-method, 78 | idiv-method, 79 | implicit-str-concat, 80 | import-error, 81 | import-self, 82 | import-star-module-level, 83 | inconsistent-return-statements, 84 | input-builtin, 85 | intern-builtin, 86 | invalid-str-codec, 87 | locally-disabled, 88 | long-builtin, 89 | long-suffix, 90 | map-builtin-not-iterating, 91 | misplaced-comparison-constant, 92 | missing-function-docstring, 93 | metaclass-assignment, 94 | next-method-called, 95 | next-method-defined, 96 | no-absolute-import, 97 | no-else-break, 98 | no-else-continue, 99 | no-else-raise, 100 | no-else-return, 101 | no-init, # added 102 | no-member, 103 | no-name-in-module, 104 | no-self-use, 105 | nonzero-method, 106 | oct-method, 107 | old-division, 108 | old-ne-operator, 109 | old-octal-literal, 110 | old-raise-syntax, 111 | parameter-unpacking, 112 | print-statement, 113 | raising-string, 114 | range-builtin-not-iterating, 115 | raw_input-builtin, 116 | rdiv-method, 117 | reduce-builtin, 118 | relative-import, 119 | reload-builtin, 120 | round-builtin, 121 | setslice-method, 122 | signature-differs, 123 | standarderror-builtin, 124 | suppressed-message, 125 | sys-max-int, 126 | too-few-public-methods, 127 | too-many-ancestors, 128 | too-many-arguments, 129 | too-many-boolean-expressions, 130 | too-many-branches, 
131 | too-many-instance-attributes, 132 | too-many-locals, 133 | too-many-nested-blocks, 134 | too-many-public-methods, 135 | too-many-return-statements, 136 | too-many-statements, 137 | trailing-newlines, 138 | unichr-builtin, 139 | unicode-builtin, 140 | unnecessary-pass, 141 | unpacking-in-except, 142 | useless-else-on-loop, 143 | useless-object-inheritance, 144 | useless-suppression, 145 | using-cmp-argument, 146 | wrong-import-order, 147 | xrange-builtin, 148 | zip-builtin-not-iterating, 149 | 150 | 151 | [REPORTS] 152 | 153 | # Set the output format. Available formats are text, parseable, colorized, msvs 154 | # (visual studio) and html. You can also give a reporter class, eg 155 | # mypackage.mymodule.MyReporterClass. 156 | output-format=text 157 | 158 | # Tells whether to display a full report or only the messages 159 | reports=no 160 | 161 | # Python expression which should return a note less than 10 (10 is the highest 162 | # note). You have access to the variables errors warning, statement which 163 | # respectively contain the number of errors / warnings messages and the total 164 | # number of statements analyzed. This is used by the global evaluation report 165 | # (RP0004). 166 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 167 | 168 | # Template used to display messages. This is a python new-style format string 169 | # used to format the message information. See doc for all details 170 | #msg-template= 171 | 172 | 173 | [BASIC] 174 | 175 | # Good variable names which should always be accepted, separated by a comma 176 | good-names=main,_ 177 | 178 | # Bad variable names which should always be refused, separated by a comma 179 | bad-names= 180 | 181 | # Colon-delimited sets of names that determine each other's naming style when 182 | # the name regexes allow several styles. 
183 | name-group= 184 | 185 | # Include a hint for the correct naming format with invalid-name 186 | include-naming-hint=no 187 | 188 | # List of decorators that produce properties, such as abc.abstractproperty. Add 189 | # to this list to register other decorators that produce valid properties. 190 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl 191 | 192 | # Regular expression matching correct function names 193 | function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$ 194 | 195 | # Regular expression matching correct variable names 196 | variable-rgx=^[a-z][a-z0-9_]*$ 197 | 198 | # Regular expression matching correct constant names 199 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 200 | 201 | # Regular expression matching correct attribute names 202 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 203 | 204 | # Regular expression matching correct argument names 205 | argument-rgx=^[a-z][a-z0-9_]*$ 206 | 207 | # Regular expression matching correct class attribute names 208 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 209 | 210 | # Regular expression matching correct inline iteration names 211 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 212 | 213 | # Regular expression matching correct class names 214 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 215 | 216 | # Regular expression matching correct module names 217 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ 218 | 219 | # Regular expression matching correct method names 220 | method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$ 221 | 222 | # Regular expression which should only match function or class names that do 223 | # 
not require a docstring. 224 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ 225 | 226 | # Minimum line length for functions/classes that require docstrings, shorter 227 | # ones are exempt. 228 | docstring-min-length=10 229 | 230 | 231 | [TYPECHECK] 232 | 233 | # List of decorators that produce context managers, such as 234 | # contextlib.contextmanager. Add to this list to register other decorators that 235 | # produce valid context managers. 236 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 237 | 238 | # Tells whether missing members accessed in mixin class should be ignored. A 239 | # mixin class is detected if its name ends with "mixin" (case insensitive). 240 | ignore-mixin-members=yes 241 | 242 | # List of module names for which member attributes should not be checked 243 | # (useful for modules/projects where namespaces are manipulated during runtime 244 | # and thus existing member attributes cannot be deduced by static analysis. It 245 | # supports qualified module names, as well as Unix pattern matching. 246 | ignored-modules= 247 | 248 | # List of class names for which member attributes should not be checked (useful 249 | # for classes with dynamically set attributes). This supports the use of 250 | # qualified names. 251 | ignored-classes=optparse.Values,thread._local,_thread._local 252 | 253 | # List of members which are set dynamically and missed by pylint inference 254 | # system, and so shouldn't trigger E1101 when accessed. Python regular 255 | # expressions are accepted. 256 | generated-members= 257 | 258 | 259 | [FORMAT] 260 | 261 | # Maximum number of characters on a single line. 262 | max-line-length=80 263 | 264 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt 265 | # lines made too long by directives to pytype. 266 | 267 | # Regexp for a line that is allowed to be longer than the limit. 
268 | ignore-long-lines=(?x)( 269 | ^\s*(\#\ )??$| 270 | ^\s*(from\s+\S+\s+)?import\s+.+$) 271 | 272 | # Allow the body of an if to be on the same line as the test if there is no 273 | # else. 274 | single-line-if-stmt=yes 275 | 276 | # Maximum number of lines in a module 277 | max-module-lines=99999 278 | 279 | # String used as indentation unit. The internal Google style guide mandates 2 280 | # spaces. Google's externaly-published style guide says 4, consistent with 281 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google 282 | # projects (like TensorFlow). 283 | indent-string=' ' 284 | 285 | # Number of spaces of indent required inside a hanging or continued line. 286 | indent-after-paren=4 287 | 288 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 289 | expected-line-ending-format= 290 | 291 | 292 | [MISCELLANEOUS] 293 | 294 | # List of note tags to take in consideration, separated by a comma. 295 | notes=TODO 296 | 297 | 298 | [STRING] 299 | 300 | # This flag controls whether inconsistent-quotes generates a warning when the 301 | # character used as a quote delimiter is used inconsistently within a module. 302 | check-quote-consistency=yes 303 | 304 | 305 | [VARIABLES] 306 | 307 | # Tells whether we should check for unused import in __init__ files. 308 | init-import=no 309 | 310 | # A regular expression matching the name of dummy variables (i.e. expectedly 311 | # not used). 312 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 313 | 314 | # List of additional names supposed to be defined in builtins. Remember that 315 | # you should avoid to define new builtins when possible. 316 | additional-builtins= 317 | 318 | # List of strings which can identify a callback function by name. A callback 319 | # name must start or end with one of those strings. 320 | callbacks=cb_,_cb 321 | 322 | # List of qualified module names which can have objects that can redefine 323 | # builtins. 
324 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 325 | 326 | 327 | [LOGGING] 328 | 329 | # Logging modules to check that the string format arguments are in logging 330 | # function parameter format 331 | logging-modules=logging,absl.logging,tensorflow.io.logging 332 | 333 | 334 | [SIMILARITIES] 335 | 336 | # Minimum lines number of a similarity. 337 | min-similarity-lines=4 338 | 339 | # Ignore comments when computing similarities. 340 | ignore-comments=yes 341 | 342 | # Ignore docstrings when computing similarities. 343 | ignore-docstrings=yes 344 | 345 | # Ignore imports when computing similarities. 346 | ignore-imports=no 347 | 348 | 349 | [SPELLING] 350 | 351 | # Spelling dictionary name. Available dictionaries: none. To make it working 352 | # install python-enchant package. 353 | spelling-dict= 354 | 355 | # List of comma separated words that should not be checked. 356 | spelling-ignore-words= 357 | 358 | # A path to a file that contains private dictionary; one word per line. 359 | spelling-private-dict-file= 360 | 361 | # Tells whether to store unknown words to indicated private dictionary in 362 | # --spelling-private-dict-file option instead of raising a message. 363 | spelling-store-unknown-words=no 364 | 365 | 366 | [IMPORTS] 367 | 368 | # Deprecated modules which should not be used, separated by a comma 369 | deprecated-modules=regsub, 370 | TERMIOS, 371 | Bastion, 372 | rexec, 373 | sets 374 | 375 | # Create a graph of every (i.e. 
internal and external) dependencies in the 376 | # given file (report RP0402 must not be disabled) 377 | import-graph= 378 | 379 | # Create a graph of external dependencies in the given file (report RP0402 must 380 | # not be disabled) 381 | ext-import-graph= 382 | 383 | # Create a graph of internal dependencies in the given file (report RP0402 must 384 | # not be disabled) 385 | int-import-graph= 386 | 387 | # Force import order to recognize a module as part of the standard 388 | # compatibility libraries. 389 | known-standard-library= 390 | 391 | # Force import order to recognize a module as part of a third party library. 392 | known-third-party=enchant, absl 393 | 394 | # Analyse import fallback blocks. This can be used to support both Python 2 and 395 | # 3 compatible code, which means that the block might have code that exists 396 | # only in one or another interpreter, leading to false positives when analysed. 397 | analyse-fallback-blocks=no 398 | 399 | 400 | [CLASSES] 401 | 402 | # List of method names used to declare (i.e. assign) instance attributes. 403 | defining-attr-methods=__init__, 404 | __new__, 405 | setUp 406 | 407 | # List of member names, which should be excluded from the protected access 408 | # warning. 409 | exclude-protected=_asdict, 410 | _fields, 411 | _replace, 412 | _source, 413 | _make 414 | 415 | # List of valid names for the first argument in a class method. 416 | valid-classmethod-first-arg=cls, 417 | class_ 418 | 419 | # List of valid names for the first argument in a metaclass class method. 420 | valid-metaclass-classmethod-first-arg=mcs 421 | 422 | 423 | [EXCEPTIONS] 424 | 425 | # Exceptions that will emit a warning when being caught. 
Defaults to 426 | # "Exception" 427 | overgeneral-exceptions=builtins.StandardError, 428 | builtins.Exception, 429 | builtins.BaseException 430 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Testing 26 | 27 | Please make sure that your PR passes all tests by running `bash test.sh` on your 28 | local machine. Also, you can run only tests that are affected by your code 29 | changes, but you will need to select them manually. 30 | 31 | ## Community Guidelines 32 | 33 | This project follows [Google's Open Source Community 34 | Guidelines](https://opensource.google.com/conduct/). 
35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements/* 4 | include mctx/py.typed 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mctx: MCTS-in-JAX 2 | 3 | Mctx is a library with a [JAX](https://github.com/google/jax)-native 4 | implementation of Monte Carlo tree search (MCTS) algorithms such as 5 | [AlphaZero](https://deepmind.com/blog/article/alphazero-shedding-new-light-grand-games-chess-shogi-and-go), 6 | [MuZero](https://deepmind.com/blog/article/muzero-mastering-go-chess-shogi-and-atari-without-rules), and 7 | [Gumbel MuZero](https://openreview.net/forum?id=bERaNdoegnO). For computation 8 | speed up, the implementation fully supports JIT-compilation. Search algorithms 9 | in Mctx are defined for and operate on batches of inputs, in parallel. 
This 10 | allows to make the most of the accelerators and enables the algorithms to work 11 | with large learned environment models parameterized by deep neural networks. 12 | 13 | ## Installation 14 | 15 | You can install the latest released version of Mctx from PyPI via: 16 | 17 | ```sh 18 | pip install mctx 19 | ``` 20 | 21 | or you can install the latest development version from GitHub: 22 | 23 | ```sh 24 | pip install git+https://github.com/google-deepmind/mctx.git 25 | ``` 26 | 27 | ## Motivation 28 | 29 | Learning and search have been important topics since the early days of AI 30 | research. In the [words of Rich Sutton](http://www.incompleteideas.net/IncIdeas/BitterLesson.html): 31 | 32 | > One thing that should be learned [...] is the great power of general purpose 33 | > methods, of methods that continue to scale with increased computation even as 34 | > the available computation becomes very great. The two methods that seem to 35 | > scale arbitrarily in this way are *search* and *learning*. 36 | 37 | Recently, search algorithms have been successfully combined with learned models 38 | parameterized by deep neural networks, resulting in some of the most powerful 39 | and general reinforcement learning algorithms to date (e.g. MuZero). 40 | However, using search algorithms in combination with deep neural networks 41 | requires efficient implementations, typically written in fast compiled 42 | languages; this can come at the expense of usability and hackability, 43 | especially for researchers that are not familiar with C++. In turn, this limits 44 | adoption and further research on this critical topic. 45 | 46 | Through this library, we hope to help researchers everywhere to contribute to 47 | such an exciting area of research. 
We provide JAX-native implementations of core
search algorithms such as MCTS, that we believe strike a good balance between
performance and usability for researchers that want to investigate search-based
algorithms in Python. The search methods provided by Mctx are
heavily configurable to allow researchers to explore a variety of ideas in
this space, and contribute to the next generation of search-based agents.

## Search in Reinforcement Learning

In Reinforcement Learning the *agent* must learn to interact with the
*environment* in order to maximize a scalar *reward* signal. On each step the
agent selects an action and receives in exchange an observation and a
reward. We may call whatever mechanism the agent uses to select the action the
agent's *policy*.

Classically, policies are parameterized directly by a function approximator (as
in REINFORCE), or policies are inferred by inspecting a set of learned estimates
of the value of each action (as in Q-learning). Alternatively, search makes it
possible to select actions by constructing, on the fly in each state, a policy
or a value function local to the current state, by *searching* using a learned
*model* of the environment.

Exhaustive search over all possible future courses of actions is computationally
prohibitive in any non-trivial environment, hence we need search algorithms
that can make the best use of a finite computational budget. Typically priors
are needed to guide which nodes in the search tree to expand (to reduce the
*breadth* of the tree that we construct), and value functions are used to
estimate the value of incomplete paths in the tree that don't reach an episode
termination (to reduce the *depth* of the search tree).
76 | 77 | ## Quickstart 78 | 79 | Mctx provides a low-level generic `search` function and high-level concrete 80 | policies: `muzero_policy` and `gumbel_muzero_policy`. 81 | 82 | The user needs to provide several learned components to specify the 83 | representation, dynamics and prediction used by [MuZero](https://deepmind.com/blog/article/muzero-mastering-go-chess-shogi-and-atari-without-rules). 84 | In the context of the Mctx library, the representation of the *root* state is 85 | specified by a `RootFnOutput`. The `RootFnOutput` contains the `prior_logits` 86 | from a policy network, the estimated `value` of the root state, and any 87 | `embedding` suitable to represent the root state for the environment model. 88 | 89 | The dynamics environment model needs to be specified by a `recurrent_fn`. 90 | A `recurrent_fn(params, rng_key, action, embedding)` call takes an `action` and 91 | a state `embedding`. The call should return a tuple `(recurrent_fn_output, 92 | new_embedding)` with a `RecurrentFnOutput` and the embedding of the next state. 93 | The `RecurrentFnOutput` contains the `reward` and `discount` for the transition, 94 | and `prior_logits` and `value` for the new state. 95 | 96 | In [`examples/visualization_demo.py`](https://github.com/google-deepmind/mctx/blob/main/examples/visualization_demo.py), you can 97 | see calls to a policy: 98 | 99 | ```python 100 | policy_output = mctx.gumbel_muzero_policy(params, rng_key, root, recurrent_fn, 101 | num_simulations=32) 102 | ``` 103 | 104 | The `policy_output.action` contains the action proposed by the search. That 105 | action can be passed to the environment. To improve the policy, the 106 | `policy_output.action_weights` contain targets usable to train the policy 107 | probabilities. 108 | 109 | We recommend to use the `gumbel_muzero_policy`. 110 | [Gumbel MuZero](https://openreview.net/forum?id=bERaNdoegnO) guarantees a policy 111 | improvement if the action values are correctly evaluated. 
The policy improvement 112 | is demonstrated in 113 | [`examples/policy_improvement_demo.py`](https://github.com/google-deepmind/mctx/blob/main/examples/policy_improvement_demo.py). 114 | 115 | ### Example projects 116 | The following projects demonstrate the Mctx usage: 117 | 118 | - [Pgx](https://github.com/sotetsuk/pgx) — A collection of 20+ vectorized 119 | JAX environments, including backgammon, chess, shogi, Go, and an AlphaZero 120 | example. 121 | - [Basic Learning Demo with Mctx](https://github.com/kenjyoung/mctx_learning_demo) — 122 | AlphaZero on random mazes. 123 | - [a0-jax](https://github.com/NTT123/a0-jax) — AlphaZero on Connect Four, 124 | Gomoku, and Go. 125 | - [muax](https://github.com/bwfbowen/muax) — MuZero on gym-style environments 126 | (CartPole, LunarLander). 127 | - [Classic MCTS](https://github.com/Carbon225/mctx-classic) — A simple example on Connect Four. 128 | - [mctx-az](https://github.com/lowrollr/mctx-az) — Mctx with AlphaZero subtree persistence. 129 | 130 | Tell us about your project. 
131 | 132 | ## Citing Mctx 133 | 134 | This repository is part of the DeepMind JAX Ecosystem, to cite Mctx 135 | please use the citation: 136 | 137 | ```bibtex 138 | @software{deepmind2020jax, 139 | title = {The {D}eep{M}ind {JAX} {E}cosystem}, 140 | author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio}, 141 | url = {http://github.com/deepmind}, 142 | year = {2020}, 143 | } 144 | ``` 145 | -------------------------------------------------------------------------------- /examples/policy_improvement_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A demonstration of the policy improvement by planning with Gumbel.""" 16 | 17 | import functools 18 | from typing import Tuple 19 | 20 | from absl import app 21 | from absl import flags 22 | import chex 23 | import jax 24 | import jax.numpy as jnp 25 | import mctx 26 | 27 | FLAGS = flags.FLAGS 28 | flags.DEFINE_integer("seed", 42, "Random seed.") 29 | flags.DEFINE_integer("batch_size", 256, "Batch size.") 30 | flags.DEFINE_integer("num_actions", 82, "Number of actions.") 31 | flags.DEFINE_integer("num_simulations", 4, "Number of simulations.") 32 | flags.DEFINE_integer("max_num_considered_actions", 16, 33 | "The maximum number of actions expanded at the root.") 34 | flags.DEFINE_integer("num_runs", 1, "Number of runs on random data.") 35 | 36 | 37 | @chex.dataclass(frozen=True) 38 | class DemoOutput: 39 | prior_policy_value: chex.Array 40 | prior_policy_action_value: chex.Array 41 | selected_action_value: chex.Array 42 | action_weights_policy_value: chex.Array 43 | 44 | 45 | def _run_demo(rng_key: chex.PRNGKey) -> Tuple[chex.PRNGKey, DemoOutput]: 46 | """Runs a search algorithm on random data.""" 47 | batch_size = FLAGS.batch_size 48 | rng_key, logits_rng, q_rng, search_rng = jax.random.split(rng_key, 4) 49 | # We will demonstrate the algorithm on random prior_logits. 50 | # Normally, the prior_logits would be produced by a policy network. 51 | prior_logits = jax.random.normal( 52 | logits_rng, shape=[batch_size, FLAGS.num_actions]) 53 | # Defining a bandit with random Q-values. Only the Q-values of the visited 54 | # actions will be revealed to the search algorithm. 55 | qvalues = jax.random.uniform(q_rng, shape=prior_logits.shape) 56 | # If we know the value under the prior policy, we can use the value to 57 | # complete the missing Q-values. 
The completed Q-values will produce an 58 | # improved policy in `policy_output.action_weights`. 59 | raw_value = jnp.sum(jax.nn.softmax(prior_logits) * qvalues, axis=-1) 60 | use_mixed_value = False 61 | 62 | # The root output would be the output of MuZero representation network. 63 | root = mctx.RootFnOutput( 64 | prior_logits=prior_logits, 65 | value=raw_value, 66 | # The embedding is used only to implement the MuZero model. 67 | embedding=jnp.zeros([batch_size]), 68 | ) 69 | # The recurrent_fn would be provided by MuZero dynamics network. 70 | recurrent_fn = _make_bandit_recurrent_fn(qvalues) 71 | 72 | # Running the search. 73 | policy_output = mctx.gumbel_muzero_policy( 74 | params=(), 75 | rng_key=search_rng, 76 | root=root, 77 | recurrent_fn=recurrent_fn, 78 | num_simulations=FLAGS.num_simulations, 79 | max_num_considered_actions=FLAGS.max_num_considered_actions, 80 | qtransform=functools.partial( 81 | mctx.qtransform_completed_by_mix_value, 82 | use_mixed_value=use_mixed_value), 83 | ) 84 | 85 | # Collecting the Q-value of the selected action. 86 | selected_action_value = qvalues[jnp.arange(batch_size), policy_output.action] 87 | 88 | # We will compare the selected action to the action selected by the 89 | # prior policy, while using the same Gumbel random numbers. 90 | gumbel = policy_output.search_tree.extra_data.root_gumbel 91 | prior_policy_action = jnp.argmax(gumbel + prior_logits, axis=-1) 92 | prior_policy_action_value = qvalues[jnp.arange(batch_size), 93 | prior_policy_action] 94 | 95 | # Computing the policy value under the new action_weights. 
96 | action_weights_policy_value = jnp.sum( 97 | policy_output.action_weights * qvalues, axis=-1) 98 | 99 | output = DemoOutput( 100 | prior_policy_value=raw_value, 101 | prior_policy_action_value=prior_policy_action_value, 102 | selected_action_value=selected_action_value, 103 | action_weights_policy_value=action_weights_policy_value, 104 | ) 105 | return rng_key, output 106 | 107 | 108 | def _make_bandit_recurrent_fn(qvalues): 109 | """Returns a recurrent_fn for a determistic bandit.""" 110 | 111 | def recurrent_fn(params, rng_key, action, embedding): 112 | del params, rng_key 113 | # For the bandit, the reward will be non-zero only at the root. 114 | reward = jnp.where(embedding == 0, 115 | qvalues[jnp.arange(action.shape[0]), action], 116 | 0.0) 117 | # On a single-player environment, use discount from [0, 1]. 118 | # On a zero-sum self-play environment, use discount=-1. 119 | discount = jnp.ones_like(reward) 120 | recurrent_fn_output = mctx.RecurrentFnOutput( 121 | reward=reward, 122 | discount=discount, 123 | prior_logits=jnp.zeros_like(qvalues), 124 | value=jnp.zeros_like(reward)) 125 | next_embedding = embedding + 1 126 | return recurrent_fn_output, next_embedding 127 | 128 | return recurrent_fn 129 | 130 | 131 | def main(_): 132 | rng_key = jax.random.PRNGKey(FLAGS.seed) 133 | jitted_run_demo = jax.jit(_run_demo) 134 | for _ in range(FLAGS.num_runs): 135 | rng_key, output = jitted_run_demo(rng_key) 136 | # Printing the obtained increase of the policy value. 137 | # The obtained increase should be non-negative. 
138 | action_value_improvement = ( 139 | output.selected_action_value - output.prior_policy_action_value) 140 | weights_value_improvement = ( 141 | output.action_weights_policy_value - output.prior_policy_value) 142 | print("action value improvement: %.3f (min=%.3f)" % 143 | (action_value_improvement.mean(), action_value_improvement.min())) 144 | print("action_weights value improvement: %.3f (min=%.3f)" % 145 | (weights_value_improvement.mean(), weights_value_improvement.min())) 146 | 147 | 148 | if __name__ == "__main__": 149 | app.run(main) 150 | -------------------------------------------------------------------------------- /examples/visualization_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================
"""A demo of Graphviz visualization of a search tree."""

from typing import Optional, Sequence

from absl import app
from absl import flags
import chex
import jax
import jax.numpy as jnp
import mctx
import pygraphviz

FLAGS = flags.FLAGS
flags.DEFINE_integer("seed", 42, "Random seed.")
flags.DEFINE_integer("num_simulations", 32, "Number of simulations.")
flags.DEFINE_integer("max_num_considered_actions", 16,
                     "The maximum number of actions expanded at the root.")
flags.DEFINE_integer("max_depth", None, "The maximum search depth.")
flags.DEFINE_string("output_file", "/tmp/search_tree.png",
                    "The output file for the visualization.")


def convert_tree_to_graph(
    tree: mctx.Tree,
    action_labels: Optional[Sequence[str]] = None,
    batch_index: int = 0
) -> pygraphviz.AGraph:
  """Converts a search tree into a Graphviz graph.

  Args:
    tree: A `Tree` containing a batch of search data.
    action_labels: Optional labels for edges, defaults to the action index.
    batch_index: Index of the batch element to plot.

  Returns:
    A Graphviz graph representation of `tree`.
  """
  chex.assert_rank(tree.node_values, 2)
  batch_size = tree.node_values.shape[0]
  if action_labels is None:
    action_labels = range(tree.num_actions)
  elif len(action_labels) != tree.num_actions:
    raise ValueError(
        f"action_labels {action_labels} has the wrong number of actions "
        f"({len(action_labels)}). "
        f"Expecting {tree.num_actions}.")

  def node_to_str(node_i, reward=0, discount=1):
    # Multi-line Graphviz node label with the node's value and visit count.
    return (f"{node_i}\n"
            f"Reward: {reward:.2f}\n"
            f"Discount: {discount:.2f}\n"
            f"Value: {tree.node_values[batch_index, node_i]:.2f}\n"
            f"Visits: {tree.node_visits[batch_index, node_i]}\n")

  def edge_to_str(node_i, a_i):
    # Edge label: action name, its Q-value and its prior probability.
    node_index = jnp.full([batch_size], node_i)
    probs = jax.nn.softmax(tree.children_prior_logits[batch_index, node_i])
    return (f"{action_labels[a_i]}\n"
            f"Q: {tree.qvalues(node_index)[batch_index, a_i]:.2f}\n"  # pytype: disable=unsupported-operands  # always-use-return-annotations
            f"p: {probs[a_i]:.2f}\n")

  graph = pygraphviz.AGraph(directed=True)

  # Add root
  graph.add_node(0, label=node_to_str(node_i=0), color="green")
  # Add all other nodes and connect them up.
  # NOTE(review): iterating node indices up to `tree.num_simulations` —
  # presumably each simulation expands at most one node; confirm in `mctx.Tree`.
  for node_i in range(tree.num_simulations):
    for a_i in range(tree.num_actions):
      # Index of children, or -1 if not expanded
      children_i = tree.children_index[batch_index, node_i, a_i]
      if children_i >= 0:
        graph.add_node(
            children_i,
            label=node_to_str(
                node_i=children_i,
                reward=tree.children_rewards[batch_index, node_i, a_i],
                discount=tree.children_discounts[batch_index, node_i, a_i]),
            color="red")
        graph.add_edge(node_i, children_i, label=edge_to_str(node_i, a_i))

  return graph


def _run_demo(rng_key: chex.PRNGKey):
  """Runs a search algorithm on a toy environment.

  Args:
    rng_key: random number generator state.

  Returns:
    The `policy_output` of `mctx.gumbel_muzero_policy` on the toy environment.
  """
  # We will define a deterministic toy environment.
  # The deterministic `transition_matrix` has shape `[num_states, num_actions]`.
  # The `transition_matrix[s, a]` holds the next state.
  transition_matrix = jnp.array([
      [1, 2, 3, 4],
      [0, 5, 0, 0],
      [0, 0, 0, 6],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
  ], dtype=jnp.int32)
  # The `rewards` have shape `[num_states, num_actions]`. The `rewards[s, a]`
  # holds the reward for that (s, a) pair.
  rewards = jnp.array([
      [1, -1, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [10, 0, 20, 0],
  ], dtype=jnp.float32)
  num_states = rewards.shape[0]
  # The discount for each (s, a) pair. Transitions to state 0 terminate
  # (discount 0); all other transitions continue with discount 1.
  discounts = jnp.where(transition_matrix > 0, 1.0, 0.0)
  # Using optimistic initial values to encourage exploration.
  values = jnp.full([num_states], 15.0)
  # The prior policies for each state.
  all_prior_logits = jnp.zeros_like(rewards)
  root, recurrent_fn = _make_batched_env_model(
      # Using batch_size=2 to test the batched search.
      batch_size=2,
      transition_matrix=transition_matrix,
      rewards=rewards,
      discounts=discounts,
      values=values,
      prior_logits=all_prior_logits)

  # Running the search.
  policy_output = mctx.gumbel_muzero_policy(
      params=(),
      rng_key=rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      num_simulations=FLAGS.num_simulations,
      max_depth=FLAGS.max_depth,
      max_num_considered_actions=FLAGS.max_num_considered_actions,
  )
  return policy_output


def _make_batched_env_model(
    batch_size: int,
    *,
    transition_matrix: chex.Array,
    rewards: chex.Array,
    discounts: chex.Array,
    values: chex.Array,
    prior_logits: chex.Array):
  """Returns a batched `(root, recurrent_fn)`.

  Args:
    batch_size: the number of parallel environments in the batch.
    transition_matrix: `[num_states, num_actions]` next-state indices.
    rewards: `[num_states, num_actions]` rewards for each (s, a) pair.
    discounts: `[num_states, num_actions]` discounts for each (s, a) pair.
    values: `[num_states]` state-value estimates.
    prior_logits: `[num_states, num_actions]` prior policy logits.
  """
  chex.assert_equal_shape([transition_matrix, rewards, discounts,
                           prior_logits])
  num_states, num_actions = transition_matrix.shape
  chex.assert_shape(values, [num_states])
  # We will start the search at state zero.
  root_state = 0
  root = mctx.RootFnOutput(
      prior_logits=jnp.full([batch_size, num_actions],
                            prior_logits[root_state]),
      value=jnp.full([batch_size], values[root_state]),
      # The embedding will hold the state index.
      embedding=jnp.zeros([batch_size], dtype=jnp.int32),
  )

  def recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key
    chex.assert_shape(action, [batch_size])
    chex.assert_shape(embedding, [batch_size])
    # Batched table lookups: each batch element steps independently.
    recurrent_fn_output = mctx.RecurrentFnOutput(
        reward=rewards[embedding, action],
        discount=discounts[embedding, action],
        prior_logits=prior_logits[embedding],
        value=values[embedding])
    next_embedding = transition_matrix[embedding, action]
    return recurrent_fn_output, next_embedding

  return root, recurrent_fn


def main(_):
  rng_key = jax.random.PRNGKey(FLAGS.seed)
  jitted_run_demo = jax.jit(_run_demo)
  print("Starting search.")
  policy_output = jitted_run_demo(rng_key)
  batch_index = 0
  selected_action = policy_output.action[batch_index]
  q_value = policy_output.search_tree.summary().qvalues[
      batch_index, selected_action]
  print("Selected action:", selected_action)
  # To estimate the value of the root state, use the Q-value of the selected
  # action. The Q-value is not affected by the exploration at the root node.
  print("Selected action Q-value:", q_value)
  graph = convert_tree_to_graph(policy_output.search_tree)
  print("Saving tree diagram to:", FLAGS.output_file)
  graph.draw(FLAGS.output_file, prog="dot")


if __name__ == "__main__":
  app.run(main)
--------------------------------------------------------------------------------
/mctx/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Mctx: Monte Carlo tree search in JAX.""" 16 | 17 | from mctx._src.action_selection import gumbel_muzero_interior_action_selection 18 | from mctx._src.action_selection import gumbel_muzero_root_action_selection 19 | from mctx._src.action_selection import GumbelMuZeroExtraData 20 | from mctx._src.action_selection import muzero_action_selection 21 | from mctx._src.base import ChanceRecurrentFnOutput 22 | from mctx._src.base import DecisionRecurrentFnOutput 23 | from mctx._src.base import InteriorActionSelectionFn 24 | from mctx._src.base import LoopFn 25 | from mctx._src.base import PolicyOutput 26 | from mctx._src.base import RecurrentFn 27 | from mctx._src.base import RecurrentFnOutput 28 | from mctx._src.base import RecurrentState 29 | from mctx._src.base import RootActionSelectionFn 30 | from mctx._src.base import RootFnOutput 31 | from mctx._src.policies import gumbel_muzero_policy 32 | from mctx._src.policies import muzero_policy 33 | from mctx._src.policies import stochastic_muzero_policy 34 | from mctx._src.qtransforms import qtransform_by_min_max 35 | from mctx._src.qtransforms import qtransform_by_parent_and_siblings 36 | from mctx._src.qtransforms import qtransform_completed_by_mix_value 37 | from mctx._src.search import search 38 | from mctx._src.tree import Tree 39 | 40 | 
__version__ = "0.0.6" 41 | 42 | __all__ = ( 43 | "ChanceRecurrentFnOutput", 44 | "DecisionRecurrentFnOutput", 45 | "GumbelMuZeroExtraData", 46 | "InteriorActionSelectionFn", 47 | "LoopFn", 48 | "PolicyOutput", 49 | "RecurrentFn", 50 | "RecurrentFnOutput", 51 | "RecurrentState", 52 | "RootActionSelectionFn", 53 | "RootFnOutput", 54 | "Tree", 55 | "gumbel_muzero_interior_action_selection", 56 | "gumbel_muzero_policy", 57 | "gumbel_muzero_root_action_selection", 58 | "muzero_action_selection", 59 | "muzero_policy", 60 | "qtransform_by_min_max", 61 | "qtransform_by_parent_and_siblings", 62 | "qtransform_completed_by_mix_value", 63 | "search", 64 | "stochastic_muzero_policy", 65 | ) 66 | 67 | # _________________________________________ 68 | # / Please don't use symbols in `_src` they \ 69 | # \ are not part of the Mctx public API. / 70 | # ----------------------------------------- 71 | # \ ^__^ 72 | # \ (oo)\_______ 73 | # (__)\ )\/\ 74 | # ||----w | 75 | # || || 76 | # 77 | -------------------------------------------------------------------------------- /mctx/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================
--------------------------------------------------------------------------------
/mctx/_src/action_selection.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A collection of action selection functions."""
from typing import Optional, TypeVar

import chex
import jax
import jax.numpy as jnp

from mctx._src import base
from mctx._src import qtransforms
from mctx._src import seq_halving
from mctx._src import tree as tree_lib


def switching_action_selection_wrapper(
    root_action_selection_fn: base.RootActionSelectionFn,
    interior_action_selection_fn: base.InteriorActionSelectionFn
) -> base.InteriorActionSelectionFn:
  """Wraps root and interior action selection fns in a conditional statement."""

  def switching_action_selection_fn(
      rng_key: chex.PRNGKey,
      tree: tree_lib.Tree,
      node_index: base.NodeIndices,
      depth: base.Depth) -> chex.Array:
    # The root selection fn takes no `depth` argument (depth is always zero
    # at the root), hence only the first three operands are forwarded to it.
    return jax.lax.cond(
        depth == 0,
        lambda x: root_action_selection_fn(*x[:3]),
        lambda x: interior_action_selection_fn(*x),
        (rng_key, tree, node_index, depth))

  return switching_action_selection_fn


def muzero_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    depth: chex.Numeric,
    *,
    pb_c_init: float = 1.25,
    pb_c_base: float = 19652.0,
    qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
) -> chex.Array:
  """Returns the action selected for a node index.

  See Appendix B in https://arxiv.org/pdf/1911.08265.pdf for more details.

  Args:
    rng_key: random number generator state.
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the node from which to select an action.
    depth: the scalar depth of the current node. The root has depth zero.
    pb_c_init: constant c_1 in the PUCT formula.
    pb_c_base: constant c_2 in the PUCT formula.
    qtransform: a monotonic transformation to convert the Q-values to [0, 1].

  Returns:
    action: the action selected from the given node.
  """
  visit_counts = tree.children_visits[node_index]
  node_visit = tree.node_visits[node_index]
  # PUCT exploration coefficient: grows slowly (logarithmically) with the
  # parent visit count.
  pb_c = pb_c_init + jnp.log((node_visit + pb_c_base + 1.) / pb_c_base)
  prior_logits = tree.children_prior_logits[node_index]
  prior_probs = jax.nn.softmax(prior_logits)
  # Exploration term: favors actions with a high prior and few visits.
  policy_score = jnp.sqrt(node_visit) * pb_c * prior_probs / (visit_counts + 1)
  chex.assert_shape([node_index, node_visit], ())
  chex.assert_equal_shape([prior_probs, visit_counts, policy_score])
  # Exploitation term: Q-values transformed by the given `qtransform`.
  value_score = qtransform(tree, node_index)

  # Add tiny bit of randomness for tie break
  node_noise_score = 1e-7 * jax.random.uniform(
      rng_key, (tree.num_actions,))
  to_argmax = value_score + policy_score + node_noise_score

  # Masking the invalid actions at the root.
  return masked_argmax(to_argmax, tree.root_invalid_actions * (depth == 0))


@chex.dataclass(frozen=True)
class GumbelMuZeroExtraData:
  """Extra data for Gumbel MuZero search."""
  # The Gumbel noise used at the root; consumed by
  # `gumbel_muzero_root_action_selection`.
  root_gumbel: chex.Array


GumbelMuZeroExtraDataType = TypeVar(  # pylint: disable=invalid-name
    "GumbelMuZeroExtraDataType", bound=GumbelMuZeroExtraData)


def gumbel_muzero_root_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree[GumbelMuZeroExtraDataType],
    node_index: chex.Numeric,
    *,
    num_simulations: chex.Numeric,
    max_num_considered_actions: chex.Numeric,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> chex.Array:
  """Returns the action selected by Sequential Halving with Gumbel.

  Initially, we sample `max_num_considered_actions` actions without replacement.
  From these, the actions with the highest `gumbel + logits + qvalues` are
  visited first.

  Args:
    rng_key: random number generator state.
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the node from which to take an action.
    num_simulations: the simulation budget.
    max_num_considered_actions: the number of actions sampled without
      replacement.
    qtransform: a monotonic transformation for the Q-values.

  Returns:
    action: the action selected from the given node.
  """
  del rng_key
  chex.assert_shape([node_index], ())
  visit_counts = tree.children_visits[node_index]
  prior_logits = tree.children_prior_logits[node_index]
  chex.assert_equal_shape([visit_counts, prior_logits])
  completed_qvalues = qtransform(tree, node_index)

  # Precomputed Sequential Halving schedule, indexed by
  # `[num_considered, simulation_index]`.
  table = jnp.array(seq_halving.get_table_of_considered_visits(
      max_num_considered_actions, num_simulations))
  num_valid_actions = jnp.sum(
      1 - tree.root_invalid_actions, axis=-1).astype(jnp.int32)
  num_considered = jnp.minimum(
      max_num_considered_actions, num_valid_actions)
  chex.assert_shape(num_considered, ())
  # At the root, the simulation_index is equal to the sum of visit counts.
  simulation_index = jnp.sum(visit_counts, -1)
  chex.assert_shape(simulation_index, ())
  # Visit count an action must currently have to be considered in this phase.
  considered_visit = table[num_considered, simulation_index]
  chex.assert_shape(considered_visit, ())
  gumbel = tree.extra_data.root_gumbel
  to_argmax = seq_halving.score_considered(
      considered_visit, gumbel, prior_logits, completed_qvalues,
      visit_counts)

  # Masking the invalid actions at the root.
  return masked_argmax(to_argmax, tree.root_invalid_actions)


def gumbel_muzero_interior_action_selection(
    rng_key: chex.PRNGKey,
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    depth: chex.Numeric,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
) -> chex.Array:
  """Selects the action with a deterministic action selection.

  The action is selected based on the visit counts to produce visitation
  frequencies similar to softmax(prior_logits + qvalues).

  Args:
    rng_key: random number generator state.
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the node from which to take an action.
    depth: the scalar depth of the current node. The root has depth zero.
    qtransform: function to obtain completed Q-values for a node.

  Returns:
    action: the action selected from the given node.
  """
  del rng_key, depth
  chex.assert_shape([node_index], ())
  visit_counts = tree.children_visits[node_index]
  prior_logits = tree.children_prior_logits[node_index]
  chex.assert_equal_shape([visit_counts, prior_logits])
  completed_qvalues = qtransform(tree, node_index)

  # The `prior_logits + completed_qvalues` provide an improved policy,
  # because the missing qvalues are replaced by v_{prior_logits}(node).
  to_argmax = _prepare_argmax_input(
      probs=jax.nn.softmax(prior_logits + completed_qvalues),
      visit_counts=visit_counts)

  chex.assert_rank(to_argmax, 1)
  return jnp.argmax(to_argmax, axis=-1).astype(jnp.int32)


def masked_argmax(
    to_argmax: chex.Array,
    invalid_actions: Optional[chex.Array]) -> chex.Array:
  """Returns a valid action with the highest `to_argmax`."""
  if invalid_actions is not None:
    chex.assert_equal_shape([to_argmax, invalid_actions])
    # The usage of the -inf inside the argmax does not lead to NaN.
    # Do not use -inf inside softmax, logsoftmax or cross-entropy.
    to_argmax = jnp.where(invalid_actions, -jnp.inf, to_argmax)
  # If all actions are invalid, the argmax returns action 0.
  return jnp.argmax(to_argmax, axis=-1).astype(jnp.int32)


def _prepare_argmax_input(probs, visit_counts):
  """Prepares the input for the deterministic selection.

  When calling argmax(_prepare_argmax_input(...)) multiple times
  with updated visit_counts, the produced visitation frequencies will
  approximate the probs.

  For the derivation, see Section 5 "Planning at non-root nodes" in
  "Policy improvement by planning with Gumbel":
  https://openreview.net/forum?id=bERaNdoegnO

  Args:
    probs: a policy or an improved policy. Shape `[num_actions]`.
    visit_counts: the existing visit counts. Shape `[num_actions]`.

  Returns:
    The input to an argmax. Shape `[num_actions]`.
  """
  chex.assert_equal_shape([probs, visit_counts])
  # Picks the action whose empirical visit frequency lags furthest behind
  # its target probability.
  to_argmax = probs - visit_counts / (
      1 + jnp.sum(visit_counts, keepdims=True, axis=-1))
  return to_argmax
--------------------------------------------------------------------------------
/mctx/_src/base.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core types used in mctx."""

from typing import Any, Callable, Generic, TypeVar, Tuple

import chex

from mctx._src import tree


# Parameters are arbitrary nested structures of `chex.Array`.
# A nested structure is either a single object, or a collection (list, tuple,
# dictionary, etc.) of other nested structures.
Params = chex.ArrayTree


# The model used to search is expressed by a `RecurrentFn` function that takes
# `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput` and
# the new state embedding.
@chex.dataclass(frozen=True)
class RecurrentFnOutput:
  """The output of a `RecurrentFn`.

  reward: `[B]` an approximate reward from the state-action transition.
  discount: `[B]` the discount between the `reward` and the `value`.
  prior_logits: `[B, num_actions]` the logits produced by a policy network.
  value: `[B]` an approximate value of the state after the state-action
    transition.
  """
  reward: chex.Array
  discount: chex.Array
  prior_logits: chex.Array
  value: chex.Array


Action = chex.Array
RecurrentState = Any
RecurrentFn = Callable[
    [Params, chex.PRNGKey, Action, RecurrentState],
    Tuple[RecurrentFnOutput, RecurrentState]]


@chex.dataclass(frozen=True)
class RootFnOutput:
  """The output of a representation network.

  prior_logits: `[B, num_actions]` the logits produced by a policy network.
  value: `[B]` an approximate value of the current state.
  embedding: `[B, ...]` the inputs to the next `recurrent_fn` call.
  """
  prior_logits: chex.Array
  value: chex.Array
  embedding: RecurrentState


# Action selection functions specify how to pick nodes to expand in the tree.
NodeIndices = chex.Array
Depth = chex.Array
RootActionSelectionFn = Callable[
    [chex.PRNGKey, tree.Tree, NodeIndices], chex.Array]
InteriorActionSelectionFn = Callable[
    [chex.PRNGKey, tree.Tree, NodeIndices, Depth],
    chex.Array]
QTransform = Callable[[tree.Tree, chex.Array], chex.Array]
# LoopFn has the same interface as jax.lax.fori_loop.
LoopFn = Callable[
    [int, int, Callable[[Any, Any], Any], Tuple[chex.PRNGKey, tree.Tree]],
    Tuple[chex.PRNGKey, tree.Tree]]

T = TypeVar("T")


@chex.dataclass(frozen=True)
class PolicyOutput(Generic[T]):
  """The output of a policy.

  action: `[B]` the proposed action.
  action_weights: `[B, num_actions]` the targets used to train a policy network.
    The action weights sum to one. Usually, the policy network is trained by
    cross-entropy:
    `cross_entropy(labels=stop_gradient(action_weights), logits=prior_logits)`.
  search_tree: `[B, ...]` the search tree of the finished search.
  """
  action: chex.Array
  action_weights: chex.Array
  search_tree: tree.Tree[T]


@chex.dataclass(frozen=True)
class DecisionRecurrentFnOutput:
  """Output of the function for expanding decision nodes.

  Expanding a decision node takes an action and a state embedding and produces
  an afterstate, which represents the state of the environment after an action
  is taken but before the environment has updated its state. Accordingly, there
  is no discount factor or reward for transitioning from state `s` to afterstate
  `sa`.

  Attributes:
    chance_logits: `[B, C]` logits of `C` chance outcomes at the afterstate.
    afterstate_value: `[B]` values of the afterstates `v(sa)`.
  """
  chance_logits: chex.Array  # [B, C]
  afterstate_value: chex.Array  # [B]


@chex.dataclass(frozen=True)
class ChanceRecurrentFnOutput:
  """Output of the function for expanding chance nodes.

  Expanding a chance node takes a chance outcome and an afterstate embedding
  and produces a state, which captures a potentially stochastic environment
  transition. When this transition occurs reward and discounts are produced as
  in a normal transition.

  Attributes:
    action_logits: `[B, A]` logits of different actions from the state.
    value: `[B]` values of the states `v(s)`.
    reward: `[B]` rewards at the states.
    discount: `[B]` discounts at the states.
  """
  action_logits: chex.Array  # [B, A]
  value: chex.Array  # [B]
  reward: chex.Array  # [B]
  discount: chex.Array  # [B]


@chex.dataclass(frozen=True)
class StochasticRecurrentState:
  """Wrapper that enables different treatment of decision and chance nodes.

  In Stochastic MuZero tree nodes can either be decision or chance nodes, these
  nodes are treated differently during expansion, search and backup, and a user
  could also pass differently structured embeddings for each type of node. This
  wrapper enables treating chance and decision nodes differently and supports
  potential differences between chance and decision node structures.

  Attributes:
    state_embedding: `[B ...]` an optionally meaningful state embedding.
    afterstate_embedding: `[B ...]` an optionally meaningful afterstate
      embedding.
    is_decision_node: `[B]` whether the node is a decision or chance node. If it
      is a decision node, `afterstate_embedding` is a dummy value. If it is a
      chance node, `state_embedding` is a dummy value.
  """
  state_embedding: chex.ArrayTree  # [B, ...]
  afterstate_embedding: chex.ArrayTree  # [B, ...]
  is_decision_node: chex.Array  # [B]


# NOTE(review): this rebinds `RecurrentState` from `Any` (defined above) to
# `chex.ArrayTree` for the stochastic-search aliases below. The earlier
# `RecurrentFn`/`RootFnOutput` definitions captured the old `Any` alias.
RecurrentState = chex.ArrayTree

DecisionRecurrentFn = Callable[[Params, chex.PRNGKey, Action, RecurrentState],
                               Tuple[DecisionRecurrentFnOutput, RecurrentState]]

ChanceRecurrentFn = Callable[[Params, chex.PRNGKey, Action, RecurrentState],
                             Tuple[ChanceRecurrentFnOutput, RecurrentState]]
--------------------------------------------------------------------------------
/mctx/_src/policies.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Search policies."""
import functools
from typing import Optional, Tuple

import chex
import jax
import jax.numpy as jnp

from mctx._src import action_selection
from mctx._src import base
from mctx._src import qtransforms
from mctx._src import search
from mctx._src import seq_halving


def muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
    dirichlet_fraction: chex.Numeric = 0.25,
    dirichlet_alpha: chex.Numeric = 0.3,
    pb_c_init: chex.Numeric = 1.25,
    pb_c_base: chex.Numeric = 19652,
    temperature: chex.Numeric = 1.0) -> base.PolicyOutput[None]:
  """Runs MuZero search and returns the `PolicyOutput`.

  In the shape descriptions, `B` denotes the batch dimension.

  Args:
    params: params to be forwarded to root and recurrent functions.
    rng_key: random number generator state, the key is consumed.
    root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
      `prior_logits` are from a policy network. The shapes are
      `([B, num_actions], [B], [B, ...])`, respectively.
    recurrent_fn: a callable to be called on the leaf nodes and unvisited
      actions retrieved by the simulation step, which takes as args
      `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
      and the new state embedding. The `rng_key` argument is consumed.
    num_simulations: the number of simulations.
    invalid_actions: a mask with invalid actions. Invalid actions
      have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`.
    max_depth: maximum search tree depth allowed during simulation.
    loop_fn: Function used to run the simulations. It may be required to pass
      hk.fori_loop if using this function inside a Haiku module.
    qtransform: function to obtain completed Q-values for a node.
    dirichlet_fraction: float from 0 to 1 interpolating between using only the
      prior policy or just the Dirichlet noise.
    dirichlet_alpha: concentration parameter to parametrize the Dirichlet
      distribution.
    pb_c_init: constant c_1 in the PUCT formula.
    pb_c_base: constant c_2 in the PUCT formula.
    temperature: temperature for acting proportionally to
      `visit_counts**(1 / temperature)`.

  Returns:
    `PolicyOutput` containing the proposed action, action_weights and the used
    search tree.
  """
  rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3)

  # Adding Dirichlet noise.
  noisy_logits = _get_logits_from_probs(
      _add_dirichlet_noise(
          dirichlet_rng_key,
          jax.nn.softmax(root.prior_logits),
          dirichlet_fraction=dirichlet_fraction,
          dirichlet_alpha=dirichlet_alpha))
  root = root.replace(
      prior_logits=_mask_invalid_actions(noisy_logits, invalid_actions))

  # Running the search.
  interior_action_selection_fn = functools.partial(
      action_selection.muzero_action_selection,
      pb_c_base=pb_c_base,
      pb_c_init=pb_c_init,
      qtransform=qtransform)
  root_action_selection_fn = functools.partial(
      interior_action_selection_fn,
      depth=0)
  search_tree = search.search(
      params=params,
      rng_key=search_rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=root_action_selection_fn,
      interior_action_selection_fn=interior_action_selection_fn,
      num_simulations=num_simulations,
      max_depth=max_depth,
      invalid_actions=invalid_actions,
      loop_fn=loop_fn)

  # Sampling the proposed action proportionally to the visit counts.
  summary = search_tree.summary()
  action_weights = summary.visit_probs
  action_logits = _apply_temperature(
      _get_logits_from_probs(action_weights), temperature)
  action = jax.random.categorical(rng_key, action_logits)
  return base.PolicyOutput(
      action=action,
      action_weights=action_weights,
      search_tree=search_tree)


def gumbel_muzero_policy(
    params: base.Params,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    recurrent_fn: base.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_completed_by_mix_value,
    max_num_considered_actions: int = 16,
    gumbel_scale: chex.Numeric = 1.,
) -> base.PolicyOutput[action_selection.GumbelMuZeroExtraData]:
  """Runs Gumbel MuZero search and returns the `PolicyOutput`.

  This policy implements Full Gumbel MuZero from
  "Policy improvement by planning with Gumbel".
  https://openreview.net/forum?id=bERaNdoegnO

  At the root of the search tree, actions are selected by Sequential Halving
  with Gumbel. At non-root nodes (aka interior nodes), actions are selected by
  the Full Gumbel MuZero deterministic action selection.

  In the shape descriptions, `B` denotes the batch dimension.

  Args:
    params: params to be forwarded to root and recurrent functions.
    rng_key: random number generator state, the key is consumed.
    root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
      `prior_logits` are from a policy network. The shapes are
      `([B, num_actions], [B], [B, ...])`, respectively.
    recurrent_fn: a callable to be called on the leaf nodes and unvisited
      actions retrieved by the simulation step, which takes as args
      `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
      and the new state embedding. The `rng_key` argument is consumed.
    num_simulations: the number of simulations.
    invalid_actions: a mask with invalid actions. Invalid actions
      have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`.
    max_depth: maximum search tree depth allowed during simulation.
    loop_fn: Function used to run the simulations. It may be required to pass
      hk.fori_loop if using this function inside a Haiku module.
    qtransform: function to obtain completed Q-values for a node.
    max_num_considered_actions: the maximum number of actions expanded at the
      root node. A smaller number of actions will be expanded if the number of
      valid actions is smaller.
    gumbel_scale: scale for the Gumbel noise. Evaluation on perfect-information
      games can use gumbel_scale=0.0.

  Returns:
    `PolicyOutput` containing the proposed action, action_weights and the used
    search tree.
  """
  # Masking invalid actions.
  root = root.replace(
      prior_logits=_mask_invalid_actions(root.prior_logits, invalid_actions))

  # Generating Gumbel.
  rng_key, gumbel_rng = jax.random.split(rng_key)
  gumbel = gumbel_scale * jax.random.gumbel(
      gumbel_rng, shape=root.prior_logits.shape, dtype=root.prior_logits.dtype)

  # Searching.
  extra_data = action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel)
  search_tree = search.search(
      params=params,
      rng_key=rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=functools.partial(
          action_selection.gumbel_muzero_root_action_selection,
          num_simulations=num_simulations,
          max_num_considered_actions=max_num_considered_actions,
          qtransform=qtransform,
      ),
      interior_action_selection_fn=functools.partial(
          action_selection.gumbel_muzero_interior_action_selection,
          qtransform=qtransform,
      ),
      num_simulations=num_simulations,
      max_depth=max_depth,
      invalid_actions=invalid_actions,
      extra_data=extra_data,
      loop_fn=loop_fn)
  summary = search_tree.summary()

  # Acting with the best action from the most visited actions.
  # The "best" action has the highest `gumbel + logits + q`.
  # Inside the minibatch, the considered_visit can be different on states with
  # a smaller number of valid actions.
  considered_visit = jnp.max(summary.visit_counts, axis=-1, keepdims=True)
  # The completed_qvalues include imputed values for unvisited actions.
  completed_qvalues = jax.vmap(qtransform, in_axes=[0, None])(  # pytype: disable=wrong-arg-types  # numpy-scalars  # pylint: disable=line-too-long
      search_tree, search_tree.ROOT_INDEX)
  to_argmax = seq_halving.score_considered(
      considered_visit, gumbel, root.prior_logits, completed_qvalues,
      summary.visit_counts)
  action = action_selection.masked_argmax(to_argmax, invalid_actions)

  # Producing action_weights usable to train the policy network.
  completed_search_logits = _mask_invalid_actions(
      root.prior_logits + completed_qvalues, invalid_actions)
  action_weights = jax.nn.softmax(completed_search_logits)
  return base.PolicyOutput(
      action=action,
      action_weights=action_weights,
      search_tree=search_tree)


def stochastic_muzero_policy(
    params: chex.ArrayTree,
    rng_key: chex.PRNGKey,
    root: base.RootFnOutput,
    decision_recurrent_fn: base.DecisionRecurrentFn,
    chance_recurrent_fn: base.ChanceRecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn: base.LoopFn = jax.lax.fori_loop,
    *,
    qtransform: base.QTransform = qtransforms.qtransform_by_parent_and_siblings,
    dirichlet_fraction: chex.Numeric = 0.25,
    dirichlet_alpha: chex.Numeric = 0.3,
    pb_c_init: chex.Numeric = 1.25,
    pb_c_base: chex.Numeric = 19652,
    temperature: chex.Numeric = 1.0) -> base.PolicyOutput[None]:
  """Runs Stochastic MuZero search.

  Implements search as described in the Stochastic MuZero paper:
  (https://openreview.net/forum?id=X6D9bAHhBQ1).

  In the shape descriptions, `B` denotes the batch dimension.
  Args:
    params: params to be forwarded to root and recurrent functions.
    rng_key: random number generator state, the key is consumed.
    root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
      `prior_logits` are from a policy network. The shapes are `([B,
      num_actions], [B], [B, ...])`, respectively.
    decision_recurrent_fn: a callable to be called on the leaf decision nodes
      and unvisited actions retrieved by the simulation step, which takes as
      args `(params, rng_key, action, state_embedding)` and returns a
      `(DecisionRecurrentFnOutput, afterstate_embedding)`.
    chance_recurrent_fn: a callable to be called on the leaf chance nodes and
      unvisited actions retrieved by the simulation step, which takes as args
      `(params, rng_key, chance_outcome, afterstate_embedding)` and returns a
      `(ChanceRecurrentFnOutput, state_embedding)`.
    num_simulations: the number of simulations.
    invalid_actions: a mask with invalid actions. Invalid actions have ones,
      valid actions have zeros in the mask. Shape `[B, num_actions]`.
    max_depth: maximum search tree depth allowed during simulation.
    loop_fn: Function used to run the simulations. It may be required to pass
      hk.fori_loop if using this function inside a Haiku module.
    qtransform: function to obtain completed Q-values for a node.
    dirichlet_fraction: float from 0 to 1 interpolating between using only the
      prior policy or just the Dirichlet noise.
    dirichlet_alpha: concentration parameter to parametrize the Dirichlet
      distribution.
    pb_c_init: constant c_1 in the PUCT formula.
    pb_c_base: constant c_2 in the PUCT formula.
    temperature: temperature for acting proportionally to `visit_counts**(1 /
      temperature)`.

  Returns:
    `PolicyOutput` containing the proposed action, action_weights and the used
    search tree.
  """

  num_actions = root.prior_logits.shape[-1]

  rng_key, dirichlet_rng_key, search_rng_key = jax.random.split(rng_key, 3)

  # Adding Dirichlet noise.
  noisy_logits = _get_logits_from_probs(
      _add_dirichlet_noise(
          dirichlet_rng_key,
          jax.nn.softmax(root.prior_logits),
          dirichlet_fraction=dirichlet_fraction,
          dirichlet_alpha=dirichlet_alpha))

  root = root.replace(
      prior_logits=_mask_invalid_actions(noisy_logits, invalid_actions))

  # construct a dummy afterstate embedding
  # NOTE(review): `rng_key` is consumed here by the dummy expansion and reused
  # below for `jax.random.categorical` — confirm this key reuse is intended.
  batch_size = jax.tree_util.tree_leaves(root.embedding)[0].shape[0]
  dummy_action = jnp.zeros([batch_size], dtype=jnp.int32)
  dummy_output, dummy_afterstate_embedding = decision_recurrent_fn(
      params, rng_key, dummy_action, root.embedding)
  num_chance_outcomes = dummy_output.chance_logits.shape[-1]

  root = root.replace(
      # pad action logits with num_chance_outcomes so dim is A + C
      prior_logits=jnp.concatenate([
          root.prior_logits,
          jnp.full([batch_size, num_chance_outcomes], fill_value=-jnp.inf)
      ], axis=-1),
      # replace embedding with wrapper.
      embedding=base.StochasticRecurrentState(
          state_embedding=root.embedding,
          afterstate_embedding=dummy_afterstate_embedding,
          is_decision_node=jnp.ones([batch_size], dtype=bool)))

  # Stochastic MuZero Change: We need to be able to tell if different nodes are
  # decision or chance. This is accomplished by imposing a special structure
  # on the embeddings stored in each node. Each embedding is an instance of
  # StochasticRecurrentState which maintains this information.
  recurrent_fn = _make_stochastic_recurrent_fn(
      decision_node_fn=decision_recurrent_fn,
      chance_node_fn=chance_recurrent_fn,
      num_actions=num_actions,
      num_chance_outcomes=num_chance_outcomes,
  )

  # Running the search.

  interior_decision_node_selection_fn = functools.partial(
      action_selection.muzero_action_selection,
      pb_c_base=pb_c_base,
      pb_c_init=pb_c_init,
      qtransform=qtransform)

  interior_action_selection_fn = _make_stochastic_action_selection_fn(
      interior_decision_node_selection_fn, num_actions)

  root_action_selection_fn = functools.partial(
      interior_action_selection_fn, depth=0)

  search_tree = search.search(
      params=params,
      rng_key=search_rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=root_action_selection_fn,
      interior_action_selection_fn=interior_action_selection_fn,
      num_simulations=num_simulations,
      max_depth=max_depth,
      invalid_actions=invalid_actions,
      loop_fn=loop_fn)

  # Sampling the proposed action proportionally to the visit counts.
  search_tree = _mask_tree(search_tree, num_actions, 'decision')
  summary = search_tree.summary()
  action_weights = summary.visit_probs
  action_logits = _apply_temperature(
      _get_logits_from_probs(action_weights), temperature)
  action = jax.random.categorical(rng_key, action_logits)
  return base.PolicyOutput(
      action=action, action_weights=action_weights, search_tree=search_tree)


def _mask_invalid_actions(logits, invalid_actions):
  """Returns logits with zero mass to invalid actions."""
  if invalid_actions is None:
    return logits
  chex.assert_equal_shape([logits, invalid_actions])
  logits = logits - jnp.max(logits, axis=-1, keepdims=True)
  # At the end of an episode, all actions can be invalid. A softmax would then
  # produce NaNs, if using -inf for the logits. We avoid the NaNs by using
  # a finite `min_logit` for the invalid actions.
  min_logit = jnp.finfo(logits.dtype).min
  return jnp.where(invalid_actions, min_logit, logits)


def _get_logits_from_probs(probs):
  # Clamp to the smallest positive float to avoid log(0) = -inf.
  tiny = jnp.finfo(probs.dtype).tiny
  return jnp.log(jnp.maximum(probs, tiny))


def _add_dirichlet_noise(rng_key, probs, *, dirichlet_alpha,
                         dirichlet_fraction):
  """Mixes the probs with Dirichlet noise."""
  chex.assert_rank(probs, 2)
  chex.assert_type([dirichlet_alpha, dirichlet_fraction], float)

  batch_size, num_actions = probs.shape
  noise = jax.random.dirichlet(
      rng_key,
      alpha=jnp.full([num_actions], fill_value=dirichlet_alpha),
      shape=(batch_size,))
  noisy_probs = (1 - dirichlet_fraction) * probs + dirichlet_fraction * noise
  return noisy_probs


def _apply_temperature(logits, temperature):
  """Returns `logits / temperature`, supporting also temperature=0."""
  # The max subtraction prevents +inf after dividing by a small temperature.
  logits = logits - jnp.max(logits, keepdims=True, axis=-1)
  tiny = jnp.finfo(logits.dtype).tiny
  return logits / jnp.maximum(tiny, temperature)


def _make_stochastic_recurrent_fn(
    decision_node_fn: base.DecisionRecurrentFn,
    chance_node_fn: base.ChanceRecurrentFn,
    num_actions: int,
    num_chance_outcomes: int,
) -> base.RecurrentFn:
  """Make Stochastic Recurrent Fn."""

  def stochastic_recurrent_fn(
      params: base.Params,
      rng: chex.PRNGKey,
      action_or_chance: base.Action,  # [B]
      state: base.StochasticRecurrentState
  ) -> Tuple[base.RecurrentFnOutput, base.StochasticRecurrentState]:
    batch_size = jax.tree_util.tree_leaves(state.state_embedding)[0].shape[0]
    # Internally we assume that there are `A' = A + C` "actions";
    # action_or_chance can take on values in `{0, 1, ..., A' - 1}`,.
    # To interpret it as an action we can leave it as is:
    action = action_or_chance - 0
    # To interpret it as a chance outcome we subtract num_actions:
    chance_outcome = action_or_chance - num_actions

    decision_output, afterstate_embedding = decision_node_fn(
        params, rng, action, state.state_embedding)
    # Outputs from DecisionRecurrentFunction produce chance logits with
    # dim `C`, to respect our internal convention that there are `A' = A + C`
    # "actions" we pad with `A` dummy logits which are ultimately ignored:
    # see `_mask_tree`.
    output_if_decision_node = base.RecurrentFnOutput(
        prior_logits=jnp.concatenate([
            jnp.full([batch_size, num_actions], fill_value=-jnp.inf),
            decision_output.chance_logits], axis=-1),
        value=decision_output.afterstate_value,
        reward=jnp.zeros_like(decision_output.afterstate_value),
        discount=jnp.ones_like(decision_output.afterstate_value))

    chance_output, state_embedding = chance_node_fn(params, rng, chance_outcome,
                                                   state.afterstate_embedding)
    # Outputs from ChanceRecurrentFunction produce action logits with dim `A`,
    # to respect our internal convention that there are `A' = A + C` "actions"
    # we pad with `C` dummy logits which are ultimately ignored: see
    # `_mask_tree`.
    output_if_chance_node = base.RecurrentFnOutput(
        prior_logits=jnp.concatenate([
            chance_output.action_logits,
            jnp.full([batch_size, num_chance_outcomes], fill_value=-jnp.inf)
        ], axis=-1),
        value=chance_output.value,
        reward=chance_output.reward,
        discount=chance_output.discount)

    new_state = base.StochasticRecurrentState(
        state_embedding=state_embedding,
        afterstate_embedding=afterstate_embedding,
        is_decision_node=jnp.logical_not(state.is_decision_node))

    def _broadcast_where(decision_leaf, chance_leaf):
      extra_dims = [1] * (len(decision_leaf.shape) - 1)
      expanded_is_decision = jnp.reshape(state.is_decision_node,
                                         [-1] + extra_dims)
      return jnp.where(
          # ensure state.is_decision node has appropriate shape.
          expanded_is_decision,
          decision_leaf, chance_leaf)

    # Both node types are expanded on every call; the relevant output is
    # selected per batch element by `_broadcast_where` below.
    output = jax.tree.map(_broadcast_where,
                          output_if_decision_node,
                          output_if_chance_node)
    return output, new_state

  return stochastic_recurrent_fn


def _mask_tree(tree: search.Tree, num_actions: int, mode: str) -> search.Tree:
  """Masks out parts of the tree based upon node type.

  "Actions" in our tree can either be action or chance values: A' = A + C. This
  utility function masks the parts of the tree containing dimensions of shape
  A' to be either A or C depending upon `mode`.

  Args:
    tree: The tree to be masked.
    num_actions: The number of environment actions A.
    mode: Either "decision" or "chance".

  Returns:
    An appropriately masked tree.
  """

  def _take_slice(x):
    if mode == 'decision':
      return x[..., :num_actions]
    elif mode == 'chance':
      return x[..., num_actions:]
    else:
      raise ValueError(f'Unknown mode: {mode}.')

  return tree.replace(
      children_index=_take_slice(tree.children_index),
      children_prior_logits=_take_slice(tree.children_prior_logits),
      children_visits=_take_slice(tree.children_visits),
      children_rewards=_take_slice(tree.children_rewards),
      children_discounts=_take_slice(tree.children_discounts),
      children_values=_take_slice(tree.children_values),
      root_invalid_actions=_take_slice(tree.root_invalid_actions))


def _make_stochastic_action_selection_fn(
    decision_node_selection_fn: base.InteriorActionSelectionFn,
    num_actions: int,
) -> base.InteriorActionSelectionFn:
  """Make Stochastic Action Selection Fn."""

  # NOTE: trees are unbatched here.

  def _chance_node_selection_fn(
      tree: search.Tree,
      node_index: chex.Array,
  ) -> chex.Array:
    """Deterministically picks a chance outcome at `node_index`.

    Scores each outcome by `softmax(chance_logits) / (visits + 1)`, which —
    analogous to `_prepare_argmax_input` — presumably drives outcome
    visitation frequencies towards the chance distribution; verify against
    the Stochastic MuZero paper.
    """
    num_chance = tree.children_visits[node_index]
    chance_logits = tree.children_prior_logits[node_index]
    prob_chance = jax.nn.softmax(chance_logits)
    argmax_chance = jnp.argmax(prob_chance / (num_chance + 1), axis=-1).astype(
        jnp.int32
    )
    return argmax_chance

  def _action_selection_fn(key: chex.PRNGKey, tree: search.Tree,
                           node_index: chex.Array,
                           depth: chex.Array) -> chex.Array:
    is_decision = tree.embeddings.is_decision_node[node_index]
    # Chance selections are offset by `num_actions` to index into the
    # internal `A' = A + C` action space.
    chance_selection = _chance_node_selection_fn(
        tree=_mask_tree(tree, num_actions, 'chance'),
        node_index=node_index) + num_actions
    decision_selection = decision_node_selection_fn(
        key, _mask_tree(tree, num_actions, 'decision'), node_index, depth)
    return jax.lax.cond(is_decision, lambda: decision_selection,
                        lambda: chance_selection)

  return _action_selection_fn
--------------------------------------------------------------------------------
/mctx/_src/qtransforms.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Monotonic transformations for the Q-values."""

import chex
import jax
import jax.numpy as jnp

from mctx._src import tree as tree_lib


def qtransform_by_min_max(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    min_value: chex.Numeric,
    max_value: chex.Numeric,
) -> chex.Array:
  """Returns Q-values normalized by the given `min_value` and `max_value`.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the parent node.
    min_value: given minimum value. Usually the `min_value` is minimum possible
      untransformed Q-value.
    max_value: given maximum value. Usually the `max_value` is maximum possible
      untransformed Q-value.

  Returns:
    Q-values normalized by `(qvalues - min_value) / (max_value - min_value)`.
    The unvisited actions will have zero Q-value. Shape `[num_actions]`.
  """
  chex.assert_shape(node_index, ())
  qvalues = tree.qvalues(node_index)
  visit_counts = tree.children_visits[node_index]
  # Unvisited actions get `min_value`, which normalizes to zero.
  value_score = jnp.where(visit_counts > 0, qvalues, min_value)
  value_score = (value_score - min_value) / ((max_value - min_value))
  return value_score


def qtransform_by_parent_and_siblings(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    epsilon: chex.Numeric = 1e-8,
) -> chex.Array:
  """Returns qvalues normalized by min, max over V(node) and qvalues.

  Args:
    tree: _unbatched_ MCTS tree state.
    node_index: scalar index of the parent node.
    epsilon: the minimum denominator for the normalization.

  Returns:
    Q-values normalized to be from the [0, 1] interval. The unvisited actions
    will have zero Q-value. Shape `[num_actions]`.
  """
  chex.assert_shape(node_index, ())
  qvalues = tree.qvalues(node_index)
  visit_counts = tree.children_visits[node_index]
  chex.assert_rank([qvalues, visit_counts, node_index], [1, 1, 0])
  node_value = tree.node_values[node_index]
  # Unvisited actions take the parent value so they do not distort min/max.
  safe_qvalues = jnp.where(visit_counts > 0, qvalues, node_value)
  chex.assert_equal_shape([safe_qvalues, qvalues])
  min_value = jnp.minimum(node_value, jnp.min(safe_qvalues, axis=-1))
  max_value = jnp.maximum(node_value, jnp.max(safe_qvalues, axis=-1))

  completed_by_min = jnp.where(visit_counts > 0, qvalues, min_value)
  normalized = (completed_by_min - min_value) / (
      jnp.maximum(max_value - min_value, epsilon))
  chex.assert_equal_shape([normalized, qvalues])
  return normalized


def qtransform_completed_by_mix_value(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    value_scale: chex.Numeric = 0.1,
    maxvisit_init: chex.Numeric = 50.0,
    rescale_values: bool = True,
    use_mixed_value: bool = True,
    epsilon: chex.Numeric = 1e-8,
) -> chex.Array:
  """Returns
completed qvalues. 98 | 99 | The missing Q-values of the unvisited actions are replaced by the 100 | mixed value, defined in Appendix D of 101 | "Policy improvement by planning with Gumbel": 102 | https://openreview.net/forum?id=bERaNdoegnO 103 | 104 | The Q-values are transformed by a linear transformation: 105 | `(maxvisit_init + max(visit_counts)) * value_scale * qvalues`. 106 | 107 | Args: 108 | tree: _unbatched_ MCTS tree state. 109 | node_index: scalar index of the parent node. 110 | value_scale: scale for the Q-values. 111 | maxvisit_init: offset to the `max(visit_counts)` in the scaling factor. 112 | rescale_values: if True, scale the qvalues by `1 / (max_q - min_q)`. 113 | use_mixed_value: if True, complete the Q-values with mixed value, 114 | otherwise complete the Q-values with the raw value. 115 | epsilon: the minimum denominator when using `rescale_values`. 116 | 117 | Returns: 118 | Completed Q-values. Shape `[num_actions]`. 119 | """ 120 | chex.assert_shape(node_index, ()) 121 | qvalues = tree.qvalues(node_index) 122 | visit_counts = tree.children_visits[node_index] 123 | 124 | # Computing the mixed value and producing completed_qvalues. 125 | raw_value = tree.raw_values[node_index] 126 | prior_probs = jax.nn.softmax( 127 | tree.children_prior_logits[node_index]) 128 | if use_mixed_value: 129 | value = _compute_mixed_value( 130 | raw_value, 131 | qvalues=qvalues, 132 | visit_counts=visit_counts, 133 | prior_probs=prior_probs) 134 | else: 135 | value = raw_value 136 | completed_qvalues = _complete_qvalues( 137 | qvalues, visit_counts=visit_counts, value=value) 138 | 139 | # Scaling the Q-values. 
140 | if rescale_values: 141 | completed_qvalues = _rescale_qvalues(completed_qvalues, epsilon) 142 | maxvisit = jnp.max(visit_counts, axis=-1) 143 | visit_scale = maxvisit_init + maxvisit 144 | return visit_scale * value_scale * completed_qvalues 145 | 146 | 147 | def _rescale_qvalues(qvalues, epsilon): 148 | """Rescales the given completed Q-values to be from the [0, 1] interval.""" 149 | min_value = jnp.min(qvalues, axis=-1, keepdims=True) 150 | max_value = jnp.max(qvalues, axis=-1, keepdims=True) 151 | return (qvalues - min_value) / jnp.maximum(max_value - min_value, epsilon) 152 | 153 | 154 | def _complete_qvalues(qvalues, *, visit_counts, value): 155 | """Returns completed Q-values, with the `value` for unvisited actions.""" 156 | chex.assert_equal_shape([qvalues, visit_counts]) 157 | chex.assert_shape(value, []) 158 | 159 | # The missing qvalues are replaced by the value. 160 | completed_qvalues = jnp.where( 161 | visit_counts > 0, 162 | qvalues, 163 | value) 164 | chex.assert_equal_shape([completed_qvalues, qvalues]) 165 | return completed_qvalues 166 | 167 | 168 | def _compute_mixed_value(raw_value, qvalues, visit_counts, prior_probs): 169 | """Interpolates the raw_value and weighted qvalues. 170 | 171 | Args: 172 | raw_value: an approximate value of the state. Shape `[]`. 173 | qvalues: Q-values for all actions. Shape `[num_actions]`. The unvisited 174 | actions have undefined Q-value. 175 | visit_counts: the visit counts for all actions. Shape `[num_actions]`. 176 | prior_probs: the action probabilities, produced by the policy network for 177 | each action. Shape `[num_actions]`. 178 | 179 | Returns: 180 | An estimator of the state value. Shape `[]`. 181 | """ 182 | sum_visit_counts = jnp.sum(visit_counts, axis=-1) 183 | # Ensuring non-nan weighted_q, even if the visited actions have zero 184 | # prior probability. 185 | prior_probs = jnp.maximum(jnp.finfo(prior_probs.dtype).tiny, prior_probs) 186 | # Summing the probabilities of the visited actions. 
187 | sum_probs = jnp.sum(jnp.where(visit_counts > 0, prior_probs, 0.0), 188 | axis=-1) 189 | weighted_q = jnp.sum(jnp.where( 190 | visit_counts > 0, 191 | prior_probs * qvalues / jnp.where(visit_counts > 0, sum_probs, 1.0), 192 | 0.0), axis=-1) 193 | return (raw_value + sum_visit_counts * weighted_q) / (sum_visit_counts + 1) 194 | -------------------------------------------------------------------------------- /mctx/_src/search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """A JAX implementation of batched MCTS.""" 16 | import functools 17 | from typing import Any, NamedTuple, Optional, Tuple, TypeVar 18 | 19 | import chex 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | from mctx._src import action_selection 24 | from mctx._src import base 25 | from mctx._src import tree as tree_lib 26 | 27 | Tree = tree_lib.Tree 28 | T = TypeVar("T") 29 | 30 | 31 | def search( 32 | params: base.Params, 33 | rng_key: chex.PRNGKey, 34 | *, 35 | root: base.RootFnOutput, 36 | recurrent_fn: base.RecurrentFn, 37 | root_action_selection_fn: base.RootActionSelectionFn, 38 | interior_action_selection_fn: base.InteriorActionSelectionFn, 39 | num_simulations: int, 40 | max_depth: Optional[int] = None, 41 | invalid_actions: Optional[chex.Array] = None, 42 | extra_data: Any = None, 43 | loop_fn: base.LoopFn = jax.lax.fori_loop) -> Tree: 44 | """Performs a full search and returns sampled actions. 45 | 46 | In the shape descriptions, `B` denotes the batch dimension. 47 | 48 | Args: 49 | params: params to be forwarded to root and recurrent functions. 50 | rng_key: random number generator state, the key is consumed. 51 | root: a `(prior_logits, value, embedding)` `RootFnOutput`. The 52 | `prior_logits` are from a policy network. The shapes are 53 | `([B, num_actions], [B], [B, ...])`, respectively. 54 | recurrent_fn: a callable to be called on the leaf nodes and unvisited 55 | actions retrieved by the simulation step, which takes as args 56 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput` 57 | and the new state embedding. The `rng_key` argument is consumed. 58 | root_action_selection_fn: function used to select an action at the root. 59 | interior_action_selection_fn: function used to select an action during 60 | simulation. 61 | num_simulations: the number of simulations. 
62 | max_depth: maximum search tree depth allowed during simulation, defined as 63 | the number of edges from the root to a leaf node. 64 | invalid_actions: a mask with invalid actions at the root. In the 65 | mask, invalid actions have ones, and valid actions have zeros. 66 | Shape `[B, num_actions]`. 67 | extra_data: extra data passed to `tree.extra_data`. Shape `[B, ...]`. 68 | loop_fn: Function used to run the simulations. It may be required to pass 69 | hk.fori_loop if using this function inside a Haiku module. 70 | 71 | Returns: 72 | `SearchResults` containing outcomes of the search, e.g. `visit_counts` 73 | `[B, num_actions]`. 74 | """ 75 | action_selection_fn = action_selection.switching_action_selection_wrapper( 76 | root_action_selection_fn=root_action_selection_fn, 77 | interior_action_selection_fn=interior_action_selection_fn 78 | ) 79 | 80 | # Do simulation, expansion, and backward steps. 81 | batch_size = root.value.shape[0] 82 | batch_range = jnp.arange(batch_size) 83 | if max_depth is None: 84 | max_depth = num_simulations 85 | if invalid_actions is None: 86 | invalid_actions = jnp.zeros_like(root.prior_logits) 87 | 88 | def body_fun(sim, loop_state): 89 | rng_key, tree = loop_state 90 | rng_key, simulate_key, expand_key = jax.random.split(rng_key, 3) 91 | # simulate is vmapped and expects batched rng keys. 92 | simulate_keys = jax.random.split(simulate_key, batch_size) 93 | parent_index, action = simulate( 94 | simulate_keys, tree, action_selection_fn, max_depth) 95 | # A node first expanded on simulation `i`, will have node index `i`. 96 | # Node 0 corresponds to the root node. 
97 | next_node_index = tree.children_index[batch_range, parent_index, action] 98 | next_node_index = jnp.where(next_node_index == Tree.UNVISITED, 99 | sim + 1, next_node_index) 100 | tree = expand( 101 | params, expand_key, tree, recurrent_fn, parent_index, 102 | action, next_node_index) 103 | tree = backward(tree, next_node_index) 104 | loop_state = rng_key, tree 105 | return loop_state 106 | 107 | # Allocate all necessary storage. 108 | tree = instantiate_tree_from_root(root, num_simulations, 109 | root_invalid_actions=invalid_actions, 110 | extra_data=extra_data) 111 | _, tree = loop_fn( 112 | 0, num_simulations, body_fun, (rng_key, tree)) 113 | 114 | return tree 115 | 116 | 117 | class _SimulationState(NamedTuple): 118 | """The state for the simulation while loop.""" 119 | rng_key: chex.PRNGKey 120 | node_index: int 121 | action: int 122 | next_node_index: int 123 | depth: int 124 | is_continuing: bool 125 | 126 | 127 | @functools.partial(jax.vmap, in_axes=[0, 0, None, None], out_axes=0) 128 | def simulate( 129 | rng_key: chex.PRNGKey, 130 | tree: Tree, 131 | action_selection_fn: base.InteriorActionSelectionFn, 132 | max_depth: int) -> Tuple[chex.Array, chex.Array]: 133 | """Traverses the tree until reaching an unvisited action or `max_depth`. 134 | 135 | Each simulation starts from the root and keeps selecting actions traversing 136 | the tree until a leaf or `max_depth` is reached. 137 | 138 | Args: 139 | rng_key: random number generator state, the key is consumed. 140 | tree: _unbatched_ MCTS tree state. 141 | action_selection_fn: function used to select an action during simulation. 142 | max_depth: maximum search tree depth allowed during simulation. 143 | 144 | Returns: 145 | `(parent_index, action)` tuple, where `parent_index` is the index of the 146 | node reached at the end of the simulation, and the `action` is the action to 147 | evaluate from the `parent_index`. 
148 | """ 149 | def cond_fun(state): 150 | return state.is_continuing 151 | 152 | def body_fun(state): 153 | # Preparing the next simulation state. 154 | node_index = state.next_node_index 155 | rng_key, action_selection_key = jax.random.split(state.rng_key) 156 | action = action_selection_fn(action_selection_key, tree, node_index, 157 | state.depth) 158 | next_node_index = tree.children_index[node_index, action] 159 | # The returned action will be visited. 160 | depth = state.depth + 1 161 | is_before_depth_cutoff = depth < max_depth 162 | is_visited = next_node_index != Tree.UNVISITED 163 | is_continuing = jnp.logical_and(is_visited, is_before_depth_cutoff) 164 | return _SimulationState( # pytype: disable=wrong-arg-types # jax-types 165 | rng_key=rng_key, 166 | node_index=node_index, 167 | action=action, 168 | next_node_index=next_node_index, 169 | depth=depth, 170 | is_continuing=is_continuing) 171 | 172 | node_index = jnp.array(Tree.ROOT_INDEX, dtype=jnp.int32) 173 | depth = jnp.zeros((), dtype=tree.children_prior_logits.dtype) 174 | # pytype: disable=wrong-arg-types # jnp-type 175 | initial_state = _SimulationState( 176 | rng_key=rng_key, 177 | node_index=tree.NO_PARENT, 178 | action=tree.NO_PARENT, 179 | next_node_index=node_index, 180 | depth=depth, 181 | is_continuing=jnp.array(True)) 182 | # pytype: enable=wrong-arg-types 183 | end_state = jax.lax.while_loop(cond_fun, body_fun, initial_state) 184 | 185 | # Returning a node with a selected action. 186 | # The action can be already visited, if the max_depth is reached. 187 | return end_state.node_index, end_state.action 188 | 189 | 190 | def expand( 191 | params: chex.Array, 192 | rng_key: chex.PRNGKey, 193 | tree: Tree[T], 194 | recurrent_fn: base.RecurrentFn, 195 | parent_index: chex.Array, 196 | action: chex.Array, 197 | next_node_index: chex.Array) -> Tree[T]: 198 | """Create and evaluate child nodes from given nodes and unvisited actions. 
199 | 200 | Args: 201 | params: params to be forwarded to recurrent function. 202 | rng_key: random number generator state. 203 | tree: the MCTS tree state to update. 204 | recurrent_fn: a callable to be called on the leaf nodes and unvisited 205 | actions retrieved by the simulation step, which takes as args 206 | `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput` 207 | and the new state embedding. The `rng_key` argument is consumed. 208 | parent_index: the index of the parent node, from which the action will be 209 | expanded. Shape `[B]`. 210 | action: the action to expand. Shape `[B]`. 211 | next_node_index: the index of the newly expanded node. This can be the index 212 | of an existing node, if `max_depth` is reached. Shape `[B]`. 213 | 214 | Returns: 215 | tree: updated MCTS tree state. 216 | """ 217 | batch_size = tree_lib.infer_batch_size(tree) 218 | batch_range = jnp.arange(batch_size) 219 | chex.assert_shape([parent_index, action, next_node_index], (batch_size,)) 220 | 221 | # Retrieve states for nodes to be evaluated. 222 | embedding = jax.tree.map( 223 | lambda x: x[batch_range, parent_index], tree.embeddings) 224 | 225 | # Evaluate and create a new node. 226 | step, embedding = recurrent_fn(params, rng_key, action, embedding) 227 | chex.assert_shape(step.prior_logits, [batch_size, tree.num_actions]) 228 | chex.assert_shape(step.reward, [batch_size]) 229 | chex.assert_shape(step.discount, [batch_size]) 230 | chex.assert_shape(step.value, [batch_size]) 231 | tree = update_tree_node( 232 | tree, next_node_index, step.prior_logits, step.value, embedding) 233 | 234 | # Return updated tree topology. 
235 | return tree.replace( 236 | children_index=batch_update( 237 | tree.children_index, next_node_index, parent_index, action), 238 | children_rewards=batch_update( 239 | tree.children_rewards, step.reward, parent_index, action), 240 | children_discounts=batch_update( 241 | tree.children_discounts, step.discount, parent_index, action), 242 | parents=batch_update(tree.parents, parent_index, next_node_index), 243 | action_from_parent=batch_update( 244 | tree.action_from_parent, action, next_node_index)) 245 | 246 | 247 | @jax.vmap 248 | def backward( 249 | tree: Tree[T], 250 | leaf_index: chex.Numeric) -> Tree[T]: 251 | """Goes up and updates the tree until all nodes reached the root. 252 | 253 | Args: 254 | tree: the MCTS tree state to update, without the batch size. 255 | leaf_index: the node index from which to do the backward. 256 | 257 | Returns: 258 | Updated MCTS tree state. 259 | """ 260 | 261 | def cond_fun(loop_state): 262 | _, _, index = loop_state 263 | return index != Tree.ROOT_INDEX 264 | 265 | def body_fun(loop_state): 266 | # Here we update the value of our parent, so we start by reversing. 
267 | tree, leaf_value, index = loop_state 268 | parent = tree.parents[index] 269 | count = tree.node_visits[parent] 270 | action = tree.action_from_parent[index] 271 | reward = tree.children_rewards[parent, action] 272 | leaf_value = reward + tree.children_discounts[parent, action] * leaf_value 273 | parent_value = ( 274 | tree.node_values[parent] * count + leaf_value) / (count + 1.0) 275 | children_values = tree.node_values[index] 276 | children_counts = tree.children_visits[parent, action] + 1 277 | 278 | tree = tree.replace( 279 | node_values=update(tree.node_values, parent_value, parent), 280 | node_visits=update(tree.node_visits, count + 1, parent), 281 | children_values=update( 282 | tree.children_values, children_values, parent, action), 283 | children_visits=update( 284 | tree.children_visits, children_counts, parent, action)) 285 | 286 | return tree, leaf_value, parent 287 | 288 | leaf_index = jnp.asarray(leaf_index, dtype=jnp.int32) 289 | loop_state = (tree, tree.node_values[leaf_index], leaf_index) 290 | tree, _, _ = jax.lax.while_loop(cond_fun, body_fun, loop_state) 291 | 292 | return tree 293 | 294 | 295 | # Utility function to set the values of certain indices to prescribed values. 296 | # This is vmapped to operate seamlessly on batches. 297 | def update(x, vals, *indices): 298 | return x.at[indices].set(vals) 299 | 300 | 301 | batch_update = jax.vmap(update) 302 | 303 | 304 | def update_tree_node( 305 | tree: Tree[T], 306 | node_index: chex.Array, 307 | prior_logits: chex.Array, 308 | value: chex.Array, 309 | embedding: chex.Array) -> Tree[T]: 310 | """Updates the tree at node index. 311 | 312 | Args: 313 | tree: `Tree` to whose node is to be updated. 314 | node_index: the index of the expanded node. Shape `[B]`. 315 | prior_logits: the prior logits to fill in for the new node, of shape 316 | `[B, num_actions]`. 317 | value: the value to fill in for the new node. Shape `[B]`. 318 | embedding: the state embeddings for the node. Shape `[B, ...]`. 
319 | 320 | Returns: 321 | The new tree with updated nodes. 322 | """ 323 | batch_size = tree_lib.infer_batch_size(tree) 324 | batch_range = jnp.arange(batch_size) 325 | chex.assert_shape(prior_logits, (batch_size, tree.num_actions)) 326 | 327 | # When using max_depth, a leaf can be expanded multiple times. 328 | new_visit = tree.node_visits[batch_range, node_index] + 1 329 | updates = dict( # pylint: disable=use-dict-literal 330 | children_prior_logits=batch_update( 331 | tree.children_prior_logits, prior_logits, node_index), 332 | raw_values=batch_update( 333 | tree.raw_values, value, node_index), 334 | node_values=batch_update( 335 | tree.node_values, value, node_index), 336 | node_visits=batch_update( 337 | tree.node_visits, new_visit, node_index), 338 | embeddings=jax.tree.map( 339 | lambda t, s: batch_update(t, s, node_index), 340 | tree.embeddings, embedding)) 341 | 342 | return tree.replace(**updates) 343 | 344 | 345 | def instantiate_tree_from_root( 346 | root: base.RootFnOutput, 347 | num_simulations: int, 348 | root_invalid_actions: chex.Array, 349 | extra_data: Any) -> Tree: 350 | """Initializes tree state at search root.""" 351 | chex.assert_rank(root.prior_logits, 2) 352 | batch_size, num_actions = root.prior_logits.shape 353 | chex.assert_shape(root.value, [batch_size]) 354 | num_nodes = num_simulations + 1 355 | data_dtype = root.value.dtype 356 | batch_node = (batch_size, num_nodes) 357 | batch_node_action = (batch_size, num_nodes, num_actions) 358 | 359 | def _zeros(x): 360 | return jnp.zeros(batch_node + x.shape[1:], dtype=x.dtype) 361 | 362 | # Create a new empty tree state and fill its root. 
363 | tree = Tree( 364 | node_visits=jnp.zeros(batch_node, dtype=jnp.int32), 365 | raw_values=jnp.zeros(batch_node, dtype=data_dtype), 366 | node_values=jnp.zeros(batch_node, dtype=data_dtype), 367 | parents=jnp.full(batch_node, Tree.NO_PARENT, dtype=jnp.int32), 368 | action_from_parent=jnp.full( 369 | batch_node, Tree.NO_PARENT, dtype=jnp.int32), 370 | children_index=jnp.full( 371 | batch_node_action, Tree.UNVISITED, dtype=jnp.int32), 372 | children_prior_logits=jnp.zeros( 373 | batch_node_action, dtype=root.prior_logits.dtype), 374 | children_values=jnp.zeros(batch_node_action, dtype=data_dtype), 375 | children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32), 376 | children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype), 377 | children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype), 378 | embeddings=jax.tree.map(_zeros, root.embedding), 379 | root_invalid_actions=root_invalid_actions, 380 | extra_data=extra_data) 381 | 382 | root_index = jnp.full([batch_size], Tree.ROOT_INDEX) 383 | tree = update_tree_node( 384 | tree, root_index, root.prior_logits, root.value, root.embedding) 385 | return tree 386 | -------------------------------------------------------------------------------- /mctx/_src/seq_halving.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Functions for Sequential Halving.""" 16 | 17 | import math 18 | 19 | import chex 20 | import jax.numpy as jnp 21 | 22 | 23 | def score_considered(considered_visit, gumbel, logits, normalized_qvalues, 24 | visit_counts): 25 | """Returns a score usable for an argmax.""" 26 | # We allow to visit a child, if it is the only considered child. 27 | low_logit = -1e9 28 | logits = logits - jnp.max(logits, keepdims=True, axis=-1) 29 | penalty = jnp.where( 30 | visit_counts == considered_visit, 31 | 0, -jnp.inf) 32 | chex.assert_equal_shape([gumbel, logits, normalized_qvalues, penalty]) 33 | return jnp.maximum(low_logit, gumbel + logits + normalized_qvalues) + penalty 34 | 35 | 36 | def get_sequence_of_considered_visits(max_num_considered_actions, 37 | num_simulations): 38 | """Returns a sequence of visit counts considered by Sequential Halving. 39 | 40 | Sequential Halving is a "pure exploration" algorithm for bandits, introduced 41 | in "Almost Optimal Exploration in Multi-Armed Bandits": 42 | http://proceedings.mlr.press/v28/karnin13.pdf 43 | 44 | The visit counts allows to implement Sequential Halving by selecting the best 45 | action from the actions with the currently considered visit count. 46 | 47 | Args: 48 | max_num_considered_actions: The maximum number of considered actions. 49 | The `max_num_considered_actions` can be smaller than the number of 50 | actions. 51 | num_simulations: The total simulation budget. 52 | 53 | Returns: 54 | A tuple with visit counts. Length `num_simulations`. 
55 | """ 56 | if max_num_considered_actions <= 1: 57 | return tuple(range(num_simulations)) 58 | log2max = int(math.ceil(math.log2(max_num_considered_actions))) 59 | sequence = [] 60 | visits = [0] * max_num_considered_actions 61 | num_considered = max_num_considered_actions 62 | while len(sequence) < num_simulations: 63 | num_extra_visits = max(1, int(num_simulations / (log2max * num_considered))) 64 | for _ in range(num_extra_visits): 65 | sequence.extend(visits[:num_considered]) 66 | for i in range(num_considered): 67 | visits[i] += 1 68 | # Halving the number of considered actions. 69 | num_considered = max(2, num_considered // 2) 70 | return tuple(sequence[:num_simulations]) 71 | 72 | 73 | def get_table_of_considered_visits(max_num_considered_actions, num_simulations): 74 | """Returns a table of sequences of visit counts. 75 | 76 | Args: 77 | max_num_considered_actions: The maximum number of considered actions. 78 | The `max_num_considered_actions` can be smaller than the number of 79 | actions. 80 | num_simulations: The total simulation budget. 81 | 82 | Returns: 83 | A tuple of sequences of visit counts. 84 | Shape [max_num_considered_actions + 1, num_simulations]. 85 | """ 86 | return tuple( 87 | get_sequence_of_considered_visits(m, num_simulations) 88 | for m in range(max_num_considered_actions + 1)) 89 | 90 | -------------------------------------------------------------------------------- /mctx/_src/tests/mctx_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for Mctx.""" 16 | 17 | from absl.testing import absltest 18 | import mctx 19 | 20 | 21 | class MctxTest(absltest.TestCase): 22 | """Test mctx can be imported correctly.""" 23 | 24 | def test_import(self): 25 | self.assertTrue(hasattr(mctx, "gumbel_muzero_policy")) 26 | self.assertTrue(hasattr(mctx, "muzero_policy")) 27 | self.assertTrue(hasattr(mctx, "qtransform_by_min_max")) 28 | self.assertTrue(hasattr(mctx, "qtransform_by_parent_and_siblings")) 29 | self.assertTrue(hasattr(mctx, "qtransform_completed_by_mix_value")) 30 | self.assertTrue(hasattr(mctx, "PolicyOutput")) 31 | self.assertTrue(hasattr(mctx, "RootFnOutput")) 32 | self.assertTrue(hasattr(mctx, "RecurrentFnOutput")) 33 | 34 | 35 | if __name__ == "__main__": 36 | absltest.main() 37 | -------------------------------------------------------------------------------- /mctx/_src/tests/policies_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `policies.py`.""" 16 | import functools 17 | 18 | from absl.testing import absltest 19 | import jax 20 | import jax.numpy as jnp 21 | import mctx 22 | from mctx._src import policies 23 | import numpy as np 24 | 25 | jax.config.update("jax_threefry_partitionable", False) 26 | 27 | 28 | def _make_bandit_recurrent_fn(rewards, dummy_embedding=()): 29 | """Returns a recurrent_fn with discount=0.""" 30 | 31 | def recurrent_fn(params, rng_key, action, embedding): 32 | del params, rng_key, embedding 33 | reward = rewards[jnp.arange(action.shape[0]), action] 34 | return mctx.RecurrentFnOutput( 35 | reward=reward, 36 | discount=jnp.zeros_like(reward), 37 | prior_logits=jnp.zeros_like(rewards), 38 | value=jnp.zeros_like(reward), 39 | ), dummy_embedding 40 | 41 | return recurrent_fn 42 | 43 | 44 | def _make_bandit_decision_and_chance_fns(rewards, num_chance_outcomes): 45 | 46 | def decision_recurrent_fn(params, rng_key, action, embedding): 47 | del params, rng_key 48 | batch_size = action.shape[0] 49 | reward = rewards[jnp.arange(batch_size), action] 50 | dummy_chance_logits = jnp.full([batch_size, num_chance_outcomes], 51 | -jnp.inf).at[:, 0].set(1.0) 52 | afterstate_embedding = (action, embedding) 53 | return mctx.DecisionRecurrentFnOutput( 54 | chance_logits=dummy_chance_logits, 55 | afterstate_value=jnp.zeros_like(reward)), afterstate_embedding 56 | 57 | def chance_recurrent_fn(params, rng_key, chance_outcome, 58 | 
afterstate_embedding): 59 | del params, rng_key, chance_outcome 60 | afterstate_action, embedding = afterstate_embedding 61 | batch_size = afterstate_action.shape[0] 62 | 63 | reward = rewards[jnp.arange(batch_size), afterstate_action] 64 | return mctx.ChanceRecurrentFnOutput( 65 | action_logits=jnp.zeros_like(rewards), 66 | value=jnp.zeros_like(reward), 67 | discount=jnp.zeros_like(reward), 68 | reward=reward), embedding 69 | 70 | return decision_recurrent_fn, chance_recurrent_fn 71 | 72 | 73 | def _get_deepest_leaf(tree, node_index): 74 | """Returns `(leaf, depth)` with maximum depth and visit count. 75 | 76 | Args: 77 | tree: _unbatched_ MCTS tree state. 78 | node_index: the node of the inspected subtree. 79 | 80 | Returns: 81 | `(leaf, depth)` of a deepest leaf. If multiple leaves have the same depth, 82 | the leaf with the highest visit count is returned. 83 | """ 84 | np.testing.assert_equal(len(tree.children_index.shape), 2) 85 | leaf = node_index 86 | max_found_depth = 0 87 | for action in range(tree.children_index.shape[-1]): 88 | next_node_index = tree.children_index[node_index, action] 89 | if next_node_index != tree.UNVISITED: 90 | found_leaf, found_depth = _get_deepest_leaf(tree, next_node_index) 91 | if ((1 + found_depth, tree.node_visits[found_leaf]) > 92 | (max_found_depth, tree.node_visits[leaf])): 93 | leaf = found_leaf 94 | max_found_depth = 1 + found_depth 95 | return leaf, max_found_depth 96 | 97 | 98 | class PoliciesTest(absltest.TestCase): 99 | 100 | def test_apply_temperature_one(self): 101 | """Tests temperature=1.""" 102 | logits = jnp.arange(6, dtype=jnp.float32) 103 | new_logits = policies._apply_temperature(logits, temperature=1.0) 104 | np.testing.assert_allclose(logits - logits.max(), new_logits) 105 | 106 | def test_apply_temperature_two(self): 107 | """Tests temperature=2.""" 108 | logits = jnp.arange(6, dtype=jnp.float32) 109 | temperature = 2.0 110 | new_logits = policies._apply_temperature(logits, temperature) 111 | 
np.testing.assert_allclose((logits - logits.max()) / temperature, 112 | new_logits) 113 | 114 | def test_apply_temperature_zero(self): 115 | """Tests temperature=0.""" 116 | logits = jnp.arange(4, dtype=jnp.float32) 117 | new_logits = policies._apply_temperature(logits, temperature=0.0) 118 | np.testing.assert_allclose( 119 | jnp.array([-2.552118e+38, -1.701412e+38, -8.507059e+37, 0.0]), 120 | new_logits, 121 | rtol=1e-3) 122 | 123 | def test_apply_temperature_zero_on_large_logits(self): 124 | """Tests temperature=0 on large logits.""" 125 | logits = jnp.array([100.0, 3.4028235e+38, -jnp.inf, -3.4028235e+38]) 126 | new_logits = policies._apply_temperature(logits, temperature=0.0) 127 | np.testing.assert_allclose( 128 | jnp.array([-jnp.inf, 0.0, -jnp.inf, -jnp.inf]), new_logits) 129 | 130 | def test_mask_invalid_actions(self): 131 | """Tests action masking.""" 132 | logits = jnp.array([1e6, -jnp.inf, 1e6 + 1, -100.0]) 133 | invalid_actions = jnp.array([0.0, 1.0, 0.0, 1.0]) 134 | masked_logits = policies._mask_invalid_actions( 135 | logits, invalid_actions) 136 | valid_probs = jax.nn.softmax(jnp.array([0.0, 1.0])) 137 | np.testing.assert_allclose( 138 | jnp.array([valid_probs[0], 0.0, valid_probs[1], 0.0]), 139 | jax.nn.softmax(masked_logits)) 140 | 141 | def test_mask_all_invalid_actions(self): 142 | """Tests a state with no valid action.""" 143 | logits = jnp.array([-jnp.inf, -jnp.inf, -jnp.inf, -jnp.inf]) 144 | invalid_actions = jnp.array([1.0, 1.0, 1.0, 1.0]) 145 | masked_logits = policies._mask_invalid_actions( 146 | logits, invalid_actions) 147 | np.testing.assert_allclose( 148 | jnp.array([0.25, 0.25, 0.25, 0.25]), 149 | jax.nn.softmax(masked_logits)) 150 | 151 | def test_muzero_policy(self): 152 | root = mctx.RootFnOutput( 153 | prior_logits=jnp.array([ 154 | [-1.0, 0.0, 2.0, 3.0], 155 | ]), 156 | value=jnp.array([0.0]), 157 | embedding=(), 158 | ) 159 | rewards = jnp.zeros_like(root.prior_logits) 160 | invalid_actions = jnp.array([ 161 | [0.0, 0.0, 0.0, 
1.0], 162 | ]) 163 | 164 | policy_output = mctx.muzero_policy( 165 | params=(), 166 | rng_key=jax.random.PRNGKey(0), 167 | root=root, 168 | recurrent_fn=_make_bandit_recurrent_fn(rewards), 169 | num_simulations=1, 170 | invalid_actions=invalid_actions, 171 | dirichlet_fraction=0.0) 172 | expected_action = jnp.array([2], dtype=jnp.int32) 173 | np.testing.assert_array_equal(expected_action, policy_output.action) 174 | expected_action_weights = jnp.array([ 175 | [0.0, 0.0, 1.0, 0.0], 176 | ]) 177 | np.testing.assert_allclose(expected_action_weights, 178 | policy_output.action_weights) 179 | 180 | def test_gumbel_muzero_policy(self): 181 | root_value = jnp.array([-5.0]) 182 | root = mctx.RootFnOutput( 183 | prior_logits=jnp.array([ 184 | [0.0, -1.0, 2.0, 3.0], 185 | ]), 186 | value=root_value, 187 | embedding=(), 188 | ) 189 | rewards = jnp.array([ 190 | [20.0, 3.0, -1.0, 10.0], 191 | ]) 192 | invalid_actions = jnp.array([ 193 | [1.0, 0.0, 0.0, 1.0], 194 | ]) 195 | 196 | value_scale = 0.05 197 | maxvisit_init = 60 198 | num_simulations = 17 199 | max_depth = 3 200 | qtransform = functools.partial( 201 | mctx.qtransform_completed_by_mix_value, 202 | value_scale=value_scale, 203 | maxvisit_init=maxvisit_init, 204 | rescale_values=True) 205 | policy_output = mctx.gumbel_muzero_policy( 206 | params=(), 207 | rng_key=jax.random.PRNGKey(0), 208 | root=root, 209 | recurrent_fn=_make_bandit_recurrent_fn(rewards), 210 | num_simulations=num_simulations, 211 | invalid_actions=invalid_actions, 212 | max_depth=max_depth, 213 | qtransform=qtransform, 214 | gumbel_scale=1.0) 215 | # Testing the action. 216 | expected_action = jnp.array([1], dtype=jnp.int32) 217 | np.testing.assert_array_equal(expected_action, policy_output.action) 218 | 219 | # Testing the action_weights. 
220 | probs = jax.nn.softmax(jnp.where( 221 | invalid_actions, -jnp.inf, root.prior_logits)) 222 | mix_value = 1.0 / (num_simulations + 1) * (root_value + num_simulations * ( 223 | probs[:, 1] * rewards[:, 1] + probs[:, 2] * rewards[:, 2])) 224 | 225 | completed_qvalues = jnp.array([ 226 | [mix_value[0], rewards[0, 1], rewards[0, 2], mix_value[0]], 227 | ]) 228 | max_value = jnp.max(completed_qvalues, axis=-1, keepdims=True) 229 | min_value = jnp.min(completed_qvalues, axis=-1, keepdims=True) 230 | total_value_scale = (maxvisit_init + np.ceil(num_simulations / 2) 231 | ) * value_scale 232 | rescaled_qvalues = total_value_scale * (completed_qvalues - min_value) / ( 233 | max_value - min_value) 234 | expected_action_weights = jax.nn.softmax( 235 | jnp.where(invalid_actions, 236 | -jnp.inf, 237 | root.prior_logits + rescaled_qvalues)) 238 | np.testing.assert_allclose(expected_action_weights, 239 | policy_output.action_weights, 240 | atol=1e-6) 241 | 242 | # Testing the visit_counts. 243 | summary = policy_output.search_tree.summary() 244 | expected_visit_counts = jnp.array( 245 | [[0.0, np.ceil(num_simulations / 2), num_simulations // 2, 0.0]]) 246 | np.testing.assert_array_equal(expected_visit_counts, summary.visit_counts) 247 | 248 | # Testing max_depth. 
249 | leaf, max_found_depth = _get_deepest_leaf( 250 | jax.tree.map(lambda x: x[0], policy_output.search_tree), 251 | policy_output.search_tree.ROOT_INDEX) 252 | self.assertEqual(max_depth, max_found_depth) 253 | self.assertEqual(6, policy_output.search_tree.node_visits[0, leaf]) 254 | 255 | def test_gumbel_muzero_policy_without_invalid_actions(self): 256 | root_value = jnp.array([-5.0]) 257 | root = mctx.RootFnOutput( 258 | prior_logits=jnp.array([ 259 | [0.0, -1.0, 2.0, 3.0], 260 | ]), 261 | value=root_value, 262 | embedding=(), 263 | ) 264 | rewards = jnp.array([ 265 | [20.0, 3.0, -1.0, 10.0], 266 | ]) 267 | 268 | value_scale = 0.05 269 | maxvisit_init = 60 270 | num_simulations = 17 271 | max_depth = 3 272 | qtransform = functools.partial( 273 | mctx.qtransform_completed_by_mix_value, 274 | value_scale=value_scale, 275 | maxvisit_init=maxvisit_init, 276 | rescale_values=True) 277 | policy_output = mctx.gumbel_muzero_policy( 278 | params=(), 279 | rng_key=jax.random.PRNGKey(0), 280 | root=root, 281 | recurrent_fn=_make_bandit_recurrent_fn(rewards), 282 | num_simulations=num_simulations, 283 | invalid_actions=None, 284 | max_depth=max_depth, 285 | qtransform=qtransform, 286 | gumbel_scale=1.0) 287 | # Testing the action. 288 | expected_action = jnp.array([3], dtype=jnp.int32) 289 | np.testing.assert_array_equal(expected_action, policy_output.action) 290 | 291 | # Testing the action_weights. 
292 | summary = policy_output.search_tree.summary() 293 | completed_qvalues = rewards 294 | max_value = jnp.max(completed_qvalues, axis=-1, keepdims=True) 295 | min_value = jnp.min(completed_qvalues, axis=-1, keepdims=True) 296 | total_value_scale = (maxvisit_init + summary.visit_counts.max() 297 | ) * value_scale 298 | rescaled_qvalues = total_value_scale * (completed_qvalues - min_value) / ( 299 | max_value - min_value) 300 | expected_action_weights = jax.nn.softmax( 301 | root.prior_logits + rescaled_qvalues) 302 | np.testing.assert_allclose(expected_action_weights, 303 | policy_output.action_weights, 304 | atol=1e-6) 305 | 306 | # Testing the visit_counts. 307 | expected_visit_counts = jnp.array( 308 | [[6, 2, 2, 7]]) 309 | np.testing.assert_array_equal(expected_visit_counts, summary.visit_counts) 310 | 311 | def test_stochastic_muzero_policy(self): 312 | """Tests that SMZ is equivalent to MZ with a dummy chance function.""" 313 | root = mctx.RootFnOutput( 314 | prior_logits=jnp.array([ 315 | [-1.0, 0.0, 2.0, 3.0], 316 | [0.0, 2.0, 5.0, -4.0], 317 | ]), 318 | value=jnp.array([1.0, 0.0]), 319 | embedding=jnp.zeros([2, 4]) 320 | ) 321 | rewards = jnp.zeros_like(root.prior_logits) 322 | invalid_actions = jnp.array([ 323 | [0.0, 0.0, 0.0, 1.0], 324 | [1.0, 0.0, 1.0, 0.0], 325 | ]) 326 | 327 | num_simulations = 10 328 | 329 | policy_output = mctx.muzero_policy( 330 | params=(), 331 | rng_key=jax.random.PRNGKey(0), 332 | root=root, 333 | recurrent_fn=_make_bandit_recurrent_fn( 334 | rewards, 335 | dummy_embedding=jnp.zeros_like(root.embedding)), 336 | num_simulations=num_simulations, 337 | invalid_actions=invalid_actions, 338 | dirichlet_fraction=0.0) 339 | 340 | num_chance_outcomes = 5 341 | 342 | decision_rec_fn, chance_rec_fn = _make_bandit_decision_and_chance_fns( 343 | rewards, num_chance_outcomes) 344 | 345 | stochastic_policy_output = mctx.stochastic_muzero_policy( 346 | params=(), 347 | rng_key=jax.random.PRNGKey(0), 348 | root=root, 349 | 
decision_recurrent_fn=decision_rec_fn, 350 | chance_recurrent_fn=chance_rec_fn, 351 | num_simulations=2 * num_simulations, 352 | invalid_actions=invalid_actions, 353 | dirichlet_fraction=0.0) 354 | 355 | np.testing.assert_array_equal(stochastic_policy_output.action, 356 | policy_output.action) 357 | 358 | np.testing.assert_allclose(stochastic_policy_output.action_weights, 359 | policy_output.action_weights) 360 | 361 | 362 | if __name__ == "__main__": 363 | absltest.main() 364 | -------------------------------------------------------------------------------- /mctx/_src/tests/qtransforms_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for `qtransforms.py`.""" 16 | from absl.testing import absltest 17 | import jax 18 | import jax.numpy as jnp 19 | from mctx._src import qtransforms 20 | import numpy as np 21 | 22 | 23 | class QtransformsTest(absltest.TestCase): 24 | 25 | def test_mix_value(self): 26 | """Tests the output of _compute_mixed_value().""" 27 | raw_value = jnp.array(-0.8) 28 | prior_logits = jnp.array([-jnp.inf, -1.0, 2.0, -jnp.inf]) 29 | probs = jax.nn.softmax(prior_logits) 30 | visit_counts = jnp.array([0, 4.0, 4.0, 0]) 31 | qvalues = 10.0 / 54 * jnp.array([20.0, 3.0, -1.0, 10.0]) 32 | mix_value = qtransforms._compute_mixed_value( 33 | raw_value, qvalues, visit_counts, probs) 34 | 35 | num_simulations = jnp.sum(visit_counts) 36 | expected_mix_value = 1.0 / (num_simulations + 1) * ( 37 | raw_value + num_simulations * 38 | (probs[1] * qvalues[1] + probs[2] * qvalues[2])) 39 | np.testing.assert_allclose(expected_mix_value, mix_value) 40 | 41 | def test_mix_value_with_zero_visits(self): 42 | """Tests that zero visit counts do not divide by zero.""" 43 | raw_value = jnp.array(-0.8) 44 | prior_logits = jnp.array([-jnp.inf, -1.0, 2.0, -jnp.inf]) 45 | probs = jax.nn.softmax(prior_logits) 46 | visit_counts = jnp.array([0, 0, 0, 0]) 47 | qvalues = jnp.zeros_like(probs) 48 | with jax.debug_nans(): 49 | mix_value = qtransforms._compute_mixed_value( 50 | raw_value, qvalues, visit_counts, probs) 51 | 52 | np.testing.assert_allclose(raw_value, mix_value) 53 | 54 | 55 | if __name__ == "__main__": 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /mctx/_src/tests/seq_halving_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for `seq_halving.py`.""" 16 | from absl.testing import absltest 17 | from mctx._src import seq_halving 18 | 19 | 20 | class SeqHalvingTest(absltest.TestCase): 21 | 22 | def _check_visits(self, expected_results, max_num_considered_actions, 23 | num_simulations): 24 | """Compares the expected results to the returned considered visits.""" 25 | self.assertLen(expected_results, num_simulations) 26 | results = seq_halving.get_sequence_of_considered_visits( 27 | max_num_considered_actions, num_simulations) 28 | self.assertEqual(tuple(expected_results), results) 29 | 30 | def test_considered_min_sims(self): 31 | # Using exactly `num_simulations = max_num_considered_actions * 32 | # log2(max_num_considered_actions)`. 33 | num_sims = 24 34 | max_num_considered = 8 35 | expected_results = [ 36 | 0, 0, 0, 0, 0, 0, 0, 0, # Considering 8 actions. 37 | 1, 1, 1, 1, # Considering 4 actions. 38 | 2, 2, 2, 2, # Considering 4 actions, round 2. 39 | 3, 3, 4, 4, 5, 5, 6, 6, # Considering 2 actions. 40 | ] # pyformat: disable 41 | self._check_visits(expected_results, max_num_considered, num_sims) 42 | 43 | def test_considered_extra_sims(self): 44 | # Using more simulations than `max_num_considered_actions * 45 | # log2(max_num_considered_actions)`. 
46 | num_sims = 47 47 | max_num_considered = 8 48 | expected_results = [ 49 | 0, 0, 0, 0, 0, 0, 0, 0, # Considering 8 actions. 50 | 1, 1, 1, 1, # Considering 4 actions. 51 | 2, 2, 2, 2, # Considering 4 actions, round 2. 52 | 3, 3, 3, 3, # Considering 4 actions, round 3. 53 | 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 54 | 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 55 | ] # pyformat: disable 56 | self._check_visits(expected_results, max_num_considered, num_sims) 57 | 58 | def test_considered_less_sims(self): 59 | # Using a very small number of simulations. 60 | num_sims = 2 61 | max_num_considered = 8 62 | expected_results = [0, 0] 63 | self._check_visits(expected_results, max_num_considered, num_sims) 64 | 65 | def test_considered_less_sims2(self): 66 | # Using `num_simulations < max_num_considered_actions * 67 | # log2(max_num_considered_actions)`. 68 | num_sims = 13 69 | max_num_considered = 8 70 | expected_results = [ 71 | 0, 0, 0, 0, 0, 0, 0, 0, # Considering 8 actions. 72 | 1, 1, 1, 1, # Considering 4 actions. 73 | 2, 74 | ] # pyformat: disable 75 | self._check_visits(expected_results, max_num_considered, num_sims) 76 | 77 | def test_considered_not_power_of_2(self): 78 | # Using max_num_considered_actions that is not a power of 2. 79 | num_sims = 24 80 | max_num_considered = 7 81 | expected_results = [ 82 | 0, 0, 0, 0, 0, 0, 0, # Considering 7 actions. 83 | 1, 1, 1, 2, 2, 2, # Considering 3 actions. 
84 | 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 85 | ] # pyformat: disable 86 | self._check_visits(expected_results, max_num_considered, num_sims) 87 | 88 | def test_considered_action0(self): 89 | num_sims = 16 90 | max_num_considered = 0 91 | expected_results = range(num_sims) 92 | self._check_visits(expected_results, max_num_considered, num_sims) 93 | 94 | def test_considered_action1(self): 95 | num_sims = 16 96 | max_num_considered = 1 97 | expected_results = range(num_sims) 98 | self._check_visits(expected_results, max_num_considered, num_sims) 99 | 100 | 101 | if __name__ == "__main__": 102 | absltest.main() 103 | -------------------------------------------------------------------------------- /mctx/_src/tests/tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """A unit test comparing the search tree to an expected search tree.""" 16 | # pylint: disable=use-dict-literal 17 | import functools 18 | import json 19 | 20 | from absl import logging 21 | from absl.testing import absltest 22 | from absl.testing import parameterized 23 | import chex 24 | import jax 25 | import jax.numpy as jnp 26 | import mctx 27 | import numpy as np 28 | 29 | jax.config.update("jax_threefry_partitionable", False) 30 | 31 | 32 | def _prepare_root(batch_size, num_actions): 33 | """Returns a root consistent with the stored expected trees.""" 34 | rng_key = jax.random.PRNGKey(0) 35 | # Using a different rng_key inside each batch element. 36 | rng_keys = [rng_key] 37 | for i in range(1, batch_size): 38 | rng_keys.append(jax.random.fold_in(rng_key, i)) 39 | embedding = jnp.stack(rng_keys) 40 | output = jax.vmap( 41 | functools.partial(_produce_prediction_output, num_actions=num_actions))( 42 | embedding) 43 | return mctx.RootFnOutput( 44 | prior_logits=output["policy_logits"], 45 | value=output["value"], 46 | embedding=embedding, 47 | ) 48 | 49 | 50 | def _produce_prediction_output(rng_key, num_actions): 51 | """Producing the model output as in the stored expected trees.""" 52 | policy_rng, value_rng, reward_rng = jax.random.split(rng_key, 3) 54 | del rng_key 55 | # Producing value from [-1, +1). 56 | value = jax.random.uniform(value_rng, shape=(), minval=-1.0, maxval=1.0) 57 | # Producing reward from [-1, +1).
58 | reward = jax.random.uniform(reward_rng, shape=(), minval=-1.0, maxval=1.0) 59 | return dict( 60 | policy_logits=jax.random.normal(policy_rng, shape=[num_actions]), 61 | value=value, 62 | reward=reward, 63 | ) 64 | 65 | 66 | def _prepare_recurrent_fn(num_actions, *, discount, zero_reward): 67 | """Returns a dynamics function consistent with the expected trees.""" 68 | 69 | def recurrent_fn(params, rng_key, action, embedding): 70 | del params, rng_key 71 | # The embeddings serve as rng_keys. 72 | embedding = jax.vmap( 73 | functools.partial(_fold_action_in, num_actions=num_actions))(embedding, 74 | action) 75 | output = jax.vmap( 76 | functools.partial(_produce_prediction_output, num_actions=num_actions))( 77 | embedding) 78 | reward = output["reward"] 79 | if zero_reward: 80 | reward = jnp.zeros_like(reward) 81 | return mctx.RecurrentFnOutput( 82 | reward=reward, 83 | discount=jnp.full_like(reward, discount), 84 | prior_logits=output["policy_logits"], 85 | value=output["value"], 86 | ), embedding 87 | 88 | return recurrent_fn 89 | 90 | 91 | def _fold_action_in(rng_key, action, num_actions): 92 | """Returns a new rng key, selected by the given action.""" 93 | chex.assert_shape(action, ()) 94 | chex.assert_type(action, jnp.int32) 95 | sub_rngs = jax.random.split(rng_key, num_actions) 96 | return sub_rngs[action] 97 | 98 | 99 | def tree_to_pytree(tree: mctx.Tree, batch_i: int = 0): 100 | """Converts the MCTS tree to nested dicts.""" 101 | nodes = {} 102 | nodes[0] = _create_pynode( 103 | tree, batch_i, 0, prior=1.0, action=None, reward=None) 104 | children_prior_probs = jax.nn.softmax(tree.children_prior_logits, axis=-1) 105 | for node_i in range(tree.num_simulations + 1): 106 | for a_i in range(tree.num_actions): 107 | prior = children_prior_probs[batch_i, node_i, a_i] 108 | # Index of children, or -1 if not expanded 109 | child_i = int(tree.children_index[batch_i, node_i, a_i]) 110 | if child_i >= 0: 111 | reward = tree.children_rewards[batch_i, node_i, a_i] 112 
| child = _create_pynode( 113 | tree, batch_i, child_i, prior=prior, action=a_i, reward=reward) 114 | nodes[child_i] = child 115 | else: 116 | child = _create_bare_pynode(prior=prior, action=a_i) 117 | # pylint: disable=line-too-long 118 | nodes[node_i]["child_stats"].append(child) # pytype: disable=attribute-error 119 | # pylint: enable=line-too-long 120 | return nodes[0] 121 | 122 | 123 | def _create_pynode(tree, batch_i, node_i, prior, action, reward): 124 | """Returns a dict with extracted search statistics.""" 125 | node = dict( 126 | prior=_round_float(prior), 127 | visit=int(tree.node_visits[batch_i, node_i]), 128 | value_view=_round_float(tree.node_values[batch_i, node_i]), 129 | raw_value_view=_round_float(tree.raw_values[batch_i, node_i]), 130 | child_stats=[], 131 | evaluation_index=node_i, 132 | ) 133 | if action is not None: 134 | node["action"] = action 135 | if reward is not None: 136 | node["reward"] = _round_float(reward) 137 | return node 138 | 139 | 140 | def _create_bare_pynode(prior, action): 141 | return dict( 142 | prior=_round_float(prior), 143 | child_stats=[], 144 | action=action, 145 | ) 146 | 147 | 148 | def _round_float(value, ndigits=10): 149 | return round(float(value), ndigits) 150 | 151 | 152 | class TreeTest(parameterized.TestCase): 153 | 154 | # Make sure to adjust the `shard_count` parameter in the build file to match 155 | # the number of parameter configurations passed to test_tree. 
156 | # pylint: disable=line-too-long 157 | @parameterized.named_parameters( 158 | ("muzero_norescale", 159 | "../mctx/_src/tests/test_data/muzero_tree.json"), 160 | ("muzero_qtransform", 161 | "../mctx/_src/tests/test_data/muzero_qtransform_tree.json"), 162 | ("gumbel_muzero_norescale", 163 | "../mctx/_src/tests/test_data/gumbel_muzero_tree.json"), 164 | ("gumbel_muzero_reward", 165 | "../mctx/_src/tests/test_data/gumbel_muzero_reward_tree.json")) 166 | # pylint: enable=line-too-long 167 | def test_tree(self, tree_data_path): 168 | with open(tree_data_path, "rb") as fd: 169 | tree = json.load(fd) 170 | reproduced = self._reproduce_tree(tree) 171 | chex.assert_trees_all_close(tree["tree"], reproduced, atol=1e-3) 172 | 173 | def _reproduce_tree(self, tree): 174 | """Reproduces the given JSON tree by running a search.""" 175 | policy_fn = dict( 176 | gumbel_muzero=mctx.gumbel_muzero_policy, 177 | muzero=mctx.muzero_policy, 178 | )[tree["algorithm"]] 179 | 180 | env_config = tree["env_config"] 181 | root = tree["tree"] 182 | num_actions = len(root["child_stats"]) 183 | num_simulations = root["visit"] - 1 184 | qtransform = functools.partial( 185 | getattr(mctx, tree["algorithm_config"].pop("qtransform")), 186 | **tree["algorithm_config"].pop("qtransform_kwargs", {})) 187 | 188 | batch_size = 3 189 | # To test the independence of the batch computation, we use different 190 | # invalid actions for the other elements of the batch. The different batch 191 | # elements will then have different search tree depths. 
192 | invalid_actions = np.zeros([batch_size, num_actions]) 193 | invalid_actions[1, 1:] = 1 194 | invalid_actions[2, 2:] = 1 195 | 196 | def run_policy(): 197 | return policy_fn( 198 | params=(), 199 | rng_key=jax.random.PRNGKey(1), 200 | root=_prepare_root(batch_size=batch_size, num_actions=num_actions), 201 | recurrent_fn=_prepare_recurrent_fn(num_actions, **env_config), 202 | num_simulations=num_simulations, 203 | qtransform=qtransform, 204 | invalid_actions=invalid_actions, 205 | **tree["algorithm_config"]) 206 | 207 | policy_output = jax.jit(run_policy)() # pylint: disable=not-callable 208 | logging.info("Done search.") 209 | 210 | return tree_to_pytree(policy_output.search_tree) 211 | 212 | 213 | if __name__ == "__main__": 214 | jax.config.update("jax_numpy_rank_promotion", "raise") 215 | absltest.main() 216 | -------------------------------------------------------------------------------- /mctx/_src/tree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """A data structure used to hold / inspect search data for a batch of inputs.""" 16 | 17 | from __future__ import annotations 18 | from typing import Any, ClassVar, Generic, TypeVar 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | 25 | T = TypeVar("T") 26 | 27 | 28 | @chex.dataclass(frozen=True) 29 | class Tree(Generic[T]): 30 | """State of a search tree. 31 | 32 | The `Tree` dataclass is used to hold and inspect search data for a batch of 33 | inputs. In the fields below `B` denotes the batch dimension, `N` represents 34 | the number of nodes in the tree, and `num_actions` is the number of discrete 35 | actions. 36 | 37 | node_visits: `[B, N]` the visit counts for each node. 38 | raw_values: `[B, N]` the raw value for each node. 39 | node_values: `[B, N]` the cumulative search value for each node. 40 | parents: `[B, N]` the node index for the parents for each node. 41 | action_from_parent: `[B, N]` action to take from the parent to reach each 42 | node. 43 | children_index: `[B, N, num_actions]` the node index of the children for each 44 | action. 45 | children_prior_logits: `[B, N, num_actions]` the action prior logits of each 46 | node. 47 | children_visits: `[B, N, num_actions]` the visit counts for children for 48 | each action. 49 | children_rewards: `[B, N, num_actions]` the immediate reward for each action. 50 | children_discounts: `[B, N, num_actions]` the discount between the 51 | `children_rewards` and the `children_values`. 52 | children_values: `[B, N, num_actions]` the value of the next node after the 53 | action. 54 | embeddings: `[B, N, ...]` the state embeddings of each node. 55 | root_invalid_actions: `[B, num_actions]` a mask with invalid actions at the 56 | root. In the mask, invalid actions have ones, and valid actions have zeros. 57 | extra_data: `[B, ...]` extra data passed to the search. 
58 | """ 59 | node_visits: chex.Array # [B, N] 60 | raw_values: chex.Array # [B, N] 61 | node_values: chex.Array # [B, N] 62 | parents: chex.Array # [B, N] 63 | action_from_parent: chex.Array # [B, N] 64 | children_index: chex.Array # [B, N, num_actions] 65 | children_prior_logits: chex.Array # [B, N, num_actions] 66 | children_visits: chex.Array # [B, N, num_actions] 67 | children_rewards: chex.Array # [B, N, num_actions] 68 | children_discounts: chex.Array # [B, N, num_actions] 69 | children_values: chex.Array # [B, N, num_actions] 70 | embeddings: Any # [B, N, ...] 71 | root_invalid_actions: chex.Array # [B, num_actions] 72 | extra_data: T # [B, ...] 73 | 74 | # The following attributes are class variables (and should not be set on 75 | # Tree instances). 76 | ROOT_INDEX: ClassVar[int] = 0 77 | NO_PARENT: ClassVar[int] = -1 78 | UNVISITED: ClassVar[int] = -1 79 | 80 | @property 81 | def num_actions(self): 82 | return self.children_index.shape[-1] 83 | 84 | @property 85 | def num_simulations(self): 86 | return self.node_visits.shape[-1] - 1 87 | 88 | def qvalues(self, indices): 89 | """Compute q-values for any node indices in the tree.""" 90 | # pytype: disable=wrong-arg-types # jnp-type 91 | if jnp.asarray(indices).shape: 92 | return jax.vmap(_unbatched_qvalues)(self, indices) 93 | else: 94 | return _unbatched_qvalues(self, indices) 95 | # pytype: enable=wrong-arg-types 96 | 97 | def summary(self) -> SearchSummary: 98 | """Extract summary statistics for the root node.""" 99 | # Get state and action values for the root nodes. 100 | chex.assert_rank(self.node_values, 2) 101 | value = self.node_values[:, Tree.ROOT_INDEX] 102 | batch_size, = value.shape 103 | root_indices = jnp.full((batch_size,), Tree.ROOT_INDEX) 104 | qvalues = self.qvalues(root_indices) 105 | # Extract visit counts and induced probabilities for the root nodes. 
106 | visit_counts = self.children_visits[:, Tree.ROOT_INDEX].astype(value.dtype) 107 | total_counts = jnp.sum(visit_counts, axis=-1, keepdims=True) 108 | visit_probs = visit_counts / jnp.maximum(total_counts, 1) 109 | visit_probs = jnp.where(total_counts > 0, visit_probs, 1 / self.num_actions) 110 | # Return relevant stats. 111 | return SearchSummary( # pytype: disable=wrong-arg-types # numpy-scalars 112 | visit_counts=visit_counts, 113 | visit_probs=visit_probs, 114 | value=value, 115 | qvalues=qvalues) 116 | 117 | 118 | def infer_batch_size(tree: Tree) -> int: 119 | """Recovers batch size from `Tree` data structure.""" 120 | if tree.node_values.ndim != 2: 121 | raise ValueError("Input tree is not batched.") 122 | chex.assert_equal_shape_prefix(jax.tree_util.tree_leaves(tree), 1) 123 | return tree.node_values.shape[0] 124 | 125 | 126 | # A number of aggregate statistics and predictions are extracted from the 127 | # search data and returned to the user for further processing. 128 | @chex.dataclass(frozen=True) 129 | class SearchSummary: 130 | """Stats from MCTS search.""" 131 | visit_counts: chex.Array 132 | visit_probs: chex.Array 133 | value: chex.Array 134 | qvalues: chex.Array 135 | 136 | 137 | def _unbatched_qvalues(tree: Tree, index: int) -> int: 138 | chex.assert_rank(tree.children_discounts, 2) 139 | return ( # pytype: disable=bad-return-type # numpy-scalars 140 | tree.children_rewards[index] 141 | + tree.children_discounts[index] * tree.children_values[index] 142 | ) 143 | -------------------------------------------------------------------------------- /mctx/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/mctx/f8cd07bcc5d7ff736ae4c1e4217d2001508f8353/mctx/py.typed -------------------------------------------------------------------------------- /requirements/requirements-test.txt: -------------------------------------------------------------------------------- 1 | 
absl-py>=2.3.1 2 | numpy>=1.24.1 3 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | chex>=0.1.91 2 | jax>=0.7.0 3 | jaxlib>=0.7.0 4 | -------------------------------------------------------------------------------- /requirements/requirements_examples.txt: -------------------------------------------------------------------------------- 1 | absl-py>=2.3.1 2 | pygraphviz>=1.7 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Install script for setuptools.""" 16 | 17 | import os 18 | from setuptools import find_namespace_packages 19 | from setuptools import setup 20 | 21 | _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 22 | 23 | 24 | def _get_version(): 25 | with open('mctx/__init__.py') as fp: 26 | for line in fp: 27 | if line.startswith('__version__') and '=' in line: 28 | version = line[line.find('=') + 1:].strip(' \'"\n') 29 | if version: 30 | return version 31 | raise ValueError('`__version__` not defined in `mctx/__init__.py`') 32 | 33 | 34 | def _parse_requirements(path): 35 | 36 | with open(os.path.join(_CURRENT_DIR, path)) as f: 37 | return [ 38 | line.rstrip() 39 | for line in f 40 | if not (line.isspace() or line.startswith('#')) 41 | ] 42 | 43 | 44 | setup( 45 | name='mctx', 46 | version=_get_version(), 47 | url='https://github.com/google-deepmind/mctx', 48 | license='Apache 2.0', 49 | author='DeepMind', 50 | description=('Monte Carlo tree search in JAX.'), 51 | long_description=open(os.path.join(_CURRENT_DIR, 'README.md')).read(), 52 | long_description_content_type='text/markdown', 53 | author_email='mctx-dev@google.com', 54 | keywords='jax planning reinforcement-learning python machine learning', 55 | packages=find_namespace_packages(exclude=['*_test.py']), 56 | install_requires=_parse_requirements( 57 | os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')), 58 | tests_require=_parse_requirements( 59 | os.path.join(_CURRENT_DIR, 'requirements', 'requirements-test.txt')), 60 | zip_safe=False, # Required for full installation. 
61 |     python_requires='>=3.11',
62 |     classifiers=[
63 |         'Development Status :: 4 - Beta',
64 |         'Intended Audience :: Science/Research',
65 |         'License :: OSI Approved :: Apache Software License',
66 |         'Programming Language :: Python',
67 |         'Programming Language :: Python :: 3',
68 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
69 |         'Topic :: Software Development :: Libraries :: Python Modules',
70 |     ],
71 | )
72 | 
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | # Runs CI tests on a local machine.
17 | set -xeuo pipefail
18 | 
19 | # Install deps in a virtual env.
20 | readonly VENV_DIR=/tmp/mctx-env
21 | rm -rf "${VENV_DIR}"
22 | python3 -m venv "${VENV_DIR}"
23 | source "${VENV_DIR}/bin/activate"
24 | python --version
25 | 
26 | # Install dependencies.
27 | pip install --upgrade pip setuptools wheel
28 | pip install flake8 pytest-xdist pylint pylint-exit
29 | pip install -r requirements/requirements.txt
30 | pip install -r requirements/requirements-test.txt
31 | 
32 | # Lint with flake8.
33 | flake8 `find mctx -name '*.py' | xargs` --count --select=E9,F63,F7,F82,E225,E251 --show-source --statistics
34 | 
35 | # Lint with pylint.
36 | # Fail on errors, warning, and conventions.
37 | PYLINT_ARGS="-efail -wfail -cfail"
38 | # Lint modules and tests separately.
39 | pylint --rcfile=.pylintrc `find mctx -name '*.py' | grep -v 'test.py' | xargs` || pylint-exit $PYLINT_ARGS $?
40 | # Disable `protected-access` warnings for tests.
41 | pylint --rcfile=.pylintrc `find mctx -name '*_test.py' | xargs` -d W0212 || pylint-exit $PYLINT_ARGS $?
42 | 
43 | # Build the package.
44 | python setup.py sdist
45 | pip wheel --verbose --no-deps --no-clean dist/mctx*.tar.gz
46 | pip install mctx*.whl
47 | 
48 | # Check types with pytype.
49 | # Note: pytype does not support 3.12 as of 23.11.23
50 | # See https://github.com/google/pytype/issues/1308
51 | if [ `python -c 'import sys; print(sys.version_info.minor)'` -lt 12 ];
52 | then
53 |   pip install pytype
54 |   pytype `find mctx/_src/ -name "*.py" | xargs` -k
55 | fi;
56 | 
57 | # Run tests using pytest.
58 | # Change directory to avoid importing the package from repo root.
59 | mkdir _testing && cd _testing
60 | 
61 | # Use all available CPU cores to parallelize the test run.
62 | pytest -n "$(grep -c ^processor /proc/cpuinfo)" --pyargs mctx
63 | cd ..
64 | 
65 | set +u
66 | deactivate
67 | echo "All tests passed. Congrats!"
68 | 
--------------------------------------------------------------------------------