├── MuZeroExample.ipynb └── README.md /MuZeroExample.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# environment:\n", 10 | "# pip3 install torch" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "# Implementation of a simple game: Tic-Tac-Toe\n", 20 | "# You can change this to another two-player game.\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "\n", 24 | "BLACK, WHITE = 1, -1 # first turn or second turn player\n", 25 | "\n", 26 | "class State:\n", 27 | "    '''Board implementation of Tic-Tac-Toe'''\n", 28 | "    X, Y = 'ABC', '123'\n", 29 | "    C = {0: '_', BLACK: 'O', WHITE: 'X'}\n", 30 | "\n", 31 | "    def __init__(self):\n", 32 | "        self.board = np.zeros((3, 3)) # (x, y)\n", 33 | "        self.color = 1\n", 34 | "        self.win_color = 0\n", 35 | "        self.record = []\n", 36 | "\n", 37 | "    def action2str(self, a):\n", 38 | "        return self.X[a // 3] + self.Y[a % 3]\n", 39 | "\n", 40 | "    def str2action(self, s):\n", 41 | "        return self.X.find(s[0]) * 3 + self.Y.find(s[1])\n", 42 | "\n", 43 | "    def record_string(self):\n", 44 | "        return ' '.join([self.action2str(a) for a in self.record])\n", 45 | "\n", 46 | "    def __str__(self):\n", 47 | "        # output the board\n", 48 | "        s = '  ' + ' '.join(self.Y) + '\\n'\n", 49 | "        for i in range(3):\n", 50 | "            s += self.X[i] + ' ' + ' '.join([self.C[self.board[i, j]] for j in range(3)]) + '\\n'\n", 51 | "        s += 'record = ' + self.record_string()\n", 52 | "        return s\n", 53 | "\n", 54 | "    def play(self, action):\n", 55 | "        # state transition function\n", 56 | "        # action is a position integer (0-8) or a string representation of an action sequence\n", 57 | "        if isinstance(action, str):\n", 58 | "            for astr in action.split():\n", 59 | "                self.play(self.str2action(astr))\n", 60 | "            return self\n", 61 | "\n", 62 | "        x, y = action // 3, action % 3\n", 63 | "        self.board[x, y] = self.color\n", 64 | "\n", 65 | "        # check whether three stones are in a line\n", 66 | "        if self.board[x, :].sum() == 3 * self.color \\\n", 67 | "           or self.board[:, y].sum() == 3 * self.color \\\n", 68 | "           or (x == y and np.diag(self.board, k=0).sum() == 3 * self.color) \\\n", 69 | "           or (x == 2 - y and np.diag(self.board[::-1,:], k=0).sum() == 3 * self.color):\n", 70 | "            self.win_color = self.color\n", 71 | "\n", 72 | "        self.color = -self.color\n", 73 | "        self.record.append(action)\n", 74 | "        return self\n", 75 | "\n", 76 | "    def terminal(self):\n", 77 | "        # terminal state check\n", 78 | "        return self.win_color != 0 or len(self.record) == 3 * 3\n", 79 | "\n", 80 | "    def terminal_reward(self):\n", 81 | "        # terminal reward\n", 82 | "        return self.win_color if self.color == BLACK else -self.win_color\n", 83 | "\n", 84 | "    def legal_actions(self):\n", 85 | "        # list of legal actions in the current state\n", 86 | "        return [a for a in range(3 * 3) if self.board[a // 3, a % 3] == 0]\n", 87 | "\n", 88 | "    def feature(self):\n", 89 | "        # input tensor for neural net (state)\n", 90 | "        return np.stack([self.board == self.color, self.board == -self.color]).astype(np.float32)\n", 91 | "\n", 92 | "    def action_feature(self, action):\n", 93 | "        # input tensor for neural net (action)\n", 94 | "        a = np.zeros((1, 3, 3), dtype=np.float32)\n", 95 | "        a[0, action // 3, action % 3] = 1\n", 96 | "        return a\n", 97 | "\n", 98 | "state = State().play('B1')\n", 99 | "print(state)\n", 100 | 
"print('input feature')\n", 101 | "print(state.feature())\n", 102 | "state = State().play('B2 A1 C2')\n", 103 | "print('input feature')\n", 104 | "print(state.feature())" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# Small neural nets with PyTorch\n", 114 | "\n", 115 | "import torch\n", 116 | "import torch.nn as nn\n", 117 | "import torch.nn.functional as F\n", 118 | "\n", 119 | "class Conv(nn.Module):\n", 120 | " def __init__(self, filters0, filters1, kernel_size, bn=False):\n", 121 | " super().__init__()\n", 122 | " self.conv = nn.Conv2d(filters0, filters1, kernel_size, stride=1, padding=kernel_size//2, bias=False)\n", 123 | " self.bn = None\n", 124 | " if bn:\n", 125 | " self.bn = nn.BatchNorm2d(filters1)\n", 126 | "\n", 127 | " def forward(self, x):\n", 128 | " h = self.conv(x)\n", 129 | " if self.bn is not None:\n", 130 | " h = self.bn(h)\n", 131 | " return h\n", 132 | "\n", 133 | "class ResidualBlock(nn.Module):\n", 134 | " def __init__(self, filters):\n", 135 | " super().__init__()\n", 136 | " self.conv = Conv(filters, filters, 3, True)\n", 137 | "\n", 138 | " def forward(self, x):\n", 139 | " return F.relu(x + (self.conv(x)))" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "num_filters = 16\n", 149 | "num_blocks = 4\n", 150 | "\n", 151 | "class Representation(nn.Module):\n", 152 | " ''' Conversion from observation to inner abstract state '''\n", 153 | " def __init__(self, input_shape):\n", 154 | " super().__init__()\n", 155 | " self.input_shape = input_shape\n", 156 | " self.board_size = self.input_shape[1] * self.input_shape[2]\n", 157 | "\n", 158 | " self.layer0 = Conv(self.input_shape[0], num_filters, 3, bn=True)\n", 159 | " self.blocks = nn.ModuleList([ResidualBlock(num_filters) for _ in range(num_blocks)])\n", 160 | "\n", 161 | " def forward(self, x):\n", 162 | " h = F.relu(self.layer0(x))\n", 163 | " for block in self.blocks:\n", 164 | " h = block(h)\n", 165 | " return h\n", 166 | "\n", 167 | " def inference(self, x):\n", 168 | " self.eval()\n", 169 | " with torch.no_grad():\n", 170 | " rp = self(torch.from_numpy(x).unsqueeze(0))\n", 171 | " return rp.cpu().numpy()[0]\n", 172 | "\n", 173 | "class Prediction(nn.Module):\n", 174 | " ''' Policy and value prediction from inner abstract state '''\n", 175 | " def __init__(self, action_shape):\n", 176 | " super().__init__()\n", 177 | " self.board_size = np.prod(action_shape[1:])\n", 178 | " self.action_size = action_shape[0] * self.board_size\n", 179 | "\n", 180 | " self.conv_p1 = Conv(num_filters, 4, 1, bn=True)\n", 181 | " self.conv_p2 = Conv(4, 1, 1)\n", 182 | "\n", 183 | " self.conv_v = Conv(num_filters, 4, 1, bn=True)\n", 184 | " self.fc_v = nn.Linear(self.board_size * 4, 1, bias=False)\n", 185 | "\n", 186 | " def forward(self, rp):\n", 187 | " h_p = F.relu(self.conv_p1(rp))\n", 188 | " h_p = self.conv_p2(h_p).view(-1, self.action_size)\n", 189 | "\n", 190 | " h_v = F.relu(self.conv_v(rp))\n", 191 | " h_v = self.fc_v(h_v.view(-1, self.board_size * 4))\n", 192 | "\n", 193 | " # range of value is -1 ~ 1\n", 194 | " return F.softmax(h_p, dim=-1), torch.tanh(h_v)\n", 195 | "\n", 196 | " def inference(self, rp):\n", 197 | " self.eval()\n", 198 | " with torch.no_grad():\n", 199 | " p, v = self(torch.from_numpy(rp).unsqueeze(0))\n", 200 | " return p.cpu().numpy()[0], v.cpu().numpy()[0][0]\n", 201 | "\n", 202 | "class 
Dynamics(nn.Module):\n", 203 | " '''Abstract state transition'''\n", 204 | " def __init__(self, rp_shape, act_shape):\n", 205 | " super().__init__()\n", 206 | " self.rp_shape = rp_shape\n", 207 | " self.layer0 = Conv(rp_shape[0] + act_shape[0], num_filters, 3, bn=True)\n", 208 | " self.blocks = nn.ModuleList([ResidualBlock(num_filters) for _ in range(num_blocks)])\n", 209 | "\n", 210 | " def forward(self, rp, a):\n", 211 | " h = torch.cat([rp, a], dim=1)\n", 212 | " h = self.layer0(h)\n", 213 | " for block in self.blocks:\n", 214 | " h = block(h)\n", 215 | " return h\n", 216 | "\n", 217 | " def inference(self, rp, a):\n", 218 | " self.eval()\n", 219 | " with torch.no_grad():\n", 220 | " rp = self(torch.from_numpy(rp).unsqueeze(0), torch.from_numpy(a).unsqueeze(0))\n", 221 | " return rp.cpu().numpy()[0]\n", 222 | "\n", 223 | "class Net(nn.Module):\n", 224 | " '''Whole net'''\n", 225 | " def __init__(self):\n", 226 | " super().__init__()\n", 227 | " state = State()\n", 228 | " input_shape = state.feature().shape\n", 229 | " action_shape = state.action_feature(0).shape\n", 230 | " rp_shape = (num_filters, *input_shape[1:])\n", 231 | "\n", 232 | " self.representation = Representation(input_shape)\n", 233 | " self.prediction = Prediction(action_shape)\n", 234 | " self.dynamics = Dynamics(rp_shape, action_shape)\n", 235 | "\n", 236 | " def predict(self, state0, path):\n", 237 | " '''Predict p and v from original state and path'''\n", 238 | " outputs = []\n", 239 | " x = state0.feature()\n", 240 | " rp = self.representation.inference(x)\n", 241 | " outputs.append(self.prediction.inference(rp))\n", 242 | " for action in path:\n", 243 | " a = state0.action_feature(action)\n", 244 | " rp = self.dynamics.inference(rp, a)\n", 245 | " outputs.append(self.prediction.inference(rp))\n", 246 | " return outputs" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "def show_net(net, state):\n", 256 | " '''Display policy (p) and value (v)'''\n", 257 | " print(state)\n", 258 | " p, v = net.predict(state, [])[-1]\n", 259 | " print('p = ')\n", 260 | " print((p * 1000).astype(int).reshape((-1, *net.representation.input_shape[1:3])))\n", 261 | " print('v = ', v)\n", 262 | " print()\n", 263 | "\n", 264 | "# Outputs before training\n", 265 | "show_net(Net(), State())" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "# Implementation of Monte Carlo Tree Search\n", 275 | "\n", 276 | "class Node:\n", 277 | " '''Search result of one abstract (or root) state'''\n", 278 | " def __init__(self, p, v):\n", 279 | " self.p, self.v = p, v\n", 280 | " self.n, self.q_sum = np.zeros_like(p), np.zeros_like(p)\n", 281 | " self.n_all, self.q_sum_all = 1, v / 2 # prior\n", 282 | "\n", 283 | " def update(self, action, q_new):\n", 284 | " # Update\n", 285 | " self.n[action] += 1\n", 286 | " self.q_sum[action] += q_new\n", 287 | "\n", 288 | " # Update overall stats\n", 289 | " self.n_all += 1\n", 290 | " self.q_sum_all += q_new" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "import time\n", 300 | "import copy\n", 301 | "\n", 302 | "class Tree:\n", 303 | " '''Monte Carlo Tree'''\n", 304 | " def __init__(self, net):\n", 305 | " self.net = net\n", 306 | " self.nodes = {}\n", 307 | "\n", 308 | " def search(self, state, path, rp, depth):\n", 
309 | "        # Return predicted value from new state\n", 310 | "        key = state.record_string()\n", 311 | "        if len(path) > 0:\n", 312 | "            key += '|' + ' '.join(map(state.action2str, path))\n", 313 | "        if key not in self.nodes:\n", 314 | "            p, v = self.net.prediction.inference(rp)\n", 315 | "            self.nodes[key] = Node(p, v)\n", 316 | "            return v\n", 317 | "\n", 318 | "        # State transition by an action selected from bandit\n", 319 | "        node = self.nodes[key]\n", 320 | "        p = node.p\n", 321 | "        mask = np.zeros_like(p)\n", 322 | "        if depth == 0:\n", 323 | "            # Add noise to policy on the root node\n", 324 | "            p = 0.75 * p + 0.25 * np.random.dirichlet([0.15] * len(p))\n", 325 | "            # On the root node, we choose an action only from legal actions\n", 326 | "            mask[state.legal_actions()] = 1\n", 327 | "            p *= mask\n", 328 | "            p /= p.sum() + 1e-16\n", 329 | "\n", 330 | "        n, q_sum = 1 + node.n, node.q_sum_all / node.n_all + node.q_sum\n", 331 | "        ucb = q_sum / n + 2.0 * np.sqrt(node.n_all) * p / n + mask * 4 # PUCB formula\n", 332 | "        best_action = np.argmax(ucb)\n", 333 | "\n", 334 | "        # Search next state by recursively calling this function\n", 335 | "        rp_next = self.net.dynamics.inference(rp, state.action_feature(best_action))\n", 336 | "        path.append(best_action)\n", 337 | "        q_new = -self.search(state, path, rp_next, depth + 1) # assuming the player to move alternates each turn\n", 338 | "        node.update(best_action, q_new)\n", 339 | "\n", 340 | "        return q_new\n", 341 | "\n", 342 | "    def think(self, state, num_simulations, temperature=0, show=False):\n", 343 | "        # Entry point of MCTS\n", 344 | "        if show:\n", 345 | "            print(state)\n", 346 | "        start, prev_time = time.time(), 0\n", 347 | "        for _ in range(num_simulations):\n", 348 | "            self.search(state, [], self.net.representation.inference(state.feature()), depth=0)\n", 349 | "\n", 350 | "            # Display search results every second\n", 351 | "            if show:\n", 352 | "                tmp_time = time.time() - start\n", 353 | "                if int(tmp_time) > int(prev_time):\n", 354 | "                    prev_time = tmp_time\n", 355 | "                    root, pv = self.nodes[state.record_string()], self.pv(state)\n", 356 | "                    print('%.2f sec. best %s. q = %.4f. n = %d / %d. pv = %s'\n", 357 | "                          % (tmp_time, state.action2str(pv[0]), root.q_sum[pv[0]] / root.n[pv[0]],\n", 358 | "                             root.n[pv[0]], root.n_all, ' '.join([state.action2str(a) for a in pv])))\n", 359 | "\n", 360 | "        # Return probability distribution weighted by the number of simulations\n", 361 | "        root = self.nodes[state.record_string()]\n", 362 | "        n = root.n + 1\n", 363 | "        n = (n / np.max(n)) ** (1 / (temperature + 1e-8))\n", 364 | "        return n / n.sum()\n", 365 | "\n", 366 | "    def pv(self, state):\n", 367 | "        # Return principal variation (the action sequence considered best)\n", 368 | "        s, pv_seq = copy.deepcopy(state), []\n", 369 | "        while True:\n", 370 | "            key = s.record_string()\n", 371 | "            if key not in self.nodes or self.nodes[key].n.sum() == 0:\n", 372 | "                break\n", 373 | "            best_action = sorted([(a, self.nodes[key].n[a]) for a in s.legal_actions()], key=lambda x: -x[1])[0][0]\n", 374 | "            pv_seq.append(best_action)\n", 375 | "            s.play(best_action)\n", 376 | "        return pv_seq" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "metadata": { 383 | "scrolled": true 384 | }, 385 | "outputs": [], 386 | "source": [ 387 | "# Search with initialized net\n", 388 | "\n", 389 | "tree = Tree(Net())\n", 390 | "tree.think(State(), 100, show=True)\n", 391 | "\n", 392 | "tree = Tree(Net())\n", 393 | "tree.think(State().play('A1 C1 A2 C2'), 200, show=True)\n", 394 | "\n", 395 | "tree = Tree(Net())\n", 396 | "tree.think(State().play('B2 A2 A3 C1 B3'), 200, show=True)\n", 397 | "\n", 398 | "tree = Tree(Net())\n", 399 | "tree.think(State().play('B2 A2 A3 C1'), 200, show=True)" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "# Training of neural net\n", 409 | "\n", 410 | "import torch.optim as optim\n", 411 | "\n", 412 | "batch_size = 32\n", 413 | "num_steps = 100\n", 414 | "\n", 415 | "def gen_target(ep, k):\n", 416 | "    '''Generate inputs and targets for training'''\n", 417 | "    # path, reward, observation, action, policy\n", 418 | "    turn_idx = np.random.randint(len(ep[0]))\n", 419 | "    ps, vs, ax = [], [], []\n", 420 | "    for t in range(turn_idx, turn_idx + k + 1):\n", 421 | "        if t < len(ep[0]):\n", 422 | "            p = ep[4][t]\n", 423 | "            a = ep[3][t]\n", 424 | "        else: # state after the game has finished\n", 425 | "            # p is 0 (so the policy loss is 0)\n", 426 | "            p = np.zeros_like(ep[4][-1])\n", 427 | "            # random action selection\n", 428 | "            a = np.zeros(np.prod(ep[3][-1].shape), dtype=np.float32)\n", 429 | "            a[np.random.randint(len(a))] = 1\n", 430 | "            a = a.reshape(ep[3][-1].shape)\n", 431 | "        vs.append([ep[1] if t % 2 == 0 else -ep[1]])\n", 432 | "        ps.append(p)\n", 433 | "        ax.append(a)\n", 434 | "\n", 435 | "    return ep[2][turn_idx], ax, ps, vs\n", 436 | "\n", 437 | "def train(episodes, net, opt):\n", 438 | "    '''Train neural net'''\n", 439 | "    p_loss_sum, v_loss_sum = 0, 0\n", 440 | "    net.train()\n", 441 | "    k = 4\n", 442 | "    for _ in range(num_steps):\n", 443 | "        x, ax, p_target, v_target = zip(*[gen_target(episodes[np.random.randint(len(episodes))], k) for _ in range(batch_size)])\n", 444 | "        x = torch.from_numpy(np.array(x))\n", 445 | "        ax = torch.from_numpy(np.array(ax))\n", 446 | "        p_target = torch.from_numpy(np.array(p_target))\n", 447 | "        v_target = torch.FloatTensor(np.array(v_target))\n", 448 | "\n", 449 | "        # Change the order of axes to [time step, batch, ...]\n", 450 | "        ax = torch.transpose(ax, 0, 1)\n", 451 | "        p_target = torch.transpose(p_target, 0, 1)\n", 452 | "        v_target = 
torch.transpose(v_target, 0, 1)\n", 453 | "\n", 454 | "        # Compute losses for k (+ current) steps\n", 455 | "        p_loss, v_loss = 0, 0\n", 456 | "        for t in range(k + 1):\n", 457 | "            rp = net.representation(x) if t == 0 else net.dynamics(rp, ax[t - 1])\n", 458 | "            p, v = net.prediction(rp)\n", 459 | "            p_loss += F.kl_div(torch.log(p), p_target[t], reduction='sum')\n", 460 | "            v_loss += torch.sum(((v_target[t] - v) ** 2) / 2)\n", 461 | "\n", 462 | "        p_loss_sum += p_loss.item()\n", 463 | "        v_loss_sum += v_loss.item()\n", 464 | "\n", 465 | "        opt.zero_grad()\n", 466 | "        (p_loss + v_loss).backward()\n", 467 | "        opt.step()\n", 468 | "\n", 469 | "    num_train_datum = num_steps * batch_size\n", 470 | "    print('p_loss %f v_loss %f' % (p_loss_sum / num_train_datum, v_loss_sum / num_train_datum))\n", 471 | "    return net" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": null, 477 | "metadata": {}, 478 | "outputs": [], 479 | "source": [ 480 | "# Battle against random agents\n", 481 | "\n", 482 | "def vs_random(net, n=100):\n", 483 | "    results = {}\n", 484 | "    for i in range(n):\n", 485 | "        first_turn = i % 2 == 0\n", 486 | "        turn = first_turn\n", 487 | "        state = State()\n", 488 | "        while not state.terminal():\n", 489 | "            if turn:\n", 490 | "                p, _ = net.predict(state, [])[-1]\n", 491 | "                action = sorted([(a, p[a]) for a in state.legal_actions()], key=lambda x:-x[1])[0][0]\n", 492 | "            else:\n", 493 | "                action = np.random.choice(state.legal_actions())\n", 494 | "            state.play(action)\n", 495 | "            turn = not turn\n", 496 | "        r = state.terminal_reward() if turn else -state.terminal_reward()\n", 497 | "        results[r] = results.get(r, 0) + 1\n", 498 | "    return results" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "metadata": { 505 | "scrolled": true 506 | }, 507 | "outputs": [], 508 | "source": [ 509 | "# Main algorithm of MuZero\n", 510 | "\n", 511 | "num_games = 50000\n", 512 | "num_games_one_epoch = 20\n", 513 | "num_simulations = 40\n", 514 | "\n", 515 | "net = Net()\n", 516 | "optimizer = optim.SGD(net.parameters(), lr=3e-4, weight_decay=3e-5, momentum=0.8)\n", 517 | "\n", 518 | "# Display battle results as {-1: lose, 0: draw, 1: win} (for episodes generated for training, 1 means that the first player won)\n", 519 | "vs_random_sum = vs_random(net)\n", 520 | "print('vs_random = ', sorted(vs_random_sum.items()))\n", 521 | "\n", 522 | "episodes = []\n", 523 | "result_distribution = {1: 0, 0: 0, -1: 0}\n", 524 | "\n", 525 | "for g in range(num_games):\n", 526 | "    # Generate one episode\n", 527 | "    record, p_targets, features, action_features = [], [], [], []\n", 528 | "    state = State()\n", 529 | "    # temperature used to make policy targets from search results\n", 530 | "    temperature = 0.7\n", 531 | "\n", 532 | "    while not state.terminal():\n", 533 | "        tree = Tree(net)\n", 534 | "        p_target = tree.think(state, num_simulations, temperature)\n", 535 | "        p_targets.append(p_target)\n", 536 | "        features.append(state.feature())\n", 537 | "\n", 538 | "        # Select an action from the generated distribution, then make a transition with that action\n", 539 | "        action = np.random.choice(np.arange(len(p_target)), p=p_target)\n", 540 | "        record.append(action)\n", 541 | "        action_features.append(state.action_feature(action))\n", 542 | "        state.play(action)\n", 543 | "        temperature *= 0.8\n", 544 | "\n", 545 | "    # reward seen from the first turn player\n", 546 | "    reward = state.terminal_reward() * (1 if len(record) % 2 == 0 else -1)\n", 547 | "    result_distribution[reward] += 1\n", 548 | "    episodes.append((record, reward, features, action_features, p_targets))\n", 549 | "\n", 550 | "    if g % num_games_one_epoch == 0:\n", 551 | "        print('game ', end='')\n", 552 | "    print(g, ' ', end='')\n", 553 | "\n", 554 | "    # Training of neural net\n", 555 | "    if (g + 1) % num_games_one_epoch == 0:\n", 556 | "        # Show the result distribution of generated episodes\n", 557 | "        print('generated = ', sorted(result_distribution.items()))\n", 558 | "        net = train(episodes, net, optimizer)\n", 559 | "        vs_random_once = vs_random(net)\n", 560 | "        print('vs_random = ', sorted(vs_random_once.items()), end='')\n", 561 | "        for r, n in vs_random_once.items():\n", 562 | "            vs_random_sum[r] += n\n", 563 | "        print(' sum = ', sorted(vs_random_sum.items()))\n", 564 | "        #show_net(net, State())\n", 565 | "        #show_net(net, State().play('A1 C1 A2 C2'))\n", 566 | "        #show_net(net, State().play('A1 B2 C3 B3 C1'))\n", 567 | "        #show_net(net, State().play('B2 A2 A3 C1 B3'))\n", 568 | "        #show_net(net, State().play('B2 A2 A3 C1'))\n", 569 | "print('finished')" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": null, 575 | "metadata": {}, 576 | "outputs": [], 577 | "source": [ 578 | "# Show outputs from trained net\n", 579 | "\n", 580 | "print('initial state')\n", 581 | "show_net(net, State())\n", 582 | "\n", 583 | "print('WIN by placing the winning stone')\n", 584 | "show_net(net, State().play('A1 C1 A2 C2'))\n", 585 | "\n", 586 | "print('LOSE to opponent\\'s double threat')\n", 587 | "show_net(net, State().play('B2 A2 A3 C1 B3'))\n", 588 | "\n", 589 | "print('WIN through a double threat')\n", 590 | "show_net(net, State().play('B2 A2 A3 C1'))\n", 591 | "\n", 592 | "# hard case: putting a stone on A1 will cause a double threat\n", 593 | "print('strategic WIN via a subsequent double threat')\n", 594 | "show_net(net, State().play('B1 A3'))" 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": null, 600 | "metadata": {}, 601 | "outputs": [], 602 | "source": [ 603 | "# Search with trained net\n", 604 | "\n", 605 | "tree = Tree(net)\n", 606 | "tree.think(State(), 100000, show=True)" 607 | ] 608 | } 609 | ], 610 | "metadata": { 611 | "kernelspec": { 612 | "display_name": "Python 3 (ipykernel)", 613 | "language": "python", 614 | "name": "python3" 615 | }, 616 | "language_info": { 617 | "codemirror_mode": { 618 | "name": "ipython", 619 | "version": 3 620 | }, 621 | "file_extension": ".py", 622 | "mimetype": "text/x-python", 623 | "name": "python", 624 | "nbconvert_exporter": "python", 625 | "pygments_lexer": "ipython3", 626 | "version": "3.9.7" 627 | } 628 | }, 629 | "nbformat": 4, 630 | "nbformat_minor": 2 631 | } 632 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MuZeroJupyterExample 2 | 3 | This is a simple reimplementation of the MuZero algorithm for two-player board games. 4 | 5 | https://arxiv.org/abs/1911.08265 6 | 7 | 8 | 2020/7/18 Update 9 | - After fixing a fatal bug in the tree search, training now works well. Please try again. 10 | 11 | 12 | AlphaZero version is here: https://github.com/YuriCat/AlphaZeroJupyterExample 13 | --------------------------------------------------------------------------------