├── cfrx ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── cfr │ │ ├── __init__.py │ │ └── cfr.py │ └── mccfr │ │ ├── __init__.py │ │ ├── test_outcome_sampling.py │ │ └── outcome_sampling.py ├── trainers │ ├── __init__.py │ ├── cfr.py │ └── mccfr.py ├── envs │ ├── nlhe_poker │ │ ├── __init__.py │ │ ├── showdown.py │ │ ├── gui.py │ │ ├── env.py │ │ └── test_showdown.py │ ├── kuhn_poker │ │ ├── data │ │ │ ├── __init__.py │ │ │ └── info_states.npz │ │ ├── __init__.py │ │ ├── constants.py │ │ └── env.py │ ├── leduc_poker │ │ ├── data │ │ │ ├── __init__.py │ │ │ └── info_states.npz │ │ ├── __init__.py │ │ ├── constants.py │ │ └── env.py │ ├── __init__.py │ └── base.py ├── tree │ ├── __init__.py │ ├── tree.py │ └── traverse.py ├── metrics │ ├── __init__.py │ ├── exploitability.py │ ├── exploitability_test.py │ └── best_response.py ├── episode.py ├── utils.py └── policy.py ├── imgs └── bench_open_spiel.png ├── MANIFEST.in ├── mypy.ini ├── CONTRIBUTING.md ├── .pre-commit-config.yaml ├── LICENSE ├── pyproject.toml ├── .gitignore ├── README.md └── examples └── cfr.ipynb /cfrx/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cfrx/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cfrx/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cfrx/envs/nlhe_poker/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cfrx/envs/kuhn_poker/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cfrx/envs/leduc_poker/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/bench_open_spiel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egiob/cfrx/HEAD/imgs/bench_open_spiel.png -------------------------------------------------------------------------------- /cfrx/tree/__init__.py: -------------------------------------------------------------------------------- 1 | from cfrx.tree.tree import Root, Tree 2 | 3 | __all__ = ["Root", "Tree"] 4 | -------------------------------------------------------------------------------- /cfrx/envs/kuhn_poker/__init__.py: -------------------------------------------------------------------------------- 1 | from cfrx.envs.kuhn_poker.env import KuhnPoker 2 | 3 | __all__ = ["KuhnPoker"] 4 | -------------------------------------------------------------------------------- /cfrx/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from cfrx.envs.base import Env, InfoState, State 2 | 3 | __all__ = ["Env", "State", "InfoState"] 4 | -------------------------------------------------------------------------------- /cfrx/envs/leduc_poker/__init__.py: -------------------------------------------------------------------------------- 1 | from cfrx.envs.leduc_poker.env import LeducPoker 2 | 3 
| __all__ = ["LeducPoker"] 4 | -------------------------------------------------------------------------------- /cfrx/envs/kuhn_poker/data/info_states.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egiob/cfrx/HEAD/cfrx/envs/kuhn_poker/data/info_states.npz -------------------------------------------------------------------------------- /cfrx/envs/leduc_poker/data/info_states.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egiob/cfrx/HEAD/cfrx/envs/leduc_poker/data/info_states.npz -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | recursive-include * *.npz 3 | 4 | # remove the test specific files 5 | recursive-exclude * *_test.py 6 | -------------------------------------------------------------------------------- /cfrx/algorithms/cfr/__init__.py: -------------------------------------------------------------------------------- 1 | from cfrx.algorithms.cfr.cfr import CFRState, do_iteration 2 | 3 | __all__ = ["CFRState", "do_iteration"] 4 | -------------------------------------------------------------------------------- /cfrx/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from cfrx.metrics.best_response import compute_best_response_value 2 | from cfrx.metrics.exploitability import exploitability 3 | -------------------------------------------------------------------------------- /cfrx/algorithms/mccfr/__init__.py: -------------------------------------------------------------------------------- 1 | from cfrx.algorithms.mccfr.outcome_sampling import MCCFRState, do_iteration 2 | 3 | __all__ = ["MCCFRState", "do_iteration"] 4 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | # Global options: 2 | 3 | [mypy] 4 | warn_return_any = True 5 | warn_unused_configs = True 6 | ignore_missing_imports = True 7 | follow_imports = skip 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contribution Guidelines 2 | 3 | All contributions are welcome, don't hesitate to send a PR or open an issue. 4 | 5 | ## License 6 | 7 | By contributing, you agree that your contributions will be licensed under MIT License. 
8 | -------------------------------------------------------------------------------- /cfrx/envs/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | import pgx.core 6 | from jaxtyping import Array, Int, PyTree 7 | 8 | InfoState = PyTree 9 | State = PyTree 10 | 11 | 12 | class BaseEnv(ABC): 13 | @abstractmethod 14 | def action_to_string(cls, action: Int[Array, ""]) -> str: 15 | pass 16 | 17 | 18 | class Env(BaseEnv, pgx.core.Env): 19 | pass 20 | -------------------------------------------------------------------------------- /cfrx/episode.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | from jaxtyping import Array, Bool, Float, Int 4 | 5 | from cfrx.envs.base import InfoState 6 | 7 | 8 | class Episode(NamedTuple): 9 | info_state: InfoState 10 | action: Float[Array, "..."] 11 | reward: Float[Array, "..."] 12 | action_mask: Bool[Array, "..."] 13 | current_player: Int[Array, "..."] 14 | behavior_prob: Float[Array, "..."] 15 | mask: Bool[Array, "..."] 16 | chance_node: Bool[Array, "..."] 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.2.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - id: check-added-large-files 9 | - id: requirements-txt-fixer 10 | - id: pretty-format-json 11 | exclude: \.ipynb 12 | - repo: https://github.com/psf/black 13 | rev: 23.11.0 # Use the sha / tag you want to point at 14 | hooks: 15 | - id: black 16 | - repo: https://github.com/pycqa/isort 17 | rev: 5.12.0 18 | hooks: 19 | - id: isort 20 | 21 | 22 | - repo: https://github.com/kynan/nbstripout 23 | rev: 0.7.1 24 | hooks: 25 | - id: nbstripout 26 | -------------------------------------------------------------------------------- /cfrx/envs/leduc_poker/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | current_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | INFO_SETS = dict( 8 | np.load( 9 | os.path.join( 10 | current_path, 11 | "data", 12 | "info_states.npz", 13 | ) 14 | ) 15 | ) 16 | 17 | 18 | # This is a bit hacky, it allows to transform an infostate into an index, and to 19 | # construct an array to efficiently lookup in Jax. 
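# Sketch of what the lines below do: each info-state vector is shifted by +1 and
# dotted with powers of 3 (`multiplier`), and the result is taken modulo 1235, which
# appears to be collision-free for this particular set of info states. The
# REVERSE_INFO_SETS_LOOKUP array then maps that hash back to a dense index, so the
# lookup reduces to a single array-indexing op in Jax.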
20 | multiplier = np.array([3**k for k in range(12)]) 21 | tr = {k: np.sum((v + 1) * multiplier) % 1235 for k, v in INFO_SETS.items()} 22 | max_value = max(list(tr.values())) 23 | REVERSE_INFO_SETS_LOOKUP = np.zeros(max_value + 1, dtype=int) 24 | REVERSE_INFO_SETS_LOOKUP[list(tr.values())] = np.arange(len(tr)) 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Raphaël Boige 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /cfrx/metrics/exploitability.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import pgx 4 | from jaxtyping import Array 5 | 6 | from cfrx.metrics.best_response import compute_best_response_value 7 | from cfrx.policy import Float, Policy 8 | from cfrx.tree import Root 9 | from cfrx.tree.traverse import instantiate_tree_from_root, traverse_tree_cfr 10 | 11 | 12 | def exploitability( 13 | env: pgx.Env, 14 | policy: Policy, 15 | policy_params: Array, 16 | n_players: int, 17 | n_max_nodes: int = 100, 18 | ) -> Float[Array, ""]: 19 | random_key = jax.random.PRNGKey(0) 20 | state = jax.jit(env.init)(random_key) 21 | 22 | values_list = [] 23 | root = Root( 24 | prior_logits=state.legal_action_mask / state.legal_action_mask.sum(), 25 | value=jnp.array(0.0), 26 | state=state, 27 | ) 28 | for traverser in range(n_players): 29 | tree = instantiate_tree_from_root( 30 | root=root, 31 | n_max_nodes=n_max_nodes, 32 | n_players=n_players, 33 | running_probabilities=True, 34 | ) 35 | 36 | tree = traverse_tree_cfr( 37 | tree, 38 | policy=policy, 39 | env=env, 40 | traverser=traverser, 41 | policy_params=policy_params, 42 | ) 43 | 44 | value = compute_best_response_value( 45 | tree, br_player=traverser, info_state_fn=env.info_state_idx 46 | ) 47 | values_list.append(value) 48 | 49 | values = jnp.stack(values_list) 50 | exploitability = jnp.array(values).sum() / n_players 51 | return exploitability 52 | -------------------------------------------------------------------------------- /cfrx/envs/kuhn_poker/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import jax 4 | import numpy as np 5 | 6 | current_path = os.path.dirname(os.path.abspath(__file__)) 7 | 8 | 
INFO_SETS = dict( 9 | np.load( 10 | os.path.join( 11 | current_path, 12 | "data", 13 | "info_states.npz", 14 | ) 15 | ) 16 | ) 17 | 18 | INFO_SET_ACTION_MASK = { 19 | "0": [False, True, False, True], 20 | "0p": [False, True, False, True], 21 | "0b": [True, False, True, False], 22 | "0pb": [True, False, True, False], 23 | "1": [False, True, False, True], 24 | "1p": [False, True, False, True], 25 | "1b": [True, False, True, False], 26 | "1pb": [True, False, True, False], 27 | "2": [False, True, False, True], 28 | "2p": [False, True, False, True], 29 | "2b": [True, False, True, False], 30 | "2pb": [True, False, True, False], 31 | } 32 | 33 | 34 | def get_kuhn_optimal_policy(alpha: float) -> dict: 35 | assert 0 <= alpha <= 1 / 3 36 | optimal_probs = { 37 | "0": [0.0, alpha, 0.0, 1 - alpha], 38 | "0p": [0.0, 1 / 3, 0.0, 2 / 3], 39 | "0b": [0.0, 0.0, 1.0, 0.0], 40 | "0pb": [0.0, 0.0, 1.0, 0.0], 41 | "1": [0.0, 0.0, 0.0, 1.0], 42 | "1p": [0.0, 0.0, 0.0, 1.0], 43 | "1b": [1 / 3, 0.0, 2 / 3, 0.0], 44 | "1pb": [alpha + 1 / 3, 0.0, 2 / 3 - alpha, 0.0], 45 | "2": [0.0, 3 * alpha, 0, 1 - 3 * alpha], 46 | "2p": [0.0, 1.0, 0.0, 0.0], 47 | "2b": [1.0, 0.0, 0.0, 0.0], 48 | "2pb": [1.0, 0.0, 0.0, 0.0], 49 | } 50 | return optimal_probs 51 | 52 | 53 | KUHN_UNIFORM_POLICY = jax.tree_map( 54 | lambda x: x / x.sum(), {k: np.array(x) for k, x in INFO_SET_ACTION_MASK.items()} 55 | ) 56 | -------------------------------------------------------------------------------- /cfrx/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Optional 4 | 5 | import jax 6 | import jax.flatten_util 7 | import jax.numpy as jnp 8 | from jaxtyping import Array, Bool, Int, PyTree 9 | 10 | 11 | def get_action_mask(state: PyTree) -> Bool[Array, "..."]: 12 | chance_action_mask = state.chance_prior > 0 13 | decision_action_mask = state.legal_action_mask 14 | max_action_mask_size = max( 15 | chance_action_mask.shape[-1], decision_action_mask.shape[-1] 16 | ) 17 | chance_action_mask = jnp.pad( 18 | chance_action_mask, 19 | (0, max_action_mask_size - chance_action_mask.shape[-1]), 20 | constant_values=(False,), 21 | ) 22 | decision_action_mask = jnp.pad( 23 | decision_action_mask, 24 | (0, max_action_mask_size - decision_action_mask.shape[-1]), 25 | constant_values=(False,), 26 | ) 27 | action_mask = jnp.where( 28 | state.chance_node[..., None], 29 | chance_action_mask, 30 | decision_action_mask, 31 | ) 32 | return action_mask 33 | 34 | 35 | def reverse_array_lookup(x: Array, lookup_table: Array) -> Int[Array, ""]: 36 | return (lookup_table == x).all(axis=1).argmax() 37 | 38 | 39 | def tree_unstack(tree: PyTree) -> list[PyTree]: 40 | leaves, treedef = jax.tree_util.tree_flatten(tree) 41 | n_trees = leaves[0].shape[0] 42 | new_leaves: list = [[] for _ in range(n_trees)] 43 | for leaf in leaves: 44 | for i in range(n_trees): 45 | new_leaves[i].append(leaf[i]) 46 | new_trees = [treedef.unflatten(leaves) for leaves in new_leaves] 47 | return new_trees 48 | 49 | 50 | def ravel(tree: PyTree) -> Array: 51 | return jax.flatten_util.ravel_pytree(tree)[0] 52 | 53 | 54 | def regret_matching(regrets: Array) -> Array: 55 | positive_regrets = jnp.maximum(regrets, 0) 56 | n_actions = positive_regrets.shape[-1] 57 | sum_pos_regret = positive_regrets.sum(axis=-1, keepdims=True) 58 | dist = jnp.where( 59 | sum_pos_regret == 0, 1 / n_actions, positive_regrets / sum_pos_regret 60 | ) 61 | return dist 62 | 63 | 64 | def log_array(x: Array, name: Optional[str] = 
None) -> None: 65 | if name is None: 66 | name = "array" 67 | jax.debug.print(name + ": {x}", x=x) 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [project] 7 | name = "cfrx" 8 | description = "Counterfactual Regret Minimization in Jax" 9 | readme = "README.md" 10 | license = {file = "LICENSE"} 11 | urls = {repository = "https://github.com/Egiob/cfrx" } 12 | authors = [ 13 | {name="Raphaël Boige"}, 14 | ] 15 | requires-python = ">=3.9.0" 16 | version = "0.0.2" 17 | dependencies = [ 18 | "jaxtyping>=0.2.19", 19 | "jax>=0.4.0", 20 | "flax>=0.7.0", 21 | "pgx==2.0.1", 22 | "tqdm~=4.66.0" 23 | ] 24 | keywords = ["jax", "game-theory", "reinforcement-learning", "cfr", "poker"] 25 | classifiers = [ 26 | "Development Status :: 3 - Alpha", 27 | "Intended Audience :: Developers", 28 | "Intended Audience :: Information Technology", 29 | "Intended Audience :: Science/Research", 30 | "License :: OSI Approved :: MIT License", 31 | "Natural Language :: English", 32 | "Programming Language :: Python :: 3", 33 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 34 | "Topic :: Scientific/Engineering :: Information Analysis", 35 | "Topic :: Scientific/Engineering :: Mathematics", 36 | ] 37 | 38 | [tool.setuptools] 39 | include-package-data = true 40 | 41 | [tool.setuptools.packages.find] 42 | where = ["."] 43 | include = ["*"] 44 | exclude = ["imgs"] 45 | 46 | [tool.isort] 47 | profile = "black" 48 | 49 | [tool.mypy] 50 | python_version = "3.10" 51 | namespace_packages = true 52 | incremental = false 53 | cache_dir = "" 54 | warn_redundant_casts = true 55 | warn_return_any = true 56 | warn_unused_configs = true 57 | warn_unused_ignores = false 58 | allow_redefinition = true 59 | disallow_untyped_calls = true 60 | disallow_untyped_defs = true 61 | disallow_incomplete_defs = true 62 | check_untyped_defs = true 63 | disallow_untyped_decorators = false 64 | strict_optional = true 65 | strict_equality = true 66 | explicit_package_bases = true 67 | follow_imports = "skip" 68 | ignore_missing_imports = true 69 | 70 | [tool.black] 71 | line-length = 89 72 | 73 | [tool.bumpver] 74 | current_version = "0.0.2" 75 | version_pattern = "MAJOR.MINOR.PATCH" 76 | commit_message = "build: bump version {old_version} -> {new_version}" 77 | commit = true 78 | tag = true 79 | push = false 80 | 81 | [tool.bumpver.file_patterns] 82 | "pyproject.toml" = ['current_version = "{version}"', 'version = "{version}"'] 83 | -------------------------------------------------------------------------------- /cfrx/metrics/exploitability_test.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pyspiel 7 | import pytest 8 | from jaxtyping import Array 9 | from open_spiel.python.algorithms.exploitability import ( 10 | exploitability as open_spiel_exploitability_fn, 11 | ) 12 | from open_spiel.python.algorithms.mccfr import AveragePolicy 13 | 14 | from cfrx.envs.kuhn_poker.constants import ( 15 | INFO_SETS, 16 | KUHN_UNIFORM_POLICY, 17 | get_kuhn_optimal_policy, 18 | ) 19 | from cfrx.envs.kuhn_poker.env import KuhnPoker 20 | from cfrx.metrics.exploitability import exploitability as cfrx_exploitability_fn 21 | from cfrx.policy import TabularPolicy 22 | 23 | 24 |
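# Sanity check: the cfrx exploitability on Kuhn Poker should agree with open_spiel's
# for a few reference policies (two optimal policies with different alpha values, and
# the uniform policy).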
@pytest.mark.parametrize( 25 | "policy_dict", 26 | [ 27 | get_kuhn_optimal_policy(alpha=0.0), 28 | KUHN_UNIFORM_POLICY, 29 | get_kuhn_optimal_policy(alpha=0.15), 30 | ], 31 | ) 32 | def test_kuhn_exploitability_vs_open_spiel(policy_dict: Dict[str, Array]) -> None: 33 | info_states_dict = INFO_SETS 34 | default = np.ones_like(next(iter(policy_dict.values()))) 35 | policy = np.stack([policy_dict.get(x, default) for x in info_states_dict]) 36 | 37 | # cfrx 38 | policy_jax = jnp.asarray(policy) 39 | n_max_nodes = 200 40 | env = KuhnPoker() 41 | policy_obj = TabularPolicy( 42 | n_actions=policy.shape[1], info_state_idx_fn=env.info_state_idx 43 | ) 44 | n_players = 2 45 | cfrx_exploitability = cfrx_exploitability_fn( 46 | env, 47 | policy_params=policy_jax, 48 | n_players=n_players, 49 | n_max_nodes=n_max_nodes, 50 | policy=policy_obj, 51 | ) 52 | 53 | # open_spiel 54 | game = pyspiel.load_game("kuhn_poker") 55 | avg_probs = np.array(policy) 56 | avg_probs = np.concatenate( 57 | [ 58 | avg_probs[:, 2:4].sum(axis=1, keepdims=True), 59 | avg_probs[:, :2].sum(axis=1, keepdims=True), 60 | ], 61 | axis=1, 62 | ) 63 | 64 | open_spiel_info_states = [[None, x] for x in avg_probs] 65 | cfrx_to_open_spiel_info_states = dict(zip(INFO_SETS.keys(), open_spiel_info_states)) 66 | 67 | average_policy = AveragePolicy( 68 | game=game, player_ids=[0, 1], infostates=cfrx_to_open_spiel_info_states 69 | ) 70 | open_spiel_exploitability = open_spiel_exploitability_fn(game, average_policy) 71 | 72 | assert np.allclose(cfrx_exploitability, open_spiel_exploitability, atol=1e-5) 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cfrx: Counterfactual Regret Minimization in Jax. 2 | 3 | cfrx is an open-source library designed for efficient implementation of counterfactual regret minimization (CFR) algorithms using JAX. It focuses on computational speed and easy parallelization on hardware accelerators like GPUs and TPUs. 4 | 5 | Key Features: 6 | 7 | - **JIT Compilation for Speed:** cfrx makes the most out of JAX's just-in-time (JIT) compilation to minimize runtime overhead and maximize computational speed. 8 | 9 | - **Hardware Accelerator Support:** It supports parallelization on GPUs and TPUs, enabling efficient scaling of computations for large-scale problems. 10 | 11 | - **Python/JAX Ease of Use:** cfrx provides a Pythonic interface built on JAX, offering simplicity and accessibility compared to traditional C++ implementations or prohibitively slow pure-Python code. 12 | 13 | ## Installation 14 | 15 | pip install cfrx 16 | 17 | ## Getting started 18 | 19 | An example notebook is available [here](examples/mccfr.ipynb). 20 | 21 | Below is a snippet for training MCCFR with outcome sampling on Kuhn Poker: 22 | ```python3 23 | import jax 24 | 25 | from cfrx.envs.kuhn_poker.env import KuhnPoker 26 | from cfrx.policy import TabularPolicy 27 | from cfrx.trainers.mccfr import MCCFRTrainer 28 | 29 | env = KuhnPoker() 30 | 31 | policy = TabularPolicy( 32 | n_actions=env.n_actions, 33 | exploration_factor=0.6, 34 | info_state_idx_fn=env.info_state_idx, 35 | ) 36 | 37 | random_key = jax.random.PRNGKey(0) 38 | 39 | trainer = MCCFRTrainer(env=env, policy=policy) 40 | 41 | training_state, metrics = trainer.train( 42 | random_key=random_key, n_iterations=100_000, metrics_period=5_000 43 | ) 44 | ``` 45 | 46 | 47 | ## Implemented and upcoming features 48 | 49 | | Algorithms | | 50 | |---|---| 51 | | MCCFR (outcome-sampling) | :white_check_mark: | 52 | | MCCFR (other variants) | :x: | 53 | | Vanilla CFR | :white_check_mark: | 54 | | Deep CFR | :x: | 55 | 56 | | Metrics | | 57 | |---|---| 58 | | Exploitability | :white_check_mark: | 59 | | Local Best Response | :x: | 60 | 61 | | Environments | | 62 | |---|---| 63 | | Kuhn Poker | :white_check_mark: | 64 | | Leduc Poker | :white_check_mark: | 65 | | Larger games | :x: | 66 | 67 | 68 | ## Performance 69 | 70 | Below is a small benchmark against `open_spiel` for MCCFR-outcome-sampling on Kuhn Poker and Leduc Poker. Compared to the Python API of `open_spiel`, `cfrx` has faster runtime and demonstrates similar convergence.
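As a rough sketch of how the cfrx side of such a comparison can be run (the open_spiel side is omitted, and the timing wrapper and iteration counts below are illustrative rather than the exact settings behind the plot):

```python3
import time

import jax

from cfrx.envs.kuhn_poker.env import KuhnPoker
from cfrx.policy import TabularPolicy
from cfrx.trainers.mccfr import MCCFRTrainer

env = KuhnPoker()
policy = TabularPolicy(
    n_actions=env.n_actions,
    exploration_factor=0.6,
    info_state_idx_fn=env.info_state_idx,
)
trainer = MCCFRTrainer(env=env, policy=policy)

start = time.perf_counter()
training_state, metrics = trainer.train(
    random_key=jax.random.PRNGKey(0), n_iterations=100_000, metrics_period=5_000
)
elapsed = time.perf_counter() - start

# One exploitability estimate is recorded every `metrics_period` iterations
# (plus one at step 0).
print(f"cfrx: {elapsed:.1f}s, final exploitability {metrics['exploitability'][-1]:.2e}")
```

The `metrics["step"]` and `metrics["exploitability"]` arrays returned by the trainer give the convergence curve that a plot like the one below is built from.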
71 | 72 | ![benchmarck_against_open_spiel_img](imgs/bench_open_spiel.png) 73 | 74 | ## See also 75 | 76 | cfrx is heavily inspired by the amazing [google-deepmind/open_spiel](https://github.com/google-deepmind/open_spiel) library as well as by many projects from the Jax ecosystem and especially [sotetsuk/pgx](https://github.com/sotetsuk/pgx) and [google-deepmind/mctx](https://github.com/google-deepmind/mctx). 77 | 78 | 79 | ## Contributing 80 | 81 | Contributions are welcome, refer to the [contributions guidelines](CONTRIBUTING.md). 82 | -------------------------------------------------------------------------------- /cfrx/policy.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections.abc import Callable 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray 7 | 8 | from cfrx.envs.base import InfoState 9 | 10 | 11 | class Policy(ABC): 12 | _n_actions: int 13 | 14 | @abstractmethod 15 | def prob_distribution( 16 | self, 17 | params: Float[Array, "... a"], 18 | info_state: InfoState[Array, "..."], 19 | action_mask: Bool[Array, "... a"], 20 | use_behavior_policy: Bool[Array, "..."], 21 | ) -> Array: 22 | pass 23 | 24 | @abstractmethod 25 | def sample( 26 | self, 27 | params: Float[Array, "... a"], 28 | info_state: InfoState[Array, "..."], 29 | action_mask: Bool[Array, "... a"], 30 | random_key: PRNGKeyArray, 31 | use_behavior_policy: Bool[Array, "..."], 32 | ) -> Array: 33 | pass 34 | 35 | 36 | class TabularPolicy(Policy): 37 | def __init__( 38 | self, 39 | n_actions: int, 40 | info_state_idx_fn: Callable[[InfoState], Int[Array, ""]], 41 | exploration_factor: float = 0.6, 42 | ): 43 | self._n_actions = n_actions 44 | self._exploration_factor = exploration_factor 45 | self._info_state_idx_fn = info_state_idx_fn 46 | 47 | def prob_distribution( 48 | self, 49 | params: Float[Array, "... a"], 50 | info_state: InfoState[Array, "..."], 51 | action_mask: Bool[Array, "... a"], 52 | use_behavior_policy: Bool[Array, "..."], 53 | ) -> Array: 54 | info_state_idx = self._info_state_idx_fn(info_state) 55 | probs = params[info_state_idx] 56 | 57 | behavior_probabilities = ( 58 | probs * (1 - self._exploration_factor) 59 | + self._exploration_factor * jnp.ones_like(probs) / self._n_actions 60 | ) 61 | 62 | probs = jnp.where(use_behavior_policy, behavior_probabilities, probs) 63 | probs = probs * action_mask 64 | probs /= probs.sum(axis=-1, keepdims=True) 65 | return probs 66 | 67 | def sample( 68 | self, 69 | params: Float[Array, "... a"], 70 | info_state: InfoState[Array, "..."], 71 | action_mask: Bool[Array, "... 
a"], 72 | random_key: PRNGKeyArray, 73 | use_behavior_policy: Bool[Array, "..."], 74 | ) -> Array: 75 | info_state_idx = self._info_state_idx_fn(info_state) 76 | probs = params[info_state_idx] 77 | 78 | behavior_probabilities = ( 79 | probs * (1 - self._exploration_factor) 80 | + self._exploration_factor * jnp.ones_like(probs) / self._n_actions 81 | ) 82 | 83 | probs = jnp.where(use_behavior_policy, behavior_probabilities, probs) 84 | 85 | probs = probs * action_mask 86 | probs /= probs.sum(axis=-1, keepdims=True) 87 | action = jax.random.choice(random_key, jnp.arange(self._n_actions), p=probs) 88 | return action 89 | -------------------------------------------------------------------------------- /examples/cfr.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "0", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import matplotlib.pyplot as plt\n", 11 | "\n", 12 | "from cfrx.algorithms.cfr import CFRState\n", 13 | "from cfrx.policy import TabularPolicy\n", 14 | "from cfrx.trainers.cfr import CFRTrainer" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "1", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "ENV_NAME = \"Kuhn Poker\"" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "2", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "if ENV_NAME == \"Kuhn Poker\":\n", 35 | " from cfrx.envs.kuhn_poker.env import KuhnPoker\n", 36 | "\n", 37 | " env_cls = KuhnPoker\n", 38 | "\n", 39 | "\n", 40 | "elif ENV_NAME == \"Leduc Poker\":\n", 41 | " from cfrx.envs.leduc_poker.env import LeducPoker\n", 42 | "\n", 43 | " env_cls = LeducPoker" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "3", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "env = env_cls()" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "4", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "training_state = CFRState.init(n_states=env.n_info_states, n_actions=env.n_actions)\n", 64 | "policy = TabularPolicy(n_actions=env.n_actions, info_state_idx_fn=env.info_state_idx)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "id": "5", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "trainer = CFRTrainer(env=env, policy=policy, device=\"cpu\")\n", 75 | "training_state, metrics = trainer.train(n_iterations=10000, metrics_period=10)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "6", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "plt.plot(metrics[\"step\"], metrics[\"exploitability\"])\n", 86 | "plt.yscale(\"log\")\n", 87 | "plt.xlabel(\"Iterations\")\n", 88 | "plt.title(f\"CFR on {ENV_NAME}\")\n", 89 | "plt.ylabel(\"Exploitability\")" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "7", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [] 99 | } 100 | ], 101 | "metadata": { 102 | "kernelspec": { 103 | "display_name": "Python 3 (ipykernel)", 104 | "language": "python", 105 | "name": "python3" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | 
"pygments_lexer": "ipython3", 117 | "version": "3.10.13" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 5 122 | } 123 | -------------------------------------------------------------------------------- /cfrx/tree/tree.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jaxtyping import Array, Float, Int, PyTree 6 | 7 | 8 | class classproperty: 9 | def __init__(self, func): 10 | self.fget = func 11 | 12 | def __get__(self, instance, owner): 13 | return self.fget(owner) 14 | 15 | 16 | class Root(NamedTuple): 17 | """ 18 | Base class to hold the root of a search tree. 19 | 20 | Args: 21 | prior_logits: `[n_actions]` the action prior logits. 22 | value: `[]` the value of the root node. 23 | state: `[...]` the state of the root node. 24 | """ 25 | 26 | prior_logits: Array 27 | value: Array 28 | state: Array 29 | 30 | 31 | UNVISITED = -1 32 | NO_PARENT = -1 33 | ROOT_INDEX = 0 34 | 35 | 36 | class Tree(NamedTuple): 37 | """ 38 | Adapted with minor modifications from https://github.com/google-deepmind/mctx. 39 | 40 | State of a search tree. 41 | 42 | The `Tree` dataclass is used to hold and inspect search data for a batch of 43 | inputs. In the fields below `N` represents the number of nodes in the tree, 44 | and `n_actions` is the number of discrete actions. 45 | 46 | node_visits: `[N]` the visit counts for each node. 47 | raw_values: `[N, n_players]` the raw value for each node. 48 | node_values: `[N, n_players]` the cumulative search value for each node. 49 | parents: `[N]` the node index for the parents for each node. 50 | action_from_parent: `[N]` action to take from the parent to reach each 51 | node. 52 | children_index: `[N, n_actions]` the node index of the children for each 53 | action. 54 | children_prior_logits: `[N, n_actions` the action prior logits of each 55 | node. 56 | children_visits: `[N, n_actions]` the visit counts for children for 57 | each action. 58 | children_rewards: `[N, n_actions, n_players]` the immediate reward for each 59 | action. 60 | children_values: `[N, n_actions, n_players]` the value of the next node after 61 | the action. 62 | states: `[N, ...]` the state embeddings of each node. 63 | depth: `[N]` the depth of each node in the tree. 64 | extra_data: `[...]` extra data passed to the tree. 65 | 66 | """ 67 | 68 | node_visits: Int[Array, "..."] 69 | raw_values: Float[Array, "... n_players"] 70 | node_values: Float[Array, "... n_players"] 71 | parents: Int[Array, "..."] 72 | action_from_parent: Int[Array, "..."] 73 | children_index: Int[Array, "... n_actions"] 74 | children_prior_logits: Float[Array, "... n_actions"] 75 | children_visits: Int[Array, "... n_actions"] 76 | children_rewards: Float[Array, "... n_actions n_players"] 77 | children_values: Float[Array, "... 
n_actions n_players"] 78 | states: PyTree 79 | depth: Int[Array, "..."] 80 | extra_data: dict[str, Array] 81 | to_visit: jax.Array 82 | 83 | @classproperty 84 | def ROOT_INDEX(cls) -> Int[Array, ""]: 85 | return jnp.asarray(ROOT_INDEX) 86 | 87 | @classproperty 88 | def NO_PARENT(cls) -> Int[Array, ""]: 89 | return jnp.asarray(NO_PARENT) 90 | 91 | @classproperty 92 | def UNVISITED(cls) -> Int[Array, ""]: 93 | return jnp.asarray(UNVISITED) 94 | 95 | @property 96 | def n_actions(self) -> int: 97 | return self.children_index.shape[-1] 98 | -------------------------------------------------------------------------------- /cfrx/trainers/cfr.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jaxtyping import Array, Int 3 | from tqdm import tqdm 4 | 5 | from cfrx.algorithms.cfr.cfr import CFRState, do_iteration 6 | from cfrx.envs import Env 7 | from cfrx.metrics import exploitability 8 | from cfrx.policy import TabularPolicy 9 | 10 | 11 | class CFRTrainer: 12 | def __init__(self, env: Env, policy: TabularPolicy, device: str): 13 | self._env = env 14 | self._policy = policy 15 | self._device = jax.devices(device)[0] 16 | self._exploitability_fn = jax.jit( 17 | lambda policy_params: exploitability( 18 | policy_params=policy_params, 19 | env=env, 20 | n_players=env.n_players, 21 | n_max_nodes=env.max_nodes, 22 | policy=policy, 23 | ), 24 | device=self._device, 25 | ) 26 | 27 | self._do_iteration_fn = jax.jit( 28 | lambda training_state, update_player: do_iteration( 29 | training_state=training_state, 30 | env=env, 31 | policy=policy, 32 | update_player=update_player, 33 | ), 34 | device=self._device, 35 | ) 36 | 37 | def do_n_iterations( 38 | self, 39 | training_state: CFRState, 40 | update_player: Int[Array, ""], 41 | n: int, 42 | ) -> tuple[CFRState, Int[Array, ""]]: 43 | def _scan_fn(carry, unused): 44 | training_state, update_player = carry 45 | 46 | update_player = (update_player + 1) % 2 47 | training_state = self._do_iteration_fn( 48 | training_state, 49 | update_player=update_player, 50 | ) 51 | 52 | return (training_state, update_player), None 53 | 54 | (new_training_state, last_update_player), _ = jax.lax.scan( 55 | _scan_fn, 56 | (training_state, update_player), 57 | None, 58 | length=n, 59 | ) 60 | 61 | return new_training_state, last_update_player 62 | 63 | def train( 64 | self, n_iterations: int, metrics_period: int 65 | ) -> tuple[CFRState, dict[str, Array]]: 66 | training_state = CFRState.init(self._env.n_info_states, self._env.n_actions) 67 | 68 | assert n_iterations % metrics_period == 0 69 | 70 | n_loops = n_iterations // metrics_period 71 | update_player = 0 72 | _do_n_iterations = jax.jit( 73 | lambda training_state, update_player: self.do_n_iterations( 74 | training_state=training_state, 75 | update_player=update_player, 76 | n=2 * metrics_period, 77 | ), 78 | device=self._device, 79 | ) 80 | metrics = [] 81 | 82 | pbar = tqdm(total=n_iterations, desc="Training", unit_scale=True) 83 | for k in range(n_loops): 84 | if k == 0: 85 | current_policy = training_state.avg_probs 86 | current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True) 87 | exp = self._exploitability_fn(policy_params=current_policy) 88 | metrics.append({"exploitability": exp, "step": 0}) 89 | pbar.set_postfix(exploitability=f"{exp:.1e}") 90 | 91 | # Do n iterations 92 | training_state, update_player = _do_n_iterations( 93 | training_state, update_player 94 | ) 95 | 96 | # Evaluate exploitability 97 | current_policy = training_state.avg_probs 98 | 
current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True) 99 | 100 | exp = self._exploitability_fn(policy_params=current_policy) 101 | metrics.append({"exploitability": exp, "step": (k + 1) * metrics_period}) 102 | pbar.set_postfix(exploitability=f"{exp:.1e}") 103 | pbar.update(metrics_period) 104 | 105 | metrics = jax.tree_map(lambda *x: jax.numpy.stack(x), *metrics) 106 | return training_state, metrics 107 | -------------------------------------------------------------------------------- /cfrx/trainers/mccfr.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jaxtyping import Array, Int, PRNGKeyArray 3 | from tqdm import tqdm 4 | 5 | from cfrx.algorithms.mccfr.outcome_sampling import MCCFRState, do_iteration 6 | from cfrx.envs import Env 7 | from cfrx.metrics import exploitability 8 | from cfrx.policy import TabularPolicy 9 | 10 | 11 | class MCCFRTrainer: 12 | def __init__(self, env: Env, policy: TabularPolicy): 13 | self._env = env 14 | self._policy = policy 15 | 16 | self._exploitability_fn = jax.jit( 17 | lambda policy_params: exploitability( 18 | policy_params=policy_params, 19 | env=env, 20 | n_players=env.n_players, 21 | n_max_nodes=env.max_nodes, 22 | policy=policy, 23 | ) 24 | ) 25 | 26 | self._do_iteration_fn = jax.jit( 27 | lambda training_state, random_key, update_player: do_iteration( 28 | training_state=training_state, 29 | random_key=random_key, 30 | env=env, 31 | policy=policy, 32 | update_player=update_player, 33 | ) 34 | ) 35 | 36 | def do_n_iterations( 37 | self, 38 | training_state: MCCFRState, 39 | update_player: Int[Array, ""], 40 | random_key: PRNGKeyArray, 41 | n: int, 42 | ) -> tuple[MCCFRState, Int[Array, ""]]: 43 | def _scan_fn(carry, unused): 44 | training_state, random_key, update_player = carry 45 | 46 | random_key, subkey = jax.random.split(random_key) 47 | update_player = (update_player + 1) % 2 48 | training_state = self._do_iteration_fn( 49 | training_state, 50 | subkey, 51 | update_player=update_player, 52 | ) 53 | 54 | return (training_state, random_key, update_player), None 55 | 56 | (new_training_state, _, last_update_player), _ = jax.lax.scan( 57 | _scan_fn, 58 | (training_state, random_key, update_player), 59 | None, 60 | length=n, 61 | ) 62 | 63 | return new_training_state, last_update_player 64 | 65 | def train( 66 | self, n_iterations: int, metrics_period: int, random_key: PRNGKeyArray 67 | ) -> tuple[MCCFRState, dict[str, Array]]: 68 | training_state = MCCFRState.init(self._env.n_info_states, self._env.n_actions) 69 | 70 | assert n_iterations % metrics_period == 0 71 | 72 | n_loops = n_iterations // metrics_period 73 | update_player = 0 74 | _do_n_iterations = jax.jit( 75 | lambda training_state, update_player, random_key: self.do_n_iterations( 76 | training_state=training_state, 77 | update_player=update_player, 78 | random_key=random_key, 79 | n=2 * metrics_period, 80 | ) 81 | ) 82 | 83 | metrics = [] 84 | pbar = tqdm(total=n_iterations, desc="Training", unit_scale=True) 85 | for k in range(n_loops): 86 | if k == 0: 87 | current_policy = training_state.avg_probs 88 | current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True) 89 | exp = self._exploitability_fn(policy_params=current_policy) 90 | metrics.append({"exploitability": exp, "step": 0}) 91 | pbar.set_postfix(exploitability=f"{exp:.1e}") 92 | 93 | random_key, subkey = jax.random.split(random_key) 94 | 95 | # Do n iterations 96 | training_state, update_player = _do_n_iterations( 97 | training_state, update_player,
subkey 98 | ) 99 | 100 | # Evaluate exploitability 101 | current_policy = training_state.avg_probs 102 | current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True) 103 | 104 | exp = self._exploitability_fn(policy_params=current_policy) 105 | metrics.append({"exploitability": exp, "step": (k + 1) * metrics_period}) 106 | pbar.set_postfix(exploitability=f"{exp:.1e}") 107 | pbar.update(metrics_period) 108 | 109 | metrics = jax.tree_map(lambda *x: jax.numpy.stack(x), *metrics) 110 | return training_state, metrics 111 | -------------------------------------------------------------------------------- /cfrx/metrics/best_response.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jaxtyping import Array, Float, Int 6 | 7 | from cfrx.tree import Tree 8 | 9 | 10 | def backward_one_infoset( 11 | tree: Tree, 12 | info_states: Array, 13 | current_infoset: Int[Array, ""], 14 | br_player: int, 15 | depth: Int[Array, ""], 16 | ) -> Tree: 17 | # Select all nodes in this infoset 18 | infoset_mask = ( 19 | (tree.depth == depth) 20 | & (~tree.states.terminated) 21 | & (info_states == current_infoset) 22 | ) 23 | 24 | is_br_player = ( 25 | (tree.states.current_player == br_player) 26 | & (~tree.states.chance_node) 27 | & infoset_mask 28 | ).any() 29 | 30 | p_opponent = tree.extra_data["p_opponent"] 31 | p_chance = tree.extra_data["p_chance"] 32 | p_self = tree.extra_data["p_self"] 33 | 34 | # Get expected values for each of the node in the infoset 35 | cf_reach_prob = (p_opponent * p_chance * p_self)[..., None] 36 | 37 | legal_action_mask = (tree.states.legal_action_mask & infoset_mask[..., None]).any( 38 | axis=0 39 | ) 40 | 41 | best_response_values = tree.children_values[..., br_player] * cf_reach_prob # (T, a) 42 | best_response_values = jnp.sum( 43 | best_response_values, axis=0, where=infoset_mask[..., None] # (a,) 44 | ) 45 | 46 | best_action = jnp.where(legal_action_mask, best_response_values, -jnp.inf).argmax() 47 | 48 | br_value = tree.children_values[:, best_action, br_player] 49 | 50 | expected_current_value = ( 51 | tree.children_values[..., br_player] * tree.children_prior_logits 52 | ).sum(axis=1) 53 | 54 | current_value = jnp.where(is_br_player, br_value, expected_current_value) 55 | 56 | new_node_values = jnp.where( 57 | infoset_mask, current_value, tree.node_values[..., br_player] 58 | ) 59 | new_children_values = jnp.where( 60 | tree.children_index != -1, new_node_values[tree.children_index], 0 61 | ) 62 | 63 | tree = tree._replace( 64 | node_values=tree.node_values.at[..., br_player].set(new_node_values), 65 | children_values=tree.children_values.at[..., br_player].set(new_children_values), 66 | ) 67 | 68 | return tree 69 | 70 | 71 | def backward_one_depth_level( 72 | tree: Tree, 73 | depth: Int[Array, ""], 74 | br_player: int, 75 | info_state_fn: Callable, 76 | ) -> Tree: 77 | info_states = jax.vmap(info_state_fn)(tree.states.info_state) 78 | 79 | def cond_fn(val: tuple[Tree, Array]) -> Array: 80 | tree, visited = val 81 | nodes_to_visit_idx = ( 82 | (tree.depth == depth) & (~tree.states.terminated) & (~visited) 83 | ) 84 | return nodes_to_visit_idx.any() 85 | 86 | def loop_fn(val: tuple[Tree, Array]) -> tuple[Tree, Array]: 87 | tree, visited = val 88 | nodes_to_visit_idx = ( 89 | (tree.depth == depth) & (~tree.states.terminated) & (~visited) 90 | ) 91 | 92 | # Select an infoset to resolve 93 | selected_infoset_idx = nodes_to_visit_idx.argmax() 94 | 
selected_infoset = info_states[selected_infoset_idx] 95 | 96 | tree = backward_one_infoset( 97 | tree=tree, 98 | depth=depth, 99 | info_states=info_states, 100 | current_infoset=selected_infoset, 101 | br_player=br_player, 102 | ) 103 | 104 | visited = jnp.where(info_states == selected_infoset, True, visited) 105 | return tree, visited 106 | 107 | visited = jnp.zeros(tree.node_values.shape[0], dtype=bool) 108 | tree, visited = jax.lax.while_loop(cond_fn, loop_fn, (tree, visited)) 109 | 110 | return tree 111 | 112 | 113 | def compute_best_response_value( 114 | tree: Tree, 115 | br_player: int, 116 | info_state_fn: Callable, 117 | ) -> Float[Array, " num_players"]: 118 | depth = tree.depth.max() 119 | 120 | def cond_fn(val: tuple[Tree, Array]) -> Array: 121 | _, depth = val 122 | return depth >= 0 123 | 124 | def loop_fn(val: tuple[Tree, Array]) -> tuple[Tree, Array]: 125 | tree, depth = val 126 | tree = backward_one_depth_level( 127 | tree=tree, depth=depth, br_player=br_player, info_state_fn=info_state_fn 128 | ) 129 | depth -= 1 130 | return tree, depth 131 | 132 | tree, _ = jax.lax.while_loop(cond_fn, loop_fn, (tree, depth)) 133 | return tree.node_values[0] 134 | -------------------------------------------------------------------------------- /cfrx/algorithms/mccfr/test_outcome_sampling.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import pytest 4 | 5 | from cfrx.algorithms.mccfr.outcome_sampling import ( 6 | MCCFRState, 7 | compute_regrets_and_strategy_profile, 8 | unroll, 9 | ) 10 | from cfrx.metrics import exploitability 11 | from cfrx.policy import TabularPolicy 12 | from cfrx.utils import regret_matching 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "env_name, num_iterations, target_exploitability", 17 | [ 18 | ("Kuhn Poker", 20000, 15e-3), 19 | ("Leduc Poker", 50000, 5e-1), 20 | ], 21 | ) 22 | def test_perf(env_name: str, num_iterations: int, target_exploitability: float): 23 | device = jax.devices("cpu")[0] 24 | 25 | random_key = jax.random.PRNGKey(0) 26 | if env_name == "Kuhn Poker": 27 | from cfrx.envs.kuhn_poker.constants import INFO_SETS 28 | from cfrx.envs.kuhn_poker.env import KuhnPoker 29 | 30 | env_cls = KuhnPoker 31 | EPISODE_LEN = 8 32 | NUM_MAX_NODES = 100 33 | 34 | elif env_name == "Leduc Poker": 35 | from cfrx.envs.leduc_poker.constants import INFO_SETS 36 | from cfrx.envs.leduc_poker.env import LeducPoker 37 | 38 | env_cls = LeducPoker 39 | EPISODE_LEN = 20 40 | NUM_MAX_NODES = 2000 41 | 42 | env = env_cls() 43 | 44 | n_states = len(INFO_SETS) 45 | n_actions = env.num_actions 46 | 47 | training_state = MCCFRState.init(n_states, n_actions) 48 | 49 | policy = TabularPolicy( 50 | n_actions=n_actions, 51 | exploration_factor=0.6, 52 | info_state_idx_fn=env.info_state_idx, 53 | ) 54 | 55 | def do_iteration(training_state, random_key, env, policy, update_player): 56 | """ 57 | Do one iteration of MCCFR: traverse the game tree once and 58 | compute counterfactual regrets and strategy profiles 59 | """ 60 | 61 | random_key, subkey = jax.random.split(random_key) 62 | 63 | # Sample one path in the game tree 64 | random_key, subkey = jax.random.split(random_key) 65 | episode, states = unroll( 66 | init_state=env.init(subkey), 67 | training_state=training_state, 68 | random_key=subkey, 69 | update_player=update_player, 70 | env=env, 71 | policy=policy, 72 | n_max_steps=EPISODE_LEN, 73 | ) 74 | 75 | # Compute counterfactual values and strategy profile 76 | ( 77 | info_states, 78 | sampled_regrets, 79 
| sampled_avg_probs, 80 | ) = compute_regrets_and_strategy_profile( 81 | episode=episode, 82 | training_state=training_state, 83 | policy=policy, 84 | update_player=update_player, 85 | ) 86 | info_states_idx = jax.vmap(env.info_state_idx)(info_states) 87 | 88 | # Store regret and strategy profile values 89 | regrets = training_state.regrets.at[info_states_idx].add(sampled_regrets) 90 | avg_probs = training_state.avg_probs.at[info_states_idx].add(sampled_avg_probs) 91 | 92 | return regrets, avg_probs, episode 93 | 94 | do_iteration = jax.jit( 95 | do_iteration, static_argnames=("env", "policy"), device=device 96 | ) 97 | 98 | # This function measures the exploitability of a strategy 99 | exploitability_fn = jax.jit( 100 | lambda policy_params: exploitability( 101 | policy_params=policy_params, 102 | env=env, 103 | n_players=2, 104 | n_max_nodes=NUM_MAX_NODES, 105 | policy=policy, 106 | ), 107 | device=device, 108 | ) 109 | 110 | # One iteration consists in updating the policy for both players 111 | n_loops = 2 * num_iterations 112 | 113 | for k in range(n_loops): 114 | random_key, subkey = jax.random.split(random_key) 115 | 116 | # Update players alternatively 117 | update_player = k % 2 118 | new_regrets, new_avg_probs, episode = do_iteration( 119 | training_state, 120 | random_key, 121 | env=env, 122 | policy=policy, 123 | update_player=update_player, 124 | ) 125 | 126 | # Accumulate regrets, compute new strategy and avg strategy 127 | new_probs = regret_matching(new_regrets) 128 | new_probs /= new_probs.sum(axis=-1, keepdims=True) 129 | 130 | training_state = training_state._replace( 131 | regrets=new_regrets, 132 | probs=new_probs, 133 | avg_probs=new_avg_probs, 134 | step=training_state.step + 1, 135 | ) 136 | 137 | current_policy = training_state.avg_probs 138 | current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True) 139 | exp = exploitability_fn(policy_params=current_policy) 140 | 141 | assert np.allclose(exp, target_exploitability, rtol=1e-1, atol=1e-2) 142 | -------------------------------------------------------------------------------- /cfrx/envs/nlhe_poker/showdown.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | def straight_flush_score(ranks: jax.Array, suits: jax.Array) -> jax.Array: 6 | suit_counts = jnp.bincount(suits, length=4) 7 | color = suit_counts.argmax() 8 | ranks = jnp.where(suits == color, ranks, -1) 9 | append_ace = jnp.where((ranks == 0).any(), 13, -1) 10 | values = jnp.unique(jnp.sort(ranks), size=7, fill_value=append_ace) 11 | diff = jnp.diff(values, append=append_ace) 12 | conv = jnp.convolve((diff == 1), jnp.ones(4), mode="same") 13 | straigth_idx = jnp.where(conv == 4, jnp.arange(7), -1).argmax() 14 | straight_rank = values[straigth_idx] 15 | return 8_000_000 + straight_rank 16 | 17 | 18 | def four_of_a_kind_score(ranks: jax.Array, suits: jax.Array) -> jax.Array: 19 | ranks = jnp.where(ranks == 0, 12, ranks - 1) 20 | rank_counts = jnp.bincount(ranks, length=13) 21 | active_rank = rank_counts.argmax() 22 | remaining_ranks = jnp.where(ranks != active_rank, ranks, -1) 23 | remaining_score = high_card_score(remaining_ranks, suits, n_start=0, n_end=1) 24 | return 7_000_000 + active_rank * 13 + remaining_score 25 | 26 | 27 | def full_house_score(ranks: jax.Array, suits: jax.Array) -> jax.Array: 28 | ranks = jnp.where(ranks == 0, 12, ranks - 1) 29 | rank_counts = jnp.bincount(ranks, length=13) 30 | three_of_a_kind_mask = rank_counts >= 3 31 | active_rank_three 
= jnp.where(three_of_a_kind_mask, jnp.arange(13), -1).argmax() 32 | rank_counts = rank_counts.at[active_rank_three].set(0) 33 | pair_mask = rank_counts >= 2 34 | active_rank_pair = jnp.where(pair_mask, jnp.arange(13), -1).argmax() 35 | return 6_000_000 + active_rank_three * 13 + active_rank_pair 36 | 37 | 38 | def flush_score(ranks: jax.Array, suits: jax.Array) -> jax.Array: 39 | ranks = jnp.where(ranks == 0, 12, ranks - 1) 40 | suit_counts = jnp.bincount(suits, length=4) 41 | color = suit_counts.argmax() 42 | colored_ranks = jnp.where(suits == color, ranks, -1) 43 | colored_ranks = colored_ranks.sort() 44 | score = high_card_score(colored_ranks, suits) 45 | return 5_000_000 + score 46 | 47 | 48 | def straight_score(ranks: jax.Array, suits: jax.Array) -> jax.Array: 49 | append_ace = jnp.where((ranks == 0).any(), 13, -1) 50 | values = jnp.unique(jnp.sort(ranks), size=7, fill_value=append_ace) 51 | diff = jnp.diff(values, append=append_ace) 52 | conv = jnp.convolve((diff == 1), jnp.ones(4), mode="same") 53 | straigth_idx = jnp.where(conv == 4, jnp.arange(7), -1).argmax() 54 | straight_rank = values[straigth_idx] 55 | return 4_000_000 + straight_rank 56 | 57 | 58 | def three_of_a_kind_score(ranks: jax.Array, suits: jax.Array) -> jax.Array: 59 | ranks = jnp.where(ranks == 0, 12, ranks - 1) 60 | rank_counts = jnp.bincount(ranks, length=13) 61 | three_of_a_kind_mask = rank_counts >= 3 62 | active_rank = jnp.where(three_of_a_kind_mask, jnp.arange(13), -1).argmax() 63 | remaining_ranks = jnp.where(ranks != active_rank, ranks, -1) 64 | kicker_score = high_card_score(remaining_ranks, suits, n_start=0, n_end=2) 65 | return 3_000_000 + active_rank * 13**2 + kicker_score 66 | 67 | 68 | def two_pair_score(ranks: jax.Array, suits: jax.Array) -> jax.Array: 69 | ranks = jnp.where(ranks == 0, 12, ranks - 1) 70 | rank_counts = jnp.bincount(ranks, length=13) 71 | pair_mask = rank_counts >= 2 72 | pair_ranks = jnp.where(pair_mask, jnp.arange(13), -1) 73 | pair_ranks = jnp.argsort(pair_ranks)[-2:] 74 | pair_ranks = jnp.sort(pair_ranks)[::-1] 75 | active_rank_first_pair, active_rank_second_pair = ( 76 | pair_ranks[0], 77 | pair_ranks[1], 78 | ) 79 | rank_counts = rank_counts.at[active_rank_first_pair].set(0) 80 | rank_counts = rank_counts.at[active_rank_second_pair].set(0) 81 | remaining_ranks = jnp.where( 82 | (ranks != active_rank_first_pair) & (ranks != active_rank_second_pair), 83 | ranks, 84 | -1, 85 | ) 86 | kicker_score = high_card_score(remaining_ranks, suits, n_start=0, n_end=1) 87 | return ( 88 | 2_000_000 89 | + active_rank_first_pair * 13**2 90 | + active_rank_second_pair * 13 91 | + kicker_score 92 | ) 93 | 94 | 95 | def one_pair_score(ranks: jax.Array, suits: jax.Array) -> jax.Array: 96 | ranks = jnp.where(ranks == 0, 12, ranks - 1) 97 | rank_counts = jnp.bincount(ranks, length=13) 98 | pair_mask = rank_counts >= 2 99 | active_rank_pair = jnp.where(pair_mask, jnp.arange(13), -1).argmax() 100 | rank_counts = rank_counts.at[active_rank_pair].set(0) 101 | remaining_ranks = jnp.where(ranks != active_rank_pair, ranks, -1) 102 | kicker_score = high_card_score(remaining_ranks, suits, n_start=0, n_end=3) 103 | return 1_000_000 + active_rank_pair * 13**3 + kicker_score 104 | 105 | 106 | def high_card_hand_score(ranks: jax.Array, suits: jax.Array) -> jax.Array: 107 | ranks = jnp.where(ranks == 0, 12, ranks - 1) 108 | return high_card_score(ranks, suits, n_start=0, n_end=5) 109 | 110 | 111 | def high_card_score( 112 | ranks: jax.Array, suits: jax.Array, n_start: int = 0, n_end: int = 5 113 | ) -> jax.Array: 
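    # Sort the ranks, keep the (n_end - n_start) highest ones, and weight them by
    # increasing powers of 13, so that comparing the returned integers is equivalent
    # to comparing the kickers from highest to lowest.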
114 | n = n_end - n_start 115 | ranks = ranks.sort() 116 | rank_scores = jnp.array([13**k for k in range(n_start, n_end)]) 117 | ranks = ranks[-n:] 118 | return (ranks * rank_scores).sum() 119 | 120 | 121 | def is_straight_fn(ranks): 122 | append_ace = jnp.where((ranks == 0).any(), 13, -1) 123 | values = jnp.unique(jnp.sort(ranks), size=7, fill_value=append_ace) 124 | diff = jnp.diff(values, append=append_ace) 125 | cond = (jnp.convolve((diff == 1), jnp.ones(4), mode="valid") == 4).any() 126 | return cond 127 | 128 | 129 | def is_straight_flush_fn(ranks, suits): 130 | suit_counts = jnp.bincount(suits, length=4) 131 | color = suit_counts.argmax() 132 | ranks = jnp.where(suits == color, ranks, -1) 133 | return is_straight_fn(ranks) 134 | 135 | 136 | def get_hand_type(ranks: jax.Array, suits: jax.Array) -> jax.Array: 137 | rank_counts = jnp.bincount(ranks, length=13) 138 | suit_counts = jnp.bincount(suits, length=4) 139 | 140 | is_straight_flush = is_straight_flush_fn(ranks, suits) 141 | higher = is_straight_flush 142 | index = 8 * is_straight_flush 143 | 144 | is_four_of_a_kind = ~higher & (rank_counts == 4).any() 145 | higher |= is_four_of_a_kind 146 | index += 7 * is_four_of_a_kind 147 | 148 | is_full = ~higher & (rank_counts == 3).any() & (rank_counts == 2).any() 149 | higher |= is_full 150 | index += 6 * is_full 151 | 152 | is_flush = ~higher & (suit_counts >= 5).any() 153 | higher |= is_flush 154 | index += 5 * is_flush 155 | 156 | is_straight = ~higher & is_straight_fn(ranks) 157 | higher |= is_straight 158 | index += 4 * is_straight 159 | 160 | is_three_of_a_kind = ~higher & (rank_counts == 3).any() 161 | higher |= is_three_of_a_kind 162 | index += 3 * is_three_of_a_kind 163 | 164 | is_two_pairs = ~higher & ((rank_counts == 2).sum() >= 2) 165 | higher |= is_two_pairs 166 | index += 2 * is_two_pairs 167 | 168 | is_one_pair = ~higher & (rank_counts == 2).any() 169 | higher |= is_one_pair 170 | index += 1 * is_one_pair 171 | 172 | return index 173 | 174 | 175 | def get_showdown_score(hand: jax.Array) -> jax.Array: 176 | hand = hand.astype(jnp.int32) 177 | ranks = hand % 13 178 | suits = hand // 13 179 | 180 | hand_type = get_hand_type(ranks, suits) 181 | 182 | score = jax.lax.switch( 183 | hand_type, 184 | [ 185 | high_card_hand_score, 186 | one_pair_score, 187 | two_pair_score, 188 | three_of_a_kind_score, 189 | straight_score, 190 | flush_score, 191 | full_house_score, 192 | four_of_a_kind_score, 193 | straight_flush_score, 194 | ], 195 | ranks, 196 | suits, 197 | ) 198 | 199 | return score 200 | -------------------------------------------------------------------------------- /cfrx/envs/nlhe_poker/gui.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from nicegui import app, ui 6 | 7 | from cfrx.envs.nlhe_poker.env import State, TexasHoldem 8 | 9 | SCALE = 0.5 10 | X0 = 400 11 | X1 = 80 12 | Y1 = 700 13 | OFF = 180 14 | DEALER_COORDS = [(3 * X0, 0), (3 * X0, 2500)] 15 | STACK_COORDS = [(120, 50), (120, 400)] 16 | BET_COORDS = [(220, 140), (220, 320)] 17 | POT_COORDS = (500, 220) 18 | 19 | COLORS = ["heart", "diamond", "club", "spade"] 20 | VALUES = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "jack", "queen", "king"] 21 | 22 | 23 | def create_svg( 24 | hands: list[tuple[int, int]], board: list[int], board_mask: int, dealer: int 25 | ) -> str: 26 | svg_open = """""" 32 | 33 | svg_close = "" 34 | 35 | svg_hands = "" 36 | for i, hand in enumerate(hands): 37 | for j, card 
in enumerate(hand): 38 | x = X0 + j * 40 39 | y = i * Y1 40 | svg_hands += f""" 41 | 43 | """ 44 | 45 | svg_board = "" 46 | for i, card in enumerate(board): 47 | x = X1 + i * OFF 48 | y = 350 49 | if i >= board_mask: 50 | svg_board += f""" 51 | 53 | """ 54 | else: 55 | svg_board += f""" 56 | 58 | """ 59 | 60 | dealer_coords = DEALER_COORDS[dealer] 61 | svg_dealer = f""" 62 | 64 | """ 65 | 66 | return f"{svg_open}{svg_hands}{svg_dealer}{svg_board}{svg_close}" 67 | 68 | 69 | def create_stack(player: int, stack: int) -> str: 70 | stack_html = f"""

72 | Player {player}
Stack: {stack}

""" 73 | 74 | return stack_html 75 | 76 | 77 | def create_pot(pot: int) -> str: 78 | pot_html = f"""

80 | Pot: {pot}

""" 81 | 82 | return pot_html 83 | 84 | 85 | def create_bet(player: int, bet: str, current_player: int) -> str: 86 | color = "red" if player == current_player else None 87 | 88 | bet_html = f"""
93 | {bet}
""" 94 | 95 | return bet_html 96 | 97 | 98 | def create_terminated(done: bool, reward: tuple[float, float]) -> str: 99 | if done: 100 | end_html = f"""
Terminated
P0: {reward[0]} 102 |
P1: {reward[1]} 103 |
""" 104 | else: 105 | end_html = "" 106 | 107 | return end_html 108 | 109 | 110 | def create_min_bet(min_bet: int, min_raise: int) -> str: 111 | min_bet_html = f"""
Min bet:
{min_bet} 113 |
Min raise:
{min_raise} 114 |
""" 115 | 116 | return min_bet_html 117 | 118 | 119 | def visualize(state: State, container: Any) -> None: 120 | with container: 121 | hands = [tuple(state.hands[0]), tuple(state.hands[1])] 122 | board = state.board.tolist() 123 | 124 | svg = create_svg( 125 | hands=hands, 126 | board=board, 127 | board_mask=int(state.board_mask), 128 | dealer=int(state.dealer), 129 | ) 130 | ui.html(svg) 131 | 132 | stack_0 = int(state.stacks[0]) 133 | 134 | stack = create_stack(0, stack_0) 135 | ui.html(stack) 136 | 137 | stack_1 = int(state.stacks[1]) 138 | 139 | stack = create_stack(1, stack_1) 140 | ui.html(stack) 141 | 142 | pot_value = int(state.bets.sum()) 143 | 144 | pot = create_pot(pot_value) 145 | ui.html(pot) 146 | 147 | current_bets = state.bets[state.current_round] 148 | current_mask = state.bets_mask[state.current_round] 149 | 150 | idx_0 = current_mask[0].sum() 151 | 152 | if idx_0 == 0: 153 | current_bet_0 = "" 154 | else: 155 | current_bet_0 = str(current_bets[0, :idx_0].sum()) 156 | 157 | idx_1 = current_mask[1].sum() 158 | if idx_1 == 0: 159 | current_bet_1 = "" 160 | else: 161 | current_bet_1 = str(current_bets[1, :idx_1].sum()) 162 | 163 | cr = int(state.current_player) 164 | 165 | bet = create_bet(0, current_bet_0, current_player=cr) 166 | ui.html(bet) 167 | 168 | bet = create_bet(1, current_bet_1, current_player=cr) 169 | ui.html(bet) 170 | 171 | done = create_terminated(bool(state.terminated), reward=tuple(state.rewards)) 172 | ui.html(done) 173 | 174 | min_bet = create_min_bet(int(state.min_bet), int(state.min_raise)) 175 | ui.html(min_bet) 176 | 177 | 178 | def visualize_previous(): 179 | global states 180 | global current_state_idx 181 | global container 182 | 183 | current_state_idx = max(current_state_idx - 1, 0) 184 | visualize(states[current_state_idx], container=container) 185 | 186 | 187 | def visualize_next(): 188 | global states 189 | global current_state_idx 190 | global container 191 | 192 | current_state_idx = min(current_state_idx + 1, len(states) - 1) 193 | visualize(states[current_state_idx], container=container) 194 | 195 | 196 | def delete_last(): 197 | global states 198 | global current_state_idx 199 | global container 200 | if len(states) > 1: 201 | states.pop(-1) 202 | current_state_idx = max(len(states) - 1, 0) 203 | visualize(states[current_state_idx], container=container) 204 | 205 | 206 | def ui_bet(): 207 | global states 208 | global current_state_idx 209 | global env 210 | global slider 211 | global container 212 | 213 | state = env.step(states[-1], action=jnp.array(slider.value)) 214 | states.append(state) 215 | current_state_idx = len(states) - 1 216 | visualize(states[current_state_idx], container=container) 217 | 218 | 219 | if __name__ in {"__main__", "__mp_main__"}: 220 | app.add_static_files("/static", "static") 221 | env = TexasHoldem() 222 | 223 | container = ui.card().tight() 224 | 225 | with ui.row().style("position: absolute; top:600px"): 226 | with ui.button_group(): 227 | ui.button("<", on_click=visualize_previous) 228 | ui.button(">", on_click=visualize_next) 229 | ui.button("DEL", on_click=delete_last) 230 | 231 | slider = ui.slider(min=-1, max=100, value=0).style("width: 200px") 232 | ui.label().bind_text_from(slider, "value") 233 | ui.button("BET", on_click=ui_bet) 234 | 235 | state = env.init(jax.random.PRNGKey(1), initial_stacks=jnp.array([100, 50])) 236 | 237 | states = [state] 238 | 239 | state = env.step(state, action=jnp.array(1.0)) 240 | states.append(state) 241 | 242 | state = env.step(state, action=jnp.array(10.0)) 243 | 
states.append(state) 244 | visualize(state, container=container) 245 | current_state_idx = len(states) - 1 246 | 247 | ui.run(show=False) 248 | -------------------------------------------------------------------------------- /cfrx/envs/kuhn_poker/env.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pgx 7 | import pgx.kuhn_poker 8 | from jaxtyping import Array, Bool, Int, PRNGKeyArray 9 | from pgx._src.dwg.kuhn_poker import CARD 10 | from pgx._src.struct import dataclass 11 | 12 | import cfrx 13 | from cfrx.envs.kuhn_poker.constants import INFO_SETS 14 | from cfrx.utils import ravel, reverse_array_lookup 15 | 16 | CARD.append("?") 17 | NUM_DIFFERENT_CARDS = 3 18 | NUM_REPEAT_CARDS = 1 19 | NUM_TOTAL_CARDS = NUM_DIFFERENT_CARDS * NUM_REPEAT_CARDS 20 | INFO_SETS_VALUES = np.stack(list(INFO_SETS.values())) 21 | 22 | 23 | class InfoState(NamedTuple): 24 | private_card: Int[Array, "..."] 25 | action_sequence: Int[Array, "..."] 26 | chance_round: Int[Array, "..."] 27 | chance_node: Bool[Array, "..."] 28 | 29 | 30 | @dataclass 31 | class State(pgx.kuhn_poker.State): 32 | info_state: InfoState = InfoState( 33 | private_card=jnp.int8(-1), 34 | action_sequence=jnp.ones(2, dtype=jnp.int8) * -1, 35 | chance_round=jnp.int8(0), 36 | chance_node=jnp.bool_(True), 37 | ) 38 | chance_node: Bool[Array, "..."] = jnp.bool_(False) 39 | chance_prior: Int[Array, "..."] = ( 40 | jnp.ones(NUM_DIFFERENT_CARDS, dtype=int) * NUM_REPEAT_CARDS 41 | ) 42 | 43 | 44 | class KuhnPoker(pgx.kuhn_poker.KuhnPoker, cfrx.envs.Env): 45 | @classmethod 46 | def action_to_string(cls, action: Int[Array, ""]) -> str: 47 | strings = ["b", "p"] 48 | if action != -1: 49 | a = int(action) // 2 50 | rep = strings[a] 51 | else: 52 | rep = "?" 
53 | return rep 54 | 55 | @property 56 | def max_episode_length(self) -> int: 57 | return 6 58 | 59 | @property 60 | def max_nodes(self) -> int: 61 | return 60 62 | 63 | @property 64 | def n_info_states(self) -> int: 65 | return len(INFO_SETS) 66 | 67 | @property 68 | def n_actions(self) -> int: 69 | return self.num_actions 70 | 71 | @property 72 | def n_players(self) -> int: 73 | return self.num_players 74 | 75 | def update_info_state( 76 | self, state: State, next_state: State, action: Int[Array, ""] 77 | ) -> InfoState: 78 | info_state = next_state.info_state 79 | 80 | private_card = jnp.where( 81 | next_state.chance_node, -1, next_state._cards[next_state.current_player] 82 | ) 83 | current_position = (info_state.action_sequence != -1).sum() 84 | action_sequence = info_state.action_sequence.at[current_position].set( 85 | jnp.int8(action) 86 | ) 87 | 88 | action_sequence = jnp.where(state.chance_node, -1, action_sequence) 89 | chance_round = info_state.chance_round + state.chance_node 90 | info_state = info_state._replace( 91 | private_card=private_card, 92 | action_sequence=action_sequence, 93 | chance_round=chance_round, 94 | chance_node=next_state.chance_node, 95 | ) 96 | return info_state 97 | 98 | def info_state_to_str(self, info_state: InfoState) -> str: 99 | if info_state.chance_node: 100 | rep = f"chance{info_state.chance_round}:" 101 | 102 | else: 103 | rep = "" 104 | 105 | strings = ["b", "p"] 106 | rep += f"{info_state.private_card}" 107 | 108 | for action in np.array(info_state.action_sequence): 109 | action = action // 2 110 | if action != -1: 111 | rep += strings[action] 112 | 113 | return rep 114 | 115 | def get_action_mask(self, state: State) -> jax.Array: 116 | return state.legal_action_mask 117 | 118 | def get_chance_mask(self, state: State) -> jax.Array: 119 | return state.chance_prior > 0 120 | 121 | def get_info_state(self, state: State) -> jax.Array: 122 | return state.info_state 123 | 124 | def info_state_idx(self, info_state: InfoState) -> Int[Array, ""]: 125 | info_state_ravel = ravel(info_state) 126 | return reverse_array_lookup(info_state_ravel, jnp.asarray(INFO_SETS_VALUES)) 127 | 128 | def _init(self, rng: PRNGKeyArray) -> State: 129 | env_state = super()._init(rng) 130 | 131 | info_state = InfoState( 132 | private_card=jnp.int8(-1), 133 | action_sequence=jnp.ones(2, dtype=jnp.int8) * -1, 134 | chance_round=jnp.int8(0), 135 | chance_node=jnp.bool_(True), 136 | ) 137 | 138 | return State( 139 | current_player=env_state.current_player.astype(jnp.int8), 140 | observation=env_state.observation, 141 | rewards=env_state.rewards, 142 | terminated=env_state.terminated, 143 | truncated=env_state.truncated, 144 | _step_count=env_state._step_count, 145 | _last_action=env_state._last_action, 146 | _cards=jnp.int8([-1, -1]), 147 | legal_action_mask=jnp.bool_([1, 1, 1, 1]), 148 | _pot=env_state._pot, 149 | info_state=info_state, 150 | chance_node=jnp.bool_(True), 151 | chance_prior=jnp.ones(NUM_DIFFERENT_CARDS, dtype=int) * NUM_REPEAT_CARDS, 152 | ) 153 | 154 | def _resolve_chance_node( 155 | self, state: State, action: Int[Array, ""], random_key: PRNGKeyArray 156 | ) -> State: 157 | draw_player = NUM_TOTAL_CARDS - state.chance_prior.sum() 158 | 159 | cards = state._cards.at[draw_player].set(action.astype(jnp.int8)) 160 | chance_prior = state.chance_prior.at[action].add(-1) 161 | chance_node = (cards == -1).any() 162 | 163 | legal_action_mask = jnp.where( 164 | chance_node, state.legal_action_mask, jnp.bool_([0, 1, 0, 1]) 165 | ) 166 | 167 | return State( 168 | 
current_player=state.current_player.astype(jnp.int8), 169 | observation=state.observation, 170 | rewards=state.rewards, 171 | terminated=state.terminated, 172 | truncated=state.truncated, 173 | _step_count=state._step_count, 174 | _last_action=state._last_action, 175 | _cards=cards, 176 | legal_action_mask=legal_action_mask, 177 | _pot=state._pot, 178 | info_state=state.info_state, 179 | chance_node=chance_node, 180 | chance_prior=chance_prior, 181 | ) 182 | 183 | def _resolve_decision_node( 184 | self, state: State, action: Int[Array, ""], random_key: PRNGKeyArray 185 | ) -> State: 186 | env_state = super()._step(state=state, action=action, key=random_key) 187 | print(env_state.legal_action_mask.shape) 188 | return State( 189 | current_player=env_state.current_player.astype(jnp.int8), 190 | observation=env_state.observation, 191 | rewards=env_state.rewards, 192 | terminated=env_state.terminated, 193 | truncated=env_state.truncated, 194 | _step_count=env_state._step_count, 195 | _last_action=env_state._last_action, 196 | _cards=env_state._cards, 197 | legal_action_mask=env_state.legal_action_mask, 198 | _pot=env_state._pot, 199 | info_state=env_state.info_state, 200 | chance_node=jnp.bool_(False), 201 | chance_prior=env_state.chance_prior, 202 | ) 203 | 204 | def _step( 205 | self, state: State, action: Int[Array, ""], random_key: PRNGKeyArray 206 | ) -> State: 207 | new_state = jax.lax.cond( 208 | state.chance_node, 209 | lambda: self._resolve_chance_node( 210 | state=state, action=action, random_key=random_key 211 | ), 212 | lambda: self._resolve_decision_node( 213 | state=state, action=action, random_key=random_key 214 | ), 215 | ) 216 | info_state = self.update_info_state( 217 | state=state, next_state=new_state, action=action 218 | ) 219 | return new_state.replace(info_state=info_state) 220 | -------------------------------------------------------------------------------- /cfrx/algorithms/cfr/cfr.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections.abc import Callable 4 | from typing import NamedTuple 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from jaxtyping import Array, Float, Int 9 | 10 | from cfrx.envs import Env 11 | from cfrx.policy import TabularPolicy 12 | from cfrx.tree import Tree 13 | from cfrx.tree.traverse_old import instantiate_tree_from_root, traverse_tree_cfr 14 | from cfrx.tree.tree_old import Root 15 | from cfrx.utils import regret_matching 16 | 17 | 18 | class CFRState(NamedTuple): 19 | regrets: Float[Array, "*batch a"] 20 | probs: Float[Array, "... a"] 21 | avg_probs: Float[Array, "... 
a"] 22 | step: Int[Array, "..."] 23 | 24 | @classmethod 25 | def init(cls, n_states: int, n_actions: int) -> CFRState: 26 | return CFRState( 27 | regrets=jnp.zeros((n_states, n_actions)), 28 | probs=jnp.ones((n_states, n_actions)) 29 | / jnp.ones((n_states, n_actions)).sum(axis=-1, keepdims=True), 30 | avg_probs=jnp.zeros((n_states, n_actions)) + 1e-6, 31 | step=jnp.array(1, dtype=int), 32 | ) 33 | 34 | 35 | def backward_one_infoset( 36 | tree: Tree, 37 | info_states: Array, 38 | current_infoset: Int[Array, ""], 39 | cfr_player: int, 40 | depth: Int[Array, ""], 41 | ) -> Tree: 42 | # Select all nodes in this infoset 43 | infoset_mask = ( 44 | (tree.depth == depth) 45 | & (~tree.states.terminated) 46 | & (info_states == current_infoset) 47 | ) 48 | 49 | is_cfr_player = ( 50 | (tree.states.current_player == cfr_player) 51 | & (~tree.states.chance_node) 52 | & infoset_mask 53 | ).any() 54 | 55 | p_opponent = tree.extra_data["p_opponent"] 56 | p_chance = tree.extra_data["p_chance"] 57 | p_self = tree.extra_data["p_self"] 58 | 59 | # Get expected values for each of the node in the infoset 60 | cf_reach_prob = p_opponent * p_chance 61 | 62 | legal_action_mask = (tree.states.legal_action_mask & infoset_mask[..., None]).any( 63 | axis=0 64 | ) 65 | 66 | cf_state_action_values = ( 67 | tree.children_values[..., cfr_player] 68 | * cf_reach_prob[..., None] 69 | * legal_action_mask 70 | ) # (T, a) 71 | cf_state_action_values = jnp.sum( 72 | cf_state_action_values, axis=0, where=infoset_mask[..., None] # (a,) 73 | ) 74 | 75 | expected_values = ( 76 | tree.children_values[..., cfr_player] * tree.children_prior_logits 77 | ).sum(axis=1) 78 | 79 | cf_state_values = jnp.sum( 80 | expected_values * cf_reach_prob, axis=0, where=infoset_mask 81 | ) # (,) 82 | 83 | new_node_values = jnp.where( 84 | infoset_mask, expected_values, tree.node_values[..., cfr_player] 85 | ) 86 | 87 | new_children_values = jnp.where( 88 | tree.children_index != -1, new_node_values[tree.children_index], 0 89 | ) 90 | 91 | regrets = ( 92 | cf_state_action_values - cf_state_values[..., None] 93 | ) * legal_action_mask # (a,) 94 | 95 | strategy_profile = jnp.sum( 96 | p_self[..., None] * tree.children_prior_logits, 97 | axis=0, 98 | where=infoset_mask[..., None], 99 | ) # (a,) 100 | 101 | tree = tree._replace( 102 | node_values=tree.node_values.at[..., cfr_player].set(new_node_values), 103 | children_values=tree.children_values.at[..., cfr_player].set( 104 | new_children_values 105 | ), 106 | extra_data={ 107 | **tree.extra_data, 108 | "regrets": jnp.where( 109 | infoset_mask[..., None] & is_cfr_player, 110 | tree.extra_data["regrets"] + regrets, 111 | tree.extra_data["regrets"], 112 | ), 113 | "strategy_profile": jnp.where( 114 | infoset_mask[..., None] & is_cfr_player, 115 | tree.extra_data["strategy_profile"] + strategy_profile, 116 | tree.extra_data["strategy_profile"], 117 | ), 118 | }, 119 | ) 120 | 121 | return tree 122 | 123 | 124 | def backward_one_depth_level( 125 | tree: Tree, 126 | depth: Int[Array, ""], 127 | cfr_player: int, 128 | info_state_fn: Callable, 129 | ) -> Tree: 130 | info_states = jax.vmap(info_state_fn)(tree.states.info_state) 131 | 132 | def cond_fn(val: tuple[Tree, Array]) -> Array: 133 | tree, visited = val 134 | nodes_to_visit_idx = ( 135 | (tree.depth == depth) & (~tree.states.terminated) & (~visited) 136 | ) 137 | return nodes_to_visit_idx.any() 138 | 139 | def loop_fn(val: tuple[Tree, Array]) -> tuple[Tree, Array]: 140 | tree, visited = val 141 | nodes_to_visit_idx = ( 142 | (tree.depth == depth) & 
(~tree.states.terminated) & (~visited) 143 | ) 144 | 145 | # Select an infoset to resolve 146 | selected_infoset_idx = nodes_to_visit_idx.argmax() 147 | selected_infoset = info_states[selected_infoset_idx] 148 | 149 | tree = backward_one_infoset( 150 | tree=tree, 151 | depth=depth, 152 | info_states=info_states, 153 | current_infoset=selected_infoset, 154 | cfr_player=cfr_player, 155 | ) 156 | 157 | visited = jnp.where(info_states == selected_infoset, True, visited) 158 | return tree, visited 159 | 160 | visited = jnp.zeros(tree.node_values.shape[0], dtype=bool) 161 | tree, visited = jax.lax.while_loop(cond_fn, loop_fn, (tree, visited)) 162 | 163 | return tree 164 | 165 | 166 | def backward_cfr( 167 | tree: Tree, 168 | cfr_player: int, 169 | info_state_fn: Callable, 170 | ) -> Tree: 171 | depth = tree.depth.max() 172 | 173 | def cond_fn(val: tuple[Tree, Array]) -> Array: 174 | _, depth = val 175 | return depth >= 0 176 | 177 | def loop_fn(val: tuple[Tree, Array]) -> tuple[Tree, Array]: 178 | tree, depth = val 179 | tree = backward_one_depth_level( 180 | tree=tree, depth=depth, cfr_player=cfr_player, info_state_fn=info_state_fn 181 | ) 182 | depth -= 1 183 | return tree, depth 184 | 185 | tree, _ = jax.lax.while_loop(cond_fn, loop_fn, (tree, depth)) 186 | return tree 187 | 188 | 189 | def do_iteration( 190 | training_state: CFRState, 191 | update_player: Int[Array, ""], 192 | env: Env, 193 | policy: TabularPolicy, 194 | ) -> CFRState: 195 | s0 = env.init(jax.random.PRNGKey(0)) 196 | root = Root( 197 | prior_logits=s0.legal_action_mask * 1.0, 198 | value=jnp.zeros(1, dtype=float), 199 | state=s0, 200 | ) 201 | 202 | tree = instantiate_tree_from_root( 203 | root, 204 | n_max_nodes=env.max_nodes, 205 | n_players=env.n_players, 206 | running_probabilities=True, 207 | ) 208 | 209 | tree = traverse_tree_cfr( 210 | tree, 211 | policy=policy, 212 | policy_params=training_state.probs, 213 | env=env, 214 | traverser=update_player, 215 | ) 216 | 217 | infoset_idx = jax.vmap(env.info_state_idx)(tree.states.info_state) 218 | squash_idx = jnp.unique(infoset_idx, size=env.n_info_states, return_index=True)[1] 219 | 220 | tree_regrets = training_state.regrets.take(infoset_idx, axis=0) 221 | tree_strategies = training_state.avg_probs.take(infoset_idx, axis=0) 222 | 223 | tree = tree._replace( 224 | extra_data={ 225 | **tree.extra_data, 226 | "regrets": tree_regrets, 227 | "strategy_profile": tree_strategies, 228 | } 229 | ) 230 | 231 | tree = backward_cfr( 232 | tree=tree, cfr_player=update_player, info_state_fn=env.info_state_idx 233 | ) 234 | 235 | regrets = tree.extra_data["regrets"][squash_idx] 236 | strategy_profile = tree.extra_data["strategy_profile"][squash_idx] 237 | probs = regret_matching(regrets) 238 | new_state = training_state._replace( 239 | regrets=regrets, 240 | probs=probs, 241 | avg_probs=strategy_profile, 242 | step=training_state.step + 1, 243 | ) 244 | return new_state 245 | -------------------------------------------------------------------------------- /cfrx/envs/nlhe_poker/env.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | import cfrx 7 | from cfrx.envs.nlhe_poker.showdown import get_showdown_score 8 | 9 | NUM_PLAYERS = 2 10 | 11 | 12 | class State(NamedTuple): 13 | board: jax.Array 14 | board_mask: jax.Array 15 | hands: jax.Array 16 | stacks: jax.Array 17 | bets: jax.Array 18 | bets_mask: jax.Array 19 | current_player: jax.Array 20 | min_bet: jax.Array 21 | 
min_raise: jax.Array 22 | current_round: jax.Array 23 | terminated: jax.Array 24 | small_blind: jax.Array 25 | showdown_result: jax.Array 26 | dealer: jax.Array 27 | rewards: jax.Array 28 | 29 | 30 | class InfoState(NamedTuple): 31 | hand: jax.Array 32 | board: jax.Array 33 | board_mask: jax.Array 34 | bets: jax.Array 35 | bets_mask: jax.Array 36 | stacks: jax.Array 37 | 38 | 39 | class TexasHoldem: 40 | def __init__(self, num_max_bets: int = 4): 41 | self.num_max_bets = num_max_bets 42 | self._num_visibles = jnp.array([0, 3, 4, 5], dtype=jnp.uint8) 43 | 44 | @classmethod 45 | def action_to_string(cls, action: jax.Array) -> str: 46 | raise NotImplementedError 47 | 48 | @property 49 | def max_episode_length(self) -> int: 50 | return 4 * self.num_max_bets 51 | 52 | @property 53 | def max_nodes(self) -> int: 54 | raise NotImplementedError 55 | 56 | @property 57 | def n_info_states(self) -> int: 58 | raise NotImplementedError 59 | 60 | def init( 61 | self, 62 | random_key: jax.Array, 63 | dealer_player: int = 0, 64 | initial_stacks: jax.Array | None = None, 65 | small_blind: float = 1.0, 66 | ): 67 | deck = jax.random.permutation(random_key, jnp.arange(52, dtype=jnp.uint8)) 68 | 69 | hands = jnp.zeros((NUM_PLAYERS, 2), dtype=jnp.uint8) 70 | for i in range(NUM_PLAYERS): 71 | cards = deck[i * 2 : (i + 1) * 2] 72 | hands = hands.at[i].set(cards) 73 | 74 | board = deck[NUM_PLAYERS * 2 : NUM_PLAYERS * 2 + 5].astype(jnp.uint8) 75 | 76 | board_mask = jnp.array(0) 77 | 78 | if initial_stacks is None: 79 | stacks = jnp.ones(NUM_PLAYERS, dtype=float) * 100 * small_blind 80 | else: 81 | stacks = initial_stacks 82 | 83 | sb_player = (dealer_player + 1) % NUM_PLAYERS 84 | bb_player = (dealer_player + 2) % NUM_PLAYERS 85 | first_player = (dealer_player + 3) % NUM_PLAYERS 86 | 87 | bets = jnp.zeros((4, NUM_PLAYERS, self.num_max_bets), dtype=float) 88 | 89 | sb_bet = jnp.minimum(small_blind, stacks[sb_player]) 90 | bb_bet = jnp.minimum(small_blind * 2, stacks[bb_player]) 91 | bets = bets.at[0, sb_player, 0].set(sb_bet) 92 | bets = bets.at[0, bb_player, 0].set(bb_bet) 93 | stacks = stacks.at[sb_player].set(stacks[sb_player] - sb_bet) 94 | stacks = stacks.at[bb_player].set(stacks[bb_player] - bb_bet) 95 | 96 | bets_mask = bets > 0.0 97 | 98 | showdown_result = self._resolve_showdown(board, hands) 99 | 100 | state = State( 101 | board=board, 102 | board_mask=board_mask, 103 | hands=hands, 104 | stacks=stacks, 105 | bets=bets, 106 | bets_mask=bets_mask, 107 | current_player=jnp.array(first_player), 108 | min_bet=jnp.array(small_blind), 109 | min_raise=2 * jnp.array(small_blind), 110 | current_round=jnp.array(0), 111 | terminated=jnp.array(False), 112 | small_blind=jnp.array(small_blind), 113 | showdown_result=showdown_result, 114 | dealer=jnp.array(dealer_player), 115 | rewards=jnp.array([0.0, 0.0]), 116 | ) 117 | 118 | return state 119 | 120 | def _is_end_round( 121 | self, 122 | current_round: jax.Array, 123 | bets_mask: jax.Array, 124 | bets: jax.Array, 125 | ) -> jax.Array: 126 | bets_mask = jnp.where( 127 | current_round == 0, bets_mask.at[0, :, 0].set(False), bets_mask 128 | ) 129 | has_everyone_spoken = bets_mask[current_round].any(axis=1).all() 130 | total_bets = bets[current_round].sum(axis=-1) 131 | same_bets = (total_bets.max() == total_bets.min()).all() 132 | 133 | end_round = has_everyone_spoken & same_bets 134 | return end_round 135 | 136 | def _resolve_showdown(self, board: jax.Array, hands: jax.Array) -> jax.Array: 137 | """ 138 | Return 0 if p0 wins, 1 if p1 wins, -1 if tie 139 | """ 140 | 141 | 
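        # Editor's note (hedged): get_showdown_score maps a 7-card hand to a
        # single integer that is compared directly here to decide the winner.
        # Cards 0..51 are split into rank = card % 13 (0 is the ace) and
        # suit = card // 13; get_hand_type picks the hand category and
        # jax.lax.switch dispatches to the matching scorer, which weights ranks
        # with powers of 13 so that kickers break ties lexicographically.
        # Illustrative call (four aces plus 2, 3, 4):
        #   >>> get_showdown_score(jnp.array([0, 13, 26, 39, 1, 2, 3]))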
p0_score = get_showdown_score(jnp.concatenate([board, hands[0]])) 142 | p1_score = get_showdown_score(jnp.concatenate([board, hands[1]])) 143 | 144 | return jnp.where(p0_score == p1_score, jnp.array(-1), p0_score < p1_score) 145 | 146 | def _current_round_to_board_mask(self, current_round: jax.Array) -> jax.Array: 147 | return self._num_visibles[current_round] 148 | 149 | def step(self, state: State, action: jax.Array) -> State: 150 | fold = action < 0.0 151 | 152 | # check = action == 0.0 153 | # bet = action > 0.0 154 | 155 | # clit bet 156 | action = jnp.clip(action, a_min=0, a_max=state.stacks[state.current_player]) 157 | 158 | cp = state.current_player 159 | cr = state.current_round 160 | 161 | # store bet, update bet mask, update stack 162 | 163 | current_idx = state.bets_mask[cr, cp].sum() 164 | new_bets_mask = state.bets_mask.at[cr, cp, current_idx].set(True) 165 | new_bets = state.bets.at[cr, cp, current_idx].set(action) 166 | new_stacks = state.stacks.at[cp].set(state.stacks[cp] - action) 167 | 168 | reward = jnp.ones(NUM_PLAYERS) * new_bets[:, cp].sum() 169 | # compute hypothetic fold reward 170 | fold_reward = jnp.where(jnp.arange(NUM_PLAYERS) == cp, -reward, reward) 171 | # jax.debug.print("fold_reward: {x}", x=fold_reward) 172 | 173 | # compute hypothetic showdown reward 174 | showdown_result = state.showdown_result 175 | showdown_reward = jnp.where( 176 | jnp.arange(NUM_PLAYERS) == showdown_result, reward, -reward 177 | ) 178 | showdown_reward = jnp.where(showdown_result == -1, 0, showdown_reward) 179 | 180 | end_round = self._is_end_round( 181 | current_round=state.current_round, 182 | bets_mask=new_bets_mask, 183 | bets=new_bets, 184 | ) 185 | 186 | # jax.debug.print("end_round: {x}", x=end_round) 187 | 188 | is_allin = (state.stacks == 0).any() & ~fold 189 | 190 | is_showdown = (end_round & (state.current_round == 3)) | is_allin 191 | 192 | done = state.terminated | is_showdown | fold 193 | 194 | # jax.debug.print("is showdown: {x}", x=is_showdown) 195 | 196 | reward = jnp.where(is_showdown, showdown_reward, fold_reward) 197 | reward = jnp.where(done, reward, 0.0) 198 | 199 | new_cr = jnp.where(end_round, cr + 1, cr) 200 | new_cr = jnp.where(is_showdown, 4, new_cr) 201 | 202 | new_current_player = jnp.where( 203 | end_round, (state.dealer + 1) % NUM_PLAYERS, (cp + 1) % NUM_PLAYERS 204 | ) 205 | 206 | bet_diff = jnp.abs(jnp.diff(new_bets[cr].sum(axis=-1)))[0] 207 | my_total_bet = new_bets[cr, cp].sum(axis=-1) 208 | 209 | min_bet = jnp.maximum(bet_diff, state.small_blind) 210 | min_bet = jnp.where(end_round, state.small_blind, min_bet) 211 | 212 | min_raise = jnp.where(end_round, min_bet, my_total_bet + min_bet) 213 | 214 | new_state = State( 215 | board=state.board, 216 | board_mask=self._num_visibles[new_cr], 217 | hands=state.hands, 218 | stacks=new_stacks, 219 | bets=new_bets, 220 | bets_mask=new_bets_mask, 221 | current_player=new_current_player, 222 | min_bet=min_bet, 223 | min_raise=min_raise, 224 | current_round=new_cr, 225 | terminated=done, 226 | small_blind=state.small_blind, 227 | showdown_result=state.showdown_result, 228 | dealer=state.dealer, 229 | rewards=reward, 230 | ) 231 | 232 | return new_state 233 | 234 | def observe(self, state: State) -> InfoState: 235 | board_mask = self._num_visibles[state.current_round] 236 | board_mask = ~( 237 | jnp.zeros(5, dtype=bool).at[board_mask].set(True).cumsum().astype(bool) 238 | ) 239 | infostate = InfoState( 240 | hand=state.hands[state.current_player], 241 | board=jnp.where(board_mask, state.board, -1), 242 | 
board_mask=board_mask, 243 | bets=state.bets, 244 | bets_mask=state.bets_mask, 245 | stacks=state.stacks, 246 | ) 247 | 248 | return infostate 249 | -------------------------------------------------------------------------------- /cfrx/envs/nlhe_poker/test_showdown.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from cfrx.envs.nlhe_poker.showdown import ( 4 | flush_score, 5 | four_of_a_kind_score, 6 | full_house_score, 7 | get_hand_type, 8 | high_card_hand_score, 9 | one_pair_score, 10 | straight_flush_score, 11 | straight_score, 12 | three_of_a_kind_score, 13 | two_pair_score, 14 | ) 15 | 16 | 17 | def test_straight_flush_order(): 18 | hands = [ 19 | ( 20 | jnp.array([0, 1, 2, 3, 4, 8, 12]), 21 | jnp.array([0, 0, 0, 0, 0, 1, 2]), 22 | ), # 5-high straight flush 23 | ( 24 | jnp.array([4, 5, 6, 7, 8, 1, 11]), 25 | jnp.array([2, 2, 2, 2, 2, 3, 0]), 26 | ), # 9-high straight flush 27 | ( 28 | jnp.array([8, 9, 10, 11, 12, 0, 2]), 29 | jnp.array([1, 1, 1, 1, 1, 2, 3]), 30 | ), # King-high straight flush 31 | ( 32 | jnp.array([9, 10, 11, 12, 0, 8, 7]), 33 | jnp.array([3, 3, 3, 3, 3, 0, 1]), 34 | ), # Ace-high straight flush 35 | ] 36 | 37 | scores = [straight_flush_score(ranks, suits) for ranks, suits in hands] 38 | 39 | print("straight_flush_score", scores) 40 | 41 | # Expected order of scores from lowest to highest 42 | expected_order = sorted(scores) 43 | assert scores == expected_order, f"Scores are not in expected order: {scores}" 44 | 45 | 46 | def test_four_of_a_kind_order(): 47 | hands = [ 48 | jnp.array([10, 10, 10, 10, 0, 1, 2]), # Four of Jacks with kicker A 49 | jnp.array([11, 11, 11, 11, 0, 1, 2]), # Four of Queens with kicker A 50 | jnp.array([12, 12, 12, 12, 0, 2, 3]), # Four of Kings with kicker A 51 | jnp.array([0, 0, 0, 0, 1, 1, 1]), # Four of Aces with kicker 2 52 | jnp.array([0, 0, 0, 0, 1, 1, 2]), # Four of Aces with kicker 3 53 | jnp.array([0, 0, 0, 0, 1, 1, 12]), # Four of Aces with kicker King 54 | jnp.array([0, 0, 0, 0, 11, 12, 12]), # Four of Aces with kicker King 55 | ] 56 | 57 | suits = (jnp.array([3, 3, 3, 3, 0, 1, 2]),) 58 | scores = [four_of_a_kind_score(ranks, suits) for ranks in hands] 59 | 60 | print("four_of_a_kind", scores) 61 | 62 | expected_order = sorted(scores) 63 | assert scores == expected_order, f"Scores are not in expected order: {scores}" 64 | 65 | 66 | def test_three_of_a_kind_order(): 67 | hands = [ 68 | (jnp.array([1, 1, 1, 2, 3, 4, 6])), # Three 2s with kickers (5,7) 69 | (jnp.array([1, 1, 1, 2, 3, 5, 6])), # Three 2s with kickers (6,7) 70 | (jnp.array([1, 1, 1, 2, 3, 5, 0])), # Three 2s with kickers (6,A) 71 | (jnp.array([1, 1, 1, 3, 3, 3, 2])), # Three 2s, three 4s with kickers (2,3) 72 | (jnp.array([1, 1, 1, 3, 3, 3, 0])), # Three 2s, three 4s with kickers (2,A) 73 | (jnp.array([10, 10, 10, 3, 3, 3, 5])), # Three Js, three 4s with kickers (4,6) 74 | (jnp.array([10, 10, 10, 3, 3, 3, 5])), # Three Js, three 4s with kickers (4,A) 75 | (jnp.array([10, 10, 10, 0, 0, 0, 5])), # Three Js, three As 76 | ] 77 | 78 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3]) 79 | scores = [three_of_a_kind_score(ranks, suits) for ranks in hands] 80 | 81 | print("three_of_a_kind", scores) 82 | 83 | expected_order = sorted(scores) 84 | assert scores == expected_order, f"Scores are not in expected order: {scores}" 85 | 86 | 87 | def test_flush_order(): 88 | hands = [ 89 | ( 90 | jnp.array([2, 4, 5, 6, 10, 4, 5]), 91 | jnp.array([0, 0, 0, 0, 0, 1, 1]), 92 | ), 93 | ( 94 | jnp.array([2, 4, 5, 8, 10, 4, 
5]), 95 | jnp.array([0, 0, 0, 0, 0, 1, 1]), 96 | ), 97 | ( 98 | jnp.array([2, 4, 5, 8, 10, 7, 5]), 99 | jnp.array([0, 0, 0, 0, 0, 0, 1]), 100 | ), 101 | ( 102 | jnp.array([0, 4, 5, 8, 10, 7, 5]), 103 | jnp.array([0, 0, 0, 0, 0, 0, 1]), 104 | ), 105 | ] 106 | 107 | scores = [flush_score(ranks, suits) for ranks, suits in hands] 108 | 109 | print("flush", scores) 110 | 111 | expected_order = sorted(scores) 112 | assert scores == expected_order, f"Scores are not in expected order: {scores}" 113 | 114 | 115 | def test_straight_order(): 116 | hands = [ 117 | (jnp.array([0, 1, 2, 3, 4, 7, 8])), 118 | (jnp.array([0, 1, 2, 3, 4, 7, 0])), 119 | (jnp.array([3, 4, 5, 6, 7, 10, 11])), 120 | (jnp.array([2, 9, 8, 9, 10, 11, 12])), 121 | (jnp.array([2, 9, 9, 10, 11, 12, 0])), 122 | ] 123 | 124 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3]) 125 | scores = [straight_score(ranks, suits) for ranks in hands] 126 | 127 | print("straight", scores) 128 | 129 | expected_order = sorted(scores) 130 | assert scores == expected_order, f"Scores are not in expected order: {scores}" 131 | 132 | 133 | def test_full_house_order(): 134 | hands = [ 135 | (jnp.array([1, 1, 2, 2, 2, 3, 4])), 136 | (jnp.array([4, 4, 3, 3, 3, 2, 1])), 137 | (jnp.array([4, 4, 4, 3, 3, 3, 2])), 138 | (jnp.array([2, 2, 10, 10, 10, 4, 5])), 139 | (jnp.array([12, 12, 10, 10, 10, 4, 5])), 140 | (jnp.array([12, 12, 0, 0, 0, 4, 5])), 141 | ] 142 | 143 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3]) 144 | scores = [full_house_score(ranks, suits) for ranks in hands] 145 | 146 | print("full_house", scores) 147 | 148 | expected_order = sorted(scores) 149 | assert scores == expected_order, f"Scores are not in expected order: {scores}" 150 | 151 | 152 | def test_two_pair_order(): 153 | hands = [ 154 | (jnp.array([1, 1, 2, 2, 3, 5, 7])), 155 | (jnp.array([1, 1, 2, 2, 4, 5, 0])), 156 | (jnp.array([1, 1, 3, 4, 5, 5, 7])), 157 | (jnp.array([1, 2, 3, 3, 5, 5, 7])), 158 | (jnp.array([1, 1, 3, 3, 5, 5, 7])), 159 | (jnp.array([11, 11, 12, 12, 2, 3, 6])), 160 | (jnp.array([12, 12, 0, 0, 2, 3, 4])), 161 | (jnp.array([12, 12, 0, 0, 11, 3, 4])), 162 | ] 163 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3]) 164 | scores = [two_pair_score(ranks, suits) for ranks in hands] 165 | 166 | print("two_pair", scores) 167 | 168 | expected_order = sorted(scores) 169 | 170 | assert scores == expected_order, f"Scores are not in expected order: {scores}" 171 | 172 | 173 | def test_one_pair_order(): 174 | hands = [ 175 | (jnp.array([1, 1, 2, 3, 4, 7, 8])), 176 | (jnp.array([1, 1, 2, 3, 9, 7, 8])), 177 | (jnp.array([1, 1, 4, 5, 0, 7, 8])), 178 | (jnp.array([12, 12, 11, 10, 9, 1, 2])), 179 | (jnp.array([0, 0, 2, 3, 4, 5, 7])), 180 | ] 181 | 182 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3]) 183 | scores = [one_pair_score(ranks, suits) for ranks in hands] 184 | 185 | print("one_pair", scores) 186 | 187 | expected_order = sorted(scores) 188 | assert scores == expected_order, f"Scores are not in expected order: {scores}" 189 | 190 | 191 | def test_high_card_order(): 192 | hands = [ 193 | (jnp.array([1, 2, 3, 4, 6, 7, 8])), 194 | (jnp.array([1, 2, 3, 4, 6, 7, 10])), 195 | (jnp.array([1, 2, 3, 4, 6, 8, 10])), 196 | (jnp.array([1, 2, 12, 4, 6, 8, 10])), 197 | (jnp.array([0, 2, 12, 4, 6, 8, 10])), 198 | ] 199 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3]) 200 | scores = [high_card_hand_score(ranks, suits) for ranks in hands] 201 | 202 | print("high_card", scores) 203 | 204 | expected_order = sorted(scores) 205 | assert scores == expected_order, f"Scores are not in expected order: {scores}" 206 | 207 | 208 | def 
test_get_hand_type(): 209 | # Royal Flush (Ace-high straight flush) 210 | ranks = jnp.asarray([0, 9, 10, 11, 12, 3, 5]) # A, 10, J, Q, K, 4, 6 211 | suits = jnp.asarray([1, 1, 1, 1, 1, 2, 3]) 212 | assert get_hand_type(ranks, suits) == 8 213 | 214 | # Straight Flush 215 | ranks = jnp.asarray([8, 9, 10, 11, 12, 3, 5]) # 9, 10, J, Q, K, 4, 6 216 | suits = jnp.asarray([2, 2, 2, 2, 2, 3, 1]) 217 | assert get_hand_type(ranks, suits) == 8 218 | 219 | # Four of a Kind 220 | ranks = jnp.asarray([1, 1, 1, 1, 2, 3, 4]) # 2, 2, 2, 2, 3, 4, 5 221 | suits = jnp.asarray([0, 1, 2, 3, 0, 1, 2]) 222 | assert get_hand_type(ranks, suits) == 7 223 | 224 | # Full House 225 | ranks = jnp.asarray([2, 2, 2, 3, 3, 4, 5]) # 3, 3, 3, 4, 4, 5, 6 226 | suits = jnp.asarray([0, 1, 2, 0, 1, 2, 3]) 227 | assert get_hand_type(ranks, suits) == 6 228 | 229 | # Flush 230 | ranks = jnp.asarray([1, 4, 6, 8, 10, 2, 3]) # 2, 5, 7, 9, J, 3, 4 231 | suits = jnp.asarray([1, 1, 1, 1, 1, 2, 0]) 232 | assert get_hand_type(ranks, suits) == 5 233 | 234 | # Straight 235 | ranks = jnp.asarray([0, 9, 10, 11, 12, 1, 2]) # 5, 6, 7, 8, 9, 10, J 236 | suits = jnp.asarray([0, 1, 2, 3, 0, 1, 2]) 237 | assert get_hand_type(ranks, suits) == 4 238 | 239 | # Three of a Kind 240 | ranks = jnp.asarray([3, 3, 3, 5, 6, 7, 8]) # 4, 4, 4, 6, 7, 8, 9 241 | suits = jnp.asarray([0, 1, 2, 0, 1, 2, 3]) 242 | assert get_hand_type(ranks, suits) == 3 243 | 244 | # Two Pairs 245 | ranks = jnp.asarray([4, 4, 7, 7, 10, 1, 2]) # 5, 5, 8, 8, J, 2, 3 246 | suits = jnp.asarray([0, 1, 2, 3, 0, 1, 2]) 247 | assert get_hand_type(ranks, suits) == 2 248 | 249 | # One Pair 250 | ranks = jnp.asarray([6, 6, 8, 9, 11, 1, 3]) # 7, 7, 9, 10, Q, 2, 4 251 | suits = jnp.asarray([0, 1, 2, 0, 1, 2, 3]) 252 | assert get_hand_type(ranks, suits) == 1 253 | 254 | ranks = jnp.asarray([1, 10, 5, 7, 9, 2, 4]) # 2, 4, 6, 8, 10, 3, 5 255 | suits = jnp.asarray([0, 1, 2, 3, 0, 1, 2]) 256 | assert get_hand_type(ranks, suits) == 0 257 | -------------------------------------------------------------------------------- /cfrx/envs/leduc_poker/env.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import pgx 7 | import pgx.leduc_holdem 8 | from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray, Shaped 9 | from pgx._src.dwg.leduc_holdem import CARD 10 | from pgx._src.struct import dataclass 11 | 12 | import cfrx.envs 13 | from cfrx.envs.leduc_poker.constants import INFO_SETS, REVERSE_INFO_SETS_LOOKUP 14 | from cfrx.utils import ravel 15 | 16 | CARD.append("?") 17 | INFO_SETS_VALUES = np.stack(list(INFO_SETS.values())) 18 | NUM_DIFFERENT_CARDS = 3 19 | NUM_REPEAT_CARDS = 2 20 | NUM_TOTAL_CARDS = NUM_DIFFERENT_CARDS * NUM_REPEAT_CARDS 21 | 22 | 23 | class InfoState(NamedTuple): 24 | private_card: Int[Array, ""] 25 | public_card: Int[Array, ""] 26 | action_sequence: Int[Array, "..."] 27 | chance_round: Int[Array, "..."] 28 | chance_node: Bool[Array, "..."] 29 | 30 | 31 | @dataclass 32 | class State(pgx.leduc_holdem.State): 33 | info_state: InfoState = InfoState( 34 | private_card=jnp.int8(-1), 35 | public_card=jnp.int8(-1), 36 | action_sequence=jnp.ones((2, 4), dtype=jnp.int8) * -1, 37 | chance_round=jnp.int8(0), 38 | chance_node=jnp.bool_(True), 39 | ) 40 | chance_node: Bool[Array, ""] = jnp.bool_(False) 41 | chance_prior: Float[Array, "..."] = ( 42 | jnp.ones(NUM_DIFFERENT_CARDS, dtype=int) * NUM_REPEAT_CARDS 43 | ) 44 | 45 | 46 | def 
convert_info_state_to_idx(info_state: InfoState) -> jnp.ndarray: 47 | """ 48 | This is a bit hacky, it allows to transform an infostate into an index, and to 49 | construct an array to efficiently lookup in Jax. 50 | """ 51 | 52 | info_state_ravel = ravel(info_state) 53 | multiplier = jnp.array([3**k for k in range(info_state_ravel.shape[-1])]) 54 | idx = jnp.sum((info_state_ravel + 1) * multiplier) % 1235 55 | return idx 56 | 57 | 58 | class LeducPoker(pgx.leduc_holdem.LeducHoldem, cfrx.envs.Env): 59 | @classmethod 60 | def action_to_string(cls, action: Int[Array, ""]) -> str: 61 | strings = ["c", "r", "f"] 62 | if action != -1: 63 | a = int(action) 64 | rep = strings[a] 65 | else: 66 | rep = "?" 67 | return rep 68 | 69 | @property 70 | def max_episode_length(self) -> int: 71 | return 12 72 | 73 | @property 74 | def max_nodes(self) -> int: 75 | return 2000 76 | 77 | @property 78 | def n_info_states(self) -> int: 79 | return len(INFO_SETS) 80 | 81 | @property 82 | def n_actions(self) -> int: 83 | return super().num_actions 84 | 85 | @property 86 | def n_players(self) -> int: 87 | return self.num_players 88 | 89 | def update_info_state( 90 | self, state: State, next_state: State, action: Int[Array, ""] 91 | ) -> InfoState: 92 | info_state = next_state.info_state 93 | assert info_state is not None 94 | private_card = next_state._cards[next_state.current_player] 95 | public_card = next_state._cards[-1] 96 | 97 | current_position = (info_state.action_sequence[state._round] != -1).sum() 98 | action_sequence = info_state.action_sequence.at[ 99 | state._round, current_position 100 | ].set(action.astype(jnp.int8)) 101 | updated_info_state = info_state._replace(action_sequence=action_sequence) 102 | 103 | info_state = jax.tree_map( 104 | lambda x, y: jnp.where(state.chance_node, x, y), 105 | info_state, 106 | updated_info_state, 107 | ) 108 | chance_round = info_state.chance_round + state.chance_node 109 | return info_state._replace( 110 | private_card=private_card, 111 | public_card=public_card, 112 | chance_round=chance_round, 113 | chance_node=next_state.chance_node, 114 | ) 115 | 116 | def info_state_to_str(self, info_state: InfoState) -> str: 117 | if info_state.chance_node: 118 | rep = f"chance{info_state.chance_round}:" 119 | 120 | else: 121 | rep = "" 122 | 123 | strings = ["c", "r", "f"] 124 | rep += f"{info_state.private_card}" 125 | 126 | for action in np.array(info_state.action_sequence)[0]: 127 | if action != -1: 128 | rep += strings[action] 129 | 130 | if info_state.public_card != -1: 131 | rep += f"{info_state.public_card}" 132 | for action in np.array(info_state.action_sequence)[1]: 133 | if action != -1: 134 | rep += strings[action] 135 | return rep 136 | 137 | def get_action_mask(self, state: State) -> jax.Array: 138 | return state.legal_action_mask 139 | 140 | def get_chance_mask(self, state: State) -> jax.Array: 141 | return state.chance_prior > 0 142 | 143 | def get_chance_probs(self, state: State) -> jax.Array: 144 | return jnp.where( 145 | (state.chance_prior != 0).any(), 146 | state.chance_prior / state.chance_prior.sum(), 147 | 0, 148 | ) 149 | 150 | def get_info_state(self, state: State) -> jax.Array: 151 | return state.info_state 152 | 153 | def info_state_idx(self, info_state: InfoState) -> Array: 154 | idx = convert_info_state_to_idx(info_state) 155 | return jnp.asarray(REVERSE_INFO_SETS_LOOKUP)[idx] 156 | 157 | def _init(self, rng: Shaped[PRNGKeyArray, "2"]) -> State: 158 | env_state = super()._init(rng) 159 | info_state = InfoState( 160 | private_card=jnp.int8(-1), 
161 | public_card=jnp.int8(-1), 162 | action_sequence=jnp.ones((2, 4), dtype=jnp.int8) * -1, 163 | chance_round=jnp.int8(0), 164 | chance_node=jnp.bool_(True), 165 | ) 166 | cards = jnp.int8([-1, -1, -1]) 167 | return State( 168 | _first_player=env_state._first_player, 169 | current_player=env_state.current_player, 170 | observation=env_state.observation, 171 | rewards=env_state.rewards, 172 | terminated=env_state.terminated, 173 | truncated=env_state.truncated, 174 | _step_count=env_state._step_count, 175 | _last_action=env_state._last_action, 176 | _round=jnp.int8(0), 177 | _cards=cards, 178 | legal_action_mask=jnp.ones_like(env_state.legal_action_mask), 179 | _chips=env_state._chips, 180 | _raise_count=env_state._raise_count, 181 | info_state=info_state, 182 | chance_prior=jnp.ones(NUM_DIFFERENT_CARDS, dtype=int) * NUM_REPEAT_CARDS, 183 | chance_node=jnp.bool_(True), 184 | ) 185 | 186 | def _resolve_decision_node( 187 | self, state: State, action: Int[Array, ""], random_key: PRNGKeyArray 188 | ) -> State: 189 | env_state = super()._step(state=state, action=action, key=random_key) 190 | 191 | is_public_card_unknown = env_state._cards[-1] == -1 192 | chance_node = ( 193 | (env_state._round > 0) & is_public_card_unknown & ~env_state.terminated 194 | ) 195 | 196 | legal_action_mask = jnp.where( 197 | chance_node, 198 | jnp.ones_like(env_state.legal_action_mask), 199 | env_state.legal_action_mask, 200 | ) 201 | 202 | state = State( 203 | _first_player=env_state._first_player, 204 | current_player=env_state.current_player, 205 | observation=env_state.observation, 206 | rewards=env_state.rewards, 207 | terminated=env_state.terminated, 208 | truncated=env_state.truncated, 209 | _step_count=env_state._step_count, 210 | _last_action=env_state._last_action, 211 | _round=env_state._round.astype(jnp.int8), 212 | _cards=env_state._cards, 213 | legal_action_mask=legal_action_mask, 214 | _chips=env_state._chips, 215 | _raise_count=env_state._raise_count, 216 | info_state=env_state.info_state, 217 | chance_node=chance_node, 218 | chance_prior=env_state.chance_prior, 219 | ) 220 | 221 | return state 222 | 223 | def _resolve_chance_node( 224 | self, state: State, action: Int[Array, ""], random_key: PRNGKeyArray 225 | ) -> State: 226 | draw_player = NUM_TOTAL_CARDS - state.chance_prior.sum() 227 | action = action.astype(jnp.int8) 228 | card_rank = action % NUM_DIFFERENT_CARDS 229 | cards = state._cards.at[draw_player].set(card_rank) 230 | chance_prior = state.chance_prior.at[action].add(-1) 231 | chance_node = (cards[:2] == -1).any() 232 | 233 | legal_action_mask = jnp.where( 234 | chance_node, 235 | jnp.ones_like(state.legal_action_mask), 236 | jnp.ones_like(state.legal_action_mask).at[-1].set(False), 237 | ) 238 | return State( 239 | _first_player=state._first_player, 240 | current_player=state.current_player, 241 | observation=state.observation, 242 | rewards=state.rewards, 243 | terminated=state.terminated, 244 | truncated=state.truncated, 245 | _step_count=state._step_count, 246 | _last_action=state._last_action, 247 | _round=state._round.astype(jnp.int8), 248 | _cards=cards, 249 | legal_action_mask=legal_action_mask, 250 | _chips=state._chips, 251 | _raise_count=state._raise_count, 252 | info_state=state.info_state, 253 | chance_prior=chance_prior, 254 | chance_node=chance_node, 255 | ) 256 | 257 | def _step(self, state: State, action: Array, random_key: PRNGKeyArray) -> State: 258 | new_state = jax.lax.cond( 259 | state.chance_node, 260 | lambda: self._resolve_chance_node( 261 | state=state, 
action=action, random_key=random_key 262 | ), 263 | lambda: self._resolve_decision_node( 264 | state=state, action=action, random_key=random_key 265 | ), 266 | ) 267 | 268 | _round = jnp.where(new_state._cards[-1] != -1, jnp.maximum(state._round, 1), 0) 269 | new_state = new_state.replace(_round=_round) 270 | info_state = self.update_info_state( 271 | state=state, next_state=new_state, action=action 272 | ) 273 | 274 | return new_state.replace(info_state=info_state) 275 | -------------------------------------------------------------------------------- /cfrx/algorithms/mccfr/outcome_sampling.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | from typing import Any, NamedTuple 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from jaxtyping import Array, Float, Int, PRNGKeyArray 9 | 10 | from cfrx.envs import Env, InfoState, State 11 | from cfrx.episode import Episode 12 | from cfrx.policy import Policy, TabularPolicy 13 | from cfrx.utils import regret_matching 14 | 15 | 16 | class MCCFRState(NamedTuple): 17 | regrets: Float[Array, "*batch a"] 18 | probs: Float[Array, "... a"] 19 | avg_probs: Float[Array, "... a"] 20 | step: Int[Array, "..."] 21 | 22 | @classmethod 23 | def init(cls, n_states: int, n_actions: int) -> MCCFRState: 24 | return MCCFRState( 25 | regrets=jnp.zeros((n_states, n_actions)), 26 | probs=jnp.ones((n_states, n_actions)) 27 | / jnp.ones((n_states, n_actions)).sum(axis=-1, keepdims=True), 28 | avg_probs=jnp.zeros((n_states, n_actions)) + 1e-6, 29 | step=jnp.array(1, dtype=int), 30 | ) 31 | 32 | 33 | def compute_sampled_counterfactual_action_value( 34 | opponent_reach_prob: Float[Array, ""], 35 | outcome_sampling_prob: Float[Array, ""], 36 | outcome_prob: Float[Array, ""], 37 | utility: Float[Array, ""], 38 | ) -> Float[Array, ""]: 39 | cf_value_a = opponent_reach_prob * outcome_prob * utility / outcome_sampling_prob 40 | return cf_value_a 41 | 42 | 43 | def compute_strategy_profile( 44 | my_reach_prob: Float[Array, ""], 45 | sample_reach_prob: Float[Array, ""], 46 | strat_probs: Float[Array, "..."], 47 | ) -> Float[Array, "..."]: 48 | avg_probs = my_reach_prob / sample_reach_prob * strat_probs 49 | 50 | return avg_probs 51 | 52 | 53 | def get_regrets( 54 | step: Episode, 55 | my_prob_cumprod: Float[Array, ""], 56 | opponent_prob_cumprod: Float[Array, ""], 57 | sample_prob_cumprod: Float[Array, ""], 58 | outcome_prob: Float[Array, ""], 59 | strat_prob_distrib: Float[Array, " num_actions"], 60 | strat_prob: Float[Array, ""], 61 | utility: Float[Array, ""], 62 | outcome_sampling_prob: Float[Array, ""], 63 | update_player: Int[Array, ""], 64 | ) -> tuple[Float[Array, " num_actions"], Float[Array, " num_actions"],]: 65 | is_current_player = step.current_player == update_player 66 | cf_value_a = compute_sampled_counterfactual_action_value( 67 | opponent_reach_prob=opponent_prob_cumprod, 68 | outcome_sampling_prob=outcome_sampling_prob, 69 | outcome_prob=outcome_prob, 70 | utility=utility, 71 | ) 72 | cf_value = cf_value_a * strat_prob 73 | 74 | regrets = jnp.zeros(step.action_mask.shape[-1]) 75 | regrets = regrets.at[step.action].set(cf_value_a) 76 | regrets = (regrets - cf_value) * step.action_mask 77 | 78 | regrets = regrets * is_current_player * step.mask * (1 - step.chance_node) 79 | 80 | avg_probs = compute_strategy_profile( 81 | my_reach_prob=my_prob_cumprod, 82 | sample_reach_prob=sample_prob_cumprod, 83 | strat_probs=strat_prob_distrib * step.action_mask, 84 | ) 85 | 
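    # Editor's note (hedged): this follows the usual outcome-sampling MCCFR
    # estimator. For the single sampled terminal history z, the sampled
    # counterfactual value of the taken action is
    #   pi_{-i}(h) * pi(h -> z) * u_i(z) / q(z),
    # with q(z) = outcome_sampling_prob, the probability that the behaviour
    # policy produced z. Both the regret and the average-strategy increments
    # are masked so they only contribute at the update player's own, non-chance
    # decision steps of the sampled episode.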
avg_probs = avg_probs * step.mask * is_current_player * (1 - step.chance_node) 86 | return regrets, avg_probs 87 | 88 | 89 | def compute_regrets_and_strategy_profile( 90 | episode: Episode, 91 | training_state: MCCFRState, 92 | policy: TabularPolicy, 93 | update_player: Int[Array, ""], 94 | ) -> tuple[InfoState[Array, "..."], Float[Array, "..."], Float[Array, "..."]]: 95 | episode_length = episode.current_player.shape[-1] 96 | utility = episode.reward[episode.mask.sum(), update_player] 97 | is_current_player = episode.current_player == update_player 98 | prob_dist_fn = functools.partial( 99 | policy.prob_distribution, params=training_state.probs 100 | ) 101 | 102 | strat_prob_distribs = jax.vmap(prob_dist_fn)( 103 | info_state=episode.info_state, 104 | action_mask=episode.action_mask, 105 | use_behavior_policy=jnp.zeros(episode_length, dtype=bool), 106 | ) 107 | 108 | strat_probs = jnp.where( 109 | episode.mask, 110 | strat_prob_distribs[jnp.arange(episode_length), episode.action], 111 | 1.0, 112 | ) 113 | strat_probs = jnp.where(episode.chance_node, episode.behavior_prob, strat_probs) 114 | my_probs = jnp.where(is_current_player, strat_probs, 1) 115 | opponent_probs = jnp.where(is_current_player, 1, strat_probs) 116 | sample_probs = jnp.where(episode.mask, episode.behavior_prob, 1) 117 | 118 | sample_probs_cumprod = jnp.cumprod(jnp.concatenate([jnp.ones(1), sample_probs])) 119 | my_probs_cumprod = jnp.cumprod(jnp.concatenate([jnp.ones(1), my_probs])) 120 | opponent_probs_cumprod = jnp.cumprod(opponent_probs) 121 | outcome_probs = jnp.cumprod(jnp.concatenate([strat_probs, jnp.ones(1)])[::-1])[::-1] 122 | 123 | outcome_sampling_prob = jnp.prod(sample_probs) 124 | 125 | _get_regrets = functools.partial( 126 | get_regrets, 127 | utility=utility, 128 | update_player=update_player, 129 | outcome_sampling_prob=outcome_sampling_prob, 130 | ) 131 | regrets, avg_probs = jax.vmap(_get_regrets)( 132 | step=episode, 133 | my_prob_cumprod=my_probs_cumprod[:-1], 134 | opponent_prob_cumprod=opponent_probs_cumprod, 135 | sample_prob_cumprod=sample_probs_cumprod[:-1], 136 | outcome_prob=outcome_probs[1:], 137 | strat_prob_distrib=strat_prob_distribs, 138 | strat_prob=strat_probs, 139 | ) 140 | regrets = regrets[:-1] 141 | avg_probs = avg_probs[:-1] 142 | episode = jax.tree_map(lambda x: x[:-1], episode) 143 | 144 | return episode.info_state, regrets, avg_probs 145 | 146 | 147 | def unroll( 148 | init_state: State, 149 | random_key: PRNGKeyArray, 150 | training_state: MCCFRState, 151 | env: Env, 152 | policy: Policy, 153 | update_player: Int[Array, ""], 154 | n_max_steps: int, 155 | ) -> tuple[Episode, State]: 156 | """ 157 | Generates a single unroll of the game. 
158 | """ 159 | 160 | def play_step( 161 | carry: tuple[State, PRNGKeyArray], unused: Any 162 | ) -> tuple[tuple[State, PRNGKeyArray], tuple[Episode, State]]: 163 | state, random_key = carry 164 | 165 | use_behavior_policy = state.current_player == update_player 166 | 167 | random_key, subkey = jax.random.split(random_key) 168 | 169 | action = policy.sample( 170 | params=training_state.probs, 171 | info_state=state.info_state, 172 | action_mask=state.legal_action_mask, 173 | random_key=subkey, 174 | use_behavior_policy=use_behavior_policy, 175 | ) 176 | probs = policy.prob_distribution( 177 | params=training_state.probs, 178 | info_state=state.info_state, 179 | action_mask=state.legal_action_mask, 180 | use_behavior_policy=use_behavior_policy, 181 | ) 182 | 183 | chance_probs = state.chance_prior / state.chance_prior.sum( 184 | axis=-1, keepdims=True 185 | ) 186 | random_key, subkey = jax.random.split(random_key) 187 | chance_action = jax.random.choice( 188 | subkey, 189 | jnp.arange(state.chance_prior.shape[0]), 190 | p=chance_probs, 191 | ) 192 | 193 | action = jnp.where(state.chance_node, chance_action, action) 194 | 195 | probs = jnp.where(state.chance_node, chance_probs[action], probs[action]) 196 | 197 | current_player = jnp.where(state.chance_node, -1, state.current_player) 198 | game_step = Episode( 199 | info_state=state.info_state, 200 | action=action, 201 | reward=state.rewards, 202 | action_mask=state.legal_action_mask, 203 | current_player=current_player, 204 | behavior_prob=probs, 205 | chance_node=state.chance_node, 206 | mask=1 - state.terminated, 207 | ) 208 | 209 | state = env.step(state, action) 210 | return (state, random_key), (game_step, state) 211 | 212 | (state, random_key), (episode, states) = jax.lax.scan( 213 | play_step, (init_state, random_key), (), length=n_max_steps 214 | ) 215 | 216 | last_game_step = Episode( 217 | info_state=state.info_state, 218 | action=jnp.array(-1), 219 | reward=state.rewards, 220 | action_mask=jnp.ones(policy._n_actions, dtype=bool), 221 | current_player=state.current_player, 222 | behavior_prob=jnp.array(1.0), 223 | chance_node=jnp.bool_(False), 224 | mask=1 - state.terminated, 225 | ) 226 | states = jax.tree_map( 227 | lambda x, y: jnp.concatenate([jnp.expand_dims(y, 0), x]), 228 | states, 229 | init_state, 230 | ) 231 | 232 | episode = jax.tree_map( 233 | lambda x, y: jnp.concatenate([x, jnp.expand_dims(y, 0)]), 234 | episode, 235 | last_game_step, 236 | ) 237 | return episode, states 238 | 239 | 240 | def do_iteration( 241 | training_state: MCCFRState, 242 | random_key: PRNGKeyArray, 243 | env: Env, 244 | policy: TabularPolicy, 245 | update_player: Int[Array, ""], 246 | ) -> tuple[Float[Array, "*batch a"], Float[Array, "*batch a"], Episode]: 247 | """ 248 | Do one iteration of MCCFR: traverse the game tree once and compute counterfactual 249 | regrets and strategy profiles. 250 | 251 | Args: 252 | training_state: The current state of the training. 253 | random_key: A random key. 254 | env: The environment. 255 | policy: The policy. 256 | update_player: The player to update. 257 | 258 | Returns: 259 | The updated regrets and average strategy profile and the episode. 
260 | """ 261 | 262 | # Sample one path in the game tree 263 | random_key, subkey = jax.random.split(random_key) 264 | episode, states = unroll( 265 | init_state=env.init(subkey), 266 | training_state=training_state, 267 | random_key=subkey, 268 | update_player=update_player, 269 | env=env, 270 | policy=policy, 271 | n_max_steps=env.max_episode_length, 272 | ) 273 | 274 | # Compute counterfactual values and strategy profile 275 | ( 276 | info_states, 277 | sampled_regrets, 278 | sampled_avg_probs, 279 | ) = compute_regrets_and_strategy_profile( 280 | episode=episode, 281 | training_state=training_state, 282 | policy=policy, 283 | update_player=update_player, 284 | ) 285 | info_states_idx = jax.vmap(env.info_state_idx)(info_states) 286 | 287 | # Store regret and strategy profile values 288 | new_regrets = training_state.regrets.at[info_states_idx].add(sampled_regrets) 289 | new_avg_probs = training_state.avg_probs.at[info_states_idx].add(sampled_avg_probs) 290 | 291 | # Accumulate regrets, compute new strategy and avg strategy 292 | new_probs = regret_matching(new_regrets) 293 | new_probs /= new_probs.sum(axis=-1, keepdims=True) 294 | 295 | training_state = training_state._replace( 296 | regrets=new_regrets, 297 | probs=new_probs, 298 | avg_probs=new_avg_probs, 299 | step=training_state.step + 1, 300 | ) 301 | return training_state 302 | -------------------------------------------------------------------------------- /cfrx/tree/traverse.py: -------------------------------------------------------------------------------- 1 | import functools as ft 2 | from typing import Tuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import pgx 7 | from jaxtyping import Array, Bool, Float, Int, PyTree 8 | 9 | from cfrx.policy import Policy 10 | from cfrx.tree import Root, Tree 11 | from cfrx.utils import get_action_mask 12 | 13 | 14 | def instantiate_tree_from_root( 15 | root: Root, 16 | n_max_nodes: int, 17 | n_players: int, 18 | running_probabilities: bool = False, 19 | ) -> Tree: 20 | """Initializes tree state at search root.""" 21 | (n_actions,) = root.prior_logits.shape 22 | 23 | data_dtype = root.value.dtype 24 | 25 | def _zeros(x: Array) -> Array: 26 | return jnp.zeros((n_max_nodes,) + x.shape, dtype=x.dtype) 27 | 28 | # Create a new empty tree state and fill its root. 
29 | tree = Tree( 30 | node_visits=jnp.zeros(n_max_nodes, dtype=jnp.int32), 31 | raw_values=jnp.zeros((n_max_nodes, n_players), dtype=data_dtype), 32 | node_values=jnp.zeros((n_max_nodes, n_players), dtype=data_dtype), 33 | parents=jnp.full( 34 | n_max_nodes, 35 | Tree.NO_PARENT, 36 | dtype=jnp.int32, 37 | ), 38 | action_from_parent=jnp.full( 39 | n_max_nodes, 40 | Tree.NO_PARENT, 41 | dtype=jnp.int32, 42 | ), 43 | children_index=jnp.full( 44 | (n_max_nodes, n_actions), 45 | Tree.UNVISITED, 46 | dtype=jnp.int32, 47 | ), 48 | children_prior_logits=jnp.zeros( 49 | (n_max_nodes, n_actions), dtype=root.prior_logits.dtype 50 | ), 51 | children_values=jnp.zeros((n_max_nodes, n_actions, n_players), dtype=data_dtype), 52 | children_visits=jnp.zeros((n_max_nodes, n_actions), dtype=jnp.int32), 53 | children_rewards=jnp.zeros( 54 | (n_max_nodes, n_actions, n_players), dtype=data_dtype 55 | ), 56 | to_visit=jnp.zeros(n_max_nodes, dtype=bool), 57 | states=jax.tree_util.tree_map(_zeros, root.state), 58 | depth=jnp.ones(n_max_nodes, dtype=jnp.int32) * -1, 59 | extra_data={}, 60 | ) 61 | new_tree: Tree = tree._replace( 62 | node_visits=tree.node_visits.at[Tree.ROOT_INDEX].set(1), 63 | states=jax.tree_map( 64 | lambda x, y: x.at[Tree.ROOT_INDEX].set(y), tree.states, root.state 65 | ), 66 | children_prior_logits=tree.children_prior_logits.at[Tree.ROOT_INDEX].set( 67 | root.prior_logits 68 | ), 69 | depth=tree.depth.at[Tree.ROOT_INDEX].set(0), 70 | ) 71 | 72 | if running_probabilities: 73 | new_tree = initialize_running_probabilities(new_tree) 74 | 75 | return new_tree 76 | 77 | 78 | def initialize_running_probabilities(tree: Tree) -> Tree: 79 | n_max_nodes = tree.node_visits.shape[-1] 80 | init_prob = (jnp.ones(n_max_nodes) * -1).at[Tree.ROOT_INDEX].set(1.0) 81 | running_probabilities = { 82 | "p_self": init_prob, 83 | "p_opponent": init_prob, 84 | "p_chance": init_prob, 85 | } 86 | tree = tree._replace(extra_data={**tree.extra_data, **running_probabilities}) 87 | return tree 88 | 89 | 90 | def add_children( 91 | tree: Tree, 92 | state: PyTree, 93 | env: pgx.Env, 94 | node_counter: jax.Array, 95 | parent_idx: jax.Array, 96 | ) -> tuple[Tree, jax.Array]: 97 | 98 | chance_fn = ft.partial(add_chance_children, env=env) 99 | player_fn = ft.partial(add_player_children, env=env) 100 | return jax.lax.cond( 101 | state.chance_node, chance_fn, player_fn, tree, state, node_counter, parent_idx 102 | ) 103 | 104 | 105 | def add_player_children( 106 | tree: Tree, 107 | state: PyTree, 108 | node_counter: jax.Array, 109 | parent_idx: jax.Array, 110 | env: pgx.Env, 111 | ) -> tuple[Tree, jax.Array]: 112 | 113 | action_mask = env.get_action_mask(state) & ~state.terminated 114 | n_actions = len(action_mask) 115 | n_max_nodes = len(tree.to_visit) 116 | action_idx = jnp.where(action_mask, jnp.arange(n_actions), n_max_nodes) 117 | action_idx = jnp.sort(action_idx) 118 | 119 | update_idx = jnp.arange(n_actions) + node_counter + 1 120 | update_idx = jnp.where(action_idx == n_max_nodes, n_max_nodes, update_idx) 121 | 122 | action_from_parent = tree.action_from_parent.at[update_idx].set(action_idx) 123 | parents = tree.parents.at[update_idx].set(parent_idx) 124 | to_visit = tree.to_visit.at[update_idx].set(True) 125 | 126 | tree = tree._replace( 127 | action_from_parent=action_from_parent, parents=parents, to_visit=to_visit 128 | ) 129 | 130 | return tree, action_mask.sum() 131 | 132 | 133 | def add_chance_children( 134 | tree: Tree, 135 | state: PyTree, 136 | node_counter: jax.Array, 137 | parent_idx: jax.Array, 138 | env: pgx.Env, 


def add_chance_children(
    tree: Tree,
    state: PyTree,
    node_counter: jax.Array,
    parent_idx: jax.Array,
    env: pgx.Env,
) -> tuple[Tree, jax.Array]:
    """Reserve one child slot per possible chance outcome of `state`.

    Mirrors `add_player_children`, but masks with the chance-outcome mask
    instead of the player action mask.
    """
    action_mask = env.get_chance_mask(state) & ~state.terminated
    n_actions = len(action_mask)
    n_max_nodes = len(tree.to_visit)
    action_idx = jnp.where(action_mask, jnp.arange(n_actions), n_max_nodes)
    action_idx = jnp.sort(action_idx)

    update_idx = jnp.arange(n_actions) + node_counter + 1
    update_idx = jnp.where(action_idx == n_max_nodes, n_max_nodes, update_idx)

    action_from_parent = tree.action_from_parent.at[update_idx].set(action_idx)
    parents = tree.parents.at[update_idx].set(parent_idx)
    to_visit = tree.to_visit.at[update_idx].set(True)

    tree = tree._replace(
        action_from_parent=action_from_parent, parents=parents, to_visit=to_visit
    )

    return tree, action_mask.sum()


def select_new_node_and_play(
    tree: Tree, env: pgx.Env
) -> tuple[PyTree, jax.Array, jax.Array, jax.Array]:
    """Select the next flagged node from the `to_visit` frontier and compute its
    state by replaying the corresponding action from its parent."""
    child_index = jnp.argmax(tree.to_visit)
    parent_index = tree.parents[child_index]
    action = tree.action_from_parent[child_index]

    parent_state = jax.tree_util.tree_map(lambda x: x[parent_index], tree.states)
    new_state = env.step(parent_state, action)

    return new_state, parent_index, child_index, action


def update_running_probabilities(
    tree: Tree,
    parent_index: jax.Array,
    next_node_index: jax.Array,
    strategy: Float[Array, ""],
    traverser: int,
) -> Tree:
    """Propagate the reach-probability factors from `parent_index` to the new
    node, multiplying in `strategy` (the probability of the traversed action)
    for whichever factor owns the parent decision."""
    parent_state = jax.tree_util.tree_map(lambda x: x[parent_index], tree.states)

    p_self = tree.extra_data["p_self"]

    # The traverser's own factor only absorbs probabilities of its own
    # (non-chance, non-terminal) decisions.
    update_p_self_condition = (
        (parent_state.current_player == traverser)
        & (~parent_state.terminated)
        & (~parent_state.chance_node)
    )
    p_self_new_value = jnp.where(
        update_p_self_condition,
        p_self[parent_index] * strategy,
        p_self[parent_index],
    )

    p_opponent = tree.extra_data["p_opponent"]

    update_p_opponent_condition = (
        (parent_state.current_player != traverser)
        & (~parent_state.terminated)
        & (~parent_state.chance_node)
    )

    p_opponent_new_value = jnp.where(
        update_p_opponent_condition,
        p_opponent[parent_index] * strategy,
        p_opponent[parent_index],
    )

    p_chance = tree.extra_data["p_chance"]

    p_chance_new_value = jnp.where(
        parent_state.chance_node,
        p_chance[parent_index] * strategy,
        p_chance[parent_index],
    )

    tree = tree._replace(
        extra_data={
            **tree.extra_data,
            **{
                "p_self": p_self.at[next_node_index].set(p_self_new_value),
                "p_opponent": p_opponent.at[next_node_index].set(p_opponent_new_value),
                "p_chance": p_chance.at[next_node_index].set(p_chance_new_value),
            },
        }
    )
    return tree
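

# Illustrative example of the reach-probability bookkeeping above (numbers are
# made up, not from the original source): along a path where chance deals a
# card with probability 1/3, the traverser then bets with probability 0.4 and
# the opponent calls with probability 0.7, the resulting node stores
# p_chance = 1/3, p_self = 0.4 and p_opponent = 0.7. Their product is the
# probability of reaching that node under the current strategy profile, which
# is exactly the factorisation that counterfactual-regret updates rely on.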


def traverse_tree_vanilla(
    tree: Tree,
    env: pgx.Env,
) -> Tree:
    """Exhaustively expand the game tree over the `to_visit` frontier without
    tracking any policy information."""

    def cond_fn(val: Tuple) -> Bool[Array, ""]:
        tree, n = val
        n_max_nodes = len(tree.node_visits)

        # Keep going while frontier nodes and free slots remain.
        return tree.to_visit.any() & (n < n_max_nodes)

    def loop_fn(val: Tuple) -> Tuple:
        tree, n = val

        new_state, parent_index, child_index, action = select_new_node_and_play(
            tree, env
        )

        tree, n_added = add_children(
            tree=tree, state=new_state, env=env, node_counter=n, parent_idx=child_index
        )

        tree = tree._replace(
            node_visits=tree.node_visits.at[child_index].set(1),
            node_values=tree.node_values.at[child_index].set(new_state.rewards),
            states=jax.tree_util.tree_map(
                lambda x, y: x.at[child_index].set(y), tree.states, new_state
            ),
            raw_values=tree.raw_values.at[child_index].set(new_state.rewards),
            children_index=tree.children_index.at[parent_index, action].set(
                child_index
            ),
            children_rewards=tree.children_rewards.at[parent_index, action].set(
                new_state.rewards
            ),
            children_values=tree.children_values.at[parent_index, action].set(
                new_state.rewards
            ),
            parents=tree.parents.at[child_index].set(parent_index),
            action_from_parent=tree.action_from_parent.at[child_index].set(action),
            depth=tree.depth.at[child_index].set(tree.depth[parent_index] + 1),
            to_visit=tree.to_visit.at[child_index].set(False),
        )

        return tree, n + n_added

    # Expand the root (slot 0) first, then grow the tree until the frontier is
    # empty or the node budget is exhausted.
    tree, n_added = add_children(
        tree=tree,
        state=jax.tree_util.tree_map(lambda x: x[0], tree.states),
        env=env,
        node_counter=0,
        parent_idx=0,
    )

    tree, _ = jax.lax.while_loop(cond_fn, loop_fn, (tree, n_added))
    return tree
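

# `traverse_tree_cfr` below performs the same expansion as
# `traverse_tree_vanilla`, but additionally queries the current policy (or the
# chance distribution) at every traversed edge, stores the resulting action
# probabilities in `children_prior_logits`, and threads the reach-probability
# factors through `update_running_probabilities`, so the returned tree carries
# the quantities a counterfactual-regret update needs.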


def traverse_tree_cfr(
    tree: Tree,
    policy: Policy,
    policy_params: Array,
    env: pgx.Env,
    traverser: int = 0,
) -> Tree:
    def cond_fn(val: Tuple) -> Bool[Array, ""]:
        tree, n = val
        n_max_nodes = len(tree.node_visits)

        return tree.to_visit.any() & (n < n_max_nodes)

    def loop_fn(val: Tuple) -> Tuple:
        tree, n = val

        new_state, parent_index, child_index, action = select_new_node_and_play(
            tree, env
        )

        tree, n_added = add_children(
            tree=tree, state=new_state, env=env, node_counter=n, parent_idx=child_index
        )

        parent_state = jax.tree_util.tree_map(lambda x: x[parent_index], tree.states)

        # Current-policy distribution over the parent's legal actions.
        strategy = policy.prob_distribution(
            params=policy_params,
            info_state=env.get_info_state(parent_state),
            action_mask=env.get_action_mask(state=parent_state),
            use_behavior_policy=jnp.bool_(False),
        )

        # Probability of the traversed outcome when the parent is a chance node.
        chance_strategy = env.get_chance_probs(parent_state)[action]

        action_prob = jnp.where(
            parent_state.chance_node, chance_strategy, strategy[action]
        )

        tree = update_running_probabilities(
            tree=tree,
            parent_index=parent_index,
            next_node_index=child_index,
            strategy=action_prob,
            traverser=traverser,
        )

        tree = tree._replace(
            node_visits=tree.node_visits.at[child_index].set(1),
            node_values=tree.node_values.at[child_index].set(new_state.rewards),
            states=jax.tree_util.tree_map(
                lambda x, y: x.at[child_index].set(y), tree.states, new_state
            ),
            raw_values=tree.raw_values.at[child_index].set(new_state.rewards),
            children_index=tree.children_index.at[parent_index, action].set(
                child_index
            ),
            children_rewards=tree.children_rewards.at[parent_index, action].set(
                new_state.rewards
            ),
            children_values=tree.children_values.at[parent_index, action].set(
                new_state.rewards
            ),
            children_prior_logits=tree.children_prior_logits.at[parent_index].set(
                # For chance parents this broadcasts the probability of the
                # traversed outcome across the whole row.
                jnp.where(parent_state.chance_node, chance_strategy, strategy)
            ),
            parents=tree.parents.at[child_index].set(parent_index),
            action_from_parent=tree.action_from_parent.at[child_index].set(action),
            depth=tree.depth.at[child_index].set(tree.depth[parent_index] + 1),
            to_visit=tree.to_visit.at[child_index].set(False),
        )

        return tree, n + n_added

    tree, n_added = add_children(
        tree=tree,
        state=jax.tree_util.tree_map(lambda x: x[0], tree.states),
        env=env,
        node_counter=0,
        parent_idx=0,
    )

    tree, _ = jax.lax.while_loop(cond_fn, loop_fn, (tree, n_added))
    return tree
--------------------------------------------------------------------------------