├── cfrx
│   ├── __init__.py
│   ├── algorithms
│   │   ├── __init__.py
│   │   ├── cfr
│   │   │   ├── __init__.py
│   │   │   └── cfr.py
│   │   └── mccfr
│   │       ├── __init__.py
│   │       ├── test_outcome_sampling.py
│   │       └── outcome_sampling.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── cfr.py
│   │   └── mccfr.py
│   ├── envs
│   │   ├── nlhe_poker
│   │   │   ├── __init__.py
│   │   │   ├── showdown.py
│   │   │   ├── gui.py
│   │   │   ├── env.py
│   │   │   └── test_showdown.py
│   │   ├── kuhn_poker
│   │   │   ├── data
│   │   │   │   ├── __init__.py
│   │   │   │   └── info_states.npz
│   │   │   ├── __init__.py
│   │   │   ├── constants.py
│   │   │   └── env.py
│   │   ├── leduc_poker
│   │   │   ├── data
│   │   │   │   ├── __init__.py
│   │   │   │   └── info_states.npz
│   │   │   ├── __init__.py
│   │   │   ├── constants.py
│   │   │   └── env.py
│   │   ├── __init__.py
│   │   └── base.py
│   ├── tree
│   │   ├── __init__.py
│   │   ├── tree.py
│   │   └── traverse.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── exploitability.py
│   │   ├── exploitability_test.py
│   │   └── best_response.py
│   ├── episode.py
│   ├── utils.py
│   └── policy.py
├── imgs
│   └── bench_open_spiel.png
├── MANIFEST.in
├── mypy.ini
├── CONTRIBUTING.md
├── .pre-commit-config.yaml
├── LICENSE
├── pyproject.toml
├── .gitignore
├── README.md
└── examples
    └── cfr.ipynb
/cfrx/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cfrx/algorithms/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cfrx/trainers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cfrx/envs/nlhe_poker/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cfrx/envs/kuhn_poker/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cfrx/envs/leduc_poker/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imgs/bench_open_spiel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Egiob/cfrx/HEAD/imgs/bench_open_spiel.png
--------------------------------------------------------------------------------
/cfrx/tree/__init__.py:
--------------------------------------------------------------------------------
1 | from cfrx.tree.tree import Root, Tree
2 |
3 | __all__ = ["Root", "Tree"]
4 |
--------------------------------------------------------------------------------
/cfrx/envs/kuhn_poker/__init__.py:
--------------------------------------------------------------------------------
1 | from cfrx.envs.kuhn_poker.env import KuhnPoker
2 |
3 | __all__ = ["KuhnPoker"]
4 |
--------------------------------------------------------------------------------
/cfrx/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from cfrx.envs.base import Env, InfoState, State
2 |
3 | __all__ = ["Env", "State", "InfoState"]
4 |
--------------------------------------------------------------------------------
/cfrx/envs/leduc_poker/__init__.py:
--------------------------------------------------------------------------------
1 | from cfrx.envs.leduc_poker.env import LeducPoker
2 |
3 | __all__ = ["LeducPoker"]
4 |
--------------------------------------------------------------------------------
/cfrx/envs/kuhn_poker/data/info_states.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Egiob/cfrx/HEAD/cfrx/envs/kuhn_poker/data/info_states.npz
--------------------------------------------------------------------------------
/cfrx/envs/leduc_poker/data/info_states.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Egiob/cfrx/HEAD/cfrx/envs/leduc_poker/data/info_states.npz
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | recursive-include * *.npz
3 |
4 | # remove the test specific files
5 | recursive-exclude * *_test.py
6 |
--------------------------------------------------------------------------------
/cfrx/algorithms/cfr/__init__.py:
--------------------------------------------------------------------------------
1 | from cfrx.algorithms.cfr.cfr import CFRState, do_iteration
2 |
3 | __all__ = ["CFRState", "do_iteration"]
4 |
--------------------------------------------------------------------------------
/cfrx/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from cfrx.metrics.best_response import compute_best_response_value
2 | from cfrx.metrics.exploitability import exploitability
3 |
--------------------------------------------------------------------------------
/cfrx/algorithms/mccfr/__init__.py:
--------------------------------------------------------------------------------
1 | from cfrx.algorithms.mccfr.outcome_sampling import MCCFRState, do_iteration
2 |
3 | __all__ = ["MCCFRState", "do_iteration"]
4 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | # Global options:
2 |
3 | [mypy]
4 | warn_return_any = True
5 | warn_unused_configs = True
6 | ignore_missing_imports = True
7 | follow_imports = skip
8 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Contribution Guidelines
2 |
3 | All contributions are welcome; don't hesitate to send a PR or open an issue.
4 |
5 | ## License
6 |
7 | By contributing, you agree that your contributions will be licensed under the MIT License.
8 |
--------------------------------------------------------------------------------
/cfrx/envs/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import ABC, abstractmethod
4 |
5 | import pgx.core
6 | from jaxtyping import Array, Int, PyTree
7 |
8 | InfoState = PyTree
9 | State = PyTree
10 |
11 |
12 | class BaseEnv(ABC):
13 | @abstractmethod
14 | def action_to_string(cls, action: Int[Array, ""]) -> str:
15 | pass
16 |
17 |
18 | class Env(BaseEnv, pgx.core.Env):
19 | pass
20 |
--------------------------------------------------------------------------------
/cfrx/episode.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 | from jaxtyping import Array, Bool, Float, Int
4 |
5 | from cfrx.envs.base import InfoState
6 |
7 |
8 | class Episode(NamedTuple):
9 | info_state: InfoState
10 | action: Float[Array, "..."]
11 | reward: Float[Array, "..."]
12 | action_mask: Bool[Array, "..."]
13 | current_player: Int[Array, "..."]
14 | behavior_prob: Float[Array, "..."]
15 | mask: Bool[Array, "..."]
16 | chance_node: Bool[Array, "..."]
17 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v3.2.0
4 | hooks:
5 | - id: trailing-whitespace
6 | - id: end-of-file-fixer
7 | - id: check-yaml
8 | - id: check-added-large-files
9 | - id: requirements-txt-fixer
10 | - id: pretty-format-json
11 | exclude: \.ipynb
12 | - repo: https://github.com/psf/black
13 | rev: 23.11.0 # Use the sha / tag you want to point at
14 | hooks:
15 | - id: black
16 | - repo: https://github.com/pycqa/isort
17 | rev: 5.12.0
18 | hooks:
19 | - id: isort
20 |
21 |
22 | - repo: https://github.com/kynan/nbstripout
23 | rev: 0.7.1
24 | hooks:
25 | - id: nbstripout
26 |
--------------------------------------------------------------------------------
/cfrx/envs/leduc_poker/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | current_path = os.path.dirname(os.path.abspath(__file__))
6 |
7 | INFO_SETS = dict(
8 | np.load(
9 | os.path.join(
10 | current_path,
11 | "data",
12 | "info_states.npz",
13 | )
14 | )
15 | )
16 |
17 |
18 | # This is a bit hacky: it maps each info state to an integer index and builds an
19 | # array that allows efficient lookups from inside JIT-compiled JAX code.
20 | multiplier = np.array([3**k for k in range(12)])
21 | tr = {k: np.sum((v + 1) * multiplier) % 1235 for k, v in INFO_SETS.items()}
22 | max_value = max(list(tr.values()))
23 | REVERSE_INFO_SETS_LOOKUP = np.zeros(max_value + 1, dtype=int)
24 | REVERSE_INFO_SETS_LOOKUP[list(tr.values())] = np.arange(len(tr))
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Raphaël Boige
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cfrx/metrics/exploitability.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import pgx
4 | from jaxtyping import Array, Float
5 | 
6 | from cfrx.metrics.best_response import compute_best_response_value
7 | from cfrx.policy import Policy
8 | from cfrx.tree import Root
9 | from cfrx.tree.traverse import instantiate_tree_from_root, traverse_tree_cfr
10 |
11 |
12 | def exploitability(
13 | env: pgx.Env,
14 | policy: Policy,
15 | policy_params: Array,
16 | n_players: int,
17 | n_max_nodes: int = 100,
18 | ) -> Float[Array, ""]:
19 | random_key = jax.random.PRNGKey(0)
20 | state = jax.jit(env.init)(random_key)
21 |
22 | values_list = []
23 | root = Root(
24 | prior_logits=state.legal_action_mask / state.legal_action_mask.sum(),
25 | value=jnp.array(0.0),
26 | state=state,
27 | )
28 | for traverser in range(n_players):
29 | tree = instantiate_tree_from_root(
30 | root=root,
31 | n_max_nodes=n_max_nodes,
32 | n_players=n_players,
33 | running_probabilities=True,
34 | )
35 |
36 | tree = traverse_tree_cfr(
37 | tree,
38 | policy=policy,
39 | env=env,
40 | traverser=traverser,
41 | policy_params=policy_params,
42 | )
43 |
44 | value = compute_best_response_value(
45 | tree, br_player=traverser, info_state_fn=env.info_state_idx
46 | )
47 | values_list.append(value)
48 |
49 | values = jnp.stack(values_list)
50 | exploitability = jnp.array(values).sum() / n_players
51 | return exploitability
52 |
--------------------------------------------------------------------------------
/cfrx/envs/kuhn_poker/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import jax
4 | import numpy as np
5 |
6 | current_path = os.path.dirname(os.path.abspath(__file__))
7 |
8 | INFO_SETS = dict(
9 | np.load(
10 | os.path.join(
11 | current_path,
12 | "data",
13 | "info_states.npz",
14 | )
15 | )
16 | )
17 |
18 | INFO_SET_ACTION_MASK = {
19 | "0": [False, True, False, True],
20 | "0p": [False, True, False, True],
21 | "0b": [True, False, True, False],
22 | "0pb": [True, False, True, False],
23 | "1": [False, True, False, True],
24 | "1p": [False, True, False, True],
25 | "1b": [True, False, True, False],
26 | "1pb": [True, False, True, False],
27 | "2": [False, True, False, True],
28 | "2p": [False, True, False, True],
29 | "2b": [True, False, True, False],
30 | "2pb": [True, False, True, False],
31 | }
32 |
33 |
34 | def get_kuhn_optimal_policy(alpha: float) -> dict:
35 | assert 0 <= alpha <= 1 / 3
36 | optimal_probs = {
37 | "0": [0.0, alpha, 0.0, 1 - alpha],
38 | "0p": [0.0, 1 / 3, 0.0, 2 / 3],
39 | "0b": [0.0, 0.0, 1.0, 0.0],
40 | "0pb": [0.0, 0.0, 1.0, 0.0],
41 | "1": [0.0, 0.0, 0.0, 1.0],
42 | "1p": [0.0, 0.0, 0.0, 1.0],
43 | "1b": [1 / 3, 0.0, 2 / 3, 0.0],
44 | "1pb": [alpha + 1 / 3, 0.0, 2 / 3 - alpha, 0.0],
45 | "2": [0.0, 3 * alpha, 0, 1 - 3 * alpha],
46 | "2p": [0.0, 1.0, 0.0, 0.0],
47 | "2b": [1.0, 0.0, 0.0, 0.0],
48 | "2pb": [1.0, 0.0, 0.0, 0.0],
49 | }
50 | return optimal_probs
51 |
52 |
53 | KUHN_UNIFORM_POLICY = jax.tree_map(
54 | lambda x: x / x.sum(), {k: np.array(x) for k, x in INFO_SET_ACTION_MASK.items()}
55 | )
56 |
--------------------------------------------------------------------------------
/cfrx/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Optional
4 |
5 | import jax
6 | import jax.flatten_util
7 | import jax.numpy as jnp
8 | from jaxtyping import Array, Bool, Int, PyTree
9 |
10 |
11 | def get_action_mask(state: PyTree) -> Bool[Array, "..."]:
12 | chance_action_mask = state.chance_prior > 0
13 | decision_action_mask = state.legal_action_mask
14 | max_action_mask_size = max(
15 | chance_action_mask.shape[-1], decision_action_mask.shape[-1]
16 | )
17 | chance_action_mask = jnp.pad(
18 | chance_action_mask,
19 | (0, max_action_mask_size - chance_action_mask.shape[-1]),
20 | constant_values=(False,),
21 | )
22 | decision_action_mask = jnp.pad(
23 | decision_action_mask,
24 | (0, max_action_mask_size - decision_action_mask.shape[-1]),
25 | constant_values=(False,),
26 | )
27 | action_mask = jnp.where(
28 | state.chance_node[..., None],
29 | chance_action_mask,
30 | decision_action_mask,
31 | )
32 | return action_mask
33 |
34 |
35 | def reverse_array_lookup(x: Array, lookup_table: Array) -> Int[Array, ""]:
36 | return (lookup_table == x).all(axis=1).argmax()
37 |
38 |
39 | def tree_unstack(tree: PyTree) -> list[PyTree]:
40 | leaves, treedef = jax.tree_util.tree_flatten(tree)
41 | n_trees = leaves[0].shape[0]
42 | new_leaves: list = [[] for _ in range(n_trees)]
43 | for leaf in leaves:
44 | for i in range(n_trees):
45 | new_leaves[i].append(leaf[i])
46 | new_trees = [treedef.unflatten(leaves) for leaves in new_leaves]
47 | return new_trees
48 |
49 |
50 | def ravel(tree: PyTree) -> Array:
51 | return jax.flatten_util.ravel_pytree(tree)[0]
52 |
53 |
54 | def regret_matching(regrets: Array) -> Array:
55 | positive_regrets = jnp.maximum(regrets, 0)
56 | n_actions = positive_regrets.shape[-1]
57 | sum_pos_regret = positive_regrets.sum(axis=-1, keepdims=True)
58 | dist = jnp.where(
59 | sum_pos_regret == 0, 1 / n_actions, positive_regrets / sum_pos_regret
60 | )
61 | return dist
62 |
63 |
64 | def log_array(x: Array, name: Optional[str] = None) -> None:
65 | if name is None:
66 | name = "array"
67 | jax.debug.print(name + ": {x}", x=x)
68 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 |
6 | [project]
7 | name = "cfrx"
8 | description = "Counterfactual Regret Minimization in Jax"
9 | readme = "README.md"
10 | license = {file = "LICENSE"}
11 | urls = {repository = "https://github.com/Egiob/cfrx" }
12 | authors = [
13 | {name="Raphaël Boige"},
14 | ]
15 | requires-python = ">=3.9.0"
16 | version = "0.0.2"
17 | dependencies = [
18 | "jaxtyping>=0.2.19",
19 | "jax>=0.4.0",
20 | "flax>=0.7.0",
21 |     "pgx==2.0.1",
22 | "tqdm~=4.66.0"
23 | ]
24 | keywords = ["jax", "game-theory", "reinforcement-learning", "cfr", "poker"]
25 | classifiers = [
26 | "Development Status :: 3 - Alpha",
27 | "Intended Audience :: Developers",
28 | "Intended Audience :: Information Technology",
29 | "Intended Audience :: Science/Research",
30 | "License :: OSI Approved :: MIT License",
31 | "Natural Language :: English",
32 | "Programming Language :: Python :: 3",
33 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
34 | "Topic :: Scientific/Engineering :: Information Analysis",
35 | "Topic :: Scientific/Engineering :: Mathematics",
36 | ]
37 |
38 | [tool.setuptools]
39 | include-package-data = true
40 |
41 | [tool.setuptools.packages.find]
42 | where = ["."]
43 | include = ["*"]
44 | exclude = ["imgs"]
45 |
46 | [tool.isort]
47 | profile = "black"
48 |
49 | [tool.mypy]
50 | python_version = "3.10"
51 | namespace_packages = true
52 | incremental = false
53 | cache_dir = ""
54 | warn_redundant_casts = true
55 | warn_return_any = true
56 | warn_unused_configs = true
57 | warn_unused_ignores = false
58 | allow_redefinition = true
59 | disallow_untyped_calls = true
60 | disallow_untyped_defs = true
61 | disallow_incomplete_defs = true
62 | check_untyped_defs = true
63 | disallow_untyped_decorators = false
64 | strict_optional = true
65 | strict_equality = true
66 | explicit_package_bases = true
67 | follow_imports = "skip"
68 | ignore_missing_imports = true
69 |
70 | [tool.black]
71 | line-length = 89
72 |
73 | [tool.bumpver]
74 | current_version = "0.0.2"
75 | version_pattern = "MAJOR.MINOR.PATCH"
76 | commit_message = "build: bump version {old_version} -> {new_version}"
77 | commit = true
78 | tag = true
79 | push = false
80 |
81 | [tool.bumpver.file_patterns]
82 | "pyproject.toml" = ['current_version = "{version}"', 'version = "{version}"']
83 |
--------------------------------------------------------------------------------
/cfrx/metrics/exploitability_test.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 | import pyspiel
7 | import pytest
8 | from jaxtyping import Array
9 | from open_spiel.python.algorithms.exploitability import (
10 | exploitability as open_spiel_exploitability_fn,
11 | )
12 | from open_spiel.python.algorithms.mccfr import AveragePolicy
13 |
14 | from cfrx.envs.kuhn_poker.constants import (
15 | INFO_SETS,
16 | KUHN_UNIFORM_POLICY,
17 | get_kuhn_optimal_policy,
18 | )
19 | from cfrx.envs.kuhn_poker.env import KuhnPoker
20 | from cfrx.metrics.exploitability import exploitability as cfrx_exploitability_fn
21 | from cfrx.policy import TabularPolicy
22 |
23 |
24 | @pytest.mark.parametrize(
25 | "policy_dict",
26 | [
27 | get_kuhn_optimal_policy(alpha=0.0),
28 | KUHN_UNIFORM_POLICY,
29 | get_kuhn_optimal_policy(alpha=0.15),
30 | ],
31 | )
32 | def test_kuhn_exploitability_vs_open_spiel(policy_dict: Dict[str, Array]) -> None:
33 | info_states_dict = INFO_SETS
34 | default = np.ones_like(next(iter(policy_dict.values())))
35 | policy = np.stack([policy_dict.get(x, default) for x in info_states_dict])
36 |
37 | # cfrx
38 | policy_jax = jnp.asarray(policy)
39 | n_max_nodes = 200
40 | env = KuhnPoker()
41 | policy_obj = TabularPolicy(
42 | n_actions=policy.shape[1], info_state_idx_fn=env.info_state_idx
43 | )
44 | n_players = 2
45 | cfrx_exploitability = cfrx_exploitability_fn(
46 | env,
47 | policy_params=policy_jax,
48 | n_players=n_players,
49 | n_max_nodes=n_max_nodes,
50 | policy=policy_obj,
51 | )
52 |
53 | # open_spiel
54 | game = pyspiel.load_game("kuhn_poker")
55 | avg_probs = np.array(policy)
56 | avg_probs = np.concatenate(
57 | [
58 | avg_probs[:, 2:4].sum(axis=1, keepdims=True),
59 | avg_probs[:, :2].sum(axis=1, keepdims=True),
60 | ],
61 | axis=1,
62 | )
63 |
64 | open_spiel_info_states = [[None, x] for x in avg_probs]
65 | cfrx_to_open_spiel_info_states = dict(zip(INFO_SETS.keys(), open_spiel_info_states))
66 |
67 | average_policy = AveragePolicy(
68 | game=game, player_ids=[0, 1], infostates=cfrx_to_open_spiel_info_states
69 | )
70 | open_spiel_exploitability = open_spiel_exploitability_fn(game, average_policy)
71 |
72 | assert np.allclose(cfrx_exploitability, open_spiel_exploitability, atol=1e-5)
73 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # cfrx: Counterfactual Regret Minimization in Jax.
2 |
3 | cfrx is an open-source library designed for efficient implementation of counterfactual regret minimization (CFR) algorithms using JAX. It focuses on computational speed and easy parallelization on hardware accelerators like GPUs and TPUs.
4 |
5 | Key Features:
6 |
7 | - **JIT Compilation for Speed:** cfrx makes the most of JAX's just-in-time (JIT) compilation to minimize runtime overhead and maximize computational speed.
8 |
9 | - **Hardware Accelerator Support:** It supports parallelization on GPUs and TPUs, enabling efficient scaling of computations for large-scale problems.
10 |
11 | - **Python/JAX Ease of Use:** cfrx provides a Pythonic interface built on JAX, offering simplicity and accessibility compared to traditional C++ implementations or prohibitively slow pure-Python code.
12 |
13 | ## Installation
14 |
15 | ```bash
16 | pip install cfrx
17 | ```
16 |
17 | ## Getting started
18 |
19 | An example notebook is available in [examples/cfr.ipynb](examples/cfr.ipynb).
20 |
21 | Below is a snippet that trains MCCFR with outcome sampling on Kuhn Poker.
22 | ```python3
23 | import jax
24 |
25 | from cfrx.envs.kuhn_poker.env import KuhnPoker
26 | from cfrx.policy import TabularPolicy
27 | from cfrx.trainers.mccfr import MCCFRTrainer
28 |
29 | env = KuhnPoker()
30 |
31 | policy = TabularPolicy(
32 | n_actions=env.n_actions,
33 | exploration_factor=0.6,
34 | info_state_idx_fn=env.info_state_idx,
35 | )
36 |
37 | random_key = jax.random.PRNGKey(0)
38 |
39 | trainer = MCCFRTrainer(env=env, policy=policy)
40 |
41 | training_state, metrics = trainer.train(
42 | random_key=random_key, n_iterations=100_000, metrics_period=5_000
43 | )
44 | ```
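45 | 
46 | The trainer returns the final training state together with a metrics dict keyed by
47 | `"step"` and `"exploitability"`. As a quick sanity check, the learning curve can be
48 | plotted (a minimal sketch, assuming `matplotlib` is installed):
49 | 
50 | ```python3
51 | import matplotlib.pyplot as plt
52 | 
53 | # `metrics` is the dict returned by trainer.train(...) above
54 | plt.plot(metrics["step"], metrics["exploitability"])
55 | plt.yscale("log")
56 | plt.xlabel("Iterations")
57 | plt.ylabel("Exploitability")
58 | plt.show()
59 | ```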
45 |
46 |
47 | ## Implemented features and upcoming features
48 |
49 | | Algorithms | |
50 | |---|---|
51 | | MCCFR (outcome-sampling) | :white_check_mark: |
52 | | MCCFR (other variants) | :x: |
53 | | Vanilla CFR | :white_check_mark: |
54 | | Deep CFR | :x: |
55 |
56 | | Metrics | |
57 | |---|---|
58 | | Exploitability | :white_check_mark: |
59 | | Local Best Response | :x: |
60 |
61 | | Environments | |
62 | |---|---|
63 | | Kuhn Poker | :white_check_mark: |
64 | | Leduc Poker | :white_check_mark: |
65 | | Larger games | :x: |
66 |
67 |
68 | ## Performance
69 |
70 | Below is a small benchmark against `open_spiel` for MCCFR (outcome sampling) on Kuhn Poker and Leduc Poker. Compared to the `open_spiel` Python API, `cfrx` runs faster while showing similar convergence.
71 |
72 | 
73 |
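74 | Much of this speedup comes from JIT-compiling the whole iteration. Below is a minimal
75 | sketch of the pattern used by the trainers (see `cfrx/trainers/mccfr.py`): `env` and
76 | `policy` are closed over, so only the array arguments are traced and the compiled
77 | update is reused across iterations.
78 | 
79 | ```python3
80 | import jax
81 | 
82 | from cfrx.algorithms.mccfr.outcome_sampling import MCCFRState, do_iteration
83 | from cfrx.envs.kuhn_poker.env import KuhnPoker
84 | from cfrx.policy import TabularPolicy
85 | 
86 | env = KuhnPoker()
87 | policy = TabularPolicy(n_actions=env.n_actions, info_state_idx_fn=env.info_state_idx)
88 | training_state = MCCFRState.init(env.n_info_states, env.n_actions)
89 | 
90 | # Compile one MCCFR (outcome-sampling) iteration once; later calls reuse the
91 | # compiled computation as long as the argument shapes do not change.
92 | do_iteration_fn = jax.jit(
93 |     lambda training_state, random_key, update_player: do_iteration(
94 |         training_state=training_state,
95 |         random_key=random_key,
96 |         env=env,
97 |         policy=policy,
98 |         update_player=update_player,
99 |     )
100 | )
101 | 
102 | training_state = do_iteration_fn(training_state, jax.random.PRNGKey(0), update_player=0)
103 | ```
104 | 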
74 | ## See also
75 |
76 | cfrx is heavily inspired by the amazing [google-deepmind/open_spiel](https://github.com/google-deepmind/open_spiel) library as well as by many projects from the Jax ecosystem and especially [sotetsuk/pgx](https://github.com/sotetsuk/pgx) and [google-deepmind/mctx](https://github.com/google-deepmind/mctx).
77 |
78 |
79 | ## Contributing
80 |
81 | Contributions are welcome; please refer to the [contribution guidelines](CONTRIBUTING.md).
82 |
--------------------------------------------------------------------------------
/cfrx/policy.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from collections.abc import Callable
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray
7 |
8 | from cfrx.envs.base import InfoState
9 |
10 |
11 | class Policy(ABC):
12 | _n_actions: int
13 |
14 | @abstractmethod
15 | def prob_distribution(
16 | self,
17 | params: Float[Array, "... a"],
18 | info_state: InfoState[Array, "..."],
19 | action_mask: Bool[Array, "... a"],
20 | use_behavior_policy: Bool[Array, "..."],
21 | ) -> Array:
22 | pass
23 |
24 | @abstractmethod
25 | def sample(
26 | self,
27 | params: Float[Array, "... a"],
28 | info_state: InfoState[Array, "..."],
29 | action_mask: Bool[Array, "... a"],
30 | random_key: PRNGKeyArray,
31 | use_behavior_policy: Bool[Array, "..."],
32 | ) -> Array:
33 | pass
34 |
35 |
36 | class TabularPolicy(Policy):
37 | def __init__(
38 | self,
39 | n_actions: int,
40 | info_state_idx_fn: Callable[[InfoState], Int[Array, ""]],
41 | exploration_factor: float = 0.6,
42 | ):
43 | self._n_actions = n_actions
44 | self._exploration_factor = exploration_factor
45 | self._info_state_idx_fn = info_state_idx_fn
46 |
47 | def prob_distribution(
48 | self,
49 | params: Float[Array, "... a"],
50 | info_state: InfoState[Array, "..."],
51 | action_mask: Bool[Array, "... a"],
52 | use_behavior_policy: Bool[Array, "..."],
53 | ) -> Array:
54 | info_state_idx = self._info_state_idx_fn(info_state)
55 | probs = params[info_state_idx]
56 |
57 | behavior_probabilities = (
58 | probs * (1 - self._exploration_factor)
59 | + self._exploration_factor * jnp.ones_like(probs) / self._n_actions
60 | )
61 |
62 | probs = jnp.where(use_behavior_policy, behavior_probabilities, probs)
63 | probs = probs * action_mask
64 | probs /= probs.sum(axis=-1, keepdims=True)
65 | return probs
66 |
67 | def sample(
68 | self,
69 | params: Float[Array, "... a"],
70 | info_state: InfoState[Array, "..."],
71 | action_mask: Bool[Array, "... a"],
72 | random_key: PRNGKeyArray,
73 | use_behavior_policy: Bool[Array, "..."],
74 | ) -> Array:
75 | info_state_idx = self._info_state_idx_fn(info_state)
76 | probs = params[info_state_idx]
77 |
78 | behavior_probabilities = (
79 | probs * (1 - self._exploration_factor)
80 | + self._exploration_factor * jnp.ones_like(probs) / self._n_actions
81 | )
82 |
83 | probs = jnp.where(use_behavior_policy, behavior_probabilities, probs)
84 |
85 | probs = probs * action_mask
86 | probs /= probs.sum(axis=-1, keepdims=True)
87 | action = jax.random.choice(random_key, jnp.arange(self._n_actions), p=probs)
88 | return action
89 |
--------------------------------------------------------------------------------
/examples/cfr.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "0",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import matplotlib.pyplot as plt\n",
11 | "\n",
12 | "from cfrx.algorithms.cfr import CFRState\n",
13 | "from cfrx.policy import TabularPolicy\n",
14 | "from cfrx.trainers.cfr import CFRTrainer"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "id": "1",
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "ENV_NAME = \"Kuhn Poker\""
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "2",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "if ENV_NAME == \"Kuhn Poker\":\n",
35 | " from cfrx.envs.kuhn_poker.env import KuhnPoker\n",
36 | "\n",
37 | " env_cls = KuhnPoker\n",
38 | "\n",
39 | "\n",
40 | "elif ENV_NAME == \"Leduc Poker\":\n",
41 | " from cfrx.envs.leduc_poker.env import LeducPoker\n",
42 | "\n",
43 | " env_cls = LeducPoker"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "id": "3",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "env = env_cls()"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "id": "4",
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "training_state = CFRState.init(n_states=env.n_info_states, n_actions=env.n_actions)\n",
64 | "policy = TabularPolicy(n_actions=env.n_actions, info_state_idx_fn=env.info_state_idx)"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "id": "5",
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "trainer = CFRTrainer(env=env, policy=policy, device=\"cpu\")\n",
75 | "training_state, metrics = trainer.train(n_iterations=10000, metrics_period=10)"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "id": "6",
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "plt.plot(metrics[\"step\"], metrics[\"exploitability\"])\n",
86 | "plt.yscale(\"log\")\n",
87 | "plt.xlabel(\"Iterations\")\n",
88 | "plt.title(f\"CFR on {ENV_NAME}\")\n",
89 | "plt.ylabel(\"Exploitability\")"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "id": "7",
96 | "metadata": {},
97 | "outputs": [],
98 | "source": []
99 | }
100 | ],
101 | "metadata": {
102 | "kernelspec": {
103 | "display_name": "Python 3 (ipykernel)",
104 | "language": "python",
105 | "name": "python3"
106 | },
107 | "language_info": {
108 | "codemirror_mode": {
109 | "name": "ipython",
110 | "version": 3
111 | },
112 | "file_extension": ".py",
113 | "mimetype": "text/x-python",
114 | "name": "python",
115 | "nbconvert_exporter": "python",
116 | "pygments_lexer": "ipython3",
117 | "version": "3.10.13"
118 | }
119 | },
120 | "nbformat": 4,
121 | "nbformat_minor": 5
122 | }
123 |
--------------------------------------------------------------------------------
/cfrx/tree/tree.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from jaxtyping import Array, Float, Int, PyTree
6 |
7 |
8 | class classproperty:
9 | def __init__(self, func):
10 | self.fget = func
11 |
12 | def __get__(self, instance, owner):
13 | return self.fget(owner)
14 |
15 |
16 | class Root(NamedTuple):
17 | """
18 | Base class to hold the root of a search tree.
19 |
20 | Args:
21 | prior_logits: `[n_actions]` the action prior logits.
22 | value: `[]` the value of the root node.
23 | state: `[...]` the state of the root node.
24 | """
25 |
26 | prior_logits: Array
27 | value: Array
28 | state: Array
29 |
30 |
31 | UNVISITED = -1
32 | NO_PARENT = -1
33 | ROOT_INDEX = 0
34 |
35 |
36 | class Tree(NamedTuple):
37 | """
38 | Adapted with minor modifications from https://github.com/google-deepmind/mctx.
39 |
40 | State of a search tree.
41 |
42 | The `Tree` dataclass is used to hold and inspect search data for a batch of
43 | inputs. In the fields below `N` represents the number of nodes in the tree,
44 | and `n_actions` is the number of discrete actions.
45 |
46 | node_visits: `[N]` the visit counts for each node.
47 | raw_values: `[N, n_players]` the raw value for each node.
48 | node_values: `[N, n_players]` the cumulative search value for each node.
49 | parents: `[N]` the node index for the parents for each node.
50 | action_from_parent: `[N]` action to take from the parent to reach each
51 | node.
52 | children_index: `[N, n_actions]` the node index of the children for each
53 | action.
54 | children_prior_logits: `[N, n_actions]` the action prior logits of each
55 | node.
56 | children_visits: `[N, n_actions]` the visit counts for children for
57 | each action.
58 | children_rewards: `[N, n_actions, n_players]` the immediate reward for each
59 | action.
60 | children_values: `[N, n_actions, n_players]` the value of the next node after
61 | the action.
62 | states: `[N, ...]` the state embeddings of each node.
63 | depth: `[N]` the depth of each node in the tree.
64 | extra_data: `[...]` extra data passed to the tree.
65 |
66 | """
67 |
68 | node_visits: Int[Array, "..."]
69 | raw_values: Float[Array, "... n_players"]
70 | node_values: Float[Array, "... n_players"]
71 | parents: Int[Array, "..."]
72 | action_from_parent: Int[Array, "..."]
73 | children_index: Int[Array, "... n_actions"]
74 | children_prior_logits: Float[Array, "... n_actions"]
75 | children_visits: Int[Array, "... n_actions"]
76 | children_rewards: Float[Array, "... n_actions n_players"]
77 | children_values: Float[Array, "... n_actions n_players"]
78 | states: PyTree
79 | depth: Int[Array, "..."]
80 | extra_data: dict[str, Array]
81 | to_visit: jax.Array
82 |
83 | @classproperty
84 | def ROOT_INDEX(cls) -> Int[Array, ""]:
85 | return jnp.asarray(ROOT_INDEX)
86 |
87 | @classproperty
88 | def NO_PARENT(cls) -> Int[Array, ""]:
89 | return jnp.asarray(NO_PARENT)
90 |
91 | @classproperty
92 | def UNVISITED(cls) -> Int[Array, ""]:
93 | return jnp.asarray(UNVISITED)
94 |
95 | @property
96 | def n_actions(self) -> int:
97 | return self.children_index.shape[-1]
98 |
--------------------------------------------------------------------------------
/cfrx/trainers/cfr.py:
--------------------------------------------------------------------------------
1 | import jax
2 | from jaxtyping import Array, Int
3 | from tqdm import tqdm
4 |
5 | from cfrx.algorithms.cfr.cfr import CFRState, do_iteration
6 | from cfrx.envs import Env
7 | from cfrx.metrics import exploitability
8 | from cfrx.policy import TabularPolicy
9 |
10 |
11 | class CFRTrainer:
12 | def __init__(self, env: Env, policy: TabularPolicy, device: str):
13 | self._env = env
14 | self._policy = policy
15 | self._device = jax.devices(device)[0]
16 | self._exploitability_fn = jax.jit(
17 | lambda policy_params: exploitability(
18 | policy_params=policy_params,
19 | env=env,
20 | n_players=env.n_players,
21 | n_max_nodes=env.max_nodes,
22 | policy=policy,
23 | ),
24 | device=self._device,
25 | )
26 |
27 | self._do_iteration_fn = jax.jit(
28 | lambda training_state, update_player: do_iteration(
29 | training_state=training_state,
30 | env=env,
31 | policy=policy,
32 | update_player=update_player,
33 | ),
34 | device=self._device,
35 | )
36 |
37 | def do_n_iterations(
38 | self,
39 | training_state: CFRState,
40 | update_player: Int[Array, ""],
41 | n: int,
42 | ) -> tuple[CFRState, Int[Array, ""]]:
43 | def _scan_fn(carry, unused):
44 | training_state, update_player = carry
45 |
46 | update_player = (update_player + 1) % 2
47 | training_state = self._do_iteration_fn(
48 | training_state,
49 | update_player=update_player,
50 | )
51 |
52 | return (training_state, update_player), None
53 |
54 | (new_training_state, last_update_player), _ = jax.lax.scan(
55 | _scan_fn,
56 | (training_state, update_player),
57 | None,
58 | length=n,
59 | )
60 |
61 | return new_training_state, last_update_player
62 |
63 | def train(
64 | self, n_iterations: int, metrics_period: int
65 | ) -> tuple[CFRState, dict[str, Array]]:
66 | training_state = CFRState.init(self._env.n_info_states, self._env.n_actions)
67 |
68 | assert n_iterations % metrics_period == 0
69 |
70 | n_loops = n_iterations // metrics_period
71 | update_player = 0
72 | _do_n_iterations = jax.jit(
73 | lambda training_state, update_player: self.do_n_iterations(
74 | training_state=training_state,
75 | update_player=update_player,
76 | n=2 * metrics_period,
77 | ),
78 | device=self._device,
79 | )
80 | metrics = []
81 |
82 | pbar = tqdm(total=n_iterations, desc="Training", unit_scale=True)
83 | for k in range(n_loops):
84 | if k == 0:
85 | current_policy = training_state.avg_probs
86 | current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True)
87 | exp = self._exploitability_fn(policy_params=current_policy)
88 | metrics.append({"exploitability": exp, "step": 0})
89 | pbar.set_postfix(exploitability=f"{exp:.1e}")
90 |
91 | # Do n iterations
92 | training_state, update_player = _do_n_iterations(
93 | training_state, update_player
94 | )
95 |
96 | # Evaluate exploitability
97 | current_policy = training_state.avg_probs
98 | current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True)
99 |
100 | exp = self._exploitability_fn(policy_params=current_policy)
101 | metrics.append({"exploitability": exp, "step": (k + 1) * metrics_period})
102 | pbar.set_postfix(exploitability=f"{exp:.1e}")
103 | pbar.update(metrics_period)
104 |
105 | metrics = jax.tree_map(lambda *x: jax.numpy.stack(x), *metrics)
106 | return training_state, metrics
107 |
--------------------------------------------------------------------------------
/cfrx/trainers/mccfr.py:
--------------------------------------------------------------------------------
1 | import jax
2 | from jaxtyping import Array, Int, PRNGKeyArray
3 | from tqdm import tqdm
4 |
5 | from cfrx.algorithms.mccfr.outcome_sampling import MCCFRState, do_iteration
6 | from cfrx.envs import Env
7 | from cfrx.metrics import exploitability
8 | from cfrx.policy import TabularPolicy
9 |
10 |
11 | class MCCFRTrainer:
12 | def __init__(self, env: Env, policy: TabularPolicy):
13 | self._env = env
14 | self._policy = policy
15 |
16 | self._exploitability_fn = jax.jit(
17 | lambda policy_params: exploitability(
18 | policy_params=policy_params,
19 | env=env,
20 | n_players=env.n_players,
21 | n_max_nodes=env.max_nodes,
22 | policy=policy,
23 | )
24 | )
25 |
26 | self._do_iteration_fn = jax.jit(
27 | lambda training_state, random_key, update_player: do_iteration(
28 | training_state=training_state,
29 | random_key=random_key,
30 | env=env,
31 | policy=policy,
32 | update_player=update_player,
33 | )
34 | )
35 |
36 | def do_n_iterations(
37 | self,
38 | training_state: MCCFRState,
39 | update_player: Int[Array, ""],
40 | random_key: PRNGKeyArray,
41 | n: int,
42 | ) -> tuple[MCCFRState, Int[Array, ""]]:
43 | def _scan_fn(carry, unused):
44 | training_state, random_key, update_player = carry
45 |
46 | random_key, subkey = jax.random.split(random_key)
47 | update_player = (update_player + 1) % 2
48 | training_state = self._do_iteration_fn(
49 | training_state,
50 | subkey,
51 | update_player=update_player,
52 | )
53 |
54 | return (training_state, random_key, update_player), None
55 |
56 | (new_training_state, _, last_update_player), _ = jax.lax.scan(
57 | _scan_fn,
58 | (training_state, random_key, update_player),
59 | None,
60 | length=n,
61 | )
62 |
63 | return new_training_state, last_update_player
64 |
65 | def train(
66 | self, n_iterations: int, metrics_period: int, random_key: PRNGKeyArray
67 | ) -> tuple[MCCFRState, dict[str, Array]]:
68 | training_state = MCCFRState.init(self._env.n_info_states, self._env.n_actions)
69 |
70 | assert n_iterations % metrics_period == 0
71 |
72 | n_loops = n_iterations // metrics_period
73 | update_player = 0
74 | _do_n_iterations = jax.jit(
75 | lambda training_state, update_player, random_key: self.do_n_iterations(
76 | training_state=training_state,
77 | update_player=update_player,
78 | random_key=random_key,
79 | n=2 * metrics_period,
80 | )
81 | )
82 |
83 | metrics = []
84 | pbar = tqdm(total=n_iterations, desc="Training", unit_scale=True)
85 | for k in range(n_loops):
86 | if k == 0:
87 | current_policy = training_state.avg_probs
88 | current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True)
89 | exp = self._exploitability_fn(policy_params=current_policy)
90 | metrics.append({"exploitability": exp, "step": 0})
91 | pbar.set_postfix(exploitability=f"{exp:.1e}")
92 |
93 | random_key, subkey = jax.random.split(random_key)
94 |
95 | # Do n iterations
96 | training_state, update_player = _do_n_iterations(
97 | training_state, update_player, subkey
98 | )
99 |
100 | # Evaluate exploitability
101 | current_policy = training_state.avg_probs
102 | current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True)
103 |
104 | exp = self._exploitability_fn(policy_params=current_policy)
105 | metrics.append({"exploitability": exp, "step": (k + 1) * metrics_period})
106 | pbar.set_postfix(exploitability=f"{exp:.1e}")
107 | pbar.update(metrics_period)
108 |
109 | metrics = jax.tree_map(lambda *x: jax.numpy.stack(x), *metrics)
110 | return training_state, metrics
111 |
--------------------------------------------------------------------------------
/cfrx/metrics/best_response.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from jaxtyping import Array, Float, Int
6 |
7 | from cfrx.tree import Tree
8 |
9 |
10 | def backward_one_infoset(
11 | tree: Tree,
12 | info_states: Array,
13 | current_infoset: Int[Array, ""],
14 | br_player: int,
15 | depth: Int[Array, ""],
16 | ) -> Tree:
17 | # Select all nodes in this infoset
18 | infoset_mask = (
19 | (tree.depth == depth)
20 | & (~tree.states.terminated)
21 | & (info_states == current_infoset)
22 | )
23 |
24 | is_br_player = (
25 | (tree.states.current_player == br_player)
26 | & (~tree.states.chance_node)
27 | & infoset_mask
28 | ).any()
29 |
30 | p_opponent = tree.extra_data["p_opponent"]
31 | p_chance = tree.extra_data["p_chance"]
32 | p_self = tree.extra_data["p_self"]
33 |
34 | # Get expected values for each of the node in the infoset
35 | cf_reach_prob = (p_opponent * p_chance * p_self)[..., None]
36 |
37 | legal_action_mask = (tree.states.legal_action_mask & infoset_mask[..., None]).any(
38 | axis=0
39 | )
40 |
41 | best_response_values = tree.children_values[..., br_player] * cf_reach_prob # (T, a)
42 | best_response_values = jnp.sum(
43 | best_response_values, axis=0, where=infoset_mask[..., None] # (a,)
44 | )
45 |
46 | best_action = jnp.where(legal_action_mask, best_response_values, -jnp.inf).argmax()
47 |
48 | br_value = tree.children_values[:, best_action, br_player]
49 |
50 | expected_current_value = (
51 | tree.children_values[..., br_player] * tree.children_prior_logits
52 | ).sum(axis=1)
53 |
54 | current_value = jnp.where(is_br_player, br_value, expected_current_value)
55 |
56 | new_node_values = jnp.where(
57 | infoset_mask, current_value, tree.node_values[..., br_player]
58 | )
59 | new_children_values = jnp.where(
60 | tree.children_index != -1, new_node_values[tree.children_index], 0
61 | )
62 |
63 | tree = tree._replace(
64 | node_values=tree.node_values.at[..., br_player].set(new_node_values),
65 | children_values=tree.children_values.at[..., br_player].set(new_children_values),
66 | )
67 |
68 | return tree
69 |
70 |
71 | def backward_one_depth_level(
72 | tree: Tree,
73 | depth: Int[Array, ""],
74 | br_player: int,
75 | info_state_fn: Callable,
76 | ) -> Tree:
77 | info_states = jax.vmap(info_state_fn)(tree.states.info_state)
78 |
79 | def cond_fn(val: tuple[Tree, Array]) -> Array:
80 | tree, visited = val
81 | nodes_to_visit_idx = (
82 | (tree.depth == depth) & (~tree.states.terminated) & (~visited)
83 | )
84 | return nodes_to_visit_idx.any()
85 |
86 | def loop_fn(val: tuple[Tree, Array]) -> tuple[Tree, Array]:
87 | tree, visited = val
88 | nodes_to_visit_idx = (
89 | (tree.depth == depth) & (~tree.states.terminated) & (~visited)
90 | )
91 |
92 | # Select an infoset to resolve
93 | selected_infoset_idx = nodes_to_visit_idx.argmax()
94 | selected_infoset = info_states[selected_infoset_idx]
95 |
96 | tree = backward_one_infoset(
97 | tree=tree,
98 | depth=depth,
99 | info_states=info_states,
100 | current_infoset=selected_infoset,
101 | br_player=br_player,
102 | )
103 |
104 | visited = jnp.where(info_states == selected_infoset, True, visited)
105 | return tree, visited
106 |
107 | visited = jnp.zeros(tree.node_values.shape[0], dtype=bool)
108 | tree, visited = jax.lax.while_loop(cond_fn, loop_fn, (tree, visited))
109 |
110 | return tree
111 |
112 |
113 | def compute_best_response_value(
114 | tree: Tree,
115 | br_player: int,
116 | info_state_fn: Callable,
117 | ) -> Float[Array, " num_players"]:
118 | depth = tree.depth.max()
119 |
120 | def cond_fn(val: tuple[Tree, Array]) -> Array:
121 | _, depth = val
122 | return depth >= 0
123 |
124 | def loop_fn(val: tuple[Tree, Array]) -> tuple[Tree, Array]:
125 | tree, depth = val
126 | tree = backward_one_depth_level(
127 | tree=tree, depth=depth, br_player=br_player, info_state_fn=info_state_fn
128 | )
129 | depth -= 1
130 | return tree, depth
131 |
132 | tree, _ = jax.lax.while_loop(cond_fn, loop_fn, (tree, depth))
133 | return tree.node_values[0]
134 |
--------------------------------------------------------------------------------
/cfrx/algorithms/mccfr/test_outcome_sampling.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import numpy as np
3 | import pytest
4 |
5 | from cfrx.algorithms.mccfr.outcome_sampling import (
6 | MCCFRState,
7 | compute_regrets_and_strategy_profile,
8 | unroll,
9 | )
10 | from cfrx.metrics import exploitability
11 | from cfrx.policy import TabularPolicy
12 | from cfrx.utils import regret_matching
13 |
14 |
15 | @pytest.mark.parametrize(
16 | "env_name, num_iterations, target_exploitability",
17 | [
18 | ("Kuhn Poker", 20000, 15e-3),
19 | ("Leduc Poker", 50000, 5e-1),
20 | ],
21 | )
22 | def test_perf(env_name: str, num_iterations: int, target_exploitability: float):
23 | device = jax.devices("cpu")[0]
24 |
25 | random_key = jax.random.PRNGKey(0)
26 | if env_name == "Kuhn Poker":
27 | from cfrx.envs.kuhn_poker.constants import INFO_SETS
28 | from cfrx.envs.kuhn_poker.env import KuhnPoker
29 |
30 | env_cls = KuhnPoker
31 | EPISODE_LEN = 8
32 | NUM_MAX_NODES = 100
33 |
34 | elif env_name == "Leduc Poker":
35 | from cfrx.envs.leduc_poker.constants import INFO_SETS
36 | from cfrx.envs.leduc_poker.env import LeducPoker
37 |
38 | env_cls = LeducPoker
39 | EPISODE_LEN = 20
40 | NUM_MAX_NODES = 2000
41 |
42 | env = env_cls()
43 |
44 | n_states = len(INFO_SETS)
45 | n_actions = env.num_actions
46 |
47 | training_state = MCCFRState.init(n_states, n_actions)
48 |
49 | policy = TabularPolicy(
50 | n_actions=n_actions,
51 | exploration_factor=0.6,
52 | info_state_idx_fn=env.info_state_idx,
53 | )
54 |
55 | def do_iteration(training_state, random_key, env, policy, update_player):
56 | """
57 | Do one iteration of MCCFR: traverse the game tree once and
58 | compute counterfactual regrets and strategy profiles
59 | """
60 |
61 | random_key, subkey = jax.random.split(random_key)
62 |
63 | # Sample one path in the game tree
64 | random_key, subkey = jax.random.split(random_key)
65 | episode, states = unroll(
66 | init_state=env.init(subkey),
67 | training_state=training_state,
68 | random_key=subkey,
69 | update_player=update_player,
70 | env=env,
71 | policy=policy,
72 | n_max_steps=EPISODE_LEN,
73 | )
74 |
75 | # Compute counterfactual values and strategy profile
76 | (
77 | info_states,
78 | sampled_regrets,
79 | sampled_avg_probs,
80 | ) = compute_regrets_and_strategy_profile(
81 | episode=episode,
82 | training_state=training_state,
83 | policy=policy,
84 | update_player=update_player,
85 | )
86 | info_states_idx = jax.vmap(env.info_state_idx)(info_states)
87 |
88 | # Store regret and strategy profile values
89 | regrets = training_state.regrets.at[info_states_idx].add(sampled_regrets)
90 | avg_probs = training_state.avg_probs.at[info_states_idx].add(sampled_avg_probs)
91 |
92 | return regrets, avg_probs, episode
93 |
94 | do_iteration = jax.jit(
95 | do_iteration, static_argnames=("env", "policy"), device=device
96 | )
97 |
98 | # This function measures the exploitability of a strategy
99 | exploitability_fn = jax.jit(
100 | lambda policy_params: exploitability(
101 | policy_params=policy_params,
102 | env=env,
103 | n_players=2,
104 | n_max_nodes=NUM_MAX_NODES,
105 | policy=policy,
106 | ),
107 | device=device,
108 | )
109 |
110 | # One iteration consists in updating the policy for both players
111 | n_loops = 2 * num_iterations
112 |
113 | for k in range(n_loops):
114 | random_key, subkey = jax.random.split(random_key)
115 |
116 | # Update players alternatively
117 | update_player = k % 2
118 | new_regrets, new_avg_probs, episode = do_iteration(
119 | training_state,
120 | random_key,
121 | env=env,
122 | policy=policy,
123 | update_player=update_player,
124 | )
125 |
126 | # Accumulate regrets, compute new strategy and avg strategy
127 | new_probs = regret_matching(new_regrets)
128 | new_probs /= new_probs.sum(axis=-1, keepdims=True)
129 |
130 | training_state = training_state._replace(
131 | regrets=new_regrets,
132 | probs=new_probs,
133 | avg_probs=new_avg_probs,
134 | step=training_state.step + 1,
135 | )
136 |
137 | current_policy = training_state.avg_probs
138 | current_policy /= training_state.avg_probs.sum(axis=-1, keepdims=True)
139 | exp = exploitability_fn(policy_params=current_policy)
140 |
141 | assert np.allclose(exp, target_exploitability, rtol=1e-1, atol=1e-2)
142 |
--------------------------------------------------------------------------------
/cfrx/envs/nlhe_poker/showdown.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 |
5 | def straight_flush_score(ranks: jax.Array, suits: jax.Array) -> jax.Array:
6 | suit_counts = jnp.bincount(suits, length=4)
7 | color = suit_counts.argmax()
8 | ranks = jnp.where(suits == color, ranks, -1)
9 | append_ace = jnp.where((ranks == 0).any(), 13, -1)
10 | values = jnp.unique(jnp.sort(ranks), size=7, fill_value=append_ace)
11 | diff = jnp.diff(values, append=append_ace)
12 | conv = jnp.convolve((diff == 1), jnp.ones(4), mode="same")
13 | straight_idx = jnp.where(conv == 4, jnp.arange(7), -1).argmax()
14 | straight_rank = values[straight_idx]
15 | return 8_000_000 + straight_rank
16 |
17 |
18 | def four_of_a_kind_score(ranks: jax.Array, suits: jax.Array) -> jax.Array:
19 | ranks = jnp.where(ranks == 0, 12, ranks - 1)
20 | rank_counts = jnp.bincount(ranks, length=13)
21 | active_rank = rank_counts.argmax()
22 | remaining_ranks = jnp.where(ranks != active_rank, ranks, -1)
23 | remaining_score = high_card_score(remaining_ranks, suits, n_start=0, n_end=1)
24 | return 7_000_000 + active_rank * 13 + remaining_score
25 |
26 |
27 | def full_house_score(ranks: jax.Array, suits: jax.Array) -> jax.Array:
28 | ranks = jnp.where(ranks == 0, 12, ranks - 1)
29 | rank_counts = jnp.bincount(ranks, length=13)
30 | three_of_a_kind_mask = rank_counts >= 3
31 | active_rank_three = jnp.where(three_of_a_kind_mask, jnp.arange(13), -1).argmax()
32 | rank_counts = rank_counts.at[active_rank_three].set(0)
33 | pair_mask = rank_counts >= 2
34 | active_rank_pair = jnp.where(pair_mask, jnp.arange(13), -1).argmax()
35 | return 6_000_000 + active_rank_three * 13 + active_rank_pair
36 |
37 |
38 | def flush_score(ranks: jax.Array, suits: jax.Array) -> jax.Array:
39 | ranks = jnp.where(ranks == 0, 12, ranks - 1)
40 | suit_counts = jnp.bincount(suits, length=4)
41 | color = suit_counts.argmax()
42 | colored_ranks = jnp.where(suits == color, ranks, -1)
43 | colored_ranks = colored_ranks.sort()
44 | score = high_card_score(colored_ranks, suits)
45 | return 5_000_000 + score
46 |
47 |
48 | def straight_score(ranks: jax.Array, suits: jax.Array) -> jax.Array:
49 | append_ace = jnp.where((ranks == 0).any(), 13, -1)
50 | values = jnp.unique(jnp.sort(ranks), size=7, fill_value=append_ace)
51 | diff = jnp.diff(values, append=append_ace)
52 | conv = jnp.convolve((diff == 1), jnp.ones(4), mode="same")
53 | straight_idx = jnp.where(conv == 4, jnp.arange(7), -1).argmax()
54 | straight_rank = values[straight_idx]
55 | return 4_000_000 + straight_rank
56 |
57 |
58 | def three_of_a_kind_score(ranks: jax.Array, suits: jax.Array) -> jax.Array:
59 | ranks = jnp.where(ranks == 0, 12, ranks - 1)
60 | rank_counts = jnp.bincount(ranks, length=13)
61 | three_of_a_kind_mask = rank_counts >= 3
62 | active_rank = jnp.where(three_of_a_kind_mask, jnp.arange(13), -1).argmax()
63 | remaining_ranks = jnp.where(ranks != active_rank, ranks, -1)
64 | kicker_score = high_card_score(remaining_ranks, suits, n_start=0, n_end=2)
65 | return 3_000_000 + active_rank * 13**2 + kicker_score
66 |
67 |
68 | def two_pair_score(ranks: jax.Array, suits: jax.Array) -> jax.Array:
69 | ranks = jnp.where(ranks == 0, 12, ranks - 1)
70 | rank_counts = jnp.bincount(ranks, length=13)
71 | pair_mask = rank_counts >= 2
72 | pair_ranks = jnp.where(pair_mask, jnp.arange(13), -1)
73 | pair_ranks = jnp.argsort(pair_ranks)[-2:]
74 | pair_ranks = jnp.sort(pair_ranks)[::-1]
75 | active_rank_first_pair, active_rank_second_pair = (
76 | pair_ranks[0],
77 | pair_ranks[1],
78 | )
79 | rank_counts = rank_counts.at[active_rank_first_pair].set(0)
80 | rank_counts = rank_counts.at[active_rank_second_pair].set(0)
81 | remaining_ranks = jnp.where(
82 | (ranks != active_rank_first_pair) & (ranks != active_rank_second_pair),
83 | ranks,
84 | -1,
85 | )
86 | kicker_score = high_card_score(remaining_ranks, suits, n_start=0, n_end=1)
87 | return (
88 | 2_000_000
89 | + active_rank_first_pair * 13**2
90 | + active_rank_second_pair * 13
91 | + kicker_score
92 | )
93 |
94 |
95 | def one_pair_score(ranks: jax.Array, suits: jax.Array) -> jax.Array:
96 | ranks = jnp.where(ranks == 0, 12, ranks - 1)
97 | rank_counts = jnp.bincount(ranks, length=13)
98 | pair_mask = rank_counts >= 2
99 | active_rank_pair = jnp.where(pair_mask, jnp.arange(13), -1).argmax()
100 | rank_counts = rank_counts.at[active_rank_pair].set(0)
101 | remaining_ranks = jnp.where(ranks != active_rank_pair, ranks, -1)
102 | kicker_score = high_card_score(remaining_ranks, suits, n_start=0, n_end=3)
103 | return 1_000_000 + active_rank_pair * 13**3 + kicker_score
104 |
105 |
106 | def high_card_hand_score(ranks: jax.Array, suits: jax.Array) -> jax.Array:
107 | ranks = jnp.where(ranks == 0, 12, ranks - 1)
108 | return high_card_score(ranks, suits, n_start=0, n_end=5)
109 |
110 |
111 | def high_card_score(
112 | ranks: jax.Array, suits: jax.Array, n_start: int = 0, n_end: int = 5
113 | ) -> jax.Array:
114 | n = n_end - n_start
115 | ranks = ranks.sort()
116 | rank_scores = jnp.array([13**k for k in range(n_start, n_end)])
117 | ranks = ranks[-n:]
118 | return (ranks * rank_scores).sum()
119 |
120 |
121 | def is_straight_fn(ranks):
122 | append_ace = jnp.where((ranks == 0).any(), 13, -1)
123 | values = jnp.unique(jnp.sort(ranks), size=7, fill_value=append_ace)
124 | diff = jnp.diff(values, append=append_ace)
125 | cond = (jnp.convolve((diff == 1), jnp.ones(4), mode="valid") == 4).any()
126 | return cond
127 |
128 |
129 | def is_straight_flush_fn(ranks, suits):
130 | suit_counts = jnp.bincount(suits, length=4)
131 | color = suit_counts.argmax()
132 | ranks = jnp.where(suits == color, ranks, -1)
133 | return is_straight_fn(ranks)
134 |
135 |
136 | def get_hand_type(ranks: jax.Array, suits: jax.Array) -> jax.Array:
137 | rank_counts = jnp.bincount(ranks, length=13)
138 | suit_counts = jnp.bincount(suits, length=4)
139 |
140 | is_straight_flush = is_straight_flush_fn(ranks, suits)
141 | higher = is_straight_flush
142 | index = 8 * is_straight_flush
143 |
144 | is_four_of_a_kind = ~higher & (rank_counts == 4).any()
145 | higher |= is_four_of_a_kind
146 | index += 7 * is_four_of_a_kind
147 |
148 | is_full = ~higher & (rank_counts == 3).any() & (rank_counts == 2).any()
149 | higher |= is_full
150 | index += 6 * is_full
151 |
152 | is_flush = ~higher & (suit_counts >= 5).any()
153 | higher |= is_flush
154 | index += 5 * is_flush
155 |
156 | is_straight = ~higher & is_straight_fn(ranks)
157 | higher |= is_straight
158 | index += 4 * is_straight
159 |
160 | is_three_of_a_kind = ~higher & (rank_counts == 3).any()
161 | higher |= is_three_of_a_kind
162 | index += 3 * is_three_of_a_kind
163 |
164 | is_two_pairs = ~higher & ((rank_counts == 2).sum() >= 2)
165 | higher |= is_two_pairs
166 | index += 2 * is_two_pairs
167 |
168 | is_one_pair = ~higher & (rank_counts == 2).any()
169 | higher |= is_one_pair
170 | index += 1 * is_one_pair
171 |
172 | return index
173 |
174 |
175 | def get_showdown_score(hand: jax.Array) -> jax.Array:
176 | hand = hand.astype(jnp.int32)
177 | ranks = hand % 13
178 | suits = hand // 13
179 |
180 | hand_type = get_hand_type(ranks, suits)
181 |
182 | score = jax.lax.switch(
183 | hand_type,
184 | [
185 | high_card_hand_score,
186 | one_pair_score,
187 | two_pair_score,
188 | three_of_a_kind_score,
189 | straight_score,
190 | flush_score,
191 | full_house_score,
192 | four_of_a_kind_score,
193 | straight_flush_score,
194 | ],
195 | ranks,
196 | suits,
197 | )
198 |
199 | return score
200 |
--------------------------------------------------------------------------------
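A quick usage sketch for the showdown module above (the two hands are made-up examples, not taken from the repository). `get_showdown_score` accepts a 7-card hand encoded as integers in [0, 52), where `card % 13` is the rank (0 = Ace, 12 = King) and `card // 13` is the suit, and returns a scalar that is larger for stronger hands:

import jax.numpy as jnp

from cfrx.envs.nlhe_poker.showdown import get_showdown_score

# Cards are encoded as 0..51: rank = card % 13 (0 = Ace), suit = card // 13.
hand_a = jnp.array([0, 1, 2, 3, 4, 20, 33])     # A-2-3-4-5 of one suit: straight flush
hand_b = jnp.array([9, 22, 35, 48, 4, 20, 33])  # four tens
assert get_showdown_score(hand_a) > get_showdown_score(hand_b)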
/cfrx/envs/nlhe_poker/gui.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from nicegui import app, ui
6 |
7 | from cfrx.envs.nlhe_poker.env import State, TexasHoldem
8 |
9 | SCALE = 0.5
10 | X0 = 400
11 | X1 = 80
12 | Y1 = 700
13 | OFF = 180
14 | DEALER_COORDS = [(3 * X0, 0), (3 * X0, 2500)]
15 | STACK_COORDS = [(120, 50), (120, 400)]
16 | BET_COORDS = [(220, 140), (220, 320)]
17 | POT_COORDS = (500, 220)
18 |
19 | COLORS = ["heart", "diamond", "club", "spade"]
20 | VALUES = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "jack", "queen", "king"]
21 |
22 |
23 | def create_svg(
24 | hands: list[tuple[int, int]], board: list[int], board_mask: int, dealer: int
25 | ) -> str:
26 | svg_open = """"
34 |
35 | svg_hands = ""
36 | for i, hand in enumerate(hands):
37 | for j, card in enumerate(hand):
38 | x = X0 + j * 40
39 | y = i * Y1
40 | svg_hands += f"""
41 |
43 | """
44 |
45 | svg_board = ""
46 | for i, card in enumerate(board):
47 | x = X1 + i * OFF
48 | y = 350
49 | if i >= board_mask:
50 | svg_board += f"""
51 |
53 | """
54 | else:
55 | svg_board += f"""
56 |
58 | """
59 |
60 | dealer_coords = DEALER_COORDS[dealer]
61 | svg_dealer = f"""
62 |
64 | """
65 |
66 | return f"{svg_open}{svg_hands}{svg_dealer}{svg_board}{svg_close}"
67 |
68 |
69 | def create_stack(player: int, stack: int) -> str:
70 | stack_html = f"""
72 | Player {player}
Stack: {stack}
"""
73 |
74 | return stack_html
75 |
76 |
77 | def create_pot(pot: int) -> str:
78 | pot_html = f"""
80 | Pot: {pot}
"""
81 |
82 | return pot_html
83 |
84 |
85 | def create_bet(player: int, bet: str, current_player: int) -> str:
86 | color = "red" if player == current_player else None
87 |
88 | bet_html = f"""
93 | {bet}
"""
94 |
95 | return bet_html
96 |
97 |
98 | def create_terminated(done: bool, reward: tuple[float, float]) -> str:
99 | if done:
100 | end_html = f"""Terminated
P0: {reward[0]}
102 |
P1: {reward[1]}
103 |
"""
104 | else:
105 | end_html = ""
106 |
107 | return end_html
108 |
109 |
110 | def create_min_bet(min_bet: int, min_raise: int) -> str:
111 | min_bet_html = f"""Min bet:
{min_bet}
113 |
Min raise:
{min_raise }
114 |
"""
115 |
116 | return min_bet_html
117 |
118 |
119 | def visualize(state: State, container: Any) -> None:
120 | with container:
121 | hands = [tuple(state.hands[0]), tuple(state.hands[1])]
122 | board = state.board.tolist()
123 |
124 | svg = create_svg(
125 | hands=hands,
126 | board=board,
127 | board_mask=int(state.board_mask),
128 | dealer=int(state.dealer),
129 | )
130 | ui.html(svg)
131 |
132 | stack_0 = int(state.stacks[0])
133 |
134 | stack = create_stack(0, stack_0)
135 | ui.html(stack)
136 |
137 | stack_1 = int(state.stacks[1])
138 |
139 | stack = create_stack(1, stack_1)
140 | ui.html(stack)
141 |
142 | pot_value = int(state.bets.sum())
143 |
144 | pot = create_pot(pot_value)
145 | ui.html(pot)
146 |
147 | current_bets = state.bets[state.current_round]
148 | current_mask = state.bets_mask[state.current_round]
149 |
150 | idx_0 = current_mask[0].sum()
151 |
152 | if idx_0 == 0:
153 | current_bet_0 = ""
154 | else:
155 | current_bet_0 = str(current_bets[0, :idx_0].sum())
156 |
157 | idx_1 = current_mask[1].sum()
158 | if idx_1 == 0:
159 | current_bet_1 = ""
160 | else:
161 | current_bet_1 = str(current_bets[1, :idx_1].sum())
162 |
163 | cr = int(state.current_player)
164 |
165 | bet = create_bet(0, current_bet_0, current_player=cr)
166 | ui.html(bet)
167 |
168 | bet = create_bet(1, current_bet_1, current_player=cr)
169 | ui.html(bet)
170 |
171 | done = create_terminated(bool(state.terminated), reward=tuple(state.rewards))
172 | ui.html(done)
173 |
174 | min_bet = create_min_bet(int(state.min_bet), int(state.min_raise))
175 | ui.html(min_bet)
176 |
177 |
178 | def visualize_previous():
179 | global states
180 | global current_state_idx
181 | global container
182 |
183 | current_state_idx = max(current_state_idx - 1, 0)
184 | visualize(states[current_state_idx], container=container)
185 |
186 |
187 | def visualize_next():
188 | global states
189 | global current_state_idx
190 | global container
191 |
192 | current_state_idx = min(current_state_idx + 1, len(states) - 1)
193 | visualize(states[current_state_idx], container=container)
194 |
195 |
196 | def delete_last():
197 | global states
198 | global current_state_idx
199 | global container
200 | if len(states) > 1:
201 | states.pop(-1)
202 | current_state_idx = max(len(states) - 1, 0)
203 | visualize(states[current_state_idx], container=container)
204 |
205 |
206 | def ui_bet():
207 | global states
208 | global current_state_idx
209 | global env
210 | global slider
211 | global container
212 |
213 | state = env.step(states[-1], action=jnp.array(slider.value))
214 | states.append(state)
215 | current_state_idx = len(states) - 1
216 | visualize(states[current_state_idx], container=container)
217 |
218 |
219 | if __name__ in {"__main__", "__mp_main__"}:
220 | app.add_static_files("/static", "static")
221 | env = TexasHoldem()
222 |
223 | container = ui.card().tight()
224 |
225 | with ui.row().style("position: absolute; top:600px"):
226 | with ui.button_group():
227 | ui.button("<", on_click=visualize_previous)
228 | ui.button(">", on_click=visualize_next)
229 | ui.button("DEL", on_click=delete_last)
230 |
231 | slider = ui.slider(min=-1, max=100, value=0).style("width: 200px")
232 | ui.label().bind_text_from(slider, "value")
233 | ui.button("BET", on_click=ui_bet)
234 |
235 | state = env.init(jax.random.PRNGKey(1), initial_stacks=jnp.array([100, 50]))
236 |
237 | states = [state]
238 |
239 | state = env.step(state, action=jnp.array(1.0))
240 | states.append(state)
241 |
242 | state = env.step(state, action=jnp.array(10.0))
243 | states.append(state)
244 | visualize(state, container=container)
245 | current_state_idx = len(states) - 1
246 |
247 | ui.run(show=False)
248 |
--------------------------------------------------------------------------------
/cfrx/envs/kuhn_poker/env.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 | import pgx
7 | import pgx.kuhn_poker
8 | from jaxtyping import Array, Bool, Int, PRNGKeyArray
9 | from pgx._src.dwg.kuhn_poker import CARD
10 | from pgx._src.struct import dataclass
11 |
12 | import cfrx
13 | from cfrx.envs.kuhn_poker.constants import INFO_SETS
14 | from cfrx.utils import ravel, reverse_array_lookup
15 |
16 | CARD.append("?")
17 | NUM_DIFFERENT_CARDS = 3
18 | NUM_REPEAT_CARDS = 1
19 | NUM_TOTAL_CARDS = NUM_DIFFERENT_CARDS * NUM_REPEAT_CARDS
20 | INFO_SETS_VALUES = np.stack(list(INFO_SETS.values()))
21 |
22 |
23 | class InfoState(NamedTuple):
24 | private_card: Int[Array, "..."]
25 | action_sequence: Int[Array, "..."]
26 | chance_round: Int[Array, "..."]
27 | chance_node: Bool[Array, "..."]
28 |
29 |
30 | @dataclass
31 | class State(pgx.kuhn_poker.State):
32 | info_state: InfoState = InfoState(
33 | private_card=jnp.int8(-1),
34 | action_sequence=jnp.ones(2, dtype=jnp.int8) * -1,
35 | chance_round=jnp.int8(0),
36 | chance_node=jnp.bool_(True),
37 | )
38 | chance_node: Bool[Array, "..."] = jnp.bool_(False)
39 | chance_prior: Int[Array, "..."] = (
40 | jnp.ones(NUM_DIFFERENT_CARDS, dtype=int) * NUM_REPEAT_CARDS
41 | )
42 |
43 |
44 | class KuhnPoker(pgx.kuhn_poker.KuhnPoker, cfrx.envs.Env):
45 | @classmethod
46 | def action_to_string(cls, action: Int[Array, ""]) -> str:
47 | strings = ["b", "p"]
48 | if action != -1:
49 | a = int(action) // 2
50 | rep = strings[a]
51 | else:
52 | rep = "?"
53 | return rep
54 |
55 | @property
56 | def max_episode_length(self) -> int:
57 | return 6
58 |
59 | @property
60 | def max_nodes(self) -> int:
61 | return 60
62 |
63 | @property
64 | def n_info_states(self) -> int:
65 | return len(INFO_SETS)
66 |
67 | @property
68 | def n_actions(self) -> int:
69 | return self.num_actions
70 |
71 | @property
72 | def n_players(self) -> int:
73 | return self.num_players
74 |
75 | def update_info_state(
76 | self, state: State, next_state: State, action: Int[Array, ""]
77 | ) -> InfoState:
78 | info_state = next_state.info_state
79 |
80 | private_card = jnp.where(
81 | next_state.chance_node, -1, next_state._cards[next_state.current_player]
82 | )
83 | current_position = (info_state.action_sequence != -1).sum()
84 | action_sequence = info_state.action_sequence.at[current_position].set(
85 | jnp.int8(action)
86 | )
87 |
88 | action_sequence = jnp.where(state.chance_node, -1, action_sequence)
89 | chance_round = info_state.chance_round + state.chance_node
90 | info_state = info_state._replace(
91 | private_card=private_card,
92 | action_sequence=action_sequence,
93 | chance_round=chance_round,
94 | chance_node=next_state.chance_node,
95 | )
96 | return info_state
97 |
98 | def info_state_to_str(self, info_state: InfoState) -> str:
99 | if info_state.chance_node:
100 | rep = f"chance{info_state.chance_round}:"
101 |
102 | else:
103 | rep = ""
104 |
105 | strings = ["b", "p"]
106 | rep += f"{info_state.private_card}"
107 |
108 | for action in np.array(info_state.action_sequence):
109 | action = action // 2
110 | if action != -1:
111 | rep += strings[action]
112 |
113 | return rep
114 |
115 | def get_action_mask(self, state: State) -> jax.Array:
116 | return state.legal_action_mask
117 |
118 | def get_chance_mask(self, state: State) -> jax.Array:
119 | return state.chance_prior > 0
120 |
121 | def get_info_state(self, state: State) -> jax.Array:
122 | return state.info_state
123 |
124 | def info_state_idx(self, info_state: InfoState) -> Int[Array, ""]:
125 | info_state_ravel = ravel(info_state)
126 | return reverse_array_lookup(info_state_ravel, jnp.asarray(INFO_SETS_VALUES))
127 |
128 | def _init(self, rng: PRNGKeyArray) -> State:
129 | env_state = super()._init(rng)
130 |
131 | info_state = InfoState(
132 | private_card=jnp.int8(-1),
133 | action_sequence=jnp.ones(2, dtype=jnp.int8) * -1,
134 | chance_round=jnp.int8(0),
135 | chance_node=jnp.bool_(True),
136 | )
137 |
138 | return State(
139 | current_player=env_state.current_player.astype(jnp.int8),
140 | observation=env_state.observation,
141 | rewards=env_state.rewards,
142 | terminated=env_state.terminated,
143 | truncated=env_state.truncated,
144 | _step_count=env_state._step_count,
145 | _last_action=env_state._last_action,
146 | _cards=jnp.int8([-1, -1]),
147 | legal_action_mask=jnp.bool_([1, 1, 1, 1]),
148 | _pot=env_state._pot,
149 | info_state=info_state,
150 | chance_node=jnp.bool_(True),
151 | chance_prior=jnp.ones(NUM_DIFFERENT_CARDS, dtype=int) * NUM_REPEAT_CARDS,
152 | )
153 |
154 | def _resolve_chance_node(
155 | self, state: State, action: Int[Array, ""], random_key: PRNGKeyArray
156 | ) -> State:
157 | draw_player = NUM_TOTAL_CARDS - state.chance_prior.sum()
158 |
159 | cards = state._cards.at[draw_player].set(action.astype(jnp.int8))
160 | chance_prior = state.chance_prior.at[action].add(-1)
161 | chance_node = (cards == -1).any()
162 |
163 | legal_action_mask = jnp.where(
164 | chance_node, state.legal_action_mask, jnp.bool_([0, 1, 0, 1])
165 | )
166 |
167 | return State(
168 | current_player=state.current_player.astype(jnp.int8),
169 | observation=state.observation,
170 | rewards=state.rewards,
171 | terminated=state.terminated,
172 | truncated=state.truncated,
173 | _step_count=state._step_count,
174 | _last_action=state._last_action,
175 | _cards=cards,
176 | legal_action_mask=legal_action_mask,
177 | _pot=state._pot,
178 | info_state=state.info_state,
179 | chance_node=chance_node,
180 | chance_prior=chance_prior,
181 | )
182 |
183 | def _resolve_decision_node(
184 | self, state: State, action: Int[Array, ""], random_key: PRNGKeyArray
185 | ) -> State:
186 | env_state = super()._step(state=state, action=action, key=random_key)
188 | return State(
189 | current_player=env_state.current_player.astype(jnp.int8),
190 | observation=env_state.observation,
191 | rewards=env_state.rewards,
192 | terminated=env_state.terminated,
193 | truncated=env_state.truncated,
194 | _step_count=env_state._step_count,
195 | _last_action=env_state._last_action,
196 | _cards=env_state._cards,
197 | legal_action_mask=env_state.legal_action_mask,
198 | _pot=env_state._pot,
199 | info_state=env_state.info_state,
200 | chance_node=jnp.bool_(False),
201 | chance_prior=env_state.chance_prior,
202 | )
203 |
204 | def _step(
205 | self, state: State, action: Int[Array, ""], random_key: PRNGKeyArray
206 | ) -> State:
207 | new_state = jax.lax.cond(
208 | state.chance_node,
209 | lambda: self._resolve_chance_node(
210 | state=state, action=action, random_key=random_key
211 | ),
212 | lambda: self._resolve_decision_node(
213 | state=state, action=action, random_key=random_key
214 | ),
215 | )
216 | info_state = self.update_info_state(
217 | state=state, next_state=new_state, action=action
218 | )
219 | return new_state.replace(info_state=info_state)
220 |
--------------------------------------------------------------------------------
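As a small illustration of how the chance-node machinery above is used (a minimal sketch; the PRNG seed and the uniform decision policy are arbitrary choices for the example), a full Kuhn poker game can be rolled out by sampling chance outcomes from `chance_prior` at chance nodes and legal actions from `legal_action_mask` at decision nodes:

import jax
import jax.numpy as jnp

from cfrx.envs.kuhn_poker import KuhnPoker

env = KuhnPoker()
key = jax.random.PRNGKey(0)
state = env.init(key)

while not bool(state.terminated):
    key, subkey = jax.random.split(key)
    if bool(state.chance_node):
        # Chance node: deal a card proportionally to the remaining deck.
        probs = state.chance_prior / state.chance_prior.sum()
    else:
        # Decision node: here we simply play uniformly over legal actions.
        probs = state.legal_action_mask / state.legal_action_mask.sum()
    action = jax.random.choice(subkey, jnp.arange(probs.shape[0]), p=probs)
    state = env.step(state, action)

print(state.rewards)  # terminal payoffs for both players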
/cfrx/algorithms/cfr/cfr.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections.abc import Callable
4 | from typing import NamedTuple
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | from jaxtyping import Array, Float, Int
9 |
10 | from cfrx.envs import Env
11 | from cfrx.policy import TabularPolicy
12 | from cfrx.tree import Tree
13 | from cfrx.tree.traverse import instantiate_tree_from_root, traverse_tree_cfr
14 | from cfrx.tree.tree import Root
15 | from cfrx.utils import regret_matching
16 |
17 |
18 | class CFRState(NamedTuple):
19 | regrets: Float[Array, "*batch a"]
20 | probs: Float[Array, "... a"]
21 | avg_probs: Float[Array, "... a"]
22 | step: Int[Array, "..."]
23 |
24 | @classmethod
25 | def init(cls, n_states: int, n_actions: int) -> CFRState:
26 | return CFRState(
27 | regrets=jnp.zeros((n_states, n_actions)),
28 | probs=jnp.ones((n_states, n_actions))
29 | / jnp.ones((n_states, n_actions)).sum(axis=-1, keepdims=True),
30 | avg_probs=jnp.zeros((n_states, n_actions)) + 1e-6,
31 | step=jnp.array(1, dtype=int),
32 | )
33 |
34 |
35 | def backward_one_infoset(
36 | tree: Tree,
37 | info_states: Array,
38 | current_infoset: Int[Array, ""],
39 | cfr_player: int,
40 | depth: Int[Array, ""],
41 | ) -> Tree:
42 | # Select all nodes in this infoset
43 | infoset_mask = (
44 | (tree.depth == depth)
45 | & (~tree.states.terminated)
46 | & (info_states == current_infoset)
47 | )
48 |
49 | is_cfr_player = (
50 | (tree.states.current_player == cfr_player)
51 | & (~tree.states.chance_node)
52 | & infoset_mask
53 | ).any()
54 |
55 | p_opponent = tree.extra_data["p_opponent"]
56 | p_chance = tree.extra_data["p_chance"]
57 | p_self = tree.extra_data["p_self"]
58 |
59 |     # Get expected values for each node in the infoset
60 | cf_reach_prob = p_opponent * p_chance
61 |
62 | legal_action_mask = (tree.states.legal_action_mask & infoset_mask[..., None]).any(
63 | axis=0
64 | )
65 |
66 | cf_state_action_values = (
67 | tree.children_values[..., cfr_player]
68 | * cf_reach_prob[..., None]
69 | * legal_action_mask
70 | ) # (T, a)
71 | cf_state_action_values = jnp.sum(
72 | cf_state_action_values, axis=0, where=infoset_mask[..., None] # (a,)
73 | )
74 |
75 | expected_values = (
76 | tree.children_values[..., cfr_player] * tree.children_prior_logits
77 | ).sum(axis=1)
78 |
79 | cf_state_values = jnp.sum(
80 | expected_values * cf_reach_prob, axis=0, where=infoset_mask
81 | ) # (,)
82 |
83 | new_node_values = jnp.where(
84 | infoset_mask, expected_values, tree.node_values[..., cfr_player]
85 | )
86 |
87 | new_children_values = jnp.where(
88 | tree.children_index != -1, new_node_values[tree.children_index], 0
89 | )
90 |
91 | regrets = (
92 | cf_state_action_values - cf_state_values[..., None]
93 | ) * legal_action_mask # (a,)
94 |
95 | strategy_profile = jnp.sum(
96 | p_self[..., None] * tree.children_prior_logits,
97 | axis=0,
98 | where=infoset_mask[..., None],
99 | ) # (a,)
100 |
101 | tree = tree._replace(
102 | node_values=tree.node_values.at[..., cfr_player].set(new_node_values),
103 | children_values=tree.children_values.at[..., cfr_player].set(
104 | new_children_values
105 | ),
106 | extra_data={
107 | **tree.extra_data,
108 | "regrets": jnp.where(
109 | infoset_mask[..., None] & is_cfr_player,
110 | tree.extra_data["regrets"] + regrets,
111 | tree.extra_data["regrets"],
112 | ),
113 | "strategy_profile": jnp.where(
114 | infoset_mask[..., None] & is_cfr_player,
115 | tree.extra_data["strategy_profile"] + strategy_profile,
116 | tree.extra_data["strategy_profile"],
117 | ),
118 | },
119 | )
120 |
121 | return tree
122 |
123 |
124 | def backward_one_depth_level(
125 | tree: Tree,
126 | depth: Int[Array, ""],
127 | cfr_player: int,
128 | info_state_fn: Callable,
129 | ) -> Tree:
130 | info_states = jax.vmap(info_state_fn)(tree.states.info_state)
131 |
132 | def cond_fn(val: tuple[Tree, Array]) -> Array:
133 | tree, visited = val
134 | nodes_to_visit_idx = (
135 | (tree.depth == depth) & (~tree.states.terminated) & (~visited)
136 | )
137 | return nodes_to_visit_idx.any()
138 |
139 | def loop_fn(val: tuple[Tree, Array]) -> tuple[Tree, Array]:
140 | tree, visited = val
141 | nodes_to_visit_idx = (
142 | (tree.depth == depth) & (~tree.states.terminated) & (~visited)
143 | )
144 |
145 | # Select an infoset to resolve
146 | selected_infoset_idx = nodes_to_visit_idx.argmax()
147 | selected_infoset = info_states[selected_infoset_idx]
148 |
149 | tree = backward_one_infoset(
150 | tree=tree,
151 | depth=depth,
152 | info_states=info_states,
153 | current_infoset=selected_infoset,
154 | cfr_player=cfr_player,
155 | )
156 |
157 | visited = jnp.where(info_states == selected_infoset, True, visited)
158 | return tree, visited
159 |
160 | visited = jnp.zeros(tree.node_values.shape[0], dtype=bool)
161 | tree, visited = jax.lax.while_loop(cond_fn, loop_fn, (tree, visited))
162 |
163 | return tree
164 |
165 |
166 | def backward_cfr(
167 | tree: Tree,
168 | cfr_player: int,
169 | info_state_fn: Callable,
170 | ) -> Tree:
171 | depth = tree.depth.max()
172 |
173 | def cond_fn(val: tuple[Tree, Array]) -> Array:
174 | _, depth = val
175 | return depth >= 0
176 |
177 | def loop_fn(val: tuple[Tree, Array]) -> tuple[Tree, Array]:
178 | tree, depth = val
179 | tree = backward_one_depth_level(
180 | tree=tree, depth=depth, cfr_player=cfr_player, info_state_fn=info_state_fn
181 | )
182 | depth -= 1
183 | return tree, depth
184 |
185 | tree, _ = jax.lax.while_loop(cond_fn, loop_fn, (tree, depth))
186 | return tree
187 |
188 |
189 | def do_iteration(
190 | training_state: CFRState,
191 | update_player: Int[Array, ""],
192 | env: Env,
193 | policy: TabularPolicy,
194 | ) -> CFRState:
195 | s0 = env.init(jax.random.PRNGKey(0))
196 | root = Root(
197 | prior_logits=s0.legal_action_mask * 1.0,
198 | value=jnp.zeros(1, dtype=float),
199 | state=s0,
200 | )
201 |
202 | tree = instantiate_tree_from_root(
203 | root,
204 | n_max_nodes=env.max_nodes,
205 | n_players=env.n_players,
206 | running_probabilities=True,
207 | )
208 |
209 | tree = traverse_tree_cfr(
210 | tree,
211 | policy=policy,
212 | policy_params=training_state.probs,
213 | env=env,
214 | traverser=update_player,
215 | )
216 |
217 | infoset_idx = jax.vmap(env.info_state_idx)(tree.states.info_state)
218 | squash_idx = jnp.unique(infoset_idx, size=env.n_info_states, return_index=True)[1]
219 |
220 | tree_regrets = training_state.regrets.take(infoset_idx, axis=0)
221 | tree_strategies = training_state.avg_probs.take(infoset_idx, axis=0)
222 |
223 | tree = tree._replace(
224 | extra_data={
225 | **tree.extra_data,
226 | "regrets": tree_regrets,
227 | "strategy_profile": tree_strategies,
228 | }
229 | )
230 |
231 | tree = backward_cfr(
232 | tree=tree, cfr_player=update_player, info_state_fn=env.info_state_idx
233 | )
234 |
235 | regrets = tree.extra_data["regrets"][squash_idx]
236 | strategy_profile = tree.extra_data["strategy_profile"][squash_idx]
237 | probs = regret_matching(regrets)
238 | new_state = training_state._replace(
239 | regrets=regrets,
240 | probs=probs,
241 | avg_probs=strategy_profile,
242 | step=training_state.step + 1,
243 | )
244 | return new_state
245 |
--------------------------------------------------------------------------------
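A minimal sketch of how `CFRState` and `do_iteration` above fit together into a training loop. This is not fully self-contained: `policy` is assumed to be a `TabularPolicy` built for the environment, whose constructor lives in `cfrx/policy.py` and is not reproduced in this excerpt; the iteration count is arbitrary.

import jax.numpy as jnp

from cfrx.algorithms.cfr.cfr import CFRState, do_iteration
from cfrx.envs.kuhn_poker import KuhnPoker

env = KuhnPoker()
# `policy` is assumed to be a TabularPolicy configured for this environment.
training_state = CFRState.init(n_states=env.n_info_states, n_actions=env.n_actions)

for step in range(100):  # number of iterations is arbitrary in this sketch
    # Alternate the player whose regrets are updated at each iteration.
    update_player = jnp.array(step % env.n_players)
    training_state = do_iteration(
        training_state=training_state,
        update_player=update_player,
        env=env,
        policy=policy,
    )

# The normalized average strategy is the quantity that converges in CFR.
avg_strategy = training_state.avg_probs / training_state.avg_probs.sum(
    axis=-1, keepdims=True
)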
/cfrx/envs/nlhe_poker/env.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 | import cfrx
7 | from cfrx.envs.nlhe_poker.showdown import get_showdown_score
8 |
9 | NUM_PLAYERS = 2
10 |
11 |
12 | class State(NamedTuple):
13 | board: jax.Array
14 | board_mask: jax.Array
15 | hands: jax.Array
16 | stacks: jax.Array
17 | bets: jax.Array
18 | bets_mask: jax.Array
19 | current_player: jax.Array
20 | min_bet: jax.Array
21 | min_raise: jax.Array
22 | current_round: jax.Array
23 | terminated: jax.Array
24 | small_blind: jax.Array
25 | showdown_result: jax.Array
26 | dealer: jax.Array
27 | rewards: jax.Array
28 |
29 |
30 | class InfoState(NamedTuple):
31 | hand: jax.Array
32 | board: jax.Array
33 | board_mask: jax.Array
34 | bets: jax.Array
35 | bets_mask: jax.Array
36 | stacks: jax.Array
37 |
38 |
39 | class TexasHoldem:
40 | def __init__(self, num_max_bets: int = 4):
41 | self.num_max_bets = num_max_bets
42 | self._num_visibles = jnp.array([0, 3, 4, 5], dtype=jnp.uint8)
43 |
44 | @classmethod
45 | def action_to_string(cls, action: jax.Array) -> str:
46 | raise NotImplementedError
47 |
48 | @property
49 | def max_episode_length(self) -> int:
50 | return 4 * self.num_max_bets
51 |
52 | @property
53 | def max_nodes(self) -> int:
54 | raise NotImplementedError
55 |
56 | @property
57 | def n_info_states(self) -> int:
58 | raise NotImplementedError
59 |
60 | def init(
61 | self,
62 | random_key: jax.Array,
63 | dealer_player: int = 0,
64 | initial_stacks: jax.Array | None = None,
65 | small_blind: float = 1.0,
66 | ):
67 | deck = jax.random.permutation(random_key, jnp.arange(52, dtype=jnp.uint8))
68 |
69 | hands = jnp.zeros((NUM_PLAYERS, 2), dtype=jnp.uint8)
70 | for i in range(NUM_PLAYERS):
71 | cards = deck[i * 2 : (i + 1) * 2]
72 | hands = hands.at[i].set(cards)
73 |
74 | board = deck[NUM_PLAYERS * 2 : NUM_PLAYERS * 2 + 5].astype(jnp.uint8)
75 |
76 | board_mask = jnp.array(0)
77 |
78 | if initial_stacks is None:
79 | stacks = jnp.ones(NUM_PLAYERS, dtype=float) * 100 * small_blind
80 | else:
81 | stacks = initial_stacks
82 |
83 | sb_player = (dealer_player + 1) % NUM_PLAYERS
84 | bb_player = (dealer_player + 2) % NUM_PLAYERS
85 | first_player = (dealer_player + 3) % NUM_PLAYERS
86 |
87 | bets = jnp.zeros((4, NUM_PLAYERS, self.num_max_bets), dtype=float)
88 |
89 | sb_bet = jnp.minimum(small_blind, stacks[sb_player])
90 | bb_bet = jnp.minimum(small_blind * 2, stacks[bb_player])
91 | bets = bets.at[0, sb_player, 0].set(sb_bet)
92 | bets = bets.at[0, bb_player, 0].set(bb_bet)
93 | stacks = stacks.at[sb_player].set(stacks[sb_player] - sb_bet)
94 | stacks = stacks.at[bb_player].set(stacks[bb_player] - bb_bet)
95 |
96 | bets_mask = bets > 0.0
97 |
98 | showdown_result = self._resolve_showdown(board, hands)
99 |
100 | state = State(
101 | board=board,
102 | board_mask=board_mask,
103 | hands=hands,
104 | stacks=stacks,
105 | bets=bets,
106 | bets_mask=bets_mask,
107 | current_player=jnp.array(first_player),
108 | min_bet=jnp.array(small_blind),
109 | min_raise=2 * jnp.array(small_blind),
110 | current_round=jnp.array(0),
111 | terminated=jnp.array(False),
112 | small_blind=jnp.array(small_blind),
113 | showdown_result=showdown_result,
114 | dealer=jnp.array(dealer_player),
115 | rewards=jnp.array([0.0, 0.0]),
116 | )
117 |
118 | return state
119 |
120 | def _is_end_round(
121 | self,
122 | current_round: jax.Array,
123 | bets_mask: jax.Array,
124 | bets: jax.Array,
125 | ) -> jax.Array:
126 | bets_mask = jnp.where(
127 | current_round == 0, bets_mask.at[0, :, 0].set(False), bets_mask
128 | )
129 | has_everyone_spoken = bets_mask[current_round].any(axis=1).all()
130 | total_bets = bets[current_round].sum(axis=-1)
131 | same_bets = (total_bets.max() == total_bets.min()).all()
132 |
133 | end_round = has_everyone_spoken & same_bets
134 | return end_round
135 |
136 | def _resolve_showdown(self, board: jax.Array, hands: jax.Array) -> jax.Array:
137 | """
138 | Return 0 if p0 wins, 1 if p1 wins, -1 if tie
139 | """
140 |
141 | p0_score = get_showdown_score(jnp.concatenate([board, hands[0]]))
142 | p1_score = get_showdown_score(jnp.concatenate([board, hands[1]]))
143 |
144 | return jnp.where(p0_score == p1_score, jnp.array(-1), p0_score < p1_score)
145 |
146 | def _current_round_to_board_mask(self, current_round: jax.Array) -> jax.Array:
147 | return self._num_visibles[current_round]
148 |
149 | def step(self, state: State, action: jax.Array) -> State:
150 | fold = action < 0.0
151 |
152 | # check = action == 0.0
153 | # bet = action > 0.0
154 |
155 |         # clip the bet to the player's remaining stack
156 | action = jnp.clip(action, a_min=0, a_max=state.stacks[state.current_player])
157 |
158 | cp = state.current_player
159 | cr = state.current_round
160 |
161 | # store bet, update bet mask, update stack
162 |
163 | current_idx = state.bets_mask[cr, cp].sum()
164 | new_bets_mask = state.bets_mask.at[cr, cp, current_idx].set(True)
165 | new_bets = state.bets.at[cr, cp, current_idx].set(action)
166 | new_stacks = state.stacks.at[cp].set(state.stacks[cp] - action)
167 |
168 | reward = jnp.ones(NUM_PLAYERS) * new_bets[:, cp].sum()
169 |         # compute hypothetical fold reward
170 | fold_reward = jnp.where(jnp.arange(NUM_PLAYERS) == cp, -reward, reward)
171 | # jax.debug.print("fold_reward: {x}", x=fold_reward)
172 |
173 |         # compute hypothetical showdown reward
174 | showdown_result = state.showdown_result
175 | showdown_reward = jnp.where(
176 | jnp.arange(NUM_PLAYERS) == showdown_result, reward, -reward
177 | )
178 | showdown_reward = jnp.where(showdown_result == -1, 0, showdown_reward)
179 |
180 | end_round = self._is_end_round(
181 | current_round=state.current_round,
182 | bets_mask=new_bets_mask,
183 | bets=new_bets,
184 | )
185 |
186 | # jax.debug.print("end_round: {x}", x=end_round)
187 |
188 | is_allin = (state.stacks == 0).any() & ~fold
189 |
190 | is_showdown = (end_round & (state.current_round == 3)) | is_allin
191 |
192 | done = state.terminated | is_showdown | fold
193 |
194 | # jax.debug.print("is showdown: {x}", x=is_showdown)
195 |
196 | reward = jnp.where(is_showdown, showdown_reward, fold_reward)
197 | reward = jnp.where(done, reward, 0.0)
198 |
199 | new_cr = jnp.where(end_round, cr + 1, cr)
200 | new_cr = jnp.where(is_showdown, 4, new_cr)
201 |
202 | new_current_player = jnp.where(
203 | end_round, (state.dealer + 1) % NUM_PLAYERS, (cp + 1) % NUM_PLAYERS
204 | )
205 |
206 | bet_diff = jnp.abs(jnp.diff(new_bets[cr].sum(axis=-1)))[0]
207 | my_total_bet = new_bets[cr, cp].sum(axis=-1)
208 |
209 | min_bet = jnp.maximum(bet_diff, state.small_blind)
210 | min_bet = jnp.where(end_round, state.small_blind, min_bet)
211 |
212 | min_raise = jnp.where(end_round, min_bet, my_total_bet + min_bet)
213 |
214 | new_state = State(
215 | board=state.board,
216 | board_mask=self._num_visibles[new_cr],
217 | hands=state.hands,
218 | stacks=new_stacks,
219 | bets=new_bets,
220 | bets_mask=new_bets_mask,
221 | current_player=new_current_player,
222 | min_bet=min_bet,
223 | min_raise=min_raise,
224 | current_round=new_cr,
225 | terminated=done,
226 | small_blind=state.small_blind,
227 | showdown_result=state.showdown_result,
228 | dealer=state.dealer,
229 | rewards=reward,
230 | )
231 |
232 | return new_state
233 |
234 | def observe(self, state: State) -> InfoState:
235 | board_mask = self._num_visibles[state.current_round]
236 | board_mask = ~(
237 | jnp.zeros(5, dtype=bool).at[board_mask].set(True).cumsum().astype(bool)
238 | )
239 | infostate = InfoState(
240 | hand=state.hands[state.current_player],
241 | board=jnp.where(board_mask, state.board, -1),
242 | board_mask=board_mask,
243 | bets=state.bets,
244 | bets_mask=state.bets_mask,
245 | stacks=state.stacks,
246 | )
247 |
248 | return infostate
249 |
--------------------------------------------------------------------------------
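To make the betting interface above concrete, here is a minimal sketch (the seed, stacks, and bet sizes are arbitrary). An action is a scalar chip amount: a negative value folds, and a non-negative value is added to the player's bet for the current round, clipped to their remaining stack.

import jax
import jax.numpy as jnp

from cfrx.envs.nlhe_poker.env import TexasHoldem

env = TexasHoldem()
state = env.init(jax.random.PRNGKey(0), initial_stacks=jnp.array([100.0, 100.0]))

state = env.step(state, action=jnp.array(1.0))   # small blind adds one chip to call
state = env.step(state, action=jnp.array(0.0))   # big blind checks, ending the round
state = env.step(state, action=jnp.array(-1.0))  # a negative amount folds

print(bool(state.terminated), state.rewards)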
/cfrx/envs/nlhe_poker/test_showdown.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 |
3 | from cfrx.envs.nlhe_poker.showdown import (
4 | flush_score,
5 | four_of_a_kind_score,
6 | full_house_score,
7 | get_hand_type,
8 | high_card_hand_score,
9 | one_pair_score,
10 | straight_flush_score,
11 | straight_score,
12 | three_of_a_kind_score,
13 | two_pair_score,
14 | )
15 |
16 |
17 | def test_straight_flush_order():
18 | hands = [
19 | (
20 | jnp.array([0, 1, 2, 3, 4, 8, 12]),
21 | jnp.array([0, 0, 0, 0, 0, 1, 2]),
22 | ), # 5-high straight flush
23 | (
24 | jnp.array([4, 5, 6, 7, 8, 1, 11]),
25 | jnp.array([2, 2, 2, 2, 2, 3, 0]),
26 | ), # 9-high straight flush
27 | (
28 | jnp.array([8, 9, 10, 11, 12, 0, 2]),
29 | jnp.array([1, 1, 1, 1, 1, 2, 3]),
30 | ), # King-high straight flush
31 | (
32 | jnp.array([9, 10, 11, 12, 0, 8, 7]),
33 | jnp.array([3, 3, 3, 3, 3, 0, 1]),
34 | ), # Ace-high straight flush
35 | ]
36 |
37 | scores = [straight_flush_score(ranks, suits) for ranks, suits in hands]
38 |
39 | print("straight_flush_score", scores)
40 |
41 | # Expected order of scores from lowest to highest
42 | expected_order = sorted(scores)
43 | assert scores == expected_order, f"Scores are not in expected order: {scores}"
44 |
45 |
46 | def test_four_of_a_kind_order():
47 | hands = [
48 | jnp.array([10, 10, 10, 10, 0, 1, 2]), # Four of Jacks with kicker A
49 | jnp.array([11, 11, 11, 11, 0, 1, 2]), # Four of Queens with kicker A
50 | jnp.array([12, 12, 12, 12, 0, 2, 3]), # Four of Kings with kicker A
51 | jnp.array([0, 0, 0, 0, 1, 1, 1]), # Four of Aces with kicker 2
52 | jnp.array([0, 0, 0, 0, 1, 1, 2]), # Four of Aces with kicker 3
53 | jnp.array([0, 0, 0, 0, 1, 1, 12]), # Four of Aces with kicker King
54 | jnp.array([0, 0, 0, 0, 11, 12, 12]), # Four of Aces with kicker King
55 | ]
56 |
57 |     suits = jnp.array([3, 3, 3, 3, 0, 1, 2])
58 | scores = [four_of_a_kind_score(ranks, suits) for ranks in hands]
59 |
60 | print("four_of_a_kind", scores)
61 |
62 | expected_order = sorted(scores)
63 | assert scores == expected_order, f"Scores are not in expected order: {scores}"
64 |
65 |
66 | def test_three_of_a_kind_order():
67 | hands = [
68 | (jnp.array([1, 1, 1, 2, 3, 4, 6])), # Three 2s with kickers (5,7)
69 | (jnp.array([1, 1, 1, 2, 3, 5, 6])), # Three 2s with kickers (6,7)
70 | (jnp.array([1, 1, 1, 2, 3, 5, 0])), # Three 2s with kickers (6,A)
71 | (jnp.array([1, 1, 1, 3, 3, 3, 2])), # Three 2s, three 4s with kickers (2,3)
72 | (jnp.array([1, 1, 1, 3, 3, 3, 0])), # Three 2s, three 4s with kickers (2,A)
73 | (jnp.array([10, 10, 10, 3, 3, 3, 5])), # Three Js, three 4s with kickers (4,6)
74 |         (jnp.array([10, 10, 10, 3, 3, 3, 0])),  # Three Js, three 4s with kickers (4,A)
75 | (jnp.array([10, 10, 10, 0, 0, 0, 5])), # Three Js, three As
76 | ]
77 |
78 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3])
79 | scores = [three_of_a_kind_score(ranks, suits) for ranks in hands]
80 |
81 | print("three_of_a_kind", scores)
82 |
83 | expected_order = sorted(scores)
84 | assert scores == expected_order, f"Scores are not in expected order: {scores}"
85 |
86 |
87 | def test_flush_order():
88 | hands = [
89 | (
90 | jnp.array([2, 4, 5, 6, 10, 4, 5]),
91 | jnp.array([0, 0, 0, 0, 0, 1, 1]),
92 | ),
93 | (
94 | jnp.array([2, 4, 5, 8, 10, 4, 5]),
95 | jnp.array([0, 0, 0, 0, 0, 1, 1]),
96 | ),
97 | (
98 | jnp.array([2, 4, 5, 8, 10, 7, 5]),
99 | jnp.array([0, 0, 0, 0, 0, 0, 1]),
100 | ),
101 | (
102 | jnp.array([0, 4, 5, 8, 10, 7, 5]),
103 | jnp.array([0, 0, 0, 0, 0, 0, 1]),
104 | ),
105 | ]
106 |
107 | scores = [flush_score(ranks, suits) for ranks, suits in hands]
108 |
109 | print("flush", scores)
110 |
111 | expected_order = sorted(scores)
112 | assert scores == expected_order, f"Scores are not in expected order: {scores}"
113 |
114 |
115 | def test_straight_order():
116 | hands = [
117 | (jnp.array([0, 1, 2, 3, 4, 7, 8])),
118 | (jnp.array([0, 1, 2, 3, 4, 7, 0])),
119 | (jnp.array([3, 4, 5, 6, 7, 10, 11])),
120 | (jnp.array([2, 9, 8, 9, 10, 11, 12])),
121 | (jnp.array([2, 9, 9, 10, 11, 12, 0])),
122 | ]
123 |
124 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3])
125 | scores = [straight_score(ranks, suits) for ranks in hands]
126 |
127 | print("straight", scores)
128 |
129 | expected_order = sorted(scores)
130 | assert scores == expected_order, f"Scores are not in expected order: {scores}"
131 |
132 |
133 | def test_full_house_order():
134 | hands = [
135 | (jnp.array([1, 1, 2, 2, 2, 3, 4])),
136 | (jnp.array([4, 4, 3, 3, 3, 2, 1])),
137 | (jnp.array([4, 4, 4, 3, 3, 3, 2])),
138 | (jnp.array([2, 2, 10, 10, 10, 4, 5])),
139 | (jnp.array([12, 12, 10, 10, 10, 4, 5])),
140 | (jnp.array([12, 12, 0, 0, 0, 4, 5])),
141 | ]
142 |
143 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3])
144 | scores = [full_house_score(ranks, suits) for ranks in hands]
145 |
146 | print("full_house", scores)
147 |
148 | expected_order = sorted(scores)
149 | assert scores == expected_order, f"Scores are not in expected order: {scores}"
150 |
151 |
152 | def test_two_pair_order():
153 | hands = [
154 | (jnp.array([1, 1, 2, 2, 3, 5, 7])),
155 | (jnp.array([1, 1, 2, 2, 4, 5, 0])),
156 | (jnp.array([1, 1, 3, 4, 5, 5, 7])),
157 | (jnp.array([1, 2, 3, 3, 5, 5, 7])),
158 | (jnp.array([1, 1, 3, 3, 5, 5, 7])),
159 | (jnp.array([11, 11, 12, 12, 2, 3, 6])),
160 | (jnp.array([12, 12, 0, 0, 2, 3, 4])),
161 | (jnp.array([12, 12, 0, 0, 11, 3, 4])),
162 | ]
163 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3])
164 | scores = [two_pair_score(ranks, suits) for ranks in hands]
165 |
166 | print("two_pair", scores)
167 |
168 | expected_order = sorted(scores)
169 |
170 | assert scores == expected_order, f"Scores are not in expected order: {scores}"
171 |
172 |
173 | def test_one_pair_order():
174 | hands = [
175 | (jnp.array([1, 1, 2, 3, 4, 7, 8])),
176 | (jnp.array([1, 1, 2, 3, 9, 7, 8])),
177 | (jnp.array([1, 1, 4, 5, 0, 7, 8])),
178 | (jnp.array([12, 12, 11, 10, 9, 1, 2])),
179 | (jnp.array([0, 0, 2, 3, 4, 5, 7])),
180 | ]
181 |
182 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3])
183 | scores = [one_pair_score(ranks, suits) for ranks in hands]
184 |
185 | print("one_pair", scores)
186 |
187 | expected_order = sorted(scores)
188 | assert scores == expected_order, f"Scores are not in expected order: {scores}"
189 |
190 |
191 | def test_high_card_order():
192 | hands = [
193 | (jnp.array([1, 2, 3, 4, 6, 7, 8])),
194 | (jnp.array([1, 2, 3, 4, 6, 7, 10])),
195 | (jnp.array([1, 2, 3, 4, 6, 8, 10])),
196 | (jnp.array([1, 2, 12, 4, 6, 8, 10])),
197 | (jnp.array([0, 2, 12, 4, 6, 8, 10])),
198 | ]
199 | suits = jnp.array([0, 0, 0, 0, 1, 2, 3])
200 | scores = [high_card_hand_score(ranks, suits) for ranks in hands]
201 |
202 | print("high_card", scores)
203 |
204 | expected_order = sorted(scores)
205 | assert scores == expected_order, f"Scores are not in expected order: {scores}"
206 |
207 |
208 | def test_get_hand_type():
209 | # Royal Flush (Ace-high straight flush)
210 | ranks = jnp.asarray([0, 9, 10, 11, 12, 3, 5]) # A, 10, J, Q, K, 4, 6
211 | suits = jnp.asarray([1, 1, 1, 1, 1, 2, 3])
212 | assert get_hand_type(ranks, suits) == 8
213 |
214 | # Straight Flush
215 | ranks = jnp.asarray([8, 9, 10, 11, 12, 3, 5]) # 9, 10, J, Q, K, 4, 6
216 | suits = jnp.asarray([2, 2, 2, 2, 2, 3, 1])
217 | assert get_hand_type(ranks, suits) == 8
218 |
219 | # Four of a Kind
220 | ranks = jnp.asarray([1, 1, 1, 1, 2, 3, 4]) # 2, 2, 2, 2, 3, 4, 5
221 | suits = jnp.asarray([0, 1, 2, 3, 0, 1, 2])
222 | assert get_hand_type(ranks, suits) == 7
223 |
224 | # Full House
225 | ranks = jnp.asarray([2, 2, 2, 3, 3, 4, 5]) # 3, 3, 3, 4, 4, 5, 6
226 | suits = jnp.asarray([0, 1, 2, 0, 1, 2, 3])
227 | assert get_hand_type(ranks, suits) == 6
228 |
229 | # Flush
230 | ranks = jnp.asarray([1, 4, 6, 8, 10, 2, 3]) # 2, 5, 7, 9, J, 3, 4
231 | suits = jnp.asarray([1, 1, 1, 1, 1, 2, 0])
232 | assert get_hand_type(ranks, suits) == 5
233 |
234 | # Straight
235 |     ranks = jnp.asarray([0, 9, 10, 11, 12, 1, 2])  # A, 10, J, Q, K, 2, 3
236 | suits = jnp.asarray([0, 1, 2, 3, 0, 1, 2])
237 | assert get_hand_type(ranks, suits) == 4
238 |
239 | # Three of a Kind
240 | ranks = jnp.asarray([3, 3, 3, 5, 6, 7, 8]) # 4, 4, 4, 6, 7, 8, 9
241 | suits = jnp.asarray([0, 1, 2, 0, 1, 2, 3])
242 | assert get_hand_type(ranks, suits) == 3
243 |
244 | # Two Pairs
245 | ranks = jnp.asarray([4, 4, 7, 7, 10, 1, 2]) # 5, 5, 8, 8, J, 2, 3
246 | suits = jnp.asarray([0, 1, 2, 3, 0, 1, 2])
247 | assert get_hand_type(ranks, suits) == 2
248 |
249 | # One Pair
250 | ranks = jnp.asarray([6, 6, 8, 9, 11, 1, 3]) # 7, 7, 9, 10, Q, 2, 4
251 | suits = jnp.asarray([0, 1, 2, 0, 1, 2, 3])
252 | assert get_hand_type(ranks, suits) == 1
253 |
254 |     ranks = jnp.asarray([1, 10, 5, 7, 9, 2, 4])  # 2, J, 6, 8, 10, 3, 5
255 | suits = jnp.asarray([0, 1, 2, 3, 0, 1, 2])
256 | assert get_hand_type(ranks, suits) == 0
257 |
--------------------------------------------------------------------------------
/cfrx/envs/leduc_poker/env.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 | import pgx
7 | import pgx.leduc_holdem
8 | from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray, Shaped
9 | from pgx._src.dwg.leduc_holdem import CARD
10 | from pgx._src.struct import dataclass
11 |
12 | import cfrx.envs
13 | from cfrx.envs.leduc_poker.constants import INFO_SETS, REVERSE_INFO_SETS_LOOKUP
14 | from cfrx.utils import ravel
15 |
16 | CARD.append("?")
17 | INFO_SETS_VALUES = np.stack(list(INFO_SETS.values()))
18 | NUM_DIFFERENT_CARDS = 3
19 | NUM_REPEAT_CARDS = 2
20 | NUM_TOTAL_CARDS = NUM_DIFFERENT_CARDS * NUM_REPEAT_CARDS
21 |
22 |
23 | class InfoState(NamedTuple):
24 | private_card: Int[Array, ""]
25 | public_card: Int[Array, ""]
26 | action_sequence: Int[Array, "..."]
27 | chance_round: Int[Array, "..."]
28 | chance_node: Bool[Array, "..."]
29 |
30 |
31 | @dataclass
32 | class State(pgx.leduc_holdem.State):
33 | info_state: InfoState = InfoState(
34 | private_card=jnp.int8(-1),
35 | public_card=jnp.int8(-1),
36 | action_sequence=jnp.ones((2, 4), dtype=jnp.int8) * -1,
37 | chance_round=jnp.int8(0),
38 | chance_node=jnp.bool_(True),
39 | )
40 | chance_node: Bool[Array, ""] = jnp.bool_(False)
41 | chance_prior: Float[Array, "..."] = (
42 | jnp.ones(NUM_DIFFERENT_CARDS, dtype=int) * NUM_REPEAT_CARDS
43 | )
44 |
45 |
46 | def convert_info_state_to_idx(info_state: InfoState) -> jnp.ndarray:
47 | """
48 |     This is a bit hacky: it turns an info state into an integer index, so that
49 |     info-set lookups can be performed efficiently with array indexing in JAX.
50 | """
51 |
52 | info_state_ravel = ravel(info_state)
53 | multiplier = jnp.array([3**k for k in range(info_state_ravel.shape[-1])])
54 | idx = jnp.sum((info_state_ravel + 1) * multiplier) % 1235
55 | return idx
56 |
57 |
58 | class LeducPoker(pgx.leduc_holdem.LeducHoldem, cfrx.envs.Env):
59 | @classmethod
60 | def action_to_string(cls, action: Int[Array, ""]) -> str:
61 | strings = ["c", "r", "f"]
62 | if action != -1:
63 | a = int(action)
64 | rep = strings[a]
65 | else:
66 | rep = "?"
67 | return rep
68 |
69 | @property
70 | def max_episode_length(self) -> int:
71 | return 12
72 |
73 | @property
74 | def max_nodes(self) -> int:
75 | return 2000
76 |
77 | @property
78 | def n_info_states(self) -> int:
79 | return len(INFO_SETS)
80 |
81 | @property
82 | def n_actions(self) -> int:
83 | return super().num_actions
84 |
85 | @property
86 | def n_players(self) -> int:
87 | return self.num_players
88 |
89 | def update_info_state(
90 | self, state: State, next_state: State, action: Int[Array, ""]
91 | ) -> InfoState:
92 | info_state = next_state.info_state
93 | assert info_state is not None
94 | private_card = next_state._cards[next_state.current_player]
95 | public_card = next_state._cards[-1]
96 |
97 | current_position = (info_state.action_sequence[state._round] != -1).sum()
98 | action_sequence = info_state.action_sequence.at[
99 | state._round, current_position
100 | ].set(action.astype(jnp.int8))
101 | updated_info_state = info_state._replace(action_sequence=action_sequence)
102 |
103 | info_state = jax.tree_map(
104 | lambda x, y: jnp.where(state.chance_node, x, y),
105 | info_state,
106 | updated_info_state,
107 | )
108 | chance_round = info_state.chance_round + state.chance_node
109 | return info_state._replace(
110 | private_card=private_card,
111 | public_card=public_card,
112 | chance_round=chance_round,
113 | chance_node=next_state.chance_node,
114 | )
115 |
116 | def info_state_to_str(self, info_state: InfoState) -> str:
117 | if info_state.chance_node:
118 | rep = f"chance{info_state.chance_round}:"
119 |
120 | else:
121 | rep = ""
122 |
123 | strings = ["c", "r", "f"]
124 | rep += f"{info_state.private_card}"
125 |
126 | for action in np.array(info_state.action_sequence)[0]:
127 | if action != -1:
128 | rep += strings[action]
129 |
130 | if info_state.public_card != -1:
131 | rep += f"{info_state.public_card}"
132 | for action in np.array(info_state.action_sequence)[1]:
133 | if action != -1:
134 | rep += strings[action]
135 | return rep
136 |
137 | def get_action_mask(self, state: State) -> jax.Array:
138 | return state.legal_action_mask
139 |
140 | def get_chance_mask(self, state: State) -> jax.Array:
141 | return state.chance_prior > 0
142 |
143 | def get_chance_probs(self, state: State) -> jax.Array:
144 | return jnp.where(
145 | (state.chance_prior != 0).any(),
146 | state.chance_prior / state.chance_prior.sum(),
147 | 0,
148 | )
149 |
150 | def get_info_state(self, state: State) -> jax.Array:
151 | return state.info_state
152 |
153 | def info_state_idx(self, info_state: InfoState) -> Array:
154 | idx = convert_info_state_to_idx(info_state)
155 | return jnp.asarray(REVERSE_INFO_SETS_LOOKUP)[idx]
156 |
157 | def _init(self, rng: Shaped[PRNGKeyArray, "2"]) -> State:
158 | env_state = super()._init(rng)
159 | info_state = InfoState(
160 | private_card=jnp.int8(-1),
161 | public_card=jnp.int8(-1),
162 | action_sequence=jnp.ones((2, 4), dtype=jnp.int8) * -1,
163 | chance_round=jnp.int8(0),
164 | chance_node=jnp.bool_(True),
165 | )
166 | cards = jnp.int8([-1, -1, -1])
167 | return State(
168 | _first_player=env_state._first_player,
169 | current_player=env_state.current_player,
170 | observation=env_state.observation,
171 | rewards=env_state.rewards,
172 | terminated=env_state.terminated,
173 | truncated=env_state.truncated,
174 | _step_count=env_state._step_count,
175 | _last_action=env_state._last_action,
176 | _round=jnp.int8(0),
177 | _cards=cards,
178 | legal_action_mask=jnp.ones_like(env_state.legal_action_mask),
179 | _chips=env_state._chips,
180 | _raise_count=env_state._raise_count,
181 | info_state=info_state,
182 | chance_prior=jnp.ones(NUM_DIFFERENT_CARDS, dtype=int) * NUM_REPEAT_CARDS,
183 | chance_node=jnp.bool_(True),
184 | )
185 |
186 | def _resolve_decision_node(
187 | self, state: State, action: Int[Array, ""], random_key: PRNGKeyArray
188 | ) -> State:
189 | env_state = super()._step(state=state, action=action, key=random_key)
190 |
191 | is_public_card_unknown = env_state._cards[-1] == -1
192 | chance_node = (
193 | (env_state._round > 0) & is_public_card_unknown & ~env_state.terminated
194 | )
195 |
196 | legal_action_mask = jnp.where(
197 | chance_node,
198 | jnp.ones_like(env_state.legal_action_mask),
199 | env_state.legal_action_mask,
200 | )
201 |
202 | state = State(
203 | _first_player=env_state._first_player,
204 | current_player=env_state.current_player,
205 | observation=env_state.observation,
206 | rewards=env_state.rewards,
207 | terminated=env_state.terminated,
208 | truncated=env_state.truncated,
209 | _step_count=env_state._step_count,
210 | _last_action=env_state._last_action,
211 | _round=env_state._round.astype(jnp.int8),
212 | _cards=env_state._cards,
213 | legal_action_mask=legal_action_mask,
214 | _chips=env_state._chips,
215 | _raise_count=env_state._raise_count,
216 | info_state=env_state.info_state,
217 | chance_node=chance_node,
218 | chance_prior=env_state.chance_prior,
219 | )
220 |
221 | return state
222 |
223 | def _resolve_chance_node(
224 | self, state: State, action: Int[Array, ""], random_key: PRNGKeyArray
225 | ) -> State:
226 | draw_player = NUM_TOTAL_CARDS - state.chance_prior.sum()
227 | action = action.astype(jnp.int8)
228 | card_rank = action % NUM_DIFFERENT_CARDS
229 | cards = state._cards.at[draw_player].set(card_rank)
230 | chance_prior = state.chance_prior.at[action].add(-1)
231 | chance_node = (cards[:2] == -1).any()
232 |
233 | legal_action_mask = jnp.where(
234 | chance_node,
235 | jnp.ones_like(state.legal_action_mask),
236 | jnp.ones_like(state.legal_action_mask).at[-1].set(False),
237 | )
238 | return State(
239 | _first_player=state._first_player,
240 | current_player=state.current_player,
241 | observation=state.observation,
242 | rewards=state.rewards,
243 | terminated=state.terminated,
244 | truncated=state.truncated,
245 | _step_count=state._step_count,
246 | _last_action=state._last_action,
247 | _round=state._round.astype(jnp.int8),
248 | _cards=cards,
249 | legal_action_mask=legal_action_mask,
250 | _chips=state._chips,
251 | _raise_count=state._raise_count,
252 | info_state=state.info_state,
253 | chance_prior=chance_prior,
254 | chance_node=chance_node,
255 | )
256 |
257 | def _step(self, state: State, action: Array, random_key: PRNGKeyArray) -> State:
258 | new_state = jax.lax.cond(
259 | state.chance_node,
260 | lambda: self._resolve_chance_node(
261 | state=state, action=action, random_key=random_key
262 | ),
263 | lambda: self._resolve_decision_node(
264 | state=state, action=action, random_key=random_key
265 | ),
266 | )
267 |
268 | _round = jnp.where(new_state._cards[-1] != -1, jnp.maximum(state._round, 1), 0)
269 | new_state = new_state.replace(_round=_round)
270 | info_state = self.update_info_state(
271 | state=state, next_state=new_state, action=action
272 | )
273 |
274 | return new_state.replace(info_state=info_state)
275 |
--------------------------------------------------------------------------------
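A minimal sketch of the info-state lookup described above (the seed is arbitrary): after the two initial chance nodes have dealt the private cards, an `InfoState` can be rendered as a string with `info_state_to_str` and resolved to its canonical index with `info_state_idx`, which internally hashes the flattened state via `convert_info_state_to_idx` and looks it up in `REVERSE_INFO_SETS_LOOKUP`.

import jax
import jax.numpy as jnp

from cfrx.envs.leduc_poker import LeducPoker

env = LeducPoker()
key = jax.random.PRNGKey(0)
state = env.init(key)

# Resolve the two initial chance nodes (one private card per player).
while bool(state.chance_node):
    key, subkey = jax.random.split(key)
    probs = state.chance_prior / state.chance_prior.sum()
    action = jax.random.choice(subkey, jnp.arange(probs.shape[0]), p=probs)
    state = env.step(state, action)

# The decision-node info state has both a readable form and an integer index.
print(env.info_state_to_str(state.info_state))
print(int(env.info_state_idx(state.info_state)), env.n_info_states)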
/cfrx/algorithms/mccfr/outcome_sampling.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import functools
4 | from typing import Any, NamedTuple
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | from jaxtyping import Array, Float, Int, PRNGKeyArray
9 |
10 | from cfrx.envs import Env, InfoState, State
11 | from cfrx.episode import Episode
12 | from cfrx.policy import Policy, TabularPolicy
13 | from cfrx.utils import regret_matching
14 |
15 |
16 | class MCCFRState(NamedTuple):
17 | regrets: Float[Array, "*batch a"]
18 | probs: Float[Array, "... a"]
19 | avg_probs: Float[Array, "... a"]
20 | step: Int[Array, "..."]
21 |
22 | @classmethod
23 | def init(cls, n_states: int, n_actions: int) -> MCCFRState:
24 | return MCCFRState(
25 | regrets=jnp.zeros((n_states, n_actions)),
26 | probs=jnp.ones((n_states, n_actions))
27 | / jnp.ones((n_states, n_actions)).sum(axis=-1, keepdims=True),
28 | avg_probs=jnp.zeros((n_states, n_actions)) + 1e-6,
29 | step=jnp.array(1, dtype=int),
30 | )
31 |
32 |
33 | def compute_sampled_counterfactual_action_value(
34 | opponent_reach_prob: Float[Array, ""],
35 | outcome_sampling_prob: Float[Array, ""],
36 | outcome_prob: Float[Array, ""],
37 | utility: Float[Array, ""],
38 | ) -> Float[Array, ""]:
39 | cf_value_a = opponent_reach_prob * outcome_prob * utility / outcome_sampling_prob
40 | return cf_value_a
41 |
42 |
43 | def compute_strategy_profile(
44 | my_reach_prob: Float[Array, ""],
45 | sample_reach_prob: Float[Array, ""],
46 | strat_probs: Float[Array, "..."],
47 | ) -> Float[Array, "..."]:
48 | avg_probs = my_reach_prob / sample_reach_prob * strat_probs
49 |
50 | return avg_probs
51 |
52 |
53 | def get_regrets(
54 | step: Episode,
55 | my_prob_cumprod: Float[Array, ""],
56 | opponent_prob_cumprod: Float[Array, ""],
57 | sample_prob_cumprod: Float[Array, ""],
58 | outcome_prob: Float[Array, ""],
59 | strat_prob_distrib: Float[Array, " num_actions"],
60 | strat_prob: Float[Array, ""],
61 | utility: Float[Array, ""],
62 | outcome_sampling_prob: Float[Array, ""],
63 | update_player: Int[Array, ""],
64 | ) -> tuple[Float[Array, " num_actions"], Float[Array, " num_actions"],]:
65 | is_current_player = step.current_player == update_player
66 | cf_value_a = compute_sampled_counterfactual_action_value(
67 | opponent_reach_prob=opponent_prob_cumprod,
68 | outcome_sampling_prob=outcome_sampling_prob,
69 | outcome_prob=outcome_prob,
70 | utility=utility,
71 | )
72 | cf_value = cf_value_a * strat_prob
73 |
74 | regrets = jnp.zeros(step.action_mask.shape[-1])
75 | regrets = regrets.at[step.action].set(cf_value_a)
76 | regrets = (regrets - cf_value) * step.action_mask
77 |
78 | regrets = regrets * is_current_player * step.mask * (1 - step.chance_node)
79 |
80 | avg_probs = compute_strategy_profile(
81 | my_reach_prob=my_prob_cumprod,
82 | sample_reach_prob=sample_prob_cumprod,
83 | strat_probs=strat_prob_distrib * step.action_mask,
84 | )
85 | avg_probs = avg_probs * step.mask * is_current_player * (1 - step.chance_node)
86 | return regrets, avg_probs
87 |
88 |
89 | def compute_regrets_and_strategy_profile(
90 | episode: Episode,
91 | training_state: MCCFRState,
92 | policy: TabularPolicy,
93 | update_player: Int[Array, ""],
94 | ) -> tuple[InfoState, Float[Array, "..."], Float[Array, "..."]]:
95 | episode_length = episode.current_player.shape[-1]
96 | utility = episode.reward[episode.mask.sum(), update_player]
97 | is_current_player = episode.current_player == update_player
98 | prob_dist_fn = functools.partial(
99 | policy.prob_distribution, params=training_state.probs
100 | )
101 |
102 | strat_prob_distribs = jax.vmap(prob_dist_fn)(
103 | info_state=episode.info_state,
104 | action_mask=episode.action_mask,
105 | use_behavior_policy=jnp.zeros(episode_length, dtype=bool),
106 | )
107 |
108 | strat_probs = jnp.where(
109 | episode.mask,
110 | strat_prob_distribs[jnp.arange(episode_length), episode.action],
111 | 1.0,
112 | )
113 | strat_probs = jnp.where(episode.chance_node, episode.behavior_prob, strat_probs)
114 | my_probs = jnp.where(is_current_player, strat_probs, 1)
115 | opponent_probs = jnp.where(is_current_player, 1, strat_probs)
116 | sample_probs = jnp.where(episode.mask, episode.behavior_prob, 1)
117 |
118 | sample_probs_cumprod = jnp.cumprod(jnp.concatenate([jnp.ones(1), sample_probs]))
119 | my_probs_cumprod = jnp.cumprod(jnp.concatenate([jnp.ones(1), my_probs]))
120 | opponent_probs_cumprod = jnp.cumprod(opponent_probs)
121 | outcome_probs = jnp.cumprod(jnp.concatenate([strat_probs, jnp.ones(1)])[::-1])[::-1]
122 |
123 | outcome_sampling_prob = jnp.prod(sample_probs)
124 |
125 | _get_regrets = functools.partial(
126 | get_regrets,
127 | utility=utility,
128 | update_player=update_player,
129 | outcome_sampling_prob=outcome_sampling_prob,
130 | )
131 | regrets, avg_probs = jax.vmap(_get_regrets)(
132 | step=episode,
133 | my_prob_cumprod=my_probs_cumprod[:-1],
134 | opponent_prob_cumprod=opponent_probs_cumprod,
135 | sample_prob_cumprod=sample_probs_cumprod[:-1],
136 | outcome_prob=outcome_probs[1:],
137 | strat_prob_distrib=strat_prob_distribs,
138 | strat_prob=strat_probs,
139 | )
140 | regrets = regrets[:-1]
141 | avg_probs = avg_probs[:-1]
142 | episode = jax.tree_map(lambda x: x[:-1], episode)
143 |
144 | return episode.info_state, regrets, avg_probs
145 |
146 |
147 | def unroll(
148 | init_state: State,
149 | random_key: PRNGKeyArray,
150 | training_state: MCCFRState,
151 | env: Env,
152 | policy: Policy,
153 | update_player: Int[Array, ""],
154 | n_max_steps: int,
155 | ) -> tuple[Episode, State]:
156 | """
157 | Generates a single unroll of the game.
158 | """
159 |
160 | def play_step(
161 | carry: tuple[State, PRNGKeyArray], unused: Any
162 | ) -> tuple[tuple[State, PRNGKeyArray], tuple[Episode, State]]:
163 | state, random_key = carry
164 |
165 | use_behavior_policy = state.current_player == update_player
166 |
167 | random_key, subkey = jax.random.split(random_key)
168 |
169 | action = policy.sample(
170 | params=training_state.probs,
171 | info_state=state.info_state,
172 | action_mask=state.legal_action_mask,
173 | random_key=subkey,
174 | use_behavior_policy=use_behavior_policy,
175 | )
176 | probs = policy.prob_distribution(
177 | params=training_state.probs,
178 | info_state=state.info_state,
179 | action_mask=state.legal_action_mask,
180 | use_behavior_policy=use_behavior_policy,
181 | )
182 |
183 | chance_probs = state.chance_prior / state.chance_prior.sum(
184 | axis=-1, keepdims=True
185 | )
186 | random_key, subkey = jax.random.split(random_key)
187 | chance_action = jax.random.choice(
188 | subkey,
189 | jnp.arange(state.chance_prior.shape[0]),
190 | p=chance_probs,
191 | )
192 |
193 | action = jnp.where(state.chance_node, chance_action, action)
194 |
195 | probs = jnp.where(state.chance_node, chance_probs[action], probs[action])
196 |
197 | current_player = jnp.where(state.chance_node, -1, state.current_player)
198 | game_step = Episode(
199 | info_state=state.info_state,
200 | action=action,
201 | reward=state.rewards,
202 | action_mask=state.legal_action_mask,
203 | current_player=current_player,
204 | behavior_prob=probs,
205 | chance_node=state.chance_node,
206 | mask=1 - state.terminated,
207 | )
208 |
209 | state = env.step(state, action)
210 | return (state, random_key), (game_step, state)
211 |
212 | (state, random_key), (episode, states) = jax.lax.scan(
213 | play_step, (init_state, random_key), (), length=n_max_steps
214 | )
215 |
216 | last_game_step = Episode(
217 | info_state=state.info_state,
218 | action=jnp.array(-1),
219 | reward=state.rewards,
220 | action_mask=jnp.ones(policy._n_actions, dtype=bool),
221 | current_player=state.current_player,
222 | behavior_prob=jnp.array(1.0),
223 | chance_node=jnp.bool_(False),
224 | mask=1 - state.terminated,
225 | )
226 | states = jax.tree_map(
227 | lambda x, y: jnp.concatenate([jnp.expand_dims(y, 0), x]),
228 | states,
229 | init_state,
230 | )
231 |
232 | episode = jax.tree_map(
233 | lambda x, y: jnp.concatenate([x, jnp.expand_dims(y, 0)]),
234 | episode,
235 | last_game_step,
236 | )
237 | return episode, states
238 |
239 |
240 | def do_iteration(
241 | training_state: MCCFRState,
242 | random_key: PRNGKeyArray,
243 | env: Env,
244 | policy: TabularPolicy,
245 | update_player: Int[Array, ""],
246 | ) -> MCCFRState:
247 | """
248 |     Perform one iteration of outcome-sampling MCCFR: sample a single trajectory
249 |     through the game tree and accumulate counterfactual regrets and the strategy profile.
250 |
251 | Args:
252 | training_state: The current state of the training.
253 | random_key: A random key.
254 | env: The environment.
255 | policy: The policy.
256 | update_player: The player to update.
257 |
258 | Returns:
259 |         The updated training state with accumulated regrets and average strategy.
260 | """
261 |
262 | # Sample one path in the game tree
263 |     random_key, init_key, rollout_key = jax.random.split(random_key, 3)
264 |     episode, states = unroll(
265 |         init_state=env.init(init_key),
266 |         training_state=training_state,
267 |         random_key=rollout_key,
268 | update_player=update_player,
269 | env=env,
270 | policy=policy,
271 | n_max_steps=env.max_episode_length,
272 | )
273 |
274 | # Compute counterfactual values and strategy profile
275 | (
276 | info_states,
277 | sampled_regrets,
278 | sampled_avg_probs,
279 | ) = compute_regrets_and_strategy_profile(
280 | episode=episode,
281 | training_state=training_state,
282 | policy=policy,
283 | update_player=update_player,
284 | )
285 | info_states_idx = jax.vmap(env.info_state_idx)(info_states)
286 |
287 | # Store regret and strategy profile values
288 | new_regrets = training_state.regrets.at[info_states_idx].add(sampled_regrets)
289 | new_avg_probs = training_state.avg_probs.at[info_states_idx].add(sampled_avg_probs)
290 |
291 |     # Regret matching: new strategy proportional to positive cumulative regrets
292 | new_probs = regret_matching(new_regrets)
293 | new_probs /= new_probs.sum(axis=-1, keepdims=True)
294 |
295 | training_state = training_state._replace(
296 | regrets=new_regrets,
297 | probs=new_probs,
298 | avg_probs=new_avg_probs,
299 | step=training_state.step + 1,
300 | )
301 | return training_state
302 |
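303 | 
304 | # ---------------------------------------------------------------------------
305 | # Illustrative usage sketch (not part of the library): a minimal training loop
306 | # built on `do_iteration`. It assumes `env`, `policy` and an initial
307 | # `MCCFRState` have been constructed elsewhere; `_example_training_loop` and
308 | # `n_iterations` are hypothetical names used only for this sketch. The update
309 | # player alternates between the two players at every iteration.
310 | # ---------------------------------------------------------------------------
311 | def _example_training_loop(
312 |     training_state: MCCFRState,
313 |     random_key: PRNGKeyArray,
314 |     env: Env,
315 |     policy: TabularPolicy,
316 |     n_iterations: int,
317 | ) -> MCCFRState:
318 |     for i in range(n_iterations):
319 |         # Alternate the player whose regrets are updated at this iteration.
320 |         random_key, subkey = jax.random.split(random_key)
321 |         training_state = do_iteration(
322 |             training_state=training_state,
323 |             random_key=subkey,
324 |             env=env,
325 |             policy=policy,
326 |             update_player=jnp.asarray(i % 2),
327 |         )
328 |     return training_state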
--------------------------------------------------------------------------------
/cfrx/tree/traverse.py:
--------------------------------------------------------------------------------
1 | import functools as ft
2 | from typing import Tuple
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | import pgx
7 | from jaxtyping import Array, Bool, Float, PyTree
8 |
9 | from cfrx.policy import Policy
10 | from cfrx.tree import Root, Tree
11 | from cfrx.utils import get_action_mask
12 |
13 |
14 | def instantiate_tree_from_root(
15 | root: Root,
16 | n_max_nodes: int,
17 | n_players: int,
18 | running_probabilities: bool = False,
19 | ) -> Tree:
20 | """Initializes tree state at search root."""
21 | (n_actions,) = root.prior_logits.shape
22 |
23 | data_dtype = root.value.dtype
24 |
25 | def _zeros(x: Array) -> Array:
26 | return jnp.zeros((n_max_nodes,) + x.shape, dtype=x.dtype)
27 |
28 | # Create a new empty tree state and fill its root.
29 | tree = Tree(
30 | node_visits=jnp.zeros(n_max_nodes, dtype=jnp.int32),
31 | raw_values=jnp.zeros((n_max_nodes, n_players), dtype=data_dtype),
32 | node_values=jnp.zeros((n_max_nodes, n_players), dtype=data_dtype),
33 | parents=jnp.full(
34 | n_max_nodes,
35 | Tree.NO_PARENT,
36 | dtype=jnp.int32,
37 | ),
38 | action_from_parent=jnp.full(
39 | n_max_nodes,
40 | Tree.NO_PARENT,
41 | dtype=jnp.int32,
42 | ),
43 | children_index=jnp.full(
44 | (n_max_nodes, n_actions),
45 | Tree.UNVISITED,
46 | dtype=jnp.int32,
47 | ),
48 | children_prior_logits=jnp.zeros(
49 | (n_max_nodes, n_actions), dtype=root.prior_logits.dtype
50 | ),
51 | children_values=jnp.zeros((n_max_nodes, n_actions, n_players), dtype=data_dtype),
52 | children_visits=jnp.zeros((n_max_nodes, n_actions), dtype=jnp.int32),
53 | children_rewards=jnp.zeros(
54 | (n_max_nodes, n_actions, n_players), dtype=data_dtype
55 | ),
56 | to_visit=jnp.zeros(n_max_nodes, dtype=bool),
57 | states=jax.tree_util.tree_map(_zeros, root.state),
58 | depth=jnp.ones(n_max_nodes, dtype=jnp.int32) * -1,
59 | extra_data={},
60 | )
61 | new_tree: Tree = tree._replace(
62 | node_visits=tree.node_visits.at[Tree.ROOT_INDEX].set(1),
63 | states=jax.tree_map(
64 | lambda x, y: x.at[Tree.ROOT_INDEX].set(y), tree.states, root.state
65 | ),
66 | children_prior_logits=tree.children_prior_logits.at[Tree.ROOT_INDEX].set(
67 | root.prior_logits
68 | ),
69 | depth=tree.depth.at[Tree.ROOT_INDEX].set(0),
70 | )
71 |
72 | if running_probabilities:
73 | new_tree = initialize_running_probabilities(new_tree)
74 |
75 | return new_tree
76 |
77 |
78 | def initialize_running_probabilities(tree: Tree) -> Tree:
79 | n_max_nodes = tree.node_visits.shape[-1]
80 | init_prob = (jnp.ones(n_max_nodes) * -1).at[Tree.ROOT_INDEX].set(1.0)
81 | running_probabilities = {
82 | "p_self": init_prob,
83 | "p_opponent": init_prob,
84 | "p_chance": init_prob,
85 | }
86 | tree = tree._replace(extra_data={**tree.extra_data, **running_probabilities})
87 | return tree
88 |
89 |
90 | def add_children(
91 | tree: Tree,
92 | state: PyTree,
93 | env: pgx.Env,
94 | node_counter: jax.Array,
95 | parent_idx: jax.Array,
96 | ) -> tuple[Tree, jax.Array]:
97 |     """Expand a node's children, dispatching to chance or player expansion."""
98 | chance_fn = ft.partial(add_chance_children, env=env)
99 | player_fn = ft.partial(add_player_children, env=env)
100 | return jax.lax.cond(
101 | state.chance_node, chance_fn, player_fn, tree, state, node_counter, parent_idx
102 | )
103 |
104 |
105 | def add_player_children(
106 | tree: Tree,
107 | state: PyTree,
108 | node_counter: jax.Array,
109 | parent_idx: jax.Array,
110 | env: pgx.Env,
111 | ) -> tuple[Tree, jax.Array]:
112 |     """Reserve child slots for each legal player action and mark them to visit."""
113 | action_mask = env.get_action_mask(state) & ~state.terminated
114 | n_actions = len(action_mask)
115 | n_max_nodes = len(tree.to_visit)
116 | action_idx = jnp.where(action_mask, jnp.arange(n_actions), n_max_nodes)
117 | action_idx = jnp.sort(action_idx)
118 |
119 | update_idx = jnp.arange(n_actions) + node_counter + 1
120 | update_idx = jnp.where(action_idx == n_max_nodes, n_max_nodes, update_idx)
121 |
122 | action_from_parent = tree.action_from_parent.at[update_idx].set(action_idx)
123 | parents = tree.parents.at[update_idx].set(parent_idx)
124 | to_visit = tree.to_visit.at[update_idx].set(True)
125 |
126 | tree = tree._replace(
127 | action_from_parent=action_from_parent, parents=parents, to_visit=to_visit
128 | )
129 |
130 | return tree, action_mask.sum()
131 |
132 |
133 | def add_chance_children(
134 | tree: Tree,
135 | state: PyTree,
136 | node_counter: jax.Array,
137 | parent_idx: jax.Array,
138 | env: pgx.Env,
139 | ) -> tuple[Tree, jax.Array]:
140 |     """Reserve child slots for each possible chance outcome and mark them to visit."""
141 | action_mask = env.get_chance_mask(state) & ~state.terminated
142 | n_actions = len(action_mask)
143 | n_max_nodes = len(tree.to_visit)
144 | action_idx = jnp.where(action_mask, jnp.arange(n_actions), n_max_nodes)
145 | action_idx = jnp.sort(action_idx)
146 |
147 | update_idx = jnp.arange(n_actions) + node_counter + 1
148 | update_idx = jnp.where(action_idx == n_max_nodes, n_max_nodes, update_idx)
149 |
150 | action_from_parent = tree.action_from_parent.at[update_idx].set(action_idx)
151 | parents = tree.parents.at[update_idx].set(parent_idx)
152 | to_visit = tree.to_visit.at[update_idx].set(True)
153 |
154 | tree = tree._replace(
155 | action_from_parent=action_from_parent, parents=parents, to_visit=to_visit
156 | )
157 |
158 | return tree, action_mask.sum()
159 |
160 |
161 | def select_new_node_and_play(
162 | tree: Tree, env: pgx.Env
163 | ) -> tuple[PyTree, jax.Array, jax.Array, jax.Array]:
164 |     """Pop the next node to visit and step the env from its parent state."""
165 | child_index = jnp.argmax(tree.to_visit)
166 | parent_index = tree.parents[child_index]
167 | action = tree.action_from_parent[child_index]
168 |
169 | parent_state = jax.tree_map(lambda x: x[parent_index], tree.states)
170 |
171 | new_state = env.step(parent_state, action)
172 |
173 | return new_state, parent_index, child_index, action
174 |
175 |
176 | def update_running_probabilities(
177 | tree: Tree,
178 | parent_index: jax.Array,
179 | next_node_index: jax.Array,
180 | strategy: Float[Array, ""],
181 | traverser: int,
182 | ) -> Tree:
183 | parent_state = jax.tree_map(lambda x: x[parent_index], tree.states)
184 |
185 | p_self = tree.extra_data["p_self"]
186 |
187 | update_p_self_condition = (
188 | (parent_state.current_player == traverser)
189 | & (~parent_state.terminated)
190 | & (~parent_state.chance_node)
191 | )
192 | p_self_new_value = jnp.where(
193 | update_p_self_condition,
194 | p_self[parent_index] * strategy,
195 | p_self[parent_index],
196 | )
197 |
198 | p_opponent = tree.extra_data["p_opponent"]
199 |
200 | update_p_opponent_condition = (
201 | (parent_state.current_player != traverser)
202 | & (~parent_state.terminated)
203 | & (~parent_state.chance_node)
204 | )
205 |
206 | p_opponent_new_value = jnp.where(
207 | update_p_opponent_condition,
208 | p_opponent[parent_index] * strategy,
209 | p_opponent[parent_index],
210 | )
211 |
212 | p_chance = tree.extra_data["p_chance"]
213 |
214 | p_chance_new_value = jnp.where(
215 | parent_state.chance_node,
216 | p_chance[parent_index] * strategy,
217 | p_chance[parent_index],
218 | )
219 |
220 | tree = tree._replace(
221 | extra_data={
222 | **tree.extra_data,
223 | **{
224 | "p_self": p_self.at[next_node_index].set(p_self_new_value),
225 | "p_opponent": p_opponent.at[next_node_index].set(p_opponent_new_value),
226 | "p_chance": p_chance.at[next_node_index].set(p_chance_new_value),
227 | },
228 | }
229 | )
230 | return tree
231 |
232 |
233 | def traverse_tree_vanilla(
234 | tree: Tree,
235 | env: pgx.Env,
236 | ) -> Tree:
237 | def cond_fn(val: Tuple) -> Bool[Array, ""]:
238 | tree, n = val
239 | n_max_nodes = len(tree.node_visits)
240 |
241 | return tree.to_visit.any() & (n < n_max_nodes)
242 |
243 | def loop_fn(val: Tuple) -> Tuple:
244 | tree, n = val
245 |
246 | new_state, parent_index, child_index, action = select_new_node_and_play(
247 | tree, env
248 | )
249 |
250 | tree, n_added = add_children(
251 | tree=tree, state=new_state, env=env, node_counter=n, parent_idx=child_index
252 | )
253 |
254 | tree = tree._replace(
255 | node_visits=tree.node_visits.at[child_index].set(1),
256 | node_values=tree.node_values.at[child_index].set(new_state.rewards),
257 | states=jax.tree_map(
258 | lambda x, y: x.at[child_index].set(y), tree.states, new_state
259 | ),
260 | raw_values=tree.raw_values.at[child_index].set(new_state.rewards),
261 | children_index=tree.children_index.at[parent_index, action].set(child_index),
262 | children_rewards=tree.children_rewards.at[parent_index, action].set(
263 | new_state.rewards
264 | ),
265 | children_values=tree.children_values.at[parent_index, action].set(
266 | new_state.rewards
267 | ),
268 | parents=tree.parents.at[child_index].set(parent_index),
269 | action_from_parent=tree.action_from_parent.at[child_index].set(action),
270 | depth=tree.depth.at[child_index].set(tree.depth[parent_index] + 1),
271 | to_visit=tree.to_visit.at[child_index].set(False),
272 | )
273 |
274 | return tree, n + n_added
275 |
276 | tree, n_added = add_children(
277 | tree=tree,
278 | state=jax.tree_map(lambda x: x[0], tree.states),
279 | env=env,
280 | node_counter=0,
281 | parent_idx=0,
282 | )
283 |
284 | tree, _ = jax.lax.while_loop(cond_fn, loop_fn, (tree, n_added))
285 | return tree
286 |
287 |
288 | def traverse_tree_cfr(
289 | tree: Tree,
290 | policy: Policy,
291 | policy_params: Array,
292 | env: pgx.Env,
293 | traverser: int = 0,
294 | ) -> Tree:
295 | def cond_fn(val: Tuple) -> Bool[Array, ""]:
296 | tree, n = val
297 | n_max_nodes = len(tree.node_visits)
298 |
299 | return tree.to_visit.any() & (n < n_max_nodes)
300 |
301 | def loop_fn(val: Tuple) -> Tuple:
302 | tree, n = val
303 |
304 | new_state, parent_index, child_index, action = select_new_node_and_play(
305 | tree, env
306 | )
307 |
308 | tree, n_added = add_children(
309 | tree=tree, state=new_state, env=env, node_counter=n, parent_idx=child_index
310 | )
311 |
312 | parent_state = jax.tree_map(lambda x: x[parent_index], tree.states)
313 |
314 | strategy = policy.prob_distribution(
315 | params=policy_params,
316 | info_state=env.get_info_state(parent_state),
317 | action_mask=env.get_action_mask(state=parent_state),
318 | use_behavior_policy=jnp.bool_(False),
319 | )
320 |
321 | chance_strategy = env.get_chance_probs(parent_state)[action]
322 |
323 |
324 | action_prob = jnp.where(
325 | parent_state.chance_node, chance_strategy, strategy[action]
326 | )
327 |
328 | tree = update_running_probabilities(
329 | tree=tree,
330 | parent_index=parent_index,
331 | next_node_index=child_index,
332 | strategy=action_prob,
333 | traverser=traverser,
334 | )
335 |
336 | tree = tree._replace(
337 | node_visits=tree.node_visits.at[child_index].set(1),
338 | node_values=tree.node_values.at[child_index].set(new_state.rewards),
339 | states=jax.tree_map(
340 | lambda x, y: x.at[child_index].set(y), tree.states, new_state
341 | ),
342 | raw_values=tree.raw_values.at[child_index].set(new_state.rewards),
343 | children_index=tree.children_index.at[parent_index, action].set(child_index),
344 | children_rewards=tree.children_rewards.at[parent_index, action].set(
345 | new_state.rewards
346 | ),
347 | children_values=tree.children_values.at[parent_index, action].set(
348 | new_state.rewards
349 | ),
350 | children_prior_logits=tree.children_prior_logits.at[parent_index].set(
351 | jnp.where(parent_state.chance_node, chance_strategy, strategy)
352 | ),
353 | parents=tree.parents.at[child_index].set(parent_index),
354 | action_from_parent=tree.action_from_parent.at[child_index].set(action),
355 | depth=tree.depth.at[child_index].set(tree.depth[parent_index] + 1),
356 | to_visit=tree.to_visit.at[child_index].set(False),
357 | )
358 |
359 | return tree, n + n_added
360 |
361 | tree, n_added = add_children(
362 | tree=tree,
363 | state=jax.tree_map(lambda x: x[0], tree.states),
364 | env=env,
365 | node_counter=0,
366 | parent_idx=0,
367 | )
368 |
369 | tree, _ = jax.lax.while_loop(cond_fn, loop_fn, (tree, n_added))
370 | return tree
371 |
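372 | 
373 | # ---------------------------------------------------------------------------
374 | # Illustrative usage sketch (not part of the library): building a tree from a
375 | # `Root` and expanding the full game tree with `traverse_tree_cfr`. It assumes
376 | # a two-player game and that `root`, `env`, `policy` and `policy_params` are
377 | # available; `_example_cfr_traversal` and `n_max_nodes` are hypothetical names
378 | # used only for this sketch.
379 | # ---------------------------------------------------------------------------
380 | def _example_cfr_traversal(
381 |     root: Root,
382 |     env: pgx.Env,
383 |     policy: Policy,
384 |     policy_params: Array,
385 |     n_max_nodes: int,
386 | ) -> Tree:
387 |     # Allocate an empty tree around the root, with reach probabilities enabled
388 |     # so that `traverse_tree_cfr` can update them at every expanded node.
389 |     tree = instantiate_tree_from_root(
390 |         root=root,
391 |         n_max_nodes=n_max_nodes,
392 |         n_players=2,
393 |         running_probabilities=True,
394 |     )
395 |     return traverse_tree_cfr(
396 |         tree=tree, policy=policy, policy_params=policy_params, env=env, traverser=0
397 |     )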
--------------------------------------------------------------------------------