├── .gitignore ├── README.rst └── muzero ├── __init__.py ├── config.py ├── game ├── __init__.py ├── cartpole.py ├── game.py └── gym_wrappers.py ├── muzero.py ├── networks ├── __init__.py ├── cartpole_network.py ├── network.py └── shared_storage.py ├── self_play ├── __init__.py ├── mcts.py ├── self_play.py └── utils.py └── training ├── __init__.py ├── replay_buffer.py └── training.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Pythond 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | # Pycharm 127 | .idea/ 128 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. |copy| unicode:: 0xA9 2 | .. |---| unicode:: U+02014 3 | 4 | ====== 5 | MuZero 6 | ====== 7 | 8 | This repository is a Python implementation of the MuZero algorithm. 9 | It is based upon the `pre-print paper`__ and the `pseudocode`__ describing the Muzero framework. 10 | Neural computations are implemented with Tensorflow. 11 | 12 | You can easily train your own MuZero, more specifically for one player and non-image based environments (such as `CartPole`__). 
13 | If you wish to train MuZero on other kinds of environments, this codebase can be used with slight modifications.
14 | 
15 | __ https://arxiv.org/abs/1911.08265
16 | __ https://arxiv.org/src/1911.08265v1/anc/pseudocode.py
17 | __ https://gym.openai.com/envs/CartPole-v1/
18 | 
19 | 
20 | **DISCLAIMER**: this is early research code. In practice, this means:
21 | 
22 | - Silent bugs may exist.
23 | - It may not work reliably on other environments or with other hyper-parameters.
24 | - The code quality and documentation are quite lacking, and much of the code might still feel "in-progress".
25 | - The training and testing pipeline is not very advanced.
26 | 
27 | Dependencies
28 | ============
29 | 
30 | We run this code using:
31 | 
32 | - Conda **4.7.12**
33 | - Python **3.7**
34 | - Tensorflow **2.0.0**
35 | - Numpy **1.17.3**
36 | 
37 | Training your MuZero
38 | ====================
39 | 
40 | This code must be run from the main function in ``muzero.py`` (don't forget to first configure your conda environment).
41 | 
42 | Training a CartPole-v1 bot
43 | --------------------------
44 | 
45 | To train a model, please follow these steps:
46 | 
47 | 1) Create a new MuZero configuration in ``config.py``, or modify an existing one.
48 | 
49 | 2) Call the chosen configuration inside the main function of ``muzero.py``.
50 | 
51 | 3) Run the main function: ``python muzero.py``.
52 | 
53 | Training on another environment
54 | -------------------------------
55 | 
56 | To train on a different environment than CartPole-v1, please follow these additional steps:
57 | 
58 | 1) Create a class that extends ``AbstractGame``; this class should implement the behavior of your environment.
59 | For instance, the ``CartPole`` class extends ``AbstractGame`` and works as a wrapper around `gym CartPole-v1`__.
60 | You can use the ``CartPole`` class as a template for any gym environment.
61 | 
62 | __ https://gym.openai.com/envs/CartPole-v1/
63 | 
64 | 2) **This step is optional** (only if you want to use a different kind of network architecture or value/reward transform).
65 | Create a class that extends ``BaseNetwork``; this class should implement the different networks (representation, value, policy, reward and dynamic) and the value/reward transforms.
66 | For instance, the ``CartPoleNetwork`` class extends ``BaseNetwork`` and implements fully connected networks.
67 | 
68 | 3) **This step is optional** (only if you use a different value/reward transform).
69 | You should implement the corresponding inverse value/reward transform by modifying the ``loss_value`` function and the reward loss term inside ``training.py``.
70 | 
71 | Differences from the paper
72 | ==========================
73 | 
74 | This implementation differs from the original paper in the following ways:
75 | 
76 | - We use fully connected layers instead of convolutional ones. This is due to the nature of our environment (CartPole-v1), whose observation vector has no spatial correlation.
77 | - We don't scale the hidden state between 0 and 1 using min-max normalization. Instead, we use a tanh activation that maps values into the range [-1, 1].
78 | - We use a slightly simpler invertible transform for the value prediction, obtained by removing the linear term (see the sketch below).
79 | - During training, samples are drawn from a uniform distribution instead of using prioritized replay.
80 | - We also scale the loss of each head by 1/K (where K is the number of unrolled steps), but we treat K as a constant (even though this does not always hold).
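The value transform mentioned above is implemented in two places: ``loss_value`` in ``training.py`` encodes a scalar target onto a categorical support using a square-root transform, and ``_value_transform`` in ``networks/cartpole_network.py`` decodes the predicted support back to a scalar by taking the expectation and squaring it. A minimal, self-contained sketch of that pair (the helper names below are illustrative and not part of the codebase)::

    import numpy as np

    def scalar_to_support(value: float, support_size: int) -> np.ndarray:
        """Spread a scalar target over two adjacent bins, as loss_value does."""
        sqrt_value = np.sqrt(value)
        floor = int(np.floor(sqrt_value))
        rest = sqrt_value - floor
        target = np.zeros(support_size)
        target[floor] = 1 - rest
        target[floor + 1] = rest
        return target

    def support_to_scalar(value_support: np.ndarray) -> float:
        """Decode a support distribution back to a scalar, as _value_transform does."""
        probs = np.exp(value_support - np.max(value_support))
        probs /= np.sum(probs)
        expectation = np.dot(probs, np.arange(len(probs)))
        return float(expectation) ** 2  # inverse transform: the square function

For CartPole, ``max_value=500`` gives a support of size ``math.ceil(math.sqrt(500)) + 1 = 24`` (see ``CartPoleNetwork.__init__``).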
81 | -------------------------------------------------------------------------------- /muzero/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johan-gras/MuZero/4f53f0c3e6b853990500b0b041306d051fce3951/muzero/__init__.py -------------------------------------------------------------------------------- /muzero/config.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Optional, Dict 3 | 4 | import tensorflow_core as tf 5 | 6 | from game.cartpole import CartPole 7 | from game.game import AbstractGame 8 | from networks.cartpole_network import CartPoleNetwork 9 | from networks.network import BaseNetwork, UniformNetwork 10 | 11 | KnownBounds = collections.namedtuple('KnownBounds', ['min', 'max']) 12 | 13 | 14 | class MuZeroConfig(object): 15 | 16 | def __init__(self, 17 | game, 18 | nb_training_loop: int, 19 | nb_episodes: int, 20 | nb_epochs: int, 21 | network_args: Dict, 22 | network, 23 | action_space_size: int, 24 | max_moves: int, 25 | discount: float, 26 | dirichlet_alpha: float, 27 | num_simulations: int, 28 | batch_size: int, 29 | td_steps: int, 30 | visit_softmax_temperature_fn, 31 | lr: float, 32 | known_bounds: Optional[KnownBounds] = None): 33 | ### Environment 34 | self.game = game 35 | 36 | ### Self-Play 37 | self.action_space_size = action_space_size 38 | # self.num_actors = num_actors 39 | 40 | self.visit_softmax_temperature_fn = visit_softmax_temperature_fn 41 | self.max_moves = max_moves 42 | self.num_simulations = num_simulations 43 | self.discount = discount 44 | 45 | # Root prior exploration noise. 46 | self.root_dirichlet_alpha = dirichlet_alpha 47 | self.root_exploration_fraction = 0.25 48 | 49 | # UCB formula 50 | self.pb_c_base = 19652 51 | self.pb_c_init = 1.25 52 | 53 | # If we already have some information about which values occur in the 54 | # environment, we can use them to initialize the rescaling. 55 | # This is not strictly necessary, but establishes identical behaviour to 56 | # AlphaZero in board games. 
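        # In the CartPole config defined below, no bounds are passed, so MinMaxStats
        # (see self_play/utils.py) falls back to the min/max values observed during the search.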
57 | self.known_bounds = known_bounds 58 | 59 | ### Training 60 | self.nb_training_loop = nb_training_loop 61 | self.nb_episodes = nb_episodes # Nb of episodes per training loop 62 | self.nb_epochs = nb_epochs # Nb of epochs per training loop 63 | 64 | # self.training_steps = int(1000e3) 65 | # self.checkpoint_interval = int(1e3) 66 | self.window_size = int(1e6) 67 | self.batch_size = batch_size 68 | self.num_unroll_steps = 5 69 | self.td_steps = td_steps 70 | 71 | self.weight_decay = 1e-4 72 | self.momentum = 0.9 73 | 74 | self.network_args = network_args 75 | self.network = network 76 | self.lr = lr 77 | # Exponential learning rate schedule 78 | # self.lr_init = lr_init 79 | # self.lr_decay_rate = 0.1 80 | # self.lr_decay_steps = lr_decay_steps 81 | 82 | def new_game(self) -> AbstractGame: 83 | return self.game(self.discount) 84 | 85 | def new_network(self) -> BaseNetwork: 86 | return self.network(**self.network_args) 87 | 88 | def uniform_network(self) -> UniformNetwork: 89 | return UniformNetwork(self.action_space_size) 90 | 91 | def new_optimizer(self) -> tf.keras.optimizers: 92 | return tf.keras.optimizers.SGD(learning_rate=self.lr, momentum=self.momentum) 93 | 94 | 95 | def make_cartpole_config() -> MuZeroConfig: 96 | def visit_softmax_temperature(num_moves, training_steps): 97 | return 1.0 98 | 99 | return MuZeroConfig( 100 | game=CartPole, 101 | nb_training_loop=50, 102 | nb_episodes=20, 103 | nb_epochs=20, 104 | network_args={'action_size': 2, 105 | 'state_size': 4, 106 | 'representation_size': 4, 107 | 'max_value': 500}, 108 | network=CartPoleNetwork, 109 | action_space_size=2, 110 | max_moves=1000, 111 | discount=0.99, 112 | dirichlet_alpha=0.25, 113 | num_simulations=11, # Odd number perform better in eval mode 114 | batch_size=512, 115 | td_steps=10, 116 | visit_softmax_temperature_fn=visit_softmax_temperature, 117 | lr=0.05) 118 | 119 | 120 | """ 121 | Legacy configs from the DeepMind's pseudocode. 122 | 123 | def make_board_game_config(action_space_size: int, max_moves: int, 124 | dirichlet_alpha: float, 125 | lr_init: float) -> MuZeroConfig: 126 | def visit_softmax_temperature(num_moves, training_steps): 127 | if num_moves < 30: 128 | return 1.0 129 | else: 130 | return 0.0 # Play according to the max. 131 | 132 | return MuZeroConfig( 133 | action_space_size=action_space_size, 134 | max_moves=max_moves, 135 | discount=1.0, 136 | dirichlet_alpha=dirichlet_alpha, 137 | num_simulations=800, 138 | batch_size=2048, 139 | td_steps=max_moves, # Always use Monte Carlo return. 
140 | num_actors=3000, 141 | lr_init=lr_init, 142 | lr_decay_steps=400e3, 143 | visit_softmax_temperature_fn=visit_softmax_temperature, 144 | known_bounds=KnownBounds(-1, 1)) 145 | 146 | 147 | def make_go_config() -> MuZeroConfig: 148 | return make_board_game_config( 149 | action_space_size=362, max_moves=722, dirichlet_alpha=0.03, lr_init=0.01) 150 | 151 | 152 | def make_chess_config() -> MuZeroConfig: 153 | return make_board_game_config( 154 | action_space_size=4672, max_moves=512, dirichlet_alpha=0.3, lr_init=0.1) 155 | 156 | 157 | def make_shogi_config() -> MuZeroConfig: 158 | return make_board_game_config( 159 | action_space_size=11259, max_moves=512, dirichlet_alpha=0.15, lr_init=0.1) 160 | 161 | 162 | def make_atari_config() -> MuZeroConfig: 163 | def visit_softmax_temperature(num_moves, training_steps): 164 | if training_steps < 500e3: 165 | return 1.0 166 | elif training_steps < 750e3: 167 | return 0.5 168 | else: 169 | return 0.25 170 | 171 | return MuZeroConfig( 172 | action_space_size=18, 173 | max_moves=27000, # Half an hour at action repeat 4. 174 | discount=0.997, 175 | dirichlet_alpha=0.25, 176 | num_simulations=50, 177 | batch_size=1024, 178 | td_steps=10, 179 | num_actors=350, 180 | lr_init=0.05, 181 | lr_decay_steps=350e3, 182 | visit_softmax_temperature_fn=visit_softmax_temperature) 183 | """ 184 | -------------------------------------------------------------------------------- /muzero/game/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johan-gras/MuZero/4f53f0c3e6b853990500b0b041306d051fce3951/muzero/game/__init__.py -------------------------------------------------------------------------------- /muzero/game/cartpole.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import gym 4 | 5 | from game.game import Action, AbstractGame 6 | from game.gym_wrappers import ScalingObservationWrapper 7 | 8 | 9 | class CartPole(AbstractGame): 10 | """The Gym CartPole environment""" 11 | 12 | def __init__(self, discount: float): 13 | super().__init__(discount) 14 | self.env = gym.make('CartPole-v1') 15 | self.env = ScalingObservationWrapper(self.env, low=[-2.4, -2.0, -0.42, -3.5], high=[2.4, 2.0, 0.42, 3.5]) 16 | self.actions = list(map(lambda i: Action(i), range(self.env.action_space.n))) 17 | self.observations = [self.env.reset()] 18 | self.done = False 19 | 20 | @property 21 | def action_space_size(self) -> int: 22 | """Return the size of the action space.""" 23 | return len(self.actions) 24 | 25 | def step(self, action) -> int: 26 | """Execute one step of the game conditioned by the given action.""" 27 | 28 | observation, reward, done, _ = self.env.step(action.index) 29 | self.observations += [observation] 30 | self.done = done 31 | return reward 32 | 33 | def terminal(self) -> bool: 34 | """Is the game is finished?""" 35 | return self.done 36 | 37 | def legal_actions(self) -> List[Action]: 38 | """Return the legal actions available at this instant.""" 39 | return self.actions 40 | 41 | def make_image(self, state_index: int): 42 | """Compute the state of the game.""" 43 | return self.observations[state_index] 44 | -------------------------------------------------------------------------------- /muzero/game/game.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | from typing import List 3 | 4 | from self_play.utils import Node 5 | 6 | 7 | class 
Action(object): 8 | """ Class that represent an action of a game.""" 9 | 10 | def __init__(self, index: int): 11 | self.index = index 12 | 13 | def __hash__(self): 14 | return self.index 15 | 16 | def __eq__(self, other): 17 | return self.index == other.index 18 | 19 | def __gt__(self, other): 20 | return self.index > other.index 21 | 22 | 23 | class Player(object): 24 | """ 25 | A one player class. 26 | This class is useless, it's here for legacy purpose and for potential adaptations for a two players MuZero. 27 | """ 28 | 29 | def __eq__(self, other): 30 | return True 31 | 32 | 33 | class ActionHistory(object): 34 | """ 35 | Simple history container used inside the search. 36 | Only used to keep track of the actions executed. 37 | """ 38 | 39 | def __init__(self, history: List[Action], action_space_size: int): 40 | self.history = list(history) 41 | self.action_space_size = action_space_size 42 | 43 | def clone(self): 44 | return ActionHistory(self.history, self.action_space_size) 45 | 46 | def add_action(self, action: Action): 47 | self.history.append(action) 48 | 49 | def last_action(self) -> Action: 50 | return self.history[-1] 51 | 52 | def action_space(self) -> List[Action]: 53 | return [Action(i) for i in range(self.action_space_size)] 54 | 55 | def to_play(self) -> Player: 56 | return Player() 57 | 58 | 59 | class AbstractGame(ABC): 60 | """ 61 | Abstract class that allows to implement a game. 62 | One instance represent a single episode of interaction with the environment. 63 | """ 64 | 65 | def __init__(self, discount: float): 66 | self.history = [] 67 | self.rewards = [] 68 | self.child_visits = [] 69 | self.root_values = [] 70 | self.discount = discount 71 | 72 | def apply(self, action: Action): 73 | """Apply an action onto the environment.""" 74 | 75 | reward = self.step(action) 76 | self.rewards.append(reward) 77 | self.history.append(action) 78 | 79 | def store_search_statistics(self, root: Node): 80 | """After each MCTS run, store the statistics generated by the search.""" 81 | 82 | sum_visits = sum(child.visit_count for child in root.children.values()) 83 | action_space = (Action(index) for index in range(self.action_space_size)) 84 | self.child_visits.append([ 85 | root.children[a].visit_count / sum_visits if a in root.children else 0 86 | for a in action_space 87 | ]) 88 | self.root_values.append(root.value()) 89 | 90 | def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int, to_play: Player): 91 | """Generate targets to learn from during the network training.""" 92 | 93 | # The value target is the discounted root value of the search tree N steps 94 | # into the future, plus the discounted sum of all rewards until then. 95 | targets = [] 96 | for current_index in range(state_index, state_index + num_unroll_steps + 1): 97 | bootstrap_index = current_index + td_steps 98 | if bootstrap_index < len(self.root_values): 99 | value = self.root_values[bootstrap_index] * self.discount ** td_steps 100 | else: 101 | value = 0 102 | 103 | for i, reward in enumerate(self.rewards[current_index:bootstrap_index]): 104 | value += reward * self.discount ** i 105 | 106 | if current_index < len(self.root_values): 107 | targets.append((value, self.rewards[current_index], self.child_visits[current_index])) 108 | else: 109 | # States past the end of games are treated as absorbing states. 
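                # The zero value and reward remain valid targets here, while the empty policy
                # list is skipped later by the policy mask built in update_weights (training.py).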
110 | targets.append((0, 0, [])) 111 | return targets 112 | 113 | def to_play(self) -> Player: 114 | """Return the current player.""" 115 | return Player() 116 | 117 | def action_history(self) -> ActionHistory: 118 | """Return the actions executed inside the search.""" 119 | return ActionHistory(self.history, self.action_space_size) 120 | 121 | # Methods to be implemented by the children class 122 | @property 123 | @abstractmethod 124 | def action_space_size(self) -> int: 125 | """Return the size of the action space.""" 126 | pass 127 | 128 | @abstractmethod 129 | def step(self, action) -> int: 130 | """Execute one step of the game conditioned by the given action.""" 131 | pass 132 | 133 | @abstractmethod 134 | def terminal(self) -> bool: 135 | """Is the game is finished?""" 136 | pass 137 | 138 | @abstractmethod 139 | def legal_actions(self) -> List[Action]: 140 | """Return the legal actions available at this instant.""" 141 | pass 142 | 143 | @abstractmethod 144 | def make_image(self, state_index: int): 145 | """Compute the state of the game.""" 146 | pass 147 | -------------------------------------------------------------------------------- /muzero/game/gym_wrappers.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | 5 | class ScalingObservationWrapper(gym.ObservationWrapper): 6 | """ 7 | Wrapper that apply a min-max scaling of observations. 8 | """ 9 | 10 | def __init__(self, env, low=None, high=None): 11 | super().__init__(env) 12 | assert isinstance(env.observation_space, gym.spaces.Box) 13 | 14 | low = np.array(self.observation_space.low if low is None else low) 15 | high = np.array(self.observation_space.high if high is None else high) 16 | 17 | self.mean = (high + low) / 2 18 | self.max = high - self.mean 19 | 20 | def observation(self, observation): 21 | return (observation - self.mean) / self.max 22 | -------------------------------------------------------------------------------- /muzero/muzero.py: -------------------------------------------------------------------------------- 1 | from config import MuZeroConfig, make_cartpole_config 2 | from networks.shared_storage import SharedStorage 3 | from self_play.self_play import run_selfplay, run_eval 4 | from training.replay_buffer import ReplayBuffer 5 | from training.training import train_network 6 | 7 | 8 | def muzero(config: MuZeroConfig): 9 | """ 10 | MuZero training is split into two independent parts: Network training and 11 | self-play data generation. 12 | These two parts only communicate by transferring the latest networks checkpoint 13 | from the training to the self-play, and the finished games from the self-play 14 | to the training. 15 | In contrast to the original MuZero algorithm this version doesn't works with 16 | multiple threads, therefore the training and self-play is done alternately. 
17 | """ 18 | storage = SharedStorage(config.new_network(), config.uniform_network(), config.new_optimizer()) 19 | replay_buffer = ReplayBuffer(config) 20 | 21 | for loop in range(config.nb_training_loop): 22 | print("Training loop", loop) 23 | score_train = run_selfplay(config, storage, replay_buffer, config.nb_episodes) 24 | train_network(config, storage, replay_buffer, config.nb_epochs) 25 | 26 | print("Train score:", score_train) 27 | print("Eval score:", run_eval(config, storage, 50)) 28 | print(f"MuZero played {config.nb_episodes * (loop + 1)} " 29 | f"episodes and trained for {config.nb_epochs * (loop + 1)} epochs.\n") 30 | 31 | return storage.latest_network() 32 | 33 | 34 | if __name__ == '__main__': 35 | config = make_cartpole_config() 36 | muzero(config) 37 | -------------------------------------------------------------------------------- /muzero/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johan-gras/MuZero/4f53f0c3e6b853990500b0b041306d051fce3951/muzero/networks/__init__.py -------------------------------------------------------------------------------- /muzero/networks/cartpole_network.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | from tensorflow_core.python.keras import regularizers 5 | from tensorflow_core.python.keras.layers.core import Dense 6 | from tensorflow_core.python.keras.models import Sequential 7 | 8 | from game.game import Action 9 | from networks.network import BaseNetwork 10 | 11 | 12 | class CartPoleNetwork(BaseNetwork): 13 | 14 | def __init__(self, 15 | state_size: int, 16 | action_size: int, 17 | representation_size: int, 18 | max_value: int, 19 | hidden_neurons: int = 64, 20 | weight_decay: float = 1e-4, 21 | representation_activation: str = 'tanh'): 22 | self.state_size = state_size 23 | self.action_size = action_size 24 | self.value_support_size = math.ceil(math.sqrt(max_value)) + 1 25 | 26 | regularizer = regularizers.l2(weight_decay) 27 | representation_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer), 28 | Dense(representation_size, activation=representation_activation, 29 | kernel_regularizer=regularizer)]) 30 | value_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer), 31 | Dense(self.value_support_size, kernel_regularizer=regularizer)]) 32 | policy_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer), 33 | Dense(action_size, kernel_regularizer=regularizer)]) 34 | dynamic_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer), 35 | Dense(representation_size, activation=representation_activation, 36 | kernel_regularizer=regularizer)]) 37 | reward_network = Sequential([Dense(16, activation='relu', kernel_regularizer=regularizer), 38 | Dense(1, kernel_regularizer=regularizer)]) 39 | 40 | super().__init__(representation_network, value_network, policy_network, dynamic_network, reward_network) 41 | 42 | def _value_transform(self, value_support: np.array) -> float: 43 | """ 44 | The value is obtained by first computing the expected value from the discrete support. 45 | Second, the inverse transform is then apply (the square function). 
46 | """ 47 | 48 | value = self._softmax(value_support) 49 | value = np.dot(value, range(self.value_support_size)) 50 | value = np.asscalar(value) ** 2 51 | return value 52 | 53 | def _reward_transform(self, reward: np.array) -> float: 54 | return np.asscalar(reward) 55 | 56 | def _conditioned_hidden_state(self, hidden_state: np.array, action: Action) -> np.array: 57 | conditioned_hidden = np.concatenate((hidden_state, np.eye(self.action_size)[action.index])) 58 | return np.expand_dims(conditioned_hidden, axis=0) 59 | 60 | def _softmax(self, values): 61 | """Compute softmax using numerical stability tricks.""" 62 | values_exp = np.exp(values - np.max(values)) 63 | return values_exp / np.sum(values_exp) 64 | -------------------------------------------------------------------------------- /muzero/networks/network.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from abc import ABC, abstractmethod 3 | from typing import Dict, List, Callable 4 | 5 | import numpy as np 6 | from tensorflow_core.python.keras.models import Model 7 | 8 | from game.game import Action 9 | 10 | 11 | class NetworkOutput(typing.NamedTuple): 12 | value: float 13 | reward: float 14 | policy_logits: Dict[Action, float] 15 | hidden_state: typing.Optional[List[float]] 16 | 17 | @staticmethod 18 | def build_policy_logits(policy_logits): 19 | return {Action(i): logit for i, logit in enumerate(policy_logits[0])} 20 | 21 | 22 | class AbstractNetwork(ABC): 23 | 24 | def __init__(self): 25 | self.training_steps = 0 26 | 27 | @abstractmethod 28 | def initial_inference(self, image) -> NetworkOutput: 29 | pass 30 | 31 | @abstractmethod 32 | def recurrent_inference(self, hidden_state, action) -> NetworkOutput: 33 | pass 34 | 35 | 36 | class UniformNetwork(AbstractNetwork): 37 | """policy -> uniform, value -> 0, reward -> 0""" 38 | 39 | def __init__(self, action_size: int): 40 | super().__init__() 41 | self.action_size = action_size 42 | 43 | def initial_inference(self, image) -> NetworkOutput: 44 | return NetworkOutput(0, 0, {Action(i): 1 / self.action_size for i in range(self.action_size)}, None) 45 | 46 | def recurrent_inference(self, hidden_state, action) -> NetworkOutput: 47 | return NetworkOutput(0, 0, {Action(i): 1 / self.action_size for i in range(self.action_size)}, None) 48 | 49 | 50 | class InitialModel(Model): 51 | """Model that combine the representation and prediction (value+policy) network.""" 52 | 53 | def __init__(self, representation_network: Model, value_network: Model, policy_network: Model): 54 | super(InitialModel, self).__init__() 55 | self.representation_network = representation_network 56 | self.value_network = value_network 57 | self.policy_network = policy_network 58 | 59 | def call(self, image): 60 | hidden_representation = self.representation_network(image) 61 | value = self.value_network(hidden_representation) 62 | policy_logits = self.policy_network(hidden_representation) 63 | return hidden_representation, value, policy_logits 64 | 65 | 66 | class RecurrentModel(Model): 67 | """Model that combine the dynamic, reward and prediction (value+policy) network.""" 68 | 69 | def __init__(self, dynamic_network: Model, reward_network: Model, value_network: Model, policy_network: Model): 70 | super(RecurrentModel, self).__init__() 71 | self.dynamic_network = dynamic_network 72 | self.reward_network = reward_network 73 | self.value_network = value_network 74 | self.policy_network = policy_network 75 | 76 | def call(self, conditioned_hidden): 77 | 
hidden_representation = self.dynamic_network(conditioned_hidden) 78 | reward = self.reward_network(conditioned_hidden) 79 | value = self.value_network(hidden_representation) 80 | policy_logits = self.policy_network(hidden_representation) 81 | return hidden_representation, reward, value, policy_logits 82 | 83 | 84 | class BaseNetwork(AbstractNetwork): 85 | """Base class that contains all the networks and models of MuZero.""" 86 | 87 | def __init__(self, representation_network: Model, value_network: Model, policy_network: Model, 88 | dynamic_network: Model, reward_network: Model): 89 | super().__init__() 90 | # Networks blocks 91 | self.representation_network = representation_network 92 | self.value_network = value_network 93 | self.policy_network = policy_network 94 | self.dynamic_network = dynamic_network 95 | self.reward_network = reward_network 96 | 97 | # Models for inference and training 98 | self.initial_model = InitialModel(self.representation_network, self.value_network, self.policy_network) 99 | self.recurrent_model = RecurrentModel(self.dynamic_network, self.reward_network, self.value_network, 100 | self.policy_network) 101 | 102 | def initial_inference(self, image: np.array) -> NetworkOutput: 103 | """representation + prediction function""" 104 | 105 | hidden_representation, value, policy_logits = self.initial_model.predict(np.expand_dims(image, 0)) 106 | output = NetworkOutput(value=self._value_transform(value), 107 | reward=0., 108 | policy_logits=NetworkOutput.build_policy_logits(policy_logits), 109 | hidden_state=hidden_representation[0]) 110 | return output 111 | 112 | def recurrent_inference(self, hidden_state: np.array, action: Action) -> NetworkOutput: 113 | """dynamics + prediction function""" 114 | 115 | conditioned_hidden = self._conditioned_hidden_state(hidden_state, action) 116 | hidden_representation, reward, value, policy_logits = self.recurrent_model.predict(conditioned_hidden) 117 | output = NetworkOutput(value=self._value_transform(value), 118 | reward=self._reward_transform(reward), 119 | policy_logits=NetworkOutput.build_policy_logits(policy_logits), 120 | hidden_state=hidden_representation[0]) 121 | return output 122 | 123 | @abstractmethod 124 | def _value_transform(self, value: np.array) -> float: 125 | pass 126 | 127 | @abstractmethod 128 | def _reward_transform(self, reward: np.array) -> float: 129 | pass 130 | 131 | @abstractmethod 132 | def _conditioned_hidden_state(self, hidden_state: np.array, action: Action) -> np.array: 133 | pass 134 | 135 | def cb_get_variables(self) -> Callable: 136 | """Return a callback that return the trainable variables of the network.""" 137 | 138 | def get_variables(): 139 | networks = (self.representation_network, self.value_network, self.policy_network, 140 | self.dynamic_network, self.reward_network) 141 | return [variables 142 | for variables_list in map(lambda n: n.weights, networks) 143 | for variables in variables_list] 144 | 145 | return get_variables 146 | -------------------------------------------------------------------------------- /muzero/networks/shared_storage.py: -------------------------------------------------------------------------------- 1 | import tensorflow_core as tf 2 | 3 | from networks.network import BaseNetwork, UniformNetwork, AbstractNetwork 4 | 5 | 6 | class SharedStorage(object): 7 | """Save the different versions of the network.""" 8 | 9 | def __init__(self, network: BaseNetwork, uniform_network: UniformNetwork, optimizer: tf.keras.optimizers): 10 | self._networks = {} 11 | 
self.current_network = network 12 | self.uniform_network = uniform_network 13 | self.optimizer = optimizer 14 | 15 | def latest_network(self) -> AbstractNetwork: 16 | if self._networks: 17 | return self._networks[max(self._networks.keys())] 18 | else: 19 | # policy -> uniform, value -> 0, reward -> 0 20 | return self.uniform_network 21 | 22 | def save_network(self, step: int, network: BaseNetwork): 23 | self._networks[step] = network 24 | -------------------------------------------------------------------------------- /muzero/self_play/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johan-gras/MuZero/4f53f0c3e6b853990500b0b041306d051fce3951/muzero/self_play/__init__.py -------------------------------------------------------------------------------- /muzero/self_play/mcts.py: -------------------------------------------------------------------------------- 1 | """MCTS module: where MuZero thinks inside the tree.""" 2 | 3 | import math 4 | import random 5 | from typing import List 6 | 7 | import numpy 8 | 9 | from config import MuZeroConfig 10 | from game.game import Player, Action, ActionHistory 11 | from networks.network import NetworkOutput, BaseNetwork 12 | from self_play.utils import MinMaxStats, Node, softmax_sample 13 | 14 | 15 | def add_exploration_noise(config: MuZeroConfig, node: Node): 16 | """ 17 | At the start of each search, we add dirichlet noise to the prior of the root 18 | to encourage the search to explore new actions. 19 | """ 20 | actions = list(node.children.keys()) 21 | noise = numpy.random.dirichlet([config.root_dirichlet_alpha] * len(actions)) 22 | frac = config.root_exploration_fraction 23 | for a, n in zip(actions, noise): 24 | node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac 25 | 26 | 27 | def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory, network: BaseNetwork): 28 | """ 29 | Core Monte Carlo Tree Search algorithm. 30 | To decide on an action, we run N simulations, always starting at the root of 31 | the search tree and traversing the tree according to the UCB formula until we 32 | reach a leaf node. 33 | """ 34 | min_max_stats = MinMaxStats(config.known_bounds) 35 | 36 | for _ in range(config.num_simulations): 37 | history = action_history.clone() 38 | node = root 39 | search_path = [node] 40 | 41 | while node.expanded(): 42 | action, node = select_child(config, node, min_max_stats) 43 | history.add_action(action) 44 | search_path.append(node) 45 | 46 | # Inside the search tree we use the dynamics function to obtain the next 47 | # hidden state given an action and the previous hidden state. 48 | parent = search_path[-2] 49 | network_output = network.recurrent_inference(parent.hidden_state, history.last_action()) 50 | expand_node(node, history.to_play(), history.action_space(), network_output) 51 | 52 | backpropagate(search_path, network_output.value, history.to_play(), config.discount, min_max_stats) 53 | 54 | 55 | def select_child(config: MuZeroConfig, node: Node, min_max_stats: MinMaxStats): 56 | """ 57 | Select the child with the highest UCB score. 
58 | """ 59 | # When the parent visit count is zero, all ucb scores are zeros, therefore we return a random child 60 | if node.visit_count == 0: 61 | return random.sample(node.children.items(), 1)[0] 62 | 63 | _, action, child = max( 64 | (ucb_score(config, node, child, min_max_stats), action, 65 | child) for action, child in node.children.items()) 66 | return action, child 67 | 68 | 69 | def ucb_score(config: MuZeroConfig, parent: Node, child: Node, 70 | min_max_stats: MinMaxStats) -> float: 71 | """ 72 | The score for a node is based on its value, plus an exploration bonus based on 73 | the prior. 74 | """ 75 | pb_c = math.log((parent.visit_count + config.pb_c_base + 1) / config.pb_c_base) + config.pb_c_init 76 | pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1) 77 | 78 | prior_score = pb_c * child.prior 79 | value_score = min_max_stats.normalize(child.value()) 80 | return prior_score + value_score 81 | 82 | 83 | def expand_node(node: Node, to_play: Player, actions: List[Action], 84 | network_output: NetworkOutput): 85 | """ 86 | We expand a node using the value, reward and policy prediction obtained from 87 | the neural networks. 88 | """ 89 | node.to_play = to_play 90 | node.hidden_state = network_output.hidden_state 91 | node.reward = network_output.reward 92 | policy = {a: math.exp(network_output.policy_logits[a]) for a in actions} 93 | policy_sum = sum(policy.values()) 94 | for action, p in policy.items(): 95 | node.children[action] = Node(p / policy_sum) 96 | 97 | 98 | def backpropagate(search_path: List[Node], value: float, to_play: Player, 99 | discount: float, min_max_stats: MinMaxStats): 100 | """ 101 | At the end of a simulation, we propagate the evaluation all the way up the 102 | tree to the root. 103 | """ 104 | for node in search_path[::-1]: 105 | node.value_sum += value if node.to_play == to_play else -value 106 | node.visit_count += 1 107 | min_max_stats.update(node.value()) 108 | 109 | value = node.reward + discount * value 110 | 111 | 112 | def select_action(config: MuZeroConfig, num_moves: int, node: Node, network: BaseNetwork, mode: str = 'softmax'): 113 | """ 114 | After running simulations inside in MCTS, we select an action based on the root's children visit counts. 115 | During training we use a softmax sample for exploration. 116 | During evaluation we select the most visited child. 
117 | """ 118 | visit_counts = [child.visit_count for child in node.children.values()] 119 | actions = [action for action in node.children.keys()] 120 | action = None 121 | if mode == 'softmax': 122 | t = config.visit_softmax_temperature_fn( 123 | num_moves=num_moves, training_steps=network.training_steps) 124 | action = softmax_sample(visit_counts, actions, t) 125 | elif mode == 'max': 126 | action, _ = max(node.children.items(), key=lambda item: item[1].visit_count) 127 | return action 128 | -------------------------------------------------------------------------------- /muzero/self_play/self_play.py: -------------------------------------------------------------------------------- 1 | """Self-Play module: where the games are played.""" 2 | 3 | from config import MuZeroConfig 4 | from game.game import AbstractGame 5 | from networks.network import AbstractNetwork 6 | from networks.shared_storage import SharedStorage 7 | from self_play.mcts import run_mcts, select_action, expand_node, add_exploration_noise 8 | from self_play.utils import Node 9 | from training.replay_buffer import ReplayBuffer 10 | 11 | 12 | def run_selfplay(config: MuZeroConfig, storage: SharedStorage, replay_buffer: ReplayBuffer, train_episodes: int): 13 | """Take the latest network, produces multiple games and save them in the shared replay buffer""" 14 | network = storage.latest_network() 15 | returns = [] 16 | for _ in range(train_episodes): 17 | game = play_game(config, network) 18 | replay_buffer.save_game(game) 19 | returns.append(sum(game.rewards)) 20 | return sum(returns) / train_episodes 21 | 22 | 23 | def run_eval(config: MuZeroConfig, storage: SharedStorage, eval_episodes: int): 24 | """Evaluate MuZero without noise added to the prior of the root and without softmax action selection""" 25 | network = storage.latest_network() 26 | returns = [] 27 | for _ in range(eval_episodes): 28 | game = play_game(config, network, train=False) 29 | returns.append(sum(game.rewards)) 30 | return sum(returns) / eval_episodes if eval_episodes else 0 31 | 32 | 33 | def play_game(config: MuZeroConfig, network: AbstractNetwork, train: bool = True) -> AbstractGame: 34 | """ 35 | Each game is produced by starting at the initial board position, then 36 | repeatedly executing a Monte Carlo Tree Search to generate moves until the end 37 | of the game is reached. 38 | """ 39 | game = config.new_game() 40 | mode_action_select = 'softmax' if train else 'max' 41 | 42 | while not game.terminal() and len(game.history) < config.max_moves: 43 | # At the root of the search tree we use the representation function to 44 | # obtain a hidden state given the current observation. 45 | root = Node(0) 46 | current_observation = game.make_image(-1) 47 | expand_node(root, game.to_play(), game.legal_actions(), network.initial_inference(current_observation)) 48 | if train: 49 | add_exploration_noise(config, root) 50 | 51 | # We then run a Monte Carlo Tree Search using only action sequences and the 52 | # model learned by the networks. 
53 | run_mcts(config, root, game.action_history(), network) 54 | action = select_action(config, len(game.history), root, network, mode=mode_action_select) 55 | game.apply(action) 56 | game.store_search_statistics(root) 57 | return game 58 | -------------------------------------------------------------------------------- /muzero/self_play/utils.py: -------------------------------------------------------------------------------- 1 | """Helpers for the MCTS""" 2 | from typing import Optional 3 | 4 | import numpy as np 5 | 6 | MAXIMUM_FLOAT_VALUE = float('inf') 7 | 8 | 9 | class MinMaxStats(object): 10 | """A class that holds the min-max values of the tree.""" 11 | 12 | def __init__(self, known_bounds): 13 | self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE 14 | self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE 15 | 16 | def update(self, value: float): 17 | if value is None: 18 | raise ValueError 19 | 20 | self.maximum = max(self.maximum, value) 21 | self.minimum = min(self.minimum, value) 22 | 23 | def normalize(self, value: float) -> float: 24 | # If the value is unknow, by default we set it to the minimum possible value 25 | if value is None: 26 | return 0.0 27 | 28 | if self.maximum > self.minimum: 29 | # We normalize only when we have set the maximum and minimum values. 30 | return (value - self.minimum) / (self.maximum - self.minimum) 31 | return value 32 | 33 | 34 | class Node(object): 35 | """A class that represent nodes inside the MCTS tree""" 36 | 37 | def __init__(self, prior: float): 38 | self.visit_count = 0 39 | self.to_play = -1 40 | self.prior = prior 41 | self.value_sum = 0 42 | self.children = {} 43 | self.hidden_state = None 44 | self.reward = 0 45 | 46 | def expanded(self) -> bool: 47 | return len(self.children) > 0 48 | 49 | def value(self) -> Optional[float]: 50 | if self.visit_count == 0: 51 | return None 52 | return self.value_sum / self.visit_count 53 | 54 | 55 | def softmax_sample(visit_counts, actions, t): 56 | counts_exp = np.exp(visit_counts) * (1 / t) 57 | probs = counts_exp / np.sum(counts_exp, axis=0) 58 | action_idx = np.random.choice(len(actions), p=probs) 59 | return actions[action_idx] 60 | -------------------------------------------------------------------------------- /muzero/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johan-gras/MuZero/4f53f0c3e6b853990500b0b041306d051fce3951/muzero/training/__init__.py -------------------------------------------------------------------------------- /muzero/training/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | from itertools import zip_longest 3 | from typing import List 4 | 5 | from config import MuZeroConfig 6 | from game.game import AbstractGame 7 | 8 | 9 | class ReplayBuffer(object): 10 | 11 | def __init__(self, config: MuZeroConfig): 12 | self.window_size = config.window_size 13 | self.batch_size = config.batch_size 14 | self.buffer = [] 15 | 16 | def save_game(self, game): 17 | if len(self.buffer) > self.window_size: 18 | self.buffer.pop(0) 19 | self.buffer.append(game) 20 | 21 | def sample_batch(self, num_unroll_steps: int, td_steps: int): 22 | # Generate some sample of data to train on 23 | games = self.sample_games() 24 | game_pos = [(g, self.sample_position(g)) for g in games] 25 | game_data = [(g.make_image(i), g.history[i:i + num_unroll_steps], 26 | g.make_target(i, num_unroll_steps, td_steps, 
g.to_play())) 27 | for (g, i) in game_pos] 28 | 29 | # Pre-process the batch 30 | image_batch, actions_time_batch, targets_batch = zip(*game_data) 31 | targets_init_batch, *targets_time_batch = zip(*targets_batch) 32 | actions_time_batch = list(zip_longest(*actions_time_batch, fillvalue=None)) 33 | 34 | # Building batch of valid actions and a dynamic mask for hidden representations during BPTT 35 | mask_time_batch = [] 36 | dynamic_mask_time_batch = [] 37 | last_mask = [True] * len(image_batch) 38 | for i, actions_batch in enumerate(actions_time_batch): 39 | mask = list(map(lambda a: bool(a), actions_batch)) 40 | dynamic_mask = [now for last, now in zip(last_mask, mask) if last] 41 | mask_time_batch.append(mask) 42 | dynamic_mask_time_batch.append(dynamic_mask) 43 | last_mask = mask 44 | actions_time_batch[i] = [action.index for action in actions_batch if action] 45 | 46 | batch = image_batch, targets_init_batch, targets_time_batch, actions_time_batch, mask_time_batch, dynamic_mask_time_batch 47 | return batch 48 | 49 | def sample_games(self) -> List[AbstractGame]: 50 | # Sample game from buffer either uniformly or according to some priority. 51 | return random.choices(self.buffer, k=self.batch_size) 52 | 53 | def sample_position(self, game: AbstractGame) -> int: 54 | # Sample position from game either uniformly or according to some priority. 55 | return random.randint(0, len(game.history)) 56 | -------------------------------------------------------------------------------- /muzero/training/training.py: -------------------------------------------------------------------------------- 1 | """Training module: this is where MuZero neurons are trained.""" 2 | 3 | import numpy as np 4 | import tensorflow_core as tf 5 | from tensorflow_core.python.keras.losses import MSE 6 | 7 | from config import MuZeroConfig 8 | from networks.network import BaseNetwork 9 | from networks.shared_storage import SharedStorage 10 | from training.replay_buffer import ReplayBuffer 11 | 12 | 13 | def train_network(config: MuZeroConfig, storage: SharedStorage, replay_buffer: ReplayBuffer, epochs: int): 14 | network = storage.current_network 15 | optimizer = storage.optimizer 16 | 17 | for _ in range(epochs): 18 | batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps) 19 | update_weights(optimizer, network, batch) 20 | storage.save_network(network.training_steps, network) 21 | 22 | 23 | def update_weights(optimizer: tf.keras.optimizers, network: BaseNetwork, batch): 24 | def scale_gradient(tensor, scale: float): 25 | """Trick function to scale the gradient in tensorflow""" 26 | return (1. 
- scale) * tf.stop_gradient(tensor) + scale * tensor 27 | 28 | def loss(): 29 | loss = 0 30 | image_batch, targets_init_batch, targets_time_batch, actions_time_batch, mask_time_batch, dynamic_mask_time_batch = batch 31 | 32 | # Initial step, from the real observation: representation + prediction networks 33 | representation_batch, value_batch, policy_batch = network.initial_model(np.array(image_batch)) 34 | 35 | # Only update the element with a policy target 36 | target_value_batch, _, target_policy_batch = zip(*targets_init_batch) 37 | mask_policy = list(map(lambda l: bool(l), target_policy_batch)) 38 | target_policy_batch = list(filter(lambda l: bool(l), target_policy_batch)) 39 | policy_batch = tf.boolean_mask(policy_batch, mask_policy) 40 | 41 | # Compute the loss of the first pass 42 | loss += tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size)) 43 | loss += tf.math.reduce_mean( 44 | tf.nn.softmax_cross_entropy_with_logits(logits=policy_batch, labels=target_policy_batch)) 45 | 46 | # Recurrent steps, from action and previous hidden state. 47 | for actions_batch, targets_batch, mask, dynamic_mask in zip(actions_time_batch, targets_time_batch, 48 | mask_time_batch, dynamic_mask_time_batch): 49 | target_value_batch, target_reward_batch, target_policy_batch = zip(*targets_batch) 50 | 51 | # Only execute BPTT for elements with an action 52 | representation_batch = tf.boolean_mask(representation_batch, dynamic_mask) 53 | target_value_batch = tf.boolean_mask(target_value_batch, mask) 54 | target_reward_batch = tf.boolean_mask(target_reward_batch, mask) 55 | # Creating conditioned_representation: concatenate representations with actions batch 56 | actions_batch = tf.one_hot(actions_batch, network.action_size) 57 | 58 | # Recurrent step from conditioned representation: recurrent + prediction networks 59 | conditioned_representation_batch = tf.concat((representation_batch, actions_batch), axis=1) 60 | representation_batch, reward_batch, value_batch, policy_batch = network.recurrent_model( 61 | conditioned_representation_batch) 62 | 63 | # Only execute BPTT for elements with a policy target 64 | target_policy_batch = [policy for policy, b in zip(target_policy_batch, mask) if b] 65 | mask_policy = list(map(lambda l: bool(l), target_policy_batch)) 66 | target_policy_batch = tf.convert_to_tensor([policy for policy in target_policy_batch if policy]) 67 | policy_batch = tf.boolean_mask(policy_batch, mask_policy) 68 | 69 | # Compute the partial loss 70 | l = (tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size)) + 71 | MSE(target_reward_batch, tf.squeeze(reward_batch)) + 72 | tf.math.reduce_mean( 73 | tf.nn.softmax_cross_entropy_with_logits(logits=policy_batch, labels=target_policy_batch))) 74 | 75 | # Scale the gradient of the loss by the average number of actions unrolled 76 | gradient_scale = 1. 
/ len(actions_time_batch) 77 | loss += scale_gradient(l, gradient_scale) 78 | 79 | # Half the gradient of the representation 80 | representation_batch = scale_gradient(representation_batch, 0.5) 81 | 82 | return loss 83 | 84 | optimizer.minimize(loss=loss, var_list=network.cb_get_variables()) 85 | network.training_steps += 1 86 | 87 | 88 | def loss_value(target_value_batch, value_batch, value_support_size: int): 89 | batch_size = len(target_value_batch) 90 | targets = np.zeros((batch_size, value_support_size)) 91 | sqrt_value = np.sqrt(target_value_batch) 92 | floor_value = np.floor(sqrt_value).astype(int) 93 | rest = sqrt_value - floor_value 94 | targets[range(batch_size), floor_value.astype(int)] = 1 - rest 95 | targets[range(batch_size), floor_value.astype(int) + 1] = rest 96 | 97 | return tf.nn.softmax_cross_entropy_with_logits(logits=value_batch, labels=targets) 98 | --------------------------------------------------------------------------------
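A side note on the ``scale_gradient`` trick defined at the top of ``update_weights``: the expression ``(1 - scale) * tf.stop_gradient(tensor) + scale * tensor`` leaves the forward value unchanged while multiplying the gradient by ``scale``. A quick standalone check (illustrative only, not part of the repository; it imports ``tensorflow`` directly, whereas the codebase uses ``tensorflow_core``):

    import tensorflow as tf

    x = tf.Variable(3.0)
    with tf.GradientTape() as tape:
        y = (1. - 0.5) * tf.stop_gradient(x) + 0.5 * x  # scale_gradient(x, 0.5)

    print(y.numpy())                     # 3.0 -> forward value is unchanged
    print(tape.gradient(y, x).numpy())   # 0.5 -> gradient is scaled by 0.5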