├── .gitignore
├── agents
│   ├── rainbow
│   │   ├── third_party
│   │   │   ├── __init__.py
│   │   │   └── dopamine
│   │   │       ├── __init__.py
│   │   │       ├── iteration_statistics.py
│   │   │       ├── logger.py
│   │   │       ├── checkpointer.py
│   │   │       ├── sum_tree.py
│   │   │       └── LICENSE
│   │   ├── configs
│   │   │   └── hanabi_rainbow.gin
│   │   ├── README.md
│   │   ├── train.py
│   │   ├── prioritized_replay_memory.py
│   │   ├── rainbow_agent.py
│   │   └── run_experiment.py
│   ├── __init__.py
│   ├── random_agent.py
│   └── simple_agent.py
├── hanabi_lib
│   ├── CMakeLists.txt
│   ├── hanabi_card.cc
│   ├── hanabi_card.h
│   ├── observation_encoder.h
│   ├── canonical_encoders.h
│   ├── hanabi_move.cc
│   ├── hanabi_history_item.cc
│   ├── hanabi_move.h
│   ├── util.h
│   ├── hanabi_history_item.h
│   ├── util.cc
│   ├── hanabi_observation.h
│   ├── hanabi_hand.cc
│   ├── hanabi_game.h
│   ├── hanabi_observation.cc
│   ├── hanabi_hand.h
│   ├── hanabi_state.h
│   ├── hanabi_game.cc
│   ├── hanabi_state.cc
│   └── canonical_encoders.cc
├── __init__.py
├── CMakeLists.txt
├── README.md
├── clean_all.sh
├── CONTRIBUTING.md
├── rl_env_example.py
├── game_example.py
├── game_example.cc
├── pyhanabi.h
└── LICENSE
/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /agents/rainbow/third_party/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hanabi_lib/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library (hanabi hanabi_card.cc hanabi_game.cc hanabi_hand.cc hanabi_history_item.cc hanabi_move.cc hanabi_observation.cc hanabi_state.cc util.cc canonical_encoders.cc) 2 | target_include_directories(hanabi PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) 3 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /agents/rainbow/third_party/dopamine/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Dopamine Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 2.8.11) 2 | project (hanabi_learning_environment) 3 | 4 | set(CMAKE_C_FLAGS "-O2 -std=c++11 -fPIC") 5 | set(CMAKE_CXX_FLAGS "-O2 -std=c++11 -fPIC") 6 | 7 | add_subdirectory (hanabi_lib) 8 | 9 | add_library (pyhanabi SHARED pyhanabi.cc) 10 | target_link_libraries (pyhanabi LINK_PUBLIC hanabi) 11 | 12 | add_executable (game_example game_example.cc) 13 | target_link_libraries (game_example LINK_PUBLIC hanabi) 14 | 15 | install(TARGETS 16 | pyhanabi 17 | LIBRARY DESTINATION lib 18 | ARCHIVE DESTINATION lib 19 | RUNTIME DESTINATION lib) 20 | 21 | install(FILES pyhanabi.h DESTINATION include) 22 | install(DIRECTORY hanabi_lib/ DESTINATION include FILES_MATCHING PATTERN "*.h") 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is not an officially supported Google product. 2 | 3 | hanabi\_learning\_environment is a research platform for Hanabi experiments. The file rl\_env.py provides an RL environment using an API similar to OpenAI Gym. A lower level game interface is provided in pyhanabi.py for non-RL methods like Monte Carlo tree search. 4 | 5 | ### Getting started 6 | ``` 7 | sudo apt-get install g++ # if you don't already have a CXX compiler 8 | sudo apt-get install cmake # if you don't already have CMake 9 | sudo apt-get install python-pip # if you don't already have pip 10 | pip install cffi # if you don't already have cffi 11 | cmake . 12 | make 13 | python rl_env_example.py # Runs RL episodes 14 | python game_example.py # Plays a game using the lower level interface 15 | ``` 16 | -------------------------------------------------------------------------------- /clean_all.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Clean-up script to bring the repository back to a pre-cmake state. 16 | 17 | #!/bin/sh 18 | if [ -f Makefile ] 19 | then 20 | make clean 21 | fi 22 | 23 | rm -rf *.pyc agents/*.pyc __pycache__ agents/__pycache__ CMakeCache.txt CMakeFiles Makefile cmake_install.cmake hanabi_lib/CMakeFiles hanabi_lib/Makefile hanabi_lib/cmake_install.cmake 24 | -------------------------------------------------------------------------------- /agents/random_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Random Agent.""" 15 | 16 | import random 17 | from rl_env import Agent 18 | 19 | 20 | class RandomAgent(Agent): 21 | """Agent that takes random legal actions.""" 22 | 23 | def __init__(self, config, *args, **kwargs): 24 | """Initialize the agent.""" 25 | self.config = config 26 | 27 | def act(self, observation): 28 | """Act based on an observation.""" 29 | if observation['current_player_offset'] == 0: 30 | return random.choice(observation['legal_moves']) 31 | else: 32 | return None 33 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_card.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "hanabi_card.h" 16 | 17 | #include "util.h" 18 | 19 | namespace hanabi_learning_env { 20 | 21 | bool HanabiCard::operator==(const HanabiCard& other_card) const { 22 | return other_card.Color() == Color() && other_card.Rank() == Rank(); 23 | } 24 | 25 | std::string HanabiCard::ToString() const { 26 | if (!IsValid()) { 27 | return std::string("XX"); 28 | } 29 | return std::string() + ColorIndexToChar(Color()) + RankIndexToChar(Rank()); 30 | } 31 | 32 | } // namespace hanabi_learning_env 33 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. 
You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_card.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef __HANABI_CARD_H__ 16 | #define __HANABI_CARD_H__ 17 | 18 | #include <string> 19 | 20 | namespace hanabi_learning_env { 21 | 22 | class HanabiCard { 23 | public: 24 | HanabiCard(int color, int rank) : color_(color), rank_(rank) {} 25 | HanabiCard() = default; // Create an invalid card. 26 | bool operator==(const HanabiCard& other_card) const; 27 | bool IsValid() const { return color_ >= 0 && rank_ >= 0; } 28 | std::string ToString() const; 29 | int Color() const { return color_; } 30 | int Rank() const { return rank_; } 31 | 32 | private: 33 | int color_ = -1; // 0 indexed card color. 34 | int rank_ = -1; // 0 indexed card rank. 35 | }; 36 | 37 | } // namespace hanabi_learning_env 38 | 39 | #endif 40 | -------------------------------------------------------------------------------- /agents/rainbow/configs/hanabi_rainbow.gin: -------------------------------------------------------------------------------- 1 | import dqn_agent 2 | import rainbow_agent 3 | import run_experiment 4 | 5 | # This configures the DQN Agent. 6 | AGENT_CLASS = @DQNAgent 7 | DQNAgent.gamma = 0.99 8 | DQNAgent.update_horizon = 1 9 | DQNAgent.min_replay_history = 500 # agent steps 10 | DQNAgent.target_update_period = 500 # agent steps 11 | DQNAgent.epsilon_train = 0.0 12 | DQNAgent.epsilon_eval = 0.0 13 | DQNAgent.epsilon_decay_period = 1000 # agent steps 14 | DQNAgent.tf_device = '/gpu:0' # '/cpu:*' use for non-GPU version 15 | 16 | # This configures the Rainbow agent. 
17 | AGENT_CLASS_2 = @RainbowAgent 18 | RainbowAgent.gamma = 0.99 19 | RainbowAgent.update_horizon = 1 20 | RainbowAgent.num_atoms = 51 21 | RainbowAgent.min_replay_history = 500 # agent steps 22 | RainbowAgent.target_update_period = 500 # agent steps 23 | RainbowAgent.epsilon_train = 0.0 24 | RainbowAgent.epsilon_eval = 0.0 25 | RainbowAgent.epsilon_decay_period = 1000 # agent steps 26 | RainbowAgent.tf_device = '/gpu:0' # '/cpu:*' use for non-GPU version 27 | WrappedReplayMemory.replay_capacity = 50000 28 | 29 | run_experiment.training_steps = 10000 30 | run_experiment.num_iterations = 10005 31 | run_experiment.checkpoint_every_n = 50 32 | run_one_iteration.evaluate_every_n = 10 33 | 34 | # Small Hanabi. 35 | create_environment.game_type = 'Hanabi-Full-CardKnowledge' 36 | create_environment.num_players = 2 37 | 38 | create_agent.agent_type = 'Rainbow' 39 | create_obs_stacker.history_size = 1 40 | 41 | rainbow_template.layer_size=512 42 | rainbow_template.num_layers=2 43 | -------------------------------------------------------------------------------- /hanabi_lib/observation_encoder.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // A helper object to translate Hanabi observations to agent inputs 16 | // (e.g. tensors). 17 | 18 | #ifndef __OBSERVATION_ENCODER_H__ 19 | #define __OBSERVATION_ENCODER_H__ 20 | 21 | #include <vector> 22 | 23 | #include "hanabi_observation.h" 24 | 25 | namespace hanabi_learning_env { 26 | 27 | class ObservationEncoder { 28 | public: 29 | enum Type { kCanonical = 0 }; 30 | virtual ~ObservationEncoder() = default; 31 | 32 | // Returns the shape (dimension sizes of the tensor). 33 | virtual std::vector<int> Shape() const = 0; 34 | 35 | // All of the canonical observation encodings are vectors of bits. We can 36 | // change this if we want something more general (e.g. floats or doubles). 37 | virtual std::vector<int> Encode(const HanabiObservation& obs) const = 0; 38 | 39 | // Return the type of this encoder. 40 | virtual Type type() const = 0; 41 | }; 42 | 43 | } // namespace hanabi_learning_env 44 | 45 | #endif 46 | -------------------------------------------------------------------------------- /hanabi_lib/canonical_encoders.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // The standard Open Hanabi observation encoders. These encoders translate 16 | // HanabiObservations to input tensors that an agent can train on. 17 | 18 | #ifndef __CANONICAL_ENCODERS_H__ 19 | #define __CANONICAL_ENCODERS_H__ 20 | 21 | #include <vector> 22 | 23 | #include "hanabi_game.h" 24 | #include "hanabi_observation.h" 25 | #include "observation_encoder.h" 26 | 27 | namespace hanabi_learning_env { 28 | 29 | // This is the canonical observation encoding. 30 | class CanonicalObservationEncoder : public ObservationEncoder { 31 | public: 32 | explicit CanonicalObservationEncoder(const HanabiGame* parent_game) 33 | : parent_game_(parent_game) {} 34 | 35 | std::vector<int> Shape() const override; 36 | std::vector<int> Encode(const HanabiObservation& obs) const override; 37 | 38 | ObservationEncoder::Type type() const override { 39 | return ObservationEncoder::Type::kCanonical; 40 | } 41 | 42 | private: 43 | const HanabiGame* parent_game_ = nullptr; 44 | }; 45 | 46 | } // namespace hanabi_learning_env 47 | 48 | #endif 49 | -------------------------------------------------------------------------------- /agents/rainbow/third_party/dopamine/iteration_statistics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Dopamine Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A class for storing iteration-specific metrics. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | class IterationStatistics(object): 24 | """A class for storing iteration-specific metrics. 25 | 26 | The internal format is as follows: we maintain a mapping from keys to lists. 27 | Each list contains all the values corresponding to the given key. 28 | 29 | For example, self.data_lists['train_episode_returns'] might contain the 30 | per-episode returns achieved during this iteration. 31 | 32 | Attributes: 33 | data_lists: dict mapping each metric_name (str) to a list of said metric 34 | across episodes. 35 | """ 36 | 37 | def __init__(self): 38 | self.data_lists = {} 39 | 40 | def append(self, data_pairs): 41 | """Add the given values to their corresponding key-indexed lists. 42 | 43 | Args: 44 | data_pairs: A dictionary of key-value pairs to be recorded. 
45 | """ 46 | for key, value in data_pairs.items(): 47 | if key not in self.data_lists: 48 | self.data_lists[key] = [] 49 | self.data_lists[key].append(value) 50 | -------------------------------------------------------------------------------- /agents/rainbow/README.md: -------------------------------------------------------------------------------- 1 | # Rainbow agent for the Hanabi Learning Environment 2 | 3 | ## Instructions 4 | 5 | The Rainbow agent is derived from the 6 | [Dopamine framework](https://github.com/google/dopamine) which is based on 7 | Tensorflow. We recommend you consult the 8 | [Tensorflow documentation](https://www.tensorflow.org/install) 9 | for additional details. 10 | 11 | To run the agent, some dependencies need to be pre-installed. If you don't have 12 | access to a GPU, then replace `tensorflow-gpu` with `tensorflow` in the line 13 | below 14 | (see [Tensorflow instructions](https://www.tensorflow.org/install/install_linux) 15 | for details). 16 | 17 | ``` 18 | pip install absl-py gin-config tensorflow-gpu cffi 19 | ``` 20 | 21 | If you would prefer to not use the GPU, you may install tensorflow instead 22 | of tensorflow-gpu and set `RainbowAgent.tf_device = '/cpu:*'` in 23 | `configs/hanabi_rainbow.gin`. 24 | 25 | The entry point to run a Rainbow agent on the Hanabi environment is `train.py`. 26 | Assuming you are running from the agent directory `agents/rainbow`, 27 | 28 | ``` 29 | PYTHONPATH=${PYTHONPATH}:../.. 30 | python -um train \ 31 | --base_dir=/tmp/hanabi_rainbow \ 32 | --gin_files='configs/hanabi_rainbow.gin' 33 | ``` 34 | 35 | The `PYTHONPATH` fix exposes `rl_env.py`, the main entry point to the Hanabi 36 | Learning Environment. The `--base_dir` argument must be provided. 37 | 38 | To get finer-grained information about the training process, you can adjust the 39 | experiment parameters in `configs/hanabi_rainbow.gin` in particular by reducing 40 | `Runner.training_steps` and `Runner.evaluation_steps`, which together determine 41 | the total number of steps needed to complete an iteration. This is useful if you 42 | want to inspect log files or checkpoints, which are generated at the end of each 43 | iteration. 44 | 45 | More generally, most parameters are easily configured using the 46 | [gin configuration framework](https://github.com/google/gin-config). 47 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_move.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "hanabi_move.h" 16 | 17 | #include "util.h" 18 | 19 | namespace hanabi_learning_env { 20 | 21 | bool HanabiMove::operator==(const HanabiMove& other_move) const { 22 | if (MoveType() != other_move.MoveType()) { 23 | return false; 24 | } 25 | switch (MoveType()) { 26 | case kPlay: 27 | case kDiscard: 28 | return CardIndex() == other_move.CardIndex(); 29 | case kRevealColor: 30 | return TargetOffset() == other_move.TargetOffset() && 31 | Color() == other_move.Color(); 32 | case kRevealRank: 33 | return TargetOffset() == other_move.TargetOffset() && 34 | Rank() == other_move.Rank(); 35 | case kDeal: 36 | return Color() == other_move.Color() && Rank() == other_move.Rank(); 37 | default: 38 | return true; 39 | } 40 | } 41 | 42 | std::string HanabiMove::ToString() const { 43 | switch (MoveType()) { 44 | case kPlay: 45 | return "(Play " + std::to_string(CardIndex()) + ")"; 46 | case kDiscard: 47 | return "(Discard " + std::to_string(CardIndex()) + ")"; 48 | case kRevealColor: 49 | return "(Reveal player +" + std::to_string(TargetOffset()) + " color " + 50 | ColorIndexToChar(Color()) + ")"; 51 | case kRevealRank: 52 | return "(Reveal player +" + std::to_string(TargetOffset()) + " rank " + 53 | RankIndexToChar(Rank()) + ")"; 54 | case kDeal: 55 | if (color_ >= 0) { 56 | return std::string("(Deal ") + ColorIndexToChar(Color()) + 57 | RankIndexToChar(Rank()) + ")"; 58 | } else { 59 | return std::string("(Deal XX)"); 60 | } 61 | default: 62 | return "(INVALID)"; 63 | } 64 | } 65 | 66 | } // namespace hanabi_learning_env 67 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_history_item.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "hanabi_history_item.h" 16 | 17 | #include <cassert> 18 | 19 | #include "util.h" 20 | 21 | namespace hanabi_learning_env { 22 | 23 | std::string HanabiHistoryItem::ToString() const { 24 | std::string str = "<" + move.ToString(); 25 | if (player >= 0) { 26 | str += " by player " + std::to_string(player); 27 | } 28 | if (scored) { 29 | str += " scored"; 30 | } 31 | if (information_token) { 32 | str += " info_token"; 33 | } 34 | if (color >= 0) { 35 | assert(rank >= 0); 36 | str += " "; 37 | str += ColorIndexToChar(color); 38 | str += RankIndexToChar(rank); 39 | } 40 | if (reveal_bitmask) { 41 | str += " reveal "; 42 | bool first = true; 43 | for (int i = 0; i < 8; ++i) { // 8 bits in reveal_bitmask 44 | if (reveal_bitmask & (1 << i)) { 45 | if (first) { 46 | first = false; 47 | } else { 48 | str += ","; 49 | } 50 | str += std::to_string(i); 51 | } 52 | } 53 | } 54 | str += ">"; 55 | return str; 56 | } 57 | 58 | void ChangeToObserverRelative(int observer_pid, int player_count, 59 | HanabiHistoryItem* item) { 60 | if (item->move.MoveType() == HanabiMove::kDeal) { 61 | assert(item->player < 0 && item->deal_to_player >= 0); 62 | item->deal_to_player = 63 | (item->deal_to_player - observer_pid + player_count) % player_count; 64 | if (item->deal_to_player == 0) { 65 | // Hide cards dealt to observer. 66 | item->move = HanabiMove(HanabiMove::kDeal, -1, -1, -1, -1); 67 | } 68 | } else { 69 | assert(item->player >= 0); 70 | item->player = (item->player - observer_pid + player_count) % player_count; 71 | } 72 | } 73 | 74 | } // namespace hanabi_learning_env 75 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_move.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef __HANABI_MOVE_H__ 16 | #define __HANABI_MOVE_H__ 17 | 18 | #include <cstdint> 19 | #include <string> 20 | 21 | namespace hanabi_learning_env { 22 | 23 | // 5 types of moves: 24 | // "Play" card_index of card in player hand 25 | // "Discard" card_index of card in player hand 26 | // "RevealColor" target_offset color hints to player all cards of color 27 | // "RevealRank" target_offset rank hints to player all cards of given rank 28 | // NOTE: RevealXYZ target_offset field is an offset from the acting player 29 | // "Deal" color rank deal card with color and rank 30 | // "Invalid" move is not valid 31 | class HanabiMove { 32 | // HanabiMove is small, and intended to be passed by value. 33 | public: 34 | enum Type { kInvalid, kPlay, kDiscard, kRevealColor, kRevealRank, kDeal }; 35 | 36 | HanabiMove(Type move_type, int8_t card_index, int8_t target_offset, 37 | int8_t color, int8_t rank) 38 | : move_type_(move_type), 39 | card_index_(card_index), 40 | target_offset_(target_offset), 41 | color_(color), 42 | rank_(rank) {} 43 | // Tests whether two moves are functionally equivalent. 
44 | bool operator==(const HanabiMove& other_move) const; 45 | std::string ToString() const; 46 | 47 | Type MoveType() const { return move_type_; } 48 | bool IsValid() const { return move_type_ != kInvalid; } 49 | int8_t CardIndex() const { return card_index_; } 50 | int8_t TargetOffset() const { return target_offset_; } 51 | int8_t Color() const { return color_; } 52 | int8_t Rank() const { return rank_; } 53 | 54 | private: 55 | Type move_type_ = kInvalid; 56 | int8_t card_index_ = -1; 57 | int8_t target_offset_ = -1; 58 | int8_t color_ = -1; 59 | int8_t rank_ = -1; 60 | }; 61 | 62 | } // namespace hanabi_learning_env 63 | 64 | #endif 65 | -------------------------------------------------------------------------------- /hanabi_lib/util.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef __UTIL_H__ 16 | #define __UTIL_H__ 17 | 18 | #include <cassert> 19 | #include <cstdlib> 20 | #include <string> 21 | #include <unordered_map> 22 | 23 | namespace hanabi_learning_env { 24 | 25 | constexpr int kMaxNumColors = 5; 26 | constexpr int kMaxNumRanks = 5; 27 | 28 | // Returns a character representation of an integer color/rank index. 29 | char ColorIndexToChar(int color); 30 | char RankIndexToChar(int rank); 31 | 32 | // Returns string associated with key in params, parsed as template type. 33 | // If key is not in params, returns the provided default value. 34 | template <class T> 35 | T ParameterValue(const std::unordered_map<std::string, std::string>& params, 36 | const std::string& key, T default_value); 37 | 38 | template <> 39 | int ParameterValue(const std::unordered_map<std::string, std::string>& params, 40 | const std::string& key, int default_value); 41 | template <> 42 | double ParameterValue( 43 | const std::unordered_map<std::string, std::string>& params, 44 | const std::string& key, double default_value); 45 | template <> 46 | std::string ParameterValue( 47 | const std::unordered_map<std::string, std::string>& params, 48 | const std::string& key, std::string default_value); 49 | template <> 50 | bool ParameterValue(const std::unordered_map<std::string, std::string>& params, 51 | const std::string& key, bool default_value); 52 | 53 | #if defined(NDEBUG) 54 | #define REQUIRE(expr) \ 55 | (expr ? (void)0 \ 56 | : (fprintf(stderr, "Input requirements failed at %s:%d in %s: %s\n", \ 57 | __FILE__, __LINE__, __func__, #expr), \ 58 | std::abort())) 59 | #else 60 | #define REQUIRE(expr) assert(expr) 61 | #endif 62 | 63 | } // namespace hanabi_learning_env 64 | 65 | #endif 66 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_history_item.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef __HANABI_HISTORY_ITEM_H__ 16 | #define __HANABI_HISTORY_ITEM_H__ 17 | 18 | #include <cstdint> 19 | #include <string> 20 | 21 | #include "hanabi_move.h" 22 | 23 | namespace hanabi_learning_env { 24 | 25 | // A move that has been made within a Hanabi game, along with the side-effects 26 | // of making that move. 27 | struct HanabiHistoryItem { 28 | explicit HanabiHistoryItem(HanabiMove move_made) : move(move_made) {} 29 | HanabiHistoryItem(const HanabiHistoryItem& past_move) = default; 30 | std::string ToString() const; 31 | 32 | // Move that was made. 33 | HanabiMove move; 34 | // Index of player who made the move. 35 | int8_t player = -1; 36 | // Indicator of whether a Play move was successful. 37 | bool scored = false; 38 | // Indicator of whether a Play/Discard move added an information token. 39 | bool information_token = false; 40 | // Color of card that was played or discarded. Valid if color_ >= 0. 41 | int8_t color = -1; 42 | // Rank of card that was played or discarded. Valid if rank_ >= 0. 43 | int8_t rank = -1; 44 | // Bitmask indicating whether a card was targeted by a RevealX move. 45 | // Bit_i=1 if color/rank of card_i matches X in a RevealX move. 46 | // For example, if cards 0 and 3 had rank 2, a RevealRank 2 move 47 | // would result in a reveal_bitmask of 9 (2^0+2^3). 48 | uint8_t reveal_bitmask = 0; 49 | // Bitmask indicating whether a card was newly revealed by a RevealX move. 50 | // Bit_i=1 if color/rank of card_i was not known before RevealX move. 51 | // For example, if cards 1, 2, and 4 had color 'R', and the color of 52 | // card 1 had previously been revealed to be 'R', a RevealColor 'R' move 53 | // would result in a newly_revealed_bitmask of 20 (2^2+2^4). 54 | uint8_t newly_revealed_bitmask = 0; 55 | // Player that received a card from a Deal move. 56 | int8_t deal_to_player = -1; 57 | }; 58 | 59 | } // namespace hanabi_learning_env 60 | 61 | #endif 62 | -------------------------------------------------------------------------------- /hanabi_lib/util.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "util.h" 16 | 17 | #include <string> 18 | 19 | namespace hanabi_learning_env { 20 | 21 | char ColorIndexToChar(int color) { 22 | if (color >= 0 && color <= kMaxNumColors) { 23 | return "RYGWB"[color]; 24 | } else { 25 | return 'X'; 26 | } 27 | } 28 | 29 | char RankIndexToChar(int rank) { 30 | if (rank >= 0 && rank <= kMaxNumRanks) { 31 | return "12345"[rank]; 32 | } else { 33 | return 'X'; 34 | } 35 | } 36 | 37 | template <> 38 | int ParameterValue( 39 | const std::unordered_map<std::string, std::string>& params, 40 | const std::string& key, int default_value) { 41 | auto iter = params.find(key); 42 | if (iter == params.end()) { 43 | return default_value; 44 | } 45 | 46 | return std::stoi(iter->second); 47 | } 48 | 49 | template <> 50 | std::string ParameterValue( 51 | const std::unordered_map<std::string, std::string>& params, 52 | const std::string& key, std::string default_value) { 53 | auto iter = params.find(key); 54 | if (iter == params.end()) { 55 | return default_value; 56 | } 57 | 58 | return iter->second; 59 | } 60 | 61 | template <> 62 | double ParameterValue( 63 | const std::unordered_map<std::string, std::string>& params, 64 | const std::string& key, double default_value) { 65 | auto iter = params.find(key); 66 | if (iter == params.end()) { 67 | return default_value; 68 | } 69 | 70 | return std::stod(iter->second); 71 | } 72 | 73 | template <> 74 | bool ParameterValue( 75 | const std::unordered_map<std::string, std::string>& params, 76 | const std::string& key, bool default_value) { 77 | auto iter = params.find(key); 78 | if (iter == params.end()) { 79 | return default_value; 80 | } 81 | 82 | return (iter->second == "1" || iter->second == "true" || 83 | iter->second == "True" 84 | ? true 85 | : false); 86 | } 87 | 88 | } // namespace hanabi_learning_env 89 | -------------------------------------------------------------------------------- /agents/simple_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Simple Agent.""" 15 | 16 | from rl_env import Agent 17 | 18 | 19 | class SimpleAgent(Agent): 20 | """Agent that applies a simple heuristic.""" 21 | 22 | def __init__(self, config, *args, **kwargs): 23 | """Initialize the agent.""" 24 | self.config = config 25 | # Extract max info tokens or set default to 8. 26 | self.max_information_tokens = config.get('information_tokens', 8) 27 | 28 | @staticmethod 29 | def playable_card(card, fireworks): 30 | """A card is playable if it can be placed on the fireworks pile.""" 31 | return card['rank'] == fireworks[card['color']] 32 | 33 | def act(self, observation): 34 | """Act based on an observation.""" 35 | if observation['current_player_offset'] != 0: 36 | return None 37 | 38 | # Check if there are any pending hints and play the card corresponding to 39 | # the hint. 
40 | for card_index, hint in enumerate(observation['card_knowledge'][0]): 41 | if hint['color'] is not None or hint['rank'] is not None: 42 | return {'action_type': 'PLAY', 'card_index': card_index} 43 | 44 | # Check if it's possible to hint a card to your colleagues. 45 | fireworks = observation['fireworks'] 46 | if observation['information_tokens'] > 0: 47 | # Check if there are any playable cards in the hands of the opponents. 48 | for player_offset in range(1, observation['num_players']): 49 | player_hand = observation['observed_hands'][player_offset] 50 | player_hints = observation['card_knowledge'][player_offset] 51 | # Check if the card in the hand of the opponent is playable. 52 | for card, hint in zip(player_hand, player_hints): 53 | if SimpleAgent.playable_card(card, 54 | fireworks) and hint['color'] is None: 55 | return { 56 | 'action_type': 'REVEAL_COLOR', 57 | 'color': card['color'], 58 | 'target_offset': player_offset 59 | } 60 | 61 | # If no card is hintable then discard or play. 62 | if observation['information_tokens'] < self.max_information_tokens: 63 | return {'action_type': 'DISCARD', 'card_index': 0} 64 | else: 65 | return {'action_type': 'PLAY', 'card_index': 0} 66 | -------------------------------------------------------------------------------- /rl_env_example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A simple episode runner using the RL environment.""" 15 | 16 | from __future__ import print_function 17 | 18 | import sys 19 | import getopt 20 | import rl_env 21 | from agents.random_agent import RandomAgent 22 | from agents.simple_agent import SimpleAgent 23 | 24 | AGENT_CLASSES = {'SimpleAgent': SimpleAgent, 'RandomAgent': RandomAgent} 25 | 26 | 27 | class Runner(object): 28 | """Runner class.""" 29 | 30 | def __init__(self, flags): 31 | """Initialize runner.""" 32 | self.flags = flags 33 | self.agent_config = {'players': flags['players']} 34 | self.environment = rl_env.make('Hanabi-Full', num_players=flags['players']) 35 | self.agent_class = AGENT_CLASSES[flags['agent_class']] 36 | 37 | def run(self): 38 | """Run episodes.""" 39 | rewards = [] 40 | for episode in range(flags['num_episodes']): 41 | observations = self.environment.reset() 42 | agents = [self.agent_class(self.agent_config) 43 | for _ in range(self.flags['players'])] 44 | done = False 45 | episode_reward = 0 46 | while not done: 47 | for agent_id, agent in enumerate(agents): 48 | observation = observations['player_observations'][agent_id] 49 | action = agent.act(observation) 50 | if observation['current_player'] == agent_id: 51 | assert action is not None 52 | current_player_action = action 53 | else: 54 | assert action is None 55 | # Make an environment step. 
56 | print('Agent: {} action: {}'.format(observation['current_player'], 57 | current_player_action)) 58 | observations, reward, done, unused_info = self.environment.step( 59 | current_player_action) 60 | episode_reward += reward 61 | rewards.append(episode_reward) 62 | print('Running episode: %d' % episode) 63 | print('Max Reward: %.3f' % max(rewards)) 64 | return rewards 65 | 66 | if __name__ == "__main__": 67 | flags = {'players': 2, 'num_episodes': 1, 'agent_class': 'SimpleAgent'} 68 | options, arguments = getopt.getopt(sys.argv[1:], '', 69 | ['players=', 70 | 'num_episodes=', 71 | 'agent_class=']) 72 | if arguments: 73 | sys.exit('usage: rl_env_example.py [options]\n' 74 | '--players number of players in the game.\n' 75 | '--num_episodes number of game episodes to run.\n' 76 | '--agent_class {}'.format(' or '.join(AGENT_CLASSES.keys()))) 77 | for flag, value in options: 78 | flag = flag[2:] # Strip leading --. 79 | flags[flag] = type(flags[flag])(value) 80 | runner = Runner(flags) 81 | runner.run() 82 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_observation.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef __HANABI_OBSERVATION_H__ 16 | #define __HANABI_OBSERVATION_H__ 17 | 18 | #include <string> 19 | #include <vector> 20 | 21 | #include "hanabi_card.h" 22 | #include "hanabi_game.h" 23 | #include "hanabi_hand.h" 24 | #include "hanabi_history_item.h" 25 | #include "hanabi_move.h" 26 | #include "hanabi_state.h" 27 | 28 | namespace hanabi_learning_env { 29 | 30 | // Agent observation of a HanabiState 31 | class HanabiObservation { 32 | public: 33 | HanabiObservation(const HanabiState& state, int observing_player); 34 | 35 | std::string ToString() const; 36 | 37 | // offset of current player from observing player. 38 | int CurPlayerOffset() const { return cur_player_offset_; } 39 | // observed hands are in relative order, with index 1 being the 40 | // first player clock-wise from observing_player. hands[0][] has 41 | // invalid cards as players don't see their own cards. 42 | const std::vector<HanabiHand>& Hands() const { return hands_; } 43 | // The element at the back is the most recent discard. 44 | const std::vector<HanabiCard>& DiscardPile() const { return discard_pile_; } 45 | const std::vector<int>& Fireworks() const { return fireworks_; } 46 | int DeckSize() const { return deck_size_; } // number of remaining cards 47 | const HanabiGame* ParentGame() const { return parent_game_; } 48 | // Moves made since observing_player's last action, most recent to oldest 49 | // (that is, last_moves[0] is the most recent move.) 50 | // Move targets are relative to observing_player not acting_player. 51 | // Note that the deal moves are included in this vector. 
52 | const std::vector<HanabiHistoryItem>& LastMoves() const { 53 | return last_moves_; 54 | } 55 | int InformationTokens() const { return information_tokens_; } 56 | int LifeTokens() const { return life_tokens_; } 57 | const std::vector<HanabiMove>& LegalMoves() const { return legal_moves_; } 58 | 59 | // returns true if card with color and rank can be played on fireworks pile 60 | bool CardPlayableOnFireworks(int color, int rank) const; 61 | bool CardPlayableOnFireworks(HanabiCard card) const { 62 | return CardPlayableOnFireworks(card.Color(), card.Rank()); 63 | } 64 | 65 | private: 66 | int cur_player_offset_; // offset of current_player from observing_player 67 | std::vector<HanabiHand> hands_; // observing player is element 0 68 | std::vector<HanabiCard> discard_pile_; // back is most recent discard 69 | std::vector<int> fireworks_; 70 | int deck_size_; 71 | std::vector<HanabiHistoryItem> last_moves_; 72 | int information_tokens_; 73 | int life_tokens_; 74 | std::vector<HanabiMove> legal_moves_; // list of legal moves 75 | const HanabiGame* parent_game_ = nullptr; 76 | }; 77 | 78 | } // namespace hanabi_learning_env 79 | 80 | #endif 81 | -------------------------------------------------------------------------------- /agents/rainbow/third_party/dopamine/logger.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Dopamine Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A lightweight logging mechanism for dopamine agents.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import pickle 23 | import tensorflow as tf 24 | 25 | 26 | CHECKPOINT_DURATION = 4 27 | 28 | 29 | class Logger(object): 30 | """Class for maintaining a dictionary of data to log.""" 31 | 32 | def __init__(self, logging_dir): 33 | """Initializes Logger. 34 | 35 | Args: 36 | logging_dir: str, Directory to which logs are written. 37 | """ 38 | # Dict used by logger to store data. 39 | self.data = {} 40 | self._logging_enabled = True 41 | 42 | if not logging_dir: 43 | tf.logging.info('Logging directory not specified, will not log.') 44 | self._logging_enabled = False 45 | return 46 | # Try to create logging directory. 47 | try: 48 | tf.gfile.MakeDirs(logging_dir) 49 | except tf.errors.PermissionDeniedError: 50 | # If it already exists, ignore exception. 51 | pass 52 | if not tf.gfile.Exists(logging_dir): 53 | tf.logging.warning( 54 | 'Could not create directory %s, logging will be disabled.', 55 | logging_dir) 56 | self._logging_enabled = False 57 | return 58 | self._logging_dir = logging_dir 59 | 60 | def __setitem__(self, key, value): 61 | """This method will set an entry at key with value in the dictionary. 62 | 63 | It will effectively overwrite any previous data at the same key. 64 | 65 | Args: 66 | key: str, indicating key where to write the entry. 67 | value: A python object to store. 
68 | """ 69 | if self._logging_enabled: 70 | self.data[key] = value 71 | 72 | def _generate_filename(self, filename_prefix, iteration_number): 73 | filename = '{}_{}'.format(filename_prefix, iteration_number) 74 | return os.path.join(self._logging_dir, filename) 75 | 76 | def log_to_file(self, filename_prefix, iteration_number): 77 | """Save the pickled dictionary to a file. 78 | 79 | Args: 80 | filename_prefix: str, name of the file to use (without iteration 81 | number). 82 | iteration_number: int, the iteration number, appended to the end of 83 | filename_prefix. 84 | """ 85 | if not self._logging_enabled: 86 | tf.logging.warning('Logging is disabled.') 87 | return 88 | log_file = self._generate_filename(filename_prefix, iteration_number) 89 | with tf.gfile.GFile(log_file, 'w') as fout: 90 | pickle.dump(self.data, fout, protocol=pickle.HIGHEST_PROTOCOL) 91 | # After writing a checkpoint file, we garbage collect the log file 92 | # that is CHECKPOINT_DURATION versions old. 93 | stale_iteration_number = iteration_number - CHECKPOINT_DURATION 94 | if stale_iteration_number >= 0: 95 | stale_file = self._generate_filename(filename_prefix, 96 | stale_iteration_number) 97 | try: 98 | tf.gfile.Remove(stale_file) 99 | except tf.errors.NotFoundError: 100 | # Ignore if file not found. 101 | pass 102 | 103 | def is_logging_enabled(self): 104 | """Return if logging is enabled.""" 105 | return self._logging_enabled 106 | -------------------------------------------------------------------------------- /agents/rainbow/train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Dopamine Authors and Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # 17 | # 18 | # This file is a fork of the original Dopamine code incorporating changes for 19 | # the multiplayer setting and the Hanabi Learning Environment. 20 | # 21 | """The entry point for running a Rainbow agent on Hanabi.""" 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | from absl import app 28 | from absl import flags 29 | 30 | from third_party.dopamine import logger 31 | 32 | import run_experiment 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | flags.DEFINE_multi_string( 37 | 'gin_files', [], 38 | 'List of paths to gin configuration files (e.g.' 39 | '"configs/hanabi_rainbow.gin").') 40 | flags.DEFINE_multi_string( 41 | 'gin_bindings', [], 42 | 'Gin bindings to override the values set in the config files ' 43 | '(e.g. "DQNAgent.epsilon_train=0.1").') 44 | 45 | flags.DEFINE_string('base_dir', None, 46 | 'Base directory to host all required sub-directories.') 47 | 48 | flags.DEFINE_string('checkpoint_dir', '', 49 | 'Directory where checkpoint files should be saved. 
If ' 50 | 'empty, no checkpoints will be saved.') 51 | flags.DEFINE_string('checkpoint_file_prefix', 'ckpt', 52 | 'Prefix to use for the checkpoint files.') 53 | flags.DEFINE_string('logging_dir', '', 54 | 'Directory where experiment data will be saved. If empty ' 55 | 'no checkpoints will be saved.') 56 | flags.DEFINE_string('logging_file_prefix', 'log', 57 | 'Prefix to use for the log files.') 58 | 59 | 60 | def launch_experiment(): 61 | """Launches the experiment. 62 | 63 | Specifically: 64 | - Load the gin configs and bindings. 65 | - Initialize the Logger object. 66 | - Initialize the environment. 67 | - Initialize the observation stacker. 68 | - Initialize the agent. 69 | - Reload from the latest checkpoint, if available, and initialize the 70 | Checkpointer object. 71 | - Run the experiment. 72 | """ 73 | if FLAGS.base_dir == None: 74 | raise ValueError('--base_dir is None: please provide a path for ' 75 | 'logs and checkpoints.') 76 | 77 | run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings) 78 | experiment_logger = logger.Logger('{}/logs'.format(FLAGS.base_dir)) 79 | 80 | environment = run_experiment.create_environment() 81 | obs_stacker = run_experiment.create_obs_stacker(environment) 82 | agent = run_experiment.create_agent(environment, obs_stacker) 83 | 84 | checkpoint_dir = '{}/checkpoints'.format(FLAGS.base_dir) 85 | start_iteration, experiment_checkpointer = ( 86 | run_experiment.initialize_checkpointing(agent, 87 | experiment_logger, 88 | checkpoint_dir, 89 | FLAGS.checkpoint_file_prefix)) 90 | 91 | run_experiment.run_experiment(agent, environment, start_iteration, 92 | obs_stacker, 93 | experiment_logger, experiment_checkpointer, 94 | checkpoint_dir, 95 | logging_file_prefix=FLAGS.logging_file_prefix) 96 | 97 | 98 | def main(unused_argv): 99 | """This main function acts as a wrapper around a gin-configurable experiment. 100 | 101 | Args: 102 | unused_argv: Arguments (unused). 103 | """ 104 | launch_experiment() 105 | 106 | if __name__ == '__main__': 107 | app.run(main) 108 | -------------------------------------------------------------------------------- /game_example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example code demonstrating the Python Hanabi interface.""" 16 | 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import pyhanabi 21 | 22 | 23 | def run_game(game_parameters): 24 | """Play a game, selecting random actions.""" 25 | 26 | def print_state(state): 27 | """Print some basic information about the state.""" 28 | print("") 29 | print("Current player: {}".format(state.cur_player())) 30 | print(state) 31 | 32 | # Example of more queries to provide more about this state. For 33 | # example, bots could use these methods to to get information 34 | # about the state in order to act accordingly. 
35 | print("### Information about the state retrieved separately ###") 36 | print("### Information tokens: {}".format(state.information_tokens())) 37 | print("### Life tokens: {}".format(state.life_tokens())) 38 | print("### Fireworks: {}".format(state.fireworks())) 39 | print("### Deck size: {}".format(state.deck_size())) 40 | print("### Discard pile: {}".format(str(state.discard_pile()))) 41 | print("### Player hands: {}".format(str(state.player_hands()))) 42 | print("") 43 | 44 | def print_observation(observation): 45 | """Print some basic information about an agent observation.""" 46 | print("--- Observation ---") 47 | print(observation) 48 | 49 | print("### Information about the observation retrieved separately ###") 50 | print("### Current player, relative to self: {}".format( 51 | observation.cur_player_offset())) 52 | print("### Observed hands: {}".format(observation.observed_hands())) 53 | print("### Card knowledge: {}".format(observation.card_knowledge())) 54 | print("### Discard pile: {}".format(observation.discard_pile())) 55 | print("### Fireworks: {}".format(observation.fireworks())) 56 | print("### Deck size: {}".format(observation.deck_size())) 57 | move_string = "### Last moves:" 58 | for move_tuple in observation.last_moves(): 59 | move_string += " {}".format(move_tuple) 60 | print(move_string) 61 | print("### Information tokens: {}".format(observation.information_tokens())) 62 | print("### Life tokens: {}".format(observation.life_tokens())) 63 | print("### Legal moves: {}".format(observation.legal_moves())) 64 | print("--- EndObservation ---") 65 | 66 | def print_encoded_observations(encoder, state, num_players): 67 | print("--- EncodedObservations ---") 68 | print("Observation encoding shape: {}".format(encoder.shape())) 69 | print("Current actual player: {}".format(state.cur_player())) 70 | for i in range(num_players): 71 | print("Encoded observation for player {}: {}".format( 72 | i, encoder.encode(state.observation(i)))) 73 | print("--- EndEncodedObservations ---") 74 | 75 | game = pyhanabi.HanabiGame(game_parameters) 76 | print(game.parameter_string(), end="") 77 | obs_encoder = pyhanabi.ObservationEncoder( 78 | game, enc_type=pyhanabi.ObservationEncoderType.CANONICAL) 79 | 80 | state = game.new_initial_state() 81 | while not state.is_terminal(): 82 | if state.cur_player() == pyhanabi.CHANCE_PLAYER_ID: 83 | state.deal_random_card() 84 | continue 85 | 86 | print_state(state) 87 | 88 | observation = state.observation(state.cur_player()) 89 | print_observation(observation) 90 | print_encoded_observations(obs_encoder, state, game.num_players()) 91 | 92 | legal_moves = state.legal_moves() 93 | print("") 94 | print("Number of legal moves: {}".format(len(legal_moves))) 95 | 96 | move = np.random.choice(legal_moves) 97 | print("Chose random legal move: {}".format(move)) 98 | 99 | state.apply_move(move) 100 | 101 | print("") 102 | print("Game done. Terminal state:") 103 | print("") 104 | print(state) 105 | print("") 106 | print("score: {}".format(state.score())) 107 | 108 | 109 | if __name__ == "__main__": 110 | # Check that the cdef and library were loaded from the standard paths. 
111 | assert pyhanabi.cdef_loaded(), "cdef failed to load" 112 | assert pyhanabi.lib_loaded(), "lib failed to load" 113 | run_game({"players": 3, "random_start_player": True}) 114 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_hand.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "hanabi_hand.h" 16 | 17 | #include 18 | #include 19 | 20 | #include "util.h" 21 | 22 | namespace hanabi_learning_env { 23 | 24 | HanabiHand::ValueKnowledge::ValueKnowledge(int value_range) 25 | : value_(-1), value_plausible_(std::max(value_range, 0), true) { 26 | assert(value_range > 0); 27 | } 28 | 29 | void HanabiHand::ValueKnowledge::ApplyIsValueHint(int value) { 30 | assert(value >= 0 && value < value_plausible_.size()); 31 | assert(value_ < 0 || value_ == value); 32 | assert(value_plausible_[value] == true); 33 | value_ = value; 34 | std::fill(value_plausible_.begin(), value_plausible_.end(), false); 35 | value_plausible_[value] = true; 36 | } 37 | 38 | void HanabiHand::ValueKnowledge::ApplyIsNotValueHint(int value) { 39 | assert(value >= 0 && value < value_plausible_.size()); 40 | assert(value_ < 0 || value_ != value); 41 | value_plausible_[value] = false; 42 | } 43 | 44 | HanabiHand::CardKnowledge::CardKnowledge(int num_colors, int num_ranks) 45 | : color_(num_colors), rank_(num_ranks) {} 46 | 47 | std::string HanabiHand::CardKnowledge::ToString() const { 48 | std::string result; 49 | result = result + (ColorHinted() ? ColorIndexToChar(Color()) : 'X') + 50 | (RankHinted() ? 
RankIndexToChar(Rank()) : 'X') + '|'; 51 | for (int c = 0; c < color_.Range(); ++c) { 52 | if (color_.IsPlausible(c)) { 53 | result += ColorIndexToChar(c); 54 | } 55 | } 56 | for (int r = 0; r < rank_.Range(); ++r) { 57 | if (rank_.IsPlausible(r)) { 58 | result += RankIndexToChar(r); 59 | } 60 | } 61 | return result; 62 | } 63 | 64 | HanabiHand::HanabiHand(const HanabiHand& hand, bool hide_cards, 65 | bool hide_knowledge) { 66 | if (hide_cards) { 67 | cards_.resize(hand.cards_.size(), HanabiCard()); 68 | } else { 69 | cards_ = hand.cards_; 70 | } 71 | if (hide_knowledge && !hand.cards_.empty()) { 72 | card_knowledge_.resize(hand.cards_.size(), 73 | CardKnowledge(hand.card_knowledge_[0].NumColors(), 74 | hand.card_knowledge_[0].NumRanks())); 75 | } else { 76 | card_knowledge_ = hand.card_knowledge_; 77 | } 78 | } 79 | 80 | void HanabiHand::AddCard(HanabiCard card, 81 | const CardKnowledge& initial_knowledge) { 82 | REQUIRE(card.IsValid()); 83 | cards_.push_back(card); 84 | card_knowledge_.push_back(initial_knowledge); 85 | } 86 | 87 | void HanabiHand::RemoveFromHand(int card_index, 88 | std::vector* discard_pile) { 89 | if (discard_pile != nullptr) { 90 | discard_pile->push_back(cards_[card_index]); 91 | } 92 | cards_.erase(cards_.begin() + card_index); 93 | card_knowledge_.erase(card_knowledge_.begin() + card_index); 94 | } 95 | 96 | uint8_t HanabiHand::RevealColor(const int color) { 97 | uint8_t mask = 0; 98 | assert(cards_.size() <= 8); // More than 8 cards is currently not supported. 99 | for (int i = 0; i < cards_.size(); ++i) { 100 | if (cards_[i].Color() == color) { 101 | if (!card_knowledge_[i].ColorHinted()) { 102 | mask |= static_cast(1) << i; 103 | } 104 | card_knowledge_[i].ApplyIsColorHint(color); 105 | } else { 106 | card_knowledge_[i].ApplyIsNotColorHint(color); 107 | } 108 | } 109 | return mask; 110 | } 111 | 112 | uint8_t HanabiHand::RevealRank(const int rank) { 113 | uint8_t mask = 0; 114 | assert(cards_.size() <= 8); // More than 8 cards is currently not supported. 115 | for (int i = 0; i < cards_.size(); ++i) { 116 | if (cards_[i].Rank() == rank) { 117 | if (!card_knowledge_[i].RankHinted()) { 118 | mask |= static_cast(1) << i; 119 | } 120 | card_knowledge_[i].ApplyIsRankHint(rank); 121 | } else { 122 | card_knowledge_[i].ApplyIsNotRankHint(rank); 123 | } 124 | } 125 | return mask; 126 | } 127 | 128 | std::string HanabiHand::ToString() const { 129 | std::string result; 130 | assert(cards_.size() == card_knowledge_.size()); 131 | for (int i = 0; i < cards_.size(); ++i) { 132 | result += 133 | cards_[i].ToString() + " || " + card_knowledge_[i].ToString() + '\n'; 134 | } 135 | return result; 136 | } 137 | 138 | } // namespace hanabi_learning_env 139 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_game.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
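HanabiHand::RevealColor and RevealRank above return a bitmask marking which cards were newly revealed by a hint. The following is a small Python rendering of that bitmask bookkeeping for readers following the hint semantics; the plain lists here are stand-ins for the C++ hand and knowledge objects.
```
def reveal_color_mask(hand_colors, already_color_hinted, hinted_color):
    """Mirrors the mask computed by HanabiHand::RevealColor: bit i is set iff
    card i matches the hinted color and its color was not already hinted."""
    mask = 0
    for i, color in enumerate(hand_colors):
        if color == hinted_color and not already_color_hinted[i]:
            mask |= 1 << i
    return mask

# Hand colors R R G W B, with the first R's color already known from an earlier hint:
print(bin(reveal_color_mask(list("RRGWB"), [True, False, False, False, False], "R")))
# -> 0b10: only the second card is newly revealed
```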
14 | 15 | #ifndef __HANABI_GAME_H__ 16 | #define __HANABI_GAME_H__ 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include "hanabi_card.h" 24 | #include "hanabi_move.h" 25 | 26 | namespace hanabi_learning_env { 27 | 28 | class HanabiGame { 29 | public: 30 | // An agent's observation of a state does include all state knowledge. 31 | // For example, observations never include an agent's own cards. 32 | // A kMinimal observation is similar to what a human sees, and does not 33 | // include any memory of past RevalColor/RevealRank hints. A CardKnowledge 34 | // observation includes per-card knowledge of past hints, as well as simple 35 | // inferred knowledge of the form "this card is not red, because it was 36 | // not revealed as red in a past move". A Seer observation 37 | // shows all cards, including the player's own cards, regardless of what 38 | // hints have been given. 39 | enum AgentObservationType { kMinimal = 0, kCardKnowledge = 1, kSeer = 2 }; 40 | 41 | explicit HanabiGame( 42 | const std::unordered_map& params); 43 | 44 | // Number of different player moves. 45 | int MaxMoves() const; 46 | // Get a HanabiMove by unique id. 47 | HanabiMove GetMove(int uid) const { return moves_[uid]; } 48 | // Get unique id for a move. Returns -1 for invalid move. 49 | int GetMoveUid(HanabiMove move) const; 50 | int GetMoveUid(HanabiMove::Type move_type, int card_index, int target_offset, 51 | int color, int rank) const; 52 | // Number of different chance outcomes. 53 | int MaxChanceOutcomes() const; 54 | // Get a chance-outcome HanabiMove by unique id. 55 | HanabiMove GetChanceOutcome(int uid) const { return chance_outcomes_[uid]; } 56 | // Get unique id for a chance-outcome move. Returns -1 for invalid move. 57 | int GetChanceOutcomeUid(HanabiMove move) const; 58 | // Randomly sample a random chance-outcome move from list of moves and 59 | // associated probability distribution. 60 | HanabiMove PickRandomChance( 61 | const std::pair, std::vector>& 62 | chance_outcomes) const; 63 | 64 | std::unordered_map Parameters() const; 65 | int MinPlayers() const { return 2; } 66 | int MaxPlayers() const { return 5; } 67 | int MinScore() const { return 0; } 68 | int MaxScore() const { return num_ranks_ * num_colors_; } 69 | std::string Name() const { return "Hanabi"; } 70 | 71 | int NumColors() const { return num_colors_; } 72 | int NumRanks() const { return num_ranks_; } 73 | int NumPlayers() const { return num_players_; } 74 | int HandSize() const { return hand_size_; } 75 | int MaxInformationTokens() const { return max_information_tokens_; } 76 | int MaxLifeTokens() const { return max_life_tokens_; } 77 | int CardsPerColor() const { return cards_per_color_; } 78 | int MaxDeckSize() const { return cards_per_color_ * num_colors_; } 79 | int NumberCardInstances(int color, int rank) const; 80 | int NumberCardInstances(HanabiCard card) const { 81 | return NumberCardInstances(card.Color(), card.Rank()); 82 | } 83 | AgentObservationType ObservationType() const { return observation_type_; } 84 | 85 | // Get the first player to act. Might be randomly generated at each call. 86 | int GetSampledStartPlayer() const; 87 | 88 | private: 89 | // Calculating max moves by move type. 
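Putting numbers to the per-type helpers that follow: MaxMoves() is the sum of discard, play, reveal-color, and reveal-rank moves. A quick worked example using those formulas (values only, not the C++ class):
```
def max_moves(players, hand_size, colors, ranks):
    """MaxMoves() = discards + plays + color hints + rank hints."""
    discards = hand_size
    plays = hand_size
    reveal_color = (players - 1) * colors
    reveal_rank = (players - 1) * ranks
    return discards + plays + reveal_color + reveal_rank

print(max_moves(players=2, hand_size=5, colors=5, ranks=5))  # 20
print(max_moves(players=4, hand_size=4, colors=5, ranks=5))  # 38
```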
90 | int MaxDiscardMoves() const { return hand_size_; } 91 | int MaxPlayMoves() const { return hand_size_; } 92 | int MaxRevealColorMoves() const { return (num_players_ - 1) * num_colors_; } 93 | int MaxRevealRankMoves() const { return (num_players_ - 1) * num_ranks_; } 94 | 95 | int HandSizeFromRules() const; 96 | HanabiMove ConstructMove(int uid) const; 97 | HanabiMove ConstructChanceOutcome(int uid) const; 98 | 99 | // Table of all possible moves in this game. 100 | std::vector moves_; 101 | // Table of all possible chance outcomes in this game. 102 | std::vector chance_outcomes_; 103 | std::unordered_map params_; 104 | int num_colors_ = -1; 105 | int num_ranks_ = -1; 106 | int num_players_ = -1; 107 | int hand_size_ = -1; 108 | int max_information_tokens_ = -1; 109 | int max_life_tokens_ = -1; 110 | int cards_per_color_ = -1; 111 | int seed_ = -1; 112 | bool random_start_player_ = false; 113 | AgentObservationType observation_type_ = kCardKnowledge; 114 | mutable std::mt19937 rng_; 115 | }; 116 | 117 | } // namespace hanabi_learning_env 118 | 119 | #endif 120 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_observation.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "hanabi_observation.h" 16 | 17 | #include 18 | #include 19 | 20 | #include "util.h" 21 | 22 | namespace hanabi_learning_env { 23 | 24 | namespace { 25 | // Returns the offset of player ID pid relative to player ID observer_pid, 26 | // or pid for negative values. That is, offset such that for a non-negative 27 | // player id pid, we have (observer_pid + offset) % num_players == pid. 28 | int PlayerToOffset(int pid, int observer_pid, int num_players) { 29 | return pid >= 0 ? (pid - observer_pid + num_players) % num_players : pid; 30 | } 31 | 32 | // Switch members from absolute player indices to observer-relative offsets, 33 | // including player indices within the contained HanabiMove. 34 | void ChangeHistoryItemToObserverRelative(int observer_pid, int num_players, 35 | bool show_cards, 36 | HanabiHistoryItem* item) { 37 | if (item->move.MoveType() == HanabiMove::kDeal) { 38 | assert(item->player < 0 && item->deal_to_player >= 0); 39 | item->deal_to_player = 40 | (item->deal_to_player - observer_pid + num_players) % num_players; 41 | if (item->deal_to_player == 0 && !show_cards) { 42 | // Hide cards dealt to observer if they shouldn't be able to see them. 
43 | item->move = HanabiMove(HanabiMove::kDeal, -1, -1, -1, -1); 44 | } 45 | } else { 46 | assert(item->player >= 0); 47 | item->player = (item->player - observer_pid + num_players) % num_players; 48 | } 49 | } 50 | } // namespace 51 | 52 | HanabiObservation::HanabiObservation(const HanabiState& state, 53 | int observing_player) 54 | : cur_player_offset_(PlayerToOffset(state.CurPlayer(), observing_player, 55 | state.ParentGame()->NumPlayers())), 56 | discard_pile_(state.DiscardPile()), 57 | fireworks_(state.Fireworks()), 58 | deck_size_(state.Deck().Size()), 59 | information_tokens_(state.InformationTokens()), 60 | life_tokens_(state.LifeTokens()), 61 | legal_moves_(state.LegalMoves(observing_player)), 62 | parent_game_(state.ParentGame()) { 63 | REQUIRE(observing_player >= 0 && 64 | observing_player < state.ParentGame()->NumPlayers()); 65 | hands_.reserve(state.Hands().size()); 66 | const bool hide_knowledge = 67 | state.ParentGame()->ObservationType() == HanabiGame::kMinimal; 68 | const bool show_cards = state.ParentGame()->ObservationType() == HanabiGame::kSeer; 69 | hands_.push_back( 70 | HanabiHand(state.Hands()[observing_player], !show_cards, hide_knowledge)); 71 | for (int offset = 1; offset < state.ParentGame()->NumPlayers(); ++offset) { 72 | hands_.push_back(HanabiHand(state.Hands()[(observing_player + offset) % 73 | state.ParentGame()->NumPlayers()], 74 | false, hide_knowledge)); 75 | } 76 | 77 | const auto& history = state.MoveHistory(); 78 | auto start = std::find_if(history.begin(), history.end(), 79 | [](const HanabiHistoryItem& item) { 80 | return item.player != kChancePlayerId; 81 | }); 82 | std::reverse_iterator rend(start); 83 | for (auto it = history.rbegin(); it != rend; ++it) { 84 | last_moves_.push_back(*it); 85 | ChangeHistoryItemToObserverRelative(observing_player, 86 | state.ParentGame()->NumPlayers(), 87 | show_cards, 88 | &last_moves_.back()); 89 | if (it->player == observing_player) { 90 | break; 91 | } 92 | } 93 | } 94 | 95 | std::string HanabiObservation::ToString() const { 96 | std::string result; 97 | result += "Life tokens: " + std::to_string(LifeTokens()) + "\n"; 98 | result += "Info tokens: " + std::to_string(InformationTokens()) + "\n"; 99 | result += "Fireworks: "; 100 | for (int i = 0; i < ParentGame()->NumColors(); ++i) { 101 | result += ColorIndexToChar(i); 102 | result += std::to_string(fireworks_[i]) + " "; 103 | } 104 | result += "\nHands:\n"; 105 | for (int i = 0; i < hands_.size(); ++i) { 106 | if (i > 0) { 107 | result += "-----\n"; 108 | } 109 | if (i == CurPlayerOffset()) { 110 | result += "Cur player\n"; 111 | } 112 | result += hands_[i].ToString(); 113 | } 114 | result += "Deck size: " + std::to_string(DeckSize()) + "\n"; 115 | result += "Discards:"; 116 | for (int i = 0; i < discard_pile_.size(); ++i) { 117 | result += " " + discard_pile_[i].ToString(); 118 | } 119 | return result; 120 | } 121 | 122 | bool HanabiObservation::CardPlayableOnFireworks(int color, int rank) const { 123 | if (color < 0 || color >= ParentGame()->NumColors()) { 124 | return false; 125 | } 126 | return rank == fireworks_[color]; 127 | } 128 | 129 | } // namespace hanabi_learning_env 130 | -------------------------------------------------------------------------------- /game_example.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
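Observations in hanabi_observation.cc above index players relative to the observer via PlayerToOffset. A one-line Python equivalent of that conversion, with a quick check:
```
def player_to_offset(pid, observer_pid, num_players):
    """Offset such that (observer_pid + offset) % num_players == pid; negative
    (chance) player ids pass through unchanged."""
    return (pid - observer_pid + num_players) % num_players if pid >= 0 else pid

# From player 2's point of view in a 3-player game, player 0 is the next player.
assert player_to_offset(0, 2, 3) == 1
assert player_to_offset(-1, 2, 3) == -1   # chance player id is unchanged
```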
5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include "hanabi_game.h" 24 | #include "hanabi_state.h" 25 | 26 | struct GameResult { 27 | int score; 28 | int fireworks_played; // Number of successful Play moves. 29 | int num_steps; // Number of moves by a player. 30 | }; 31 | 32 | constexpr const char* kGameParamArgPrefix = "--config.hanabi."; 33 | 34 | GameResult SimulateGame(const hanabi_learning_env::HanabiGame& game, 35 | bool verbose, std::mt19937* rng) { 36 | hanabi_learning_env::HanabiState state(&game); 37 | GameResult result = {0, 0, 0}; 38 | while (!state.IsTerminal()) { 39 | if (state.CurPlayer() == hanabi_learning_env::kChancePlayerId) { 40 | // All of this could be replaced with state.ApplyRandomChance(). 41 | // Only done this way to demonstrate picking specific chance moves. 42 | auto chance_outcomes = state.ChanceOutcomes(); 43 | std::discrete_distribution dist( 44 | chance_outcomes.second.begin(), chance_outcomes.second.end()); 45 | auto move = chance_outcomes.first[dist(*rng)]; 46 | if (verbose) { 47 | std::cout << "Legal chance:"; 48 | for (int i = 0; i < chance_outcomes.first.size(); ++i) { 49 | std::cout << " <" << chance_outcomes.first[i].ToString() << ", " 50 | << chance_outcomes.second[i] << ">"; 51 | } 52 | std::cout << "\n"; 53 | std::cout << "Sampled move: " << move.ToString() << "\n\n"; 54 | } 55 | state.ApplyMove(move); 56 | continue; 57 | } 58 | 59 | auto legal_moves = state.LegalMoves(state.CurPlayer()); 60 | std::uniform_int_distribution dist( 61 | 0, legal_moves.size() - 1); 62 | auto move = legal_moves[dist(*rng)]; 63 | if (verbose) { 64 | std::cout << "Current player: " << state.CurPlayer() << "\n"; 65 | std::cout << state.ToString() << "\n\n"; 66 | std::cout << "Legal moves:"; 67 | for (int i = 0; i < legal_moves.size(); ++i) { 68 | std::cout << " " << legal_moves[i].ToString(); 69 | } 70 | std::cout << "\n"; 71 | std::cout << "Sampled move: " << move.ToString() << "\n\n"; 72 | } 73 | state.ApplyMove(move); 74 | ++result.num_steps; 75 | if (state.MoveHistory().back().scored) { 76 | ++result.fireworks_played; 77 | } 78 | } 79 | 80 | if (verbose) { 81 | std::cout << "Game done, terminal state:\n" << state.ToString() << "\n\n"; 82 | std::cout << "score = " << state.Score() << "\n\n"; 83 | } 84 | 85 | result.score = state.Score(); 86 | return result; 87 | } 88 | 89 | void SimulateGames( 90 | const std::unordered_map& game_params, 91 | int num_trials = 1, bool verbose = true) { 92 | std::mt19937 rng; 93 | rng.seed(std::random_device()()); 94 | 95 | hanabi_learning_env::HanabiGame game(game_params); 96 | auto params = game.Parameters(); 97 | std::cout << "Hanabi game created, with parameters:\n"; 98 | for (const auto& item : params) { 99 | std::cout << " " << item.first << "=" << item.second << "\n"; 100 | } 101 | 102 | std::vector results; 103 | results.reserve(num_trials); 104 | for (int trial = 0; trial < num_trials; ++trial) { 105 | results.push_back(SimulateGame(game, verbose, &rng)); 106 | } 107 | 108 | if (num_trials > 1) { 109 | GameResult 
avg_score = std::accumulate( 110 | results.begin(), results.end(), GameResult(), 111 | [](const GameResult& lhs, const GameResult& rhs) { 112 | GameResult result = {lhs.score + rhs.score, 113 | lhs.fireworks_played + rhs.fireworks_played, 114 | lhs.num_steps + rhs.num_steps}; 115 | return result; 116 | }); 117 | std::cout << "Average score: " 118 | << static_cast(avg_score.score) / results.size() 119 | << " average number of fireworks played: " 120 | << static_cast(avg_score.fireworks_played) / 121 | results.size() 122 | << " average num_steps: " 123 | << static_cast(avg_score.num_steps) / results.size() 124 | << "\n"; 125 | } 126 | } 127 | 128 | std::unordered_map ParseArguments(int argc, 129 | char** argv) { 130 | std::unordered_map game_params; 131 | const auto prefix_len = strlen(kGameParamArgPrefix); 132 | for (int i = 1; i < argc; ++i) { 133 | std::string param = argv[i]; 134 | if (param.compare(0, prefix_len, kGameParamArgPrefix) == 0 && 135 | param.size() > prefix_len) { 136 | std::string value; 137 | param = param.substr(prefix_len, std::string::npos); 138 | auto value_pos = param.find("="); 139 | if (value_pos != std::string::npos) { 140 | value = param.substr(value_pos + 1, std::string::npos); 141 | param = param.substr(0, value_pos); 142 | } 143 | game_params[param] = value; 144 | } 145 | } 146 | return game_params; 147 | } 148 | 149 | int main(int argc, char** argv) { 150 | auto game_params = ParseArguments(argc, argv); 151 | SimulateGames(game_params); 152 | return 0; 153 | } 154 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_hand.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef __HANABI_HAND_H__ 16 | #define __HANABI_HAND_H__ 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #include "hanabi_card.h" 23 | 24 | namespace hanabi_learning_env { 25 | 26 | class HanabiHand { 27 | public: 28 | class ValueKnowledge { 29 | // Knowledge about an unknown integer variable in range 0 to value_range-1. 30 | // Records hints that either reveal the exact value (no longer unknown), 31 | // or reveal that the variable is not some particular value. 32 | // For example, ValueKnowledge(3) tracks a variable that can be 0, 1, or 2. 33 | // Initially, ValueHinted()=false, value()=-1, and ValueCouldBe(v)=true 34 | // for v=0, 1, and 2. 35 | // After recording that the value is not 1, we have 36 | // ValueHinted()=false, value()=-1, and ValueCouldBe(1)=false. 37 | // After recording that the value is 0, we have 38 | // ValueHinted()=true, value()=0, and ValueCouldBe(v)=false for v=1, and 2. 39 | public: 40 | explicit ValueKnowledge(int value_range); 41 | int Range() const { return value_plausible_.size(); } 42 | // Returns true if and only if the exact value was revealed. 43 | // Does not perform inference to get a known value from not-value hints. 
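ParseArguments in game_example.cc above accepts game parameters as `--config.hanabi.<name>=<value>` arguments. Below is a Python rendering of the same prefix parsing, handy when scripting the example binary; the sample argv is illustrative.
```
PREFIX = "--config.hanabi."

def parse_hanabi_args(argv):
    """Mirrors ParseArguments: keep only --config.hanabi.<name>=<value> entries."""
    params = {}
    for arg in argv[1:]:
        if arg.startswith(PREFIX) and len(arg) > len(PREFIX):
            name, _, value = arg[len(PREFIX):].partition("=")
            params[name] = value   # value stays "" when no '=' is given, as in the C++
    return params

print(parse_hanabi_args(
    ["game_example", "--config.hanabi.players=3", "--config.hanabi.colors=4"]))
# -> {'players': '3', 'colors': '4'}
```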
44 | bool ValueHinted() const { return value_ >= 0; } 45 | int Value() const { return value_; } // -1 if value was not hinted. 46 | // Returns true if we have no hint saying variable is not the given value. 47 | bool IsPlausible(int value) const { return value_plausible_[value]; } 48 | // Record a hint that gives the value of the variable. 49 | void ApplyIsValueHint(int value); 50 | // Record a hint that the variable does not have the given value. 51 | void ApplyIsNotValueHint(int value); 52 | 53 | private: 54 | // Value if hint directly provided the value, or -1 with no direct hint. 55 | int value_ = -1; 56 | std::vector value_plausible_; // Knowledge from not-value hints. 57 | }; 58 | 59 | class CardKnowledge { 60 | // Hinted knowledge about color and rank of an initially unknown card. 61 | public: 62 | CardKnowledge(int num_colors, int num_ranks); 63 | // Returns number of possible colors being tracked. 64 | int NumColors() const { return color_.Range(); } 65 | // Returns true if and only if the exact color was revealed. 66 | // Does not perform inference to get a known color from not-color hints. 67 | bool ColorHinted() const { return color_.ValueHinted(); } 68 | // Color of card if it was hinted, -1 if not hinted. 69 | int Color() const { return color_.Value(); } 70 | // Returns true if we have no hint saying card is not the given color. 71 | bool ColorPlausible(int color) const { return color_.IsPlausible(color); } 72 | void ApplyIsColorHint(int color) { color_.ApplyIsValueHint(color); } 73 | void ApplyIsNotColorHint(int color) { color_.ApplyIsNotValueHint(color); } 74 | // Returns number of possible ranks being tracked. 75 | int NumRanks() const { return rank_.Range(); } 76 | // Returns true if and only if the exact rank was revealed. 77 | // Does not perform inference to get a known rank from not-rank hints. 78 | bool RankHinted() const { return rank_.ValueHinted(); } 79 | // Rank of card if it was hinted, -1 if not hinted. 80 | int Rank() const { return rank_.Value(); } 81 | // Returns true if we have no hint saying card is not the given rank. 82 | bool RankPlausible(int rank) const { return rank_.IsPlausible(rank); } 83 | void ApplyIsRankHint(int rank) { rank_.ApplyIsValueHint(rank); } 84 | void ApplyIsNotRankHint(int rank) { rank_.ApplyIsNotValueHint(rank); } 85 | std::string ToString() const; 86 | 87 | private: 88 | ValueKnowledge color_; 89 | ValueKnowledge rank_; 90 | }; 91 | 92 | HanabiHand() {} 93 | HanabiHand(const HanabiHand& hand) 94 | : cards_(hand.cards_), card_knowledge_(hand.card_knowledge_) {} 95 | // Copy hand. Hide cards (set to invalid) if hide_cards is true. 96 | // Hide card knowledge (set to unknown) if hide_knowledge is true. 97 | HanabiHand(const HanabiHand& hand, bool hide_cards, bool hide_knowledge); 98 | // Cards and corresponding card knowledge are always arranged from oldest to 99 | // newest, with the oldest card or knowledge at index 0. 100 | const std::vector& Cards() const { return cards_; } 101 | const std::vector& Knowledge() const { 102 | return card_knowledge_; 103 | } 104 | void AddCard(HanabiCard card, const CardKnowledge& initial_knowledge); 105 | // Remove card_index card from hand. Put in discard_pile if not nullptr 106 | // (pushes the card to the back of the discard_pile vector). 107 | void RemoveFromHand(int card_index, std::vector* discard_pile); 108 | // Make cards with the given rank visible. 109 | // Returns new information bitmask, bit_i set if card_i color was revealed 110 | // and was previously unknown. 
111 | uint8_t RevealRank(int rank); 112 | // Make cards with the given color visible. 113 | // Returns new information bitmask, bit_i set if card_i color was revealed 114 | // and was previously unknown. 115 | uint8_t RevealColor(int color); 116 | std::string ToString() const; 117 | 118 | private: 119 | // A set of cards and knowledge about them. 120 | std::vector cards_; 121 | std::vector card_knowledge_; 122 | }; 123 | 124 | } // namespace hanabi_learning_env 125 | 126 | #endif 127 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_state.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef __HANABI_STATE_H__ 16 | #define __HANABI_STATE_H__ 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #include "hanabi_card.h" 23 | #include "hanabi_game.h" 24 | #include "hanabi_hand.h" 25 | #include "hanabi_history_item.h" 26 | #include "hanabi_move.h" 27 | 28 | namespace hanabi_learning_env { 29 | 30 | constexpr int kChancePlayerId = -1; 31 | 32 | class HanabiState { 33 | public: 34 | class HanabiDeck { 35 | public: 36 | explicit HanabiDeck(const HanabiGame& game); 37 | // DealCard returns invalid card on failure. 38 | HanabiCard DealCard(int color, int rank); 39 | HanabiCard DealCard(std::mt19937* rng); 40 | int Size() const { return total_count_; } 41 | bool Empty() const { return total_count_ == 0; } 42 | int CardCount(int color, int rank) const { 43 | return card_count_[CardToIndex(color, rank)]; 44 | } 45 | 46 | private: 47 | int CardToIndex(int color, int rank) const { 48 | return color * num_ranks_ + rank; 49 | } 50 | int IndexToColor(int index) const { return index / num_ranks_; } 51 | int IndexToRank(int index) const { return index % num_ranks_; } 52 | 53 | // Number of instances in the deck for each card. 54 | // E.g., if card_count_[CardToIndex(card)] == 2, then there are two 55 | // instances of card remaining in the deck, available to be dealt out. 56 | std::vector card_count_; 57 | int total_count_ = -1; // Total number of cards available to be dealt out. 58 | int num_ranks_ = -1; // From game.NumRanks(), used to map card to index. 59 | }; 60 | 61 | enum EndOfGameType { 62 | kNotFinished, // Not the end of game. 63 | kOutOfLifeTokens, // Players ran out of life tokens. 64 | kOutOfCards, // Players ran out of cards. 65 | kCompletedFireworks // All fireworks played. 66 | }; 67 | 68 | // Construct a HanabiState, initialised to the start of the game. 69 | // If start_player >= 0, the game-provided start player is overridden 70 | // and the first player after chance is start_player. 71 | explicit HanabiState(const HanabiGame* parent_game, int start_player = -1); 72 | // Copy constructor for recursive game traversals using copy + apply-move. 
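The copy-constructor comment just above is what makes simple search possible: copy the state, apply a move, inspect the result. A sketch through the Python wrapper, assuming it exposes the C API's CopyState as state.copy() (an assumption, not confirmed by this listing); the information-token printout is only a stand-in for a real evaluation.
```
import pyhanabi

game = pyhanabi.HanabiGame({"players": 2})
state = game.new_initial_state()
while state.cur_player() == pyhanabi.CHANCE_PLAYER_ID:
    state.deal_random_card()

# Try every legal move on a copy, leaving the real state untouched.
for move in state.legal_moves():
    child = state.copy()          # assumed wrapper around the C API's CopyState
    child.apply_move(move)
    print(move, "->", child.information_tokens(), "information tokens left")
```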
73 | HanabiState(const HanabiState& state) = default; 74 | 75 | bool MoveIsLegal(HanabiMove move) const; 76 | void ApplyMove(HanabiMove move); 77 | // Legal moves for state. Moves point into an unchanging list in parent_game. 78 | std::vector LegalMoves(int player) const; 79 | // Returns true if card with color and rank can be played on fireworks pile. 80 | bool CardPlayableOnFireworks(int color, int rank) const; 81 | bool CardPlayableOnFireworks(HanabiCard card) const { 82 | return CardPlayableOnFireworks(card.Color(), card.Rank()); 83 | } 84 | bool ChanceOutcomeIsLegal(HanabiMove move) const { return MoveIsLegal(move); } 85 | double ChanceOutcomeProb(HanabiMove move) const; 86 | void ApplyChanceOutcome(HanabiMove move) { ApplyMove(move); } 87 | void ApplyRandomChance(); 88 | // Get the valid chance moves, and associated probabilities. 89 | // Guaranteed that moves.size() == probabilities.size(). 90 | std::pair, std::vector> ChanceOutcomes() 91 | const; 92 | EndOfGameType EndOfGameStatus() const; 93 | bool IsTerminal() const { return EndOfGameStatus() != kNotFinished; } 94 | int Score() const; 95 | std::string ToString() const; 96 | 97 | int CurPlayer() const { return cur_player_; } 98 | int LifeTokens() const { return life_tokens_; } 99 | int InformationTokens() const { return information_tokens_; } 100 | const std::vector& Hands() const { return hands_; } 101 | const std::vector& Fireworks() const { return fireworks_; } 102 | const HanabiGame* ParentGame() const { return parent_game_; } 103 | const HanabiDeck& Deck() const { return deck_; } 104 | // Get the discard pile (the element at the back is the most recent discard.) 105 | const std::vector& DiscardPile() const { return discard_pile_; } 106 | // Sequence of moves from beginning of game. Stored as . 107 | const std::vector& MoveHistory() const { 108 | return move_history_; 109 | } 110 | 111 | private: 112 | // Add card to table if possible, if not lose a life token. 113 | // Returns 114 | // success is true iff card was successfully added to fireworks. 115 | // information_token_added is true iff information_tokens increase 116 | // (i.e., success=true, highest rank was added, and not at max tokens.) 117 | std::pair AddToFireworks(HanabiCard card); 118 | const HanabiHand& HandByOffset(int offset) const { 119 | return hands_[(cur_player_ + offset) % hands_.size()]; 120 | } 121 | HanabiHand* HandByOffset(int offset) { 122 | return &hands_[(cur_player_ + offset) % hands_.size()]; 123 | } 124 | void AdvanceToNextPlayer(); // Set cur_player to next player to act. 125 | bool HintingIsLegal(HanabiMove move) const; 126 | int PlayerToDeal() const; // -1 if no player needs a card. 127 | bool IncrementInformationTokens(); 128 | void DecrementInformationTokens(); 129 | void DecrementLifeTokens(); 130 | 131 | const HanabiGame* parent_game_ = nullptr; 132 | HanabiDeck deck_; 133 | // Back element of discard_pile_ is most recently discarded card. 134 | std::vector discard_pile_; 135 | std::vector hands_; 136 | std::vector move_history_; 137 | int cur_player_ = -1; 138 | int next_non_chance_player_ = -1; // Next non-chance player to act. 139 | int information_tokens_ = -1; 140 | int life_tokens_ = -1; 141 | std::vector fireworks_; 142 | int turns_to_play_ = -1; // Number of turns to play once deck is empty. 
143 | }; 144 | 145 | } // namespace hanabi_learning_env 146 | 147 | #endif 148 | -------------------------------------------------------------------------------- /agents/rainbow/third_party/dopamine/checkpointer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Dopamine Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A checkpointing mechanism for Dopamine agents. 16 | 17 | This Checkpointer expects a base directory where checkpoints for different 18 | iterations are stored. Specifically, Checkpointer.save_checkpoint() takes in 19 | as input a dictionary 'data' to be pickled to disk. At each iteration, we 20 | write a file called 'cpkt.#', where # is the iteration number. The 21 | Checkpointer also cleans up old files, maintaining up to the CHECKPOINT_DURATION 22 | most recent iterations. 23 | 24 | The Checkpointer writes a sentinel file to indicate that checkpointing was 25 | globally successful. This means that all other checkpointing activities 26 | (saving the Tensorflow graph, the replay buffer) should be performed *prior* 27 | to calling Checkpointer.save_checkpoint(). This allows the Checkpointer to 28 | detect incomplete checkpoints. 29 | 30 | #### Example 31 | 32 | After running 10 iterations (numbered 0...9) with base_directory='/checkpoint', 33 | the following files will exist: 34 | ``` 35 | /checkpoint/cpkt.6 36 | /checkpoint/cpkt.7 37 | /checkpoint/cpkt.8 38 | /checkpoint/cpkt.9 39 | /checkpoint/sentinel_checkpoint_complete.6 40 | /checkpoint/sentinel_checkpoint_complete.7 41 | /checkpoint/sentinel_checkpoint_complete.8 42 | /checkpoint/sentinel_checkpoint_complete.9 43 | ``` 44 | """ 45 | 46 | from __future__ import absolute_import 47 | from __future__ import division 48 | from __future__ import print_function 49 | 50 | import os 51 | import pickle 52 | import tensorflow as tf 53 | 54 | CHECKPOINT_DURATION = 4 55 | 56 | 57 | def get_latest_checkpoint_number(base_directory): 58 | """Returns the version number of the latest completed checkpoint. 59 | 60 | Args: 61 | base_directory: str, directory in which to look for checkpoint files. 62 | 63 | Returns: 64 | int, the iteration number of the latest checkpoint, or -1 if none was found. 65 | """ 66 | glob = os.path.join(base_directory, 'sentinel_checkpoint_complete.*') 67 | def extract_iteration(x): 68 | return int(x[x.rfind('.') + 1:]) 69 | try: 70 | checkpoint_files = tf.gfile.Glob(glob) 71 | except tf.errors.NotFoundError: 72 | return -1 73 | try: 74 | latest_iteration = max(extract_iteration(x) for x in checkpoint_files) 75 | return latest_iteration 76 | except ValueError: 77 | return -1 78 | 79 | 80 | class Checkpointer(object): 81 | """Class for managing checkpoints for Dopamine agents. 82 | """ 83 | 84 | def __init__(self, base_directory, checkpoint_file_prefix='ckpt', 85 | checkpoint_frequency=1): 86 | """Initializes Checkpointer. 
87 | 88 | Args: 89 | base_directory: str, directory where all checkpoints are saved/loaded. 90 | checkpoint_file_prefix: str, prefix to use for naming checkpoint files. 91 | checkpoint_frequency: int, the frequency at which to checkpoint. 92 | 93 | Raises: 94 | ValueError: if base_directory is empty, or not creatable. 95 | """ 96 | if not base_directory: 97 | raise ValueError('No path provided to Checkpointer.') 98 | self._checkpoint_file_prefix = checkpoint_file_prefix 99 | self._checkpoint_frequency = checkpoint_frequency 100 | self._base_directory = base_directory 101 | try: 102 | tf.gfile.MakeDirs(base_directory) 103 | except tf.errors.PermissionDeniedError: 104 | # We catch the PermissionDeniedError and issue a more useful exception. 105 | raise ValueError('Unable to create checkpoint path: {}.'.format( 106 | base_directory)) 107 | 108 | def _generate_filename(self, file_prefix, iteration_number): 109 | """Returns a checkpoint filename from prefix and iteration number.""" 110 | filename = '{}.{}'.format(file_prefix, iteration_number) 111 | return os.path.join(self._base_directory, filename) 112 | 113 | def _save_data_to_file(self, data, filename): 114 | """Saves the given 'data' object to a file.""" 115 | with tf.gfile.GFile(filename, 'w') as fout: 116 | pickle.dump(data, fout) 117 | 118 | def save_checkpoint(self, iteration_number, data): 119 | """Saves a new checkpoint at the current iteration_number. 120 | 121 | Args: 122 | iteration_number: int, the current iteration number for this checkpoint. 123 | data: Any (picklable) python object containing the data to store in the 124 | checkpoint. 125 | """ 126 | if iteration_number % self._checkpoint_frequency != 0: 127 | return 128 | 129 | filename = self._generate_filename(self._checkpoint_file_prefix, 130 | iteration_number) 131 | self._save_data_to_file(data, filename) 132 | filename = self._generate_filename('sentinel_checkpoint_complete', 133 | iteration_number) 134 | with tf.gfile.GFile(filename, 'wb') as fout: 135 | fout.write('done') 136 | 137 | self._clean_up_old_checkpoints(iteration_number) 138 | 139 | def _clean_up_old_checkpoints(self, iteration_number): 140 | """Removes sufficiently old checkpoints.""" 141 | # After writing a the checkpoint and sentinel file, we garbage collect files 142 | # that are CHECKPOINT_DURATION * self._checkpoint_frequency versions old. 143 | stale_iteration_number = iteration_number - (self._checkpoint_frequency * 144 | CHECKPOINT_DURATION) 145 | 146 | if stale_iteration_number >= 0: 147 | stale_file = self._generate_filename(self._checkpoint_file_prefix, 148 | stale_iteration_number) 149 | stale_sentinel = self._generate_filename('sentinel_checkpoint_complete', 150 | stale_iteration_number) 151 | try: 152 | tf.gfile.Remove(stale_file) 153 | tf.gfile.Remove(stale_sentinel) 154 | except tf.errors.NotFoundError: 155 | # Ignore if file not found. 156 | tf.logging.info('Unable to remove {} or {}.'.format(stale_file, 157 | stale_sentinel)) 158 | 159 | def _load_data_from_file(self, filename): 160 | if not tf.gfile.Exists(filename): 161 | return None 162 | with tf.gfile.GFile(filename, 'rb') as fin: 163 | return pickle.load(fin) 164 | 165 | def load_checkpoint(self, iteration_number): 166 | """Tries to reload a checkpoint at the selected iteration number. 167 | 168 | Args: 169 | iteration_number: The checkpoint iteration number to try to load. 
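Typical use of this Checkpointer pairs save_checkpoint during training with get_latest_checkpoint_number and load_checkpoint at restart. A minimal sketch; the import path and directory are illustrative, and the stored dict is a stand-in for real experiment state.
```
# Import path and base directory are illustrative.
from third_party.dopamine import checkpointer

ckpt = checkpointer.Checkpointer('/tmp/hanabi_ckpt', checkpoint_file_prefix='ckpt')

for iteration in range(10):
    ckpt.save_checkpoint(iteration, {'iteration': iteration})   # any picklable object

latest = checkpointer.get_latest_checkpoint_number('/tmp/hanabi_ckpt')  # 9, or -1 if none completed
if latest >= 0:
    data = ckpt.load_checkpoint(latest)
```
With the default CHECKPOINT_DURATION of 4 and a checkpoint frequency of 1, the sketch above leaves only the four most recent checkpoint and sentinel files on disk.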
170 | 171 | Returns: 172 | If the checkpoint files exist, two unpickled objects that were passed in 173 | as data to save_checkpoint; returns None if the files do not exist. 174 | """ 175 | checkpoint_file = self._generate_filename(self._checkpoint_file_prefix, 176 | iteration_number) 177 | return self._load_data_from_file(checkpoint_file) 178 | -------------------------------------------------------------------------------- /agents/rainbow/third_party/dopamine/sum_tree.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Dopamine Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A sum tree data structure. 16 | 17 | Used for prioritized experience replay. See prioritized_replay_buffer.py 18 | and Schaul et al. (2015). 19 | """ 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import math 25 | import random 26 | 27 | import numpy as np 28 | 29 | 30 | class SumTree(object): 31 | """A sum tree data structure for storing replay priorities. 32 | 33 | A sum tree is a complete binary tree whose leaves contain values called 34 | priorities. Internal nodes maintain the sum of the priorities of all leaf 35 | nodes in their subtree. 36 | 37 | For capacity = 4, the tree may look like this: 38 | 39 | +---+ 40 | |2.5| 41 | +-+-+ 42 | | 43 | +-------+--------+ 44 | | | 45 | +-+-+ +-+-+ 46 | |1.5| |1.0| 47 | +-+-+ +-+-+ 48 | | | 49 | +----+----+ +----+----+ 50 | | | | | 51 | +-+-+ +-+-+ +-+-+ +-+-+ 52 | |0.5| |1.0| |0.5| |0.5| 53 | +---+ +---+ +---+ +---+ 54 | 55 | This is stored in a list of numpy arrays: 56 | self.nodes = [ [2.5], [1.5, 1], [0.5, 1, 0.5, 0.5] ] 57 | 58 | For conciseness, we allocate arrays as powers of two, and pad the excess 59 | elements with zero values. 60 | 61 | This is similar to the usual array-based representation of a complete binary 62 | tree, but is a little more user-friendly. 63 | """ 64 | 65 | def __init__(self, capacity): 66 | """Creates the sum tree data structure for the given replay capacity. 67 | 68 | Args: 69 | capacity: int, the maximum number of elements that can be stored in this 70 | data structure. 71 | 72 | Raises: 73 | ValueError: If requested capacity is not positive. 74 | """ 75 | assert isinstance(capacity, int) 76 | if capacity <= 0: 77 | raise ValueError('Sum tree capacity should be positive. Got: {}'. 78 | format(capacity)) 79 | 80 | self.nodes = [] 81 | tree_depth = int(math.ceil(np.log2(capacity))) 82 | level_size = 1 83 | for _ in range(tree_depth + 1): 84 | nodes_at_this_depth = np.zeros(level_size) 85 | self.nodes.append(nodes_at_this_depth) 86 | 87 | level_size *= 2 88 | 89 | self.max_recorded_priority = 1.0 90 | 91 | def _total_priority(self): 92 | """Returns the sum of all priorities stored in this sum tree. 93 | 94 | Returns: 95 | float, sum of priorities stored in this sum tree. 
96 | """ 97 | return self.nodes[0][0] 98 | 99 | def sample(self, query_value=None): 100 | """Samples an element from the sum tree. 101 | 102 | Each element has probability p_i / sum_j p_j of being picked, where p_i is 103 | the (positive) value associated with node i (possibly unnormalized). 104 | 105 | Args: 106 | query_value: float in [0, 1], used as the random value to select a 107 | sample. If None, will select one randomly in [0, 1). 108 | 109 | Returns: 110 | int, a random element from the sum tree. 111 | 112 | Raises: 113 | Exception: If the sum tree is empty (i.e. its node values sum to 0), or if 114 | the supplied query_value is larger than the total sum. 115 | """ 116 | if self._total_priority() == 0.0: 117 | raise Exception('Cannot sample from an empty sum tree.') 118 | 119 | if query_value and (query_value < 0. or query_value > 1.): 120 | raise ValueError('query_value must be in [0, 1].') 121 | 122 | # Sample a value in range [0, R), where R is the value stored at the root. 123 | query_value = random.random() if query_value is None else query_value 124 | query_value *= self._total_priority() 125 | 126 | # Now traverse the sum tree. 127 | node_index = 0 128 | for nodes_at_this_depth in self.nodes[1:]: 129 | # Compute children of previous depth's node. 130 | left_child = node_index * 2 131 | 132 | left_sum = nodes_at_this_depth[left_child] 133 | # Each subtree describes a range [0, a), where a is its value. 134 | if query_value < left_sum: # Recurse into left subtree. 135 | node_index = left_child 136 | else: # Recurse into right subtree. 137 | node_index = left_child + 1 138 | # Adjust query to be relative to right subtree. 139 | query_value -= left_sum 140 | 141 | return node_index 142 | 143 | def stratified_sample(self, batch_size): 144 | """Performs stratified sampling using the sum tree. 145 | 146 | Let R be the value at the root (total value of sum tree). This method will 147 | divide [0, R) into batch_size segments, pick a random number from each of 148 | those segments, and use that random number to sample from the sum_tree. This 149 | is as specified in Schaul et al. (2015). 150 | 151 | Args: 152 | batch_size: int, the number of strata to use. 153 | Returns: 154 | list of batch_size elements sampled from the sum tree. 155 | 156 | Raises: 157 | Exception: If the sum tree is empty (i.e. its node values sum to 0). 158 | """ 159 | if self._total_priority() == 0.0: 160 | raise Exception('Cannot sample from an empty sum tree.') 161 | 162 | bounds = np.linspace(0., 1., batch_size + 1) 163 | assert len(bounds) == batch_size + 1 164 | segments = [(bounds[i], bounds[i+1]) for i in range(batch_size)] 165 | query_values = [random.uniform(x[0], x[1]) for x in segments] 166 | return [self.sample(query_value=x) for x in query_values] 167 | 168 | def get(self, node_index): 169 | """Returns the value of the leaf node corresponding to the index. 170 | 171 | Args: 172 | node_index: The index of the leaf node. 173 | Returns: 174 | The value of the leaf node. 175 | """ 176 | return self.nodes[-1][node_index] 177 | 178 | def set(self, node_index, value): 179 | """Sets the value of a leaf node and updates internal nodes accordingly. 180 | 181 | This operation takes O(log(capacity)). 182 | Args: 183 | node_index: int, the index of the leaf node to be updated. 184 | value: float, the value which we assign to the node. This value must be 185 | nonnegative. Setting value = 0 will cause the element to never be 186 | sampled. 187 | 188 | Raises: 189 | ValueError: If the given value is negative. 
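A quick tour of the SumTree API defined above: set priorities on the leaves, then sample leaf indices in proportion to them. The import path below mirrors the repository layout and may need adjusting.
```
from third_party.dopamine import sum_tree

tree = sum_tree.SumTree(capacity=8)

# Use the leaf priorities from the diagram in the class docstring.
for index, priority in enumerate([0.5, 1.0, 0.5, 0.5]):
    tree.set(index, priority)

print(tree.get(1))                    # 1.0
print(tree.sample(query_value=0.3))   # 0.3 * 2.5 = 0.75, which lands in leaf 1
print(tree.stratified_sample(4))      # one leaf index per quartile of total priority
```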
190 | """ 191 | if value < 0.0: 192 | raise ValueError('Sum tree values should be nonnegative. Got {}'. 193 | format(value)) 194 | self.max_recorded_priority = max(value, self.max_recorded_priority) 195 | 196 | delta_value = value - self.nodes[-1][node_index] 197 | 198 | # Now traverse back the tree, adjusting all sums along the way. 199 | for nodes_at_this_depth in reversed(self.nodes): 200 | # Note: Adding a delta leads to some tolerable numerical inaccuracies. 201 | nodes_at_this_depth[node_index] += delta_value 202 | node_index //= 2 203 | 204 | assert node_index == 0, ('Sum tree traversal failed, final node index ' 205 | 'is not 0.') 206 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_game.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "hanabi_game.h" 16 | 17 | #include "util.h" 18 | 19 | namespace hanabi_learning_env { 20 | 21 | namespace { 22 | // Constants. 23 | const int kDefaultPlayers = 2; 24 | const int kInformationTokens = 8; 25 | const int kLifeTokens = 3; 26 | const bool kDefaultRandomStart = false; 27 | } // namespace 28 | 29 | HanabiGame::HanabiGame( 30 | const std::unordered_map& params) { 31 | params_ = params; 32 | num_players_ = ParameterValue(params_, "players", kDefaultPlayers); 33 | REQUIRE(num_players_ >= MinPlayers() && num_players_ <= MaxPlayers()); 34 | num_colors_ = ParameterValue(params_, "colors", kMaxNumColors); 35 | REQUIRE(num_colors_ > 0 && num_colors_ <= kMaxNumColors); 36 | num_ranks_ = ParameterValue(params_, "ranks", kMaxNumRanks); 37 | REQUIRE(num_ranks_ > 0 && num_ranks_ <= kMaxNumRanks); 38 | hand_size_ = ParameterValue(params_, "hand_size", HandSizeFromRules()); 39 | max_information_tokens_ = ParameterValue( 40 | params_, "max_information_tokens", kInformationTokens); 41 | max_life_tokens_ = 42 | ParameterValue(params_, "max_life_tokens", kLifeTokens); 43 | seed_ = ParameterValue(params_, "seed", -1); 44 | random_start_player_ = 45 | ParameterValue(params_, "random_start_player", kDefaultRandomStart); 46 | observation_type_ = AgentObservationType(ParameterValue( 47 | params_, "observation_type", AgentObservationType::kCardKnowledge)); 48 | while (seed_ == -1) { 49 | seed_ = std::random_device()(); 50 | } 51 | rng_.seed(seed_); 52 | 53 | // Work out number of cards per color, and check deck size is large enough. 54 | cards_per_color_ = 0; 55 | for (int rank = 0; rank < num_ranks_; ++rank) { 56 | cards_per_color_ += NumberCardInstances(0, rank); 57 | } 58 | REQUIRE(hand_size_ * num_players_ <= cards_per_color_ * num_colors_); 59 | 60 | // Build static list of moves. 
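The constructor above documents the full parameter set a game accepts (players, colors, ranks, hand_size, max_information_tokens, max_life_tokens, seed, random_start_player, observation_type). Through the Python wrapper the same names are passed as a dict, as game_example.py does; the values below are just one example configuration, not a recommended setting.
```
import pyhanabi

game = pyhanabi.HanabiGame({
    "players": 3,
    "colors": 4,
    "ranks": 5,
    "hand_size": 5,
    "max_information_tokens": 8,
    "max_life_tokens": 3,
    "seed": 1234,
    "random_start_player": True,
    "observation_type": 1,   # kCardKnowledge; see AgentObservationType in hanabi_game.h
})
print(game.parameter_string(), end="")
```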
61 | for (int uid = 0; uid < MaxMoves(); ++uid) { 62 | moves_.push_back(ConstructMove(uid)); 63 | } 64 | for (int uid = 0; uid < MaxChanceOutcomes(); ++uid) { 65 | chance_outcomes_.push_back(ConstructChanceOutcome(uid)); 66 | } 67 | } 68 | 69 | int HanabiGame::MaxMoves() const { 70 | return MaxDiscardMoves() + MaxPlayMoves() + MaxRevealColorMoves() + 71 | MaxRevealRankMoves(); 72 | } 73 | 74 | int HanabiGame::GetMoveUid(HanabiMove move) const { 75 | return GetMoveUid(move.MoveType(), move.CardIndex(), move.TargetOffset(), 76 | move.Color(), move.Rank()); 77 | } 78 | 79 | int HanabiGame::GetMoveUid(HanabiMove::Type move_type, int card_index, 80 | int target_offset, int color, int rank) const { 81 | switch (move_type) { 82 | case HanabiMove::kDiscard: 83 | return card_index; 84 | case HanabiMove::kPlay: 85 | return MaxDiscardMoves() + card_index; 86 | case HanabiMove::kRevealColor: 87 | return MaxDiscardMoves() + MaxPlayMoves() + 88 | (target_offset - 1) * NumColors() + color; 89 | case HanabiMove::kRevealRank: 90 | return MaxDiscardMoves() + MaxPlayMoves() + MaxRevealColorMoves() + 91 | (target_offset - 1) * NumRanks() + rank; 92 | default: 93 | return -1; 94 | } 95 | } 96 | 97 | int HanabiGame::MaxChanceOutcomes() const { return NumColors() * NumRanks(); } 98 | 99 | int HanabiGame::GetChanceOutcomeUid(HanabiMove move) const { 100 | if (move.MoveType() != HanabiMove::kDeal) { 101 | return -1; 102 | } 103 | return (move.TargetOffset() * NumColors() + move.Color()) * NumRanks() + 104 | move.Rank(); 105 | } 106 | 107 | HanabiMove HanabiGame::PickRandomChance( 108 | const std::pair, std::vector>& 109 | chance_outcomes) const { 110 | std::discrete_distribution dist( 111 | chance_outcomes.second.begin(), chance_outcomes.second.end()); 112 | return chance_outcomes.first[dist(rng_)]; 113 | } 114 | 115 | std::unordered_map HanabiGame::Parameters() const { 116 | return {{"players", std::to_string(num_players_)}, 117 | {"colors", std::to_string(NumColors())}, 118 | {"ranks", std::to_string(NumRanks())}, 119 | {"hand_size", std::to_string(HandSize())}, 120 | {"max_information_tokens", std::to_string(MaxInformationTokens())}, 121 | {"max_life_tokens", std::to_string(MaxLifeTokens())}, 122 | {"seed", std::to_string(seed_)}, 123 | {"random_start_player", random_start_player_ ? "true" : "false"}, 124 | {"observation_type", std::to_string(observation_type_)}}; 125 | } 126 | 127 | int HanabiGame::NumberCardInstances(int color, int rank) const { 128 | if (color < 0 || color >= NumColors() || rank < 0 || rank >= NumRanks()) { 129 | return 0; 130 | } 131 | if (rank == 0) { 132 | return 3; 133 | } else if (rank == NumRanks() - 1) { 134 | return 1; 135 | } 136 | return 2; 137 | } 138 | 139 | int HanabiGame::GetSampledStartPlayer() const { 140 | if (random_start_player_) { 141 | std::uniform_int_distribution dist( 142 | 0, num_players_ - 1); 143 | return dist(rng_); 144 | } 145 | return 0; 146 | } 147 | 148 | int HanabiGame::HandSizeFromRules() const { 149 | if (num_players_ < 4) { 150 | return 5; 151 | } 152 | return 4; 153 | } 154 | 155 | // Uid mapping. 
h=hand_size, p=num_players, c=colors, r=ranks 156 | // 0, h-1: discard 157 | // h, 2h-1: play 158 | // 2h, 2h+(p-1)c-1: color hint 159 | // 2h+(p-1)c, 2h+(p-1)c+(p-1)r-1: rank hint 160 | HanabiMove HanabiGame::ConstructMove(int uid) const { 161 | if (uid < 0 || uid >= MaxMoves()) { 162 | return HanabiMove(HanabiMove::kInvalid, /*card_index=*/-1, 163 | /*target_offset=*/-1, /*color=*/-1, /*rank=*/-1); 164 | } 165 | if (uid < MaxDiscardMoves()) { 166 | return HanabiMove(HanabiMove::kDiscard, /*card_index=*/uid, 167 | /*target_offset=*/-1, /*color=*/-1, /*rank=*/-1); 168 | } 169 | uid -= MaxDiscardMoves(); 170 | if (uid < MaxPlayMoves()) { 171 | return HanabiMove(HanabiMove::kPlay, /*card_index=*/uid, 172 | /*target_offset=*/-1, /*color=*/-1, /*rank=*/-1); 173 | } 174 | uid -= MaxPlayMoves(); 175 | if (uid < MaxRevealColorMoves()) { 176 | return HanabiMove(HanabiMove::kRevealColor, /*card_index=*/-1, 177 | /*target_offset=*/1 + uid / NumColors(), 178 | /*color=*/uid % NumColors(), /*rank=*/-1); 179 | } 180 | uid -= MaxRevealColorMoves(); 181 | return HanabiMove(HanabiMove::kRevealRank, /*card_index=*/-1, 182 | /*target_offset=*/1 + uid / NumRanks(), 183 | /*color=*/-1, /*rank=*/uid % NumRanks()); 184 | } 185 | 186 | HanabiMove HanabiGame::ConstructChanceOutcome(int uid) const { 187 | if (uid < 0 || uid >= MaxChanceOutcomes()) { 188 | return HanabiMove(HanabiMove::kInvalid, /*card_index=*/-1, 189 | /*target_offset=*/-1, /*color=*/-1, /*rank=*/-1); 190 | } 191 | return HanabiMove(HanabiMove::kDeal, /*card_index=*/-1, 192 | /*target_offset=*/-1, 193 | /*color=*/uid / NumRanks() % NumColors(), 194 | /*rank=*/uid % NumRanks()); 195 | } 196 | 197 | } // namespace hanabi_learning_env 198 | -------------------------------------------------------------------------------- /pyhanabi.h: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef __PYHANABI_H__ 16 | #define __PYHANABI_H__ 17 | 18 | /** 19 | * This is a pure C API to the C++ code. 20 | * All the declarations are loaded in pyhanabi.py. 21 | * The set of functions below is referred to as the 'cdef' throughout the code. 22 | */ 23 | 24 | extern "C" { 25 | 26 | typedef struct PyHanabiCard { 27 | int color; 28 | int rank; 29 | } pyhanabi_card_t; 30 | 31 | typedef struct PyHanabiCardKnowledge { 32 | /* Points to a hanabi_learning_env::HanabiHand::CardKnowledge. */ 33 | const void* knowledge; 34 | } pyhanabi_card_knowledge_t; 35 | 36 | typedef struct PyHanabiMove { 37 | /* Points to a hanabi_learning_env::HanabiMove. */ 38 | void* move; 39 | } pyhanabi_move_t; 40 | 41 | typedef struct PyHanabiHistoryItem { 42 | /* Points to a hanabi_learning_env::HanabiHistoryItem. */ 43 | void* item; 44 | } pyhanabi_history_item_t; 45 | 46 | typedef struct PyHanabiState { 47 | /* Points to a hanabi_learning_env::HanabiState. 
*/ 48 | void* state; 49 | } pyhanabi_state_t; 50 | 51 | typedef struct PyHanabiGame { 52 | /* Points to a hanabi_learning_env::HanabiGame. */ 53 | void* game; 54 | } pyhanabi_game_t; 55 | 56 | typedef struct PyHanabiObservation { 57 | /* Points to a hanabi_learning_env::HanabiObservation. */ 58 | void* observation; 59 | } pyhanabi_observation_t; 60 | 61 | typedef struct PyHanabiObservationEncoder { 62 | /* Points to a hanabi_learning_env::ObservationEncoder. */ 63 | void* encoder; 64 | } pyhanabi_observation_encoder_t; 65 | 66 | /* Utility Functions. */ 67 | void DeleteString(char* str); 68 | 69 | /* Card functions. */ 70 | int CardValid(pyhanabi_card_t* card); 71 | 72 | /* CardKnowledge functions */ 73 | char* CardKnowledgeToString(pyhanabi_card_knowledge_t* knowledge); 74 | int ColorWasHinted(pyhanabi_card_knowledge_t* knowledge); 75 | int KnownColor(pyhanabi_card_knowledge_t* knowledge); 76 | int ColorIsPlausible(pyhanabi_card_knowledge_t* knowledge, int color); 77 | int RankWasHinted(pyhanabi_card_knowledge_t* knowledge); 78 | int KnownRank(pyhanabi_card_knowledge_t* knowledge); 79 | int RankIsPlausible(pyhanabi_card_knowledge_t* knowledge, int rank); 80 | 81 | /* Move functions. */ 82 | void DeleteMoveList(void* movelist); 83 | int NumMoves(void* movelist); 84 | void GetMove(void* movelist, int index, pyhanabi_move_t* move); 85 | void DeleteMove(pyhanabi_move_t* move); 86 | char* MoveToString(pyhanabi_move_t* move); 87 | int MoveType(pyhanabi_move_t* move); 88 | int CardIndex(pyhanabi_move_t* move); 89 | int TargetOffset(pyhanabi_move_t* move); 90 | int MoveColor(pyhanabi_move_t* move); 91 | int MoveRank(pyhanabi_move_t* move); 92 | bool GetDiscardMove(int card_index, pyhanabi_move_t* move); 93 | bool GetPlayMove(int card_index, pyhanabi_move_t* move); 94 | bool GetRevealColorMove(int target_offset, int color, pyhanabi_move_t* move); 95 | bool GetRevealRankMove(int target_offset, int rank, pyhanabi_move_t* move); 96 | 97 | /* HistoryItem functions. */ 98 | void DeleteHistoryItem(pyhanabi_history_item_t* item); 99 | char* HistoryItemToString(pyhanabi_history_item_t* item); 100 | void HistoryItemMove(pyhanabi_history_item_t* item, pyhanabi_move_t* move); 101 | int HistoryItemPlayer(pyhanabi_history_item_t* item); 102 | int HistoryItemScored(pyhanabi_history_item_t* item); 103 | int HistoryItemInformationToken(pyhanabi_history_item_t* item); 104 | int HistoryItemColor(pyhanabi_history_item_t* item); 105 | int HistoryItemRank(pyhanabi_history_item_t* item); 106 | int HistoryItemRevealBitmask(pyhanabi_history_item_t* item); 107 | int HistoryItemNewlyRevealedBitmask(pyhanabi_history_item_t* item); 108 | int HistoryItemDealToPlayer(pyhanabi_history_item_t* item); 109 | 110 | /* State functions. 
*/ 111 | void NewState(pyhanabi_game_t* game, pyhanabi_state_t* state); 112 | void CopyState(const pyhanabi_state_t* src, pyhanabi_state_t* dest); 113 | void DeleteState(pyhanabi_state_t* state); 114 | const void* StateParentGame(pyhanabi_state_t* state); 115 | void StateApplyMove(pyhanabi_state_t* state, pyhanabi_move_t* move); 116 | int StateCurPlayer(pyhanabi_state_t* state); 117 | void StateDealRandomCard(pyhanabi_state_t* state); 118 | int StateDeckSize(pyhanabi_state_t* state); 119 | int StateFireworks(pyhanabi_state_t* state, int color); 120 | int StateDiscardPileSize(pyhanabi_state_t* state); 121 | void StateGetDiscard(pyhanabi_state_t* state, int index, pyhanabi_card_t* card); 122 | int StateGetHandSize(pyhanabi_state_t* state, int pid); 123 | void StateGetHandCard(pyhanabi_state_t* state, int pid, int index, 124 | pyhanabi_card_t* card); 125 | int StateEndOfGameStatus(pyhanabi_state_t* state); 126 | int StateInformationTokens(pyhanabi_state_t* state); 127 | void* StateLegalMoves(pyhanabi_state_t* state); 128 | int StateLifeTokens(pyhanabi_state_t* state); 129 | int StateNumPlayers(pyhanabi_state_t* state); 130 | int StateScore(pyhanabi_state_t* state); 131 | char* StateToString(pyhanabi_state_t* state); 132 | bool MoveIsLegal(const pyhanabi_state_t* state, const pyhanabi_move_t* move); 133 | bool CardPlayableOnFireworks(const pyhanabi_state_t* state, int color, 134 | int rank); 135 | int StateLenMoveHistory(pyhanabi_state_t* state); 136 | void StateGetMoveHistory(pyhanabi_state_t* state, int index, 137 | pyhanabi_history_item_t* item); 138 | 139 | /* Game functions. */ 140 | void DeleteGame(pyhanabi_game_t* game); 141 | void NewDefaultGame(pyhanabi_game_t* game); 142 | void NewGame(pyhanabi_game_t* game, int list_length, const char** param_list); 143 | char* GameParamString(pyhanabi_game_t* game); 144 | int NumPlayers(pyhanabi_game_t* game); 145 | int NumColors(pyhanabi_game_t* game); 146 | int NumRanks(pyhanabi_game_t* game); 147 | int HandSize(pyhanabi_game_t* game); 148 | int MaxInformationTokens(pyhanabi_game_t* game); 149 | int MaxLifeTokens(pyhanabi_game_t* game); 150 | int ObservationType(pyhanabi_game_t* game); 151 | int NumCards(pyhanabi_game_t* game, int color, int rank); 152 | int GetMoveUid(pyhanabi_game_t* game, pyhanabi_move_t* move); 153 | void GetMoveByUid(pyhanabi_game_t* game, int move_uid, pyhanabi_move_t* move); 154 | int MaxMoves(pyhanabi_game_t* game); 155 | 156 | /* Observation functions. 
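
   Rough sketch (assumptions: it reuses the state and a player index from the
   sketch above, and an observation is created per player and freed
   separately; the names are the declarations below):

     pyhanabi_observation_t obs;
     NewObservation(&state, player, &obs);
     int info = ObsInformationTokens(&obs);
     int legal = ObsNumLegalMoves(&obs);
     DeleteObservation(&obs);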
*/ 157 | void NewObservation(pyhanabi_state_t* state, int player, 158 | pyhanabi_observation_t* observation); 159 | void DeleteObservation(pyhanabi_observation_t* observation); 160 | char* ObsToString(pyhanabi_observation_t* observation); 161 | int ObsCurPlayerOffset(pyhanabi_observation_t* observation); 162 | int ObsNumPlayers(pyhanabi_observation_t* observation); 163 | int ObsGetHandSize(pyhanabi_observation_t* observation, int pid); 164 | void ObsGetHandCard(pyhanabi_observation_t* observation, int pid, int index, 165 | pyhanabi_card_t* card); 166 | void ObsGetHandCardKnowledge(pyhanabi_observation_t* observation, int pid, 167 | int index, pyhanabi_card_knowledge_t* knowledge); 168 | int ObsDiscardPileSize(pyhanabi_observation_t* observation); 169 | void ObsGetDiscard(pyhanabi_observation_t* observation, int index, 170 | pyhanabi_card_t* card); 171 | int ObsFireworks(pyhanabi_observation_t* observation, int color); 172 | int ObsDeckSize(pyhanabi_observation_t* observation); 173 | int ObsNumLastMoves(pyhanabi_observation_t* observation); 174 | void ObsGetLastMove(pyhanabi_observation_t* observation, int index, 175 | pyhanabi_history_item_t* item); 176 | int ObsInformationTokens(pyhanabi_observation_t* observation); 177 | int ObsLifeTokens(pyhanabi_observation_t* observation); 178 | int ObsNumLegalMoves(pyhanabi_observation_t* observation); 179 | void ObsGetLegalMove(pyhanabi_observation_t* observation, int index, 180 | pyhanabi_move_t* move); 181 | bool ObsCardPlayableOnFireworks(const pyhanabi_observation_t* observation, 182 | int color, int rank); 183 | 184 | /* ObservationEncoder functions. */ 185 | void NewObservationEncoder(pyhanabi_observation_encoder_t* encoder, 186 | pyhanabi_game_t* game, int type); 187 | void DeleteObservationEncoder(pyhanabi_observation_encoder_t* encoder); 188 | char* ObservationShape(pyhanabi_observation_encoder_t* encoder); 189 | char* EncodeObservation(pyhanabi_observation_encoder_t* encoder, 190 | pyhanabi_observation_t* observation); 191 | 192 | void EncodeObs(pyhanabi_observation_encoder_t *encoder, 193 | pyhanabi_observation_t *observation, 194 | int *encoding); 195 | 196 | } /* extern "C" */ 197 | 198 | #endif 199 | -------------------------------------------------------------------------------- /agents/rainbow/prioritized_replay_memory.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Dopamine Authors and Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # 17 | # 18 | # This file is a fork of the original Dopamine code incorporating changes for 19 | # the multiplayer setting and the Hanabi Learning Environment. 20 | # 21 | """An implementation of Prioritized Experience Replay (PER). 22 | 23 | This implementation is based on the paper "Prioritized Experience Replay" 24 | by Tom Schaul et al. (2015). 
Many thanks to Tom Schaul, John Quan, and Matteo 25 | Hessel for providing useful pointers on the algorithm and its implementation. 26 | """ 27 | 28 | from __future__ import absolute_import 29 | from __future__ import division 30 | from __future__ import print_function 31 | 32 | from third_party.dopamine import sum_tree 33 | import gin.tf 34 | import numpy as np 35 | import replay_memory 36 | import tensorflow as tf 37 | 38 | DEFAULT_PRIORITY = 100.0 39 | 40 | 41 | class OutOfGraphPrioritizedReplayMemory(replay_memory.OutOfGraphReplayMemory): 42 | """An Out of Graph Replay Memory for Prioritized Experience Replay. 43 | 44 | See replay_memory.py for details. 45 | """ 46 | 47 | def __init__(self, num_actions, observation_size, stack_size, replay_capacity, 48 | batch_size, update_horizon=1, gamma=1.0): 49 | """This data structure does the heavy lifting in the replay memory. 50 | 51 | Args: 52 | num_actions: int, number of actions. 53 | observation_size: int, size of an input observation. 54 | stack_size: int, number of frames to use in state stack. 55 | replay_capacity: int, number of transitions to keep in memory. 56 | batch_size: int, batch size. 57 | update_horizon: int, length of update ('n' in n-step update). 58 | gamma: int, the discount factor. 59 | """ 60 | super(OutOfGraphPrioritizedReplayMemory, self).__init__( 61 | num_actions=num_actions, 62 | observation_size=observation_size, stack_size=stack_size, 63 | replay_capacity=replay_capacity, batch_size=batch_size, 64 | update_horizon=update_horizon, gamma=gamma) 65 | 66 | self.sum_tree = sum_tree.SumTree(replay_capacity) 67 | 68 | def add(self, observation, action, reward, terminal, legal_actions): 69 | """Adds a transition to the replay memory. 70 | 71 | Since the next_observation in the transition will be the observation added 72 | next there is no need to pass it. 73 | 74 | If the replay memory is at capacity the oldest transition will be discarded. 75 | 76 | Compared to OutOfGraphReplayMemory.add(), this version also sets the 77 | priority of dummy frames to 0. 78 | 79 | Args: 80 | observation: `np.array` uint8, (observation_size, observation_size). 81 | action: int, indicating the action in the transition. 82 | reward: float, indicating the reward received in the transition. 83 | terminal: int, acting as a boolean indicating whether the transition 84 | was terminal (1) or not (0). 85 | legal_actions: Binary vector indicating legal actions (1 == legal). 86 | """ 87 | if self.is_empty() or self.terminals[self.cursor() - 1] == 1: 88 | dummy_observation = np.zeros((self._observation_size)) 89 | dummy_legal_actions = np.zeros((self._num_actions)) 90 | for _ in range(self._stack_size - 1): 91 | self._add(dummy_observation, 0, 0, 0, dummy_legal_actions, priority=0.0) 92 | 93 | self._add(observation, action, reward, terminal, legal_actions, 94 | priority=DEFAULT_PRIORITY) 95 | 96 | def _add(self, observation, action, reward, terminal, legal_actions, 97 | priority=DEFAULT_PRIORITY): 98 | new_element_index = self.cursor() 99 | 100 | super(OutOfGraphPrioritizedReplayMemory, self)._add( 101 | observation, action, reward, terminal, legal_actions) 102 | 103 | self.sum_tree.set(new_element_index, priority) 104 | 105 | def sample_index_batch(self, batch_size): 106 | """Returns a batch of valid indices. 107 | 108 | Args: 109 | batch_size: int, number of indices returned. 110 | 111 | Returns: 112 | List of size batch_size containing valid indices. 113 | 114 | Raises: 115 | Exception: If the batch was not constructed after maximum number of tries. 
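
    Sampling note (inferred from the implementation): each index is drawn
    from the sum tree with probability proportional to its priority, and a
    draw that does not correspond to a valid transition is rejected and
    consumes one of the allowed attempts.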
116 | """ 117 | indices = [] 118 | allowed_attempts = replay_memory.MAX_SAMPLE_ATTEMPTS 119 | 120 | while len(indices) < batch_size and allowed_attempts > 0: 121 | index = self.sum_tree.sample() 122 | 123 | if self.is_valid_transition(index): 124 | indices.append(index) 125 | else: 126 | allowed_attempts -= 1 127 | 128 | if len(indices) != batch_size: 129 | raise Exception('Could only sample {} valid transitions'.format( 130 | len(indices))) 131 | else: 132 | return indices 133 | 134 | def set_priority(self, indices, priorities): 135 | """Sets the priority of the given elements according to Schaul et al. 136 | 137 | Args: 138 | indices: `np.array` of indices in range [0, replay_capacity). 139 | priorities: list of floats, the corresponding priorities. 140 | """ 141 | assert indices.dtype == np.int32, ('Indices must be integers, ' 142 | 'given: {}'.format(indices.dtype)) 143 | for i, memory_index in enumerate(indices): 144 | self.sum_tree.set(memory_index, priorities[i]) 145 | 146 | def get_priority(self, indices, batch_size=None): 147 | """Fetches the priorities correspond to a batch of memory indices. 148 | 149 | For any memory location not yet used, the corresponding priority is 0. 150 | 151 | Args: 152 | indices: `np.array` of indices in range [0, replay_capacity). 153 | batch_size: int, requested number of items. 154 | Returns: 155 | The corresponding priorities. 156 | """ 157 | if batch_size is None: 158 | batch_size = self._batch_size 159 | if batch_size != self._state_batch.shape[0]: 160 | self.reset_state_batch_arrays(batch_size) 161 | 162 | priority_batch = np.empty((batch_size), dtype=np.float32) 163 | 164 | assert indices.dtype == np.int32, ('Indices must be integers, ' 165 | 'given: {}'.format(indices.dtype)) 166 | for i, memory_index in enumerate(indices): 167 | priority_batch[i] = self.sum_tree.get(memory_index) 168 | 169 | return priority_batch 170 | 171 | 172 | @gin.configurable(blacklist=['observation_size', 'stack_size']) 173 | class WrappedPrioritizedReplayMemory(replay_memory.WrappedReplayMemory): 174 | """In graph wrapper for the python Replay Memory. 175 | 176 | Usage: 177 | To add a transition: run the operation add_transition_op 178 | (and feed all the placeholders in add_transition_ph) 179 | 180 | To sample a batch: Construct operations that depend on any of the 181 | sampling tensors. Every sess.run using any of these 182 | tensors will sample a new transition. 183 | 184 | When using staging: Need to prefetch the next batch with each train_op by 185 | calling self.prefetch_batch. 186 | 187 | Everytime this op is called a new transition batch 188 | would be prefetched. 189 | 190 | Attributes: 191 | # The following tensors are sampled randomly each sess.run 192 | states 193 | actions 194 | rewards 195 | next_states 196 | terminals 197 | 198 | add_transition_op: tf operation to add a transition to the replay 199 | memory. All the following placeholders need to be fed. 200 | add_obs_ph 201 | add_action_ph 202 | add_reward_ph 203 | add_terminal_ph 204 | """ 205 | 206 | def __init__(self, 207 | num_actions, 208 | observation_size, 209 | stack_size, 210 | use_staging=True, 211 | replay_capacity=1000000, 212 | batch_size=32, 213 | update_horizon=1, 214 | gamma=1.0): 215 | """Initializes a graph wrapper for the python Replay Memory. 216 | 217 | Args: 218 | num_actions: int, number of possible actions. 219 | observation_size: int, size of an input observation. 220 | stack_size: int, number of frames to use in state stack. 
221 | use_staging: bool, when True it would use a staging area to prefetch 222 | the next sampling batch. 223 | replay_capacity: int, number of transitions to keep in memory. 224 | batch_size: int. 225 | update_horizon: int, length of update ('n' in n-step update). 226 | gamma: int, the discount factor. 227 | 228 | Raises: 229 | ValueError: If update_horizon is not positive. 230 | ValueError: If discount factor is not in [0, 1]. 231 | """ 232 | memory = OutOfGraphPrioritizedReplayMemory(num_actions, observation_size, 233 | stack_size, replay_capacity, 234 | batch_size, update_horizon, 235 | gamma) 236 | super(WrappedPrioritizedReplayMemory, self).__init__( 237 | num_actions, 238 | observation_size, stack_size, use_staging, replay_capacity, batch_size, 239 | update_horizon, gamma, wrapped_memory=memory) 240 | 241 | def tf_set_priority(self, indices, losses): 242 | """Sets the priorities for the given indices. 243 | 244 | Args: 245 | indices: tensor of indices (int32), size k. 246 | losses: tensor of losses (float), size k. 247 | 248 | Returns: 249 | A TF op setting the priorities according to Prioritized Experience 250 | Replay. 251 | """ 252 | return tf.py_func( 253 | self.memory.set_priority, [indices, losses], 254 | [], 255 | name='prioritized_replay_set_priority_py_func') 256 | 257 | def tf_get_priority(self, indices): 258 | """Gets the priorities for the given indices. 259 | 260 | Args: 261 | indices: tensor of indices (int32), size k. 262 | 263 | Returns: 264 | A tensor (float32) of priorities. 265 | """ 266 | return tf.py_func( 267 | self.memory.get_priority, [indices], 268 | [tf.float32], 269 | name='prioritized_replay_get_priority_py_func') 270 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /agents/rainbow/third_party/dopamine/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018 The Dopamine Authors. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 
10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 
124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 
181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /hanabi_lib/hanabi_state.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "hanabi_state.h" 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | #include "util.h" 22 | 23 | namespace hanabi_learning_env { 24 | 25 | namespace { 26 | // Returns bitmask of card indices which match color. 27 | uint8_t HandColorBitmask(const HanabiHand& hand, int color) { 28 | uint8_t mask = 0; 29 | const auto& cards = hand.Cards(); 30 | assert(cards.size() <= 8); // More than 8 cards is not supported. 31 | for (int i = 0; i < cards.size(); ++i) { 32 | if (cards[i].Color() == color) { 33 | mask |= static_cast(1) << i; 34 | } 35 | } 36 | return mask; 37 | } 38 | 39 | // Returns bitmask of card indices which match color. 40 | uint8_t HandRankBitmask(const HanabiHand& hand, int rank) { 41 | uint8_t mask = 0; 42 | const auto& cards = hand.Cards(); 43 | assert(cards.size() <= 8); // More than 8 cards is not supported. 
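  // Bit i of the returned mask is set iff the i-th card in the hand has the
  // given rank. Illustrative example: hand ranks {0, 2, 0} with rank == 0
  // yield 0b101, i.e. cards 0 and 2 match.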
44 | for (int i = 0; i < cards.size(); ++i) { 45 | if (cards[i].Rank() == rank) { 46 | mask |= static_cast(1) << i; 47 | } 48 | } 49 | return mask; 50 | } 51 | } // namespace 52 | 53 | HanabiState::HanabiDeck::HanabiDeck(const HanabiGame& game) 54 | : card_count_(game.NumColors() * game.NumRanks(), 0), 55 | total_count_(0), 56 | num_ranks_(game.NumRanks()) { 57 | for (int color = 0; color < game.NumColors(); ++color) { 58 | for (int rank = 0; rank < game.NumRanks(); ++rank) { 59 | auto count = game.NumberCardInstances(color, rank); 60 | card_count_[CardToIndex(color, rank)] = count; 61 | total_count_ += count; 62 | } 63 | } 64 | } 65 | 66 | HanabiCard HanabiState::HanabiDeck::DealCard(std::mt19937* rng) { 67 | if (Empty()) { 68 | return HanabiCard(); 69 | } 70 | std::discrete_distribution dist( 71 | card_count_.begin(), card_count_.end()); 72 | int index = dist(*rng); 73 | assert(card_count_[index] > 0); 74 | --card_count_[index]; 75 | --total_count_; 76 | return HanabiCard(IndexToColor(index), IndexToRank(index)); 77 | } 78 | 79 | HanabiCard HanabiState::HanabiDeck::DealCard(int color, int rank) { 80 | int index = CardToIndex(color, rank); 81 | if (card_count_[index] <= 0) { 82 | return HanabiCard(); 83 | } 84 | assert(card_count_[index] > 0); 85 | --card_count_[index]; 86 | --total_count_; 87 | return HanabiCard(IndexToColor(index), IndexToRank(index)); 88 | } 89 | 90 | HanabiState::HanabiState(const HanabiGame* parent_game, int start_player) 91 | : parent_game_(parent_game), 92 | deck_(*parent_game), 93 | hands_(parent_game->NumPlayers()), 94 | cur_player_(kChancePlayerId), 95 | next_non_chance_player_(start_player >= 0 && 96 | start_player < parent_game->NumPlayers() 97 | ? start_player 98 | : parent_game->GetSampledStartPlayer()), 99 | information_tokens_(parent_game->MaxInformationTokens()), 100 | life_tokens_(parent_game->MaxLifeTokens()), 101 | fireworks_(parent_game->NumColors(), 0), 102 | turns_to_play_(parent_game->NumPlayers()) {} 103 | 104 | void HanabiState::AdvanceToNextPlayer() { 105 | if (!deck_.Empty() && PlayerToDeal() >= 0) { 106 | cur_player_ = kChancePlayerId; 107 | } else { 108 | cur_player_ = next_non_chance_player_; 109 | next_non_chance_player_ = (cur_player_ + 1) % hands_.size(); 110 | } 111 | } 112 | 113 | bool HanabiState::IncrementInformationTokens() { 114 | if (information_tokens_ < ParentGame()->MaxInformationTokens()) { 115 | ++information_tokens_; 116 | return true; 117 | } else { 118 | return false; 119 | } 120 | } 121 | 122 | void HanabiState::DecrementInformationTokens() { 123 | assert(information_tokens_ > 0); 124 | --information_tokens_; 125 | } 126 | 127 | void HanabiState::DecrementLifeTokens() { 128 | assert(life_tokens_ > 0); 129 | --life_tokens_; 130 | } 131 | 132 | std::pair HanabiState::AddToFireworks(HanabiCard card) { 133 | if (CardPlayableOnFireworks(card)) { 134 | ++fireworks_[card.Color()]; 135 | // Check if player completed a stack. 
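    // Completing a stack also attempts to grant an information token; the
    // second element of the returned pair records whether a token was
    // actually added (IncrementInformationTokens() returns false at the cap).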
136 | if (fireworks_[card.Color()] == ParentGame()->NumRanks()) { 137 | return {true, IncrementInformationTokens()}; 138 | } 139 | return {true, false}; 140 | } else { 141 | DecrementLifeTokens(); 142 | return {false, false}; 143 | } 144 | } 145 | 146 | bool HanabiState::HintingIsLegal(HanabiMove move) const { 147 | if (InformationTokens() <= 0) { 148 | return false; 149 | } 150 | if (move.TargetOffset() < 1 || 151 | move.TargetOffset() >= ParentGame()->NumPlayers()) { 152 | return false; 153 | } 154 | return true; 155 | } 156 | 157 | int HanabiState::PlayerToDeal() const { 158 | for (int i = 0; i < hands_.size(); ++i) { 159 | if (hands_[i].Cards().size() < ParentGame()->HandSize()) { 160 | return i; 161 | } 162 | } 163 | return -1; 164 | } 165 | 166 | bool HanabiState::MoveIsLegal(HanabiMove move) const { 167 | switch (move.MoveType()) { 168 | case HanabiMove::kDeal: 169 | if (cur_player_ != kChancePlayerId) { 170 | return false; 171 | } 172 | if (deck_.CardCount(move.Color(), move.Rank()) == 0) { 173 | return false; 174 | } 175 | break; 176 | case HanabiMove::kDiscard: 177 | if (InformationTokens() >= ParentGame()->MaxInformationTokens()) { 178 | return false; 179 | } 180 | if (move.CardIndex() >= hands_[cur_player_].Cards().size()) { 181 | return false; 182 | } 183 | break; 184 | case HanabiMove::kPlay: 185 | if (move.CardIndex() >= hands_[cur_player_].Cards().size()) { 186 | return false; 187 | } 188 | break; 189 | case HanabiMove::kRevealColor: { 190 | if (!HintingIsLegal(move)) { 191 | return false; 192 | } 193 | const auto& cards = HandByOffset(move.TargetOffset()).Cards(); 194 | if (!std::any_of(cards.begin(), cards.end(), 195 | [move](const HanabiCard& card) { 196 | return card.Color() == move.Color(); 197 | })) { 198 | return false; 199 | } 200 | break; 201 | } 202 | case HanabiMove::kRevealRank: { 203 | if (!HintingIsLegal(move)) { 204 | return false; 205 | } 206 | const auto& cards = HandByOffset(move.TargetOffset()).Cards(); 207 | if (!std::any_of(cards.begin(), cards.end(), 208 | [move](const HanabiCard& card) { 209 | return card.Rank() == move.Rank(); 210 | })) { 211 | return false; 212 | } 213 | break; 214 | } 215 | default: 216 | return false; 217 | } 218 | return true; 219 | } 220 | 221 | void HanabiState::ApplyMove(HanabiMove move) { 222 | REQUIRE(MoveIsLegal(move)); 223 | if (deck_.Empty()) { 224 | --turns_to_play_; 225 | } 226 | HanabiHistoryItem history(move); 227 | history.player = cur_player_; 228 | switch (move.MoveType()) { 229 | case HanabiMove::kDeal: { 230 | history.deal_to_player = PlayerToDeal(); 231 | HanabiHand::CardKnowledge card_knowledge(ParentGame()->NumColors(), 232 | ParentGame()->NumRanks()); 233 | if (parent_game_->ObservationType() == HanabiGame::kSeer){ 234 | card_knowledge.ApplyIsColorHint(move.Color()); 235 | card_knowledge.ApplyIsRankHint(move.Rank()); 236 | } 237 | hands_[history.deal_to_player].AddCard( 238 | deck_.DealCard(move.Color(), move.Rank()), 239 | card_knowledge); 240 | } 241 | break; 242 | case HanabiMove::kDiscard: 243 | history.information_token = IncrementInformationTokens(); 244 | history.color = hands_[cur_player_].Cards()[move.CardIndex()].Color(); 245 | history.rank = hands_[cur_player_].Cards()[move.CardIndex()].Rank(); 246 | hands_[cur_player_].RemoveFromHand(move.CardIndex(), &discard_pile_); 247 | break; 248 | case HanabiMove::kPlay: 249 | history.color = hands_[cur_player_].Cards()[move.CardIndex()].Color(); 250 | history.rank = hands_[cur_player_].Cards()[move.CardIndex()].Rank(); 251 | std::tie(history.scored, 
history.information_token) = 252 | AddToFireworks(hands_[cur_player_].Cards()[move.CardIndex()]); 253 | hands_[cur_player_].RemoveFromHand( 254 | move.CardIndex(), history.scored ? nullptr : &discard_pile_); 255 | break; 256 | case HanabiMove::kRevealColor: 257 | DecrementInformationTokens(); 258 | history.reveal_bitmask = 259 | HandColorBitmask(*HandByOffset(move.TargetOffset()), move.Color()); 260 | history.newly_revealed_bitmask = 261 | HandByOffset(move.TargetOffset())->RevealColor(move.Color()); 262 | break; 263 | case HanabiMove::kRevealRank: 264 | DecrementInformationTokens(); 265 | history.reveal_bitmask = 266 | HandRankBitmask(*HandByOffset(move.TargetOffset()), move.Rank()); 267 | history.newly_revealed_bitmask = 268 | HandByOffset(move.TargetOffset())->RevealRank(move.Rank()); 269 | break; 270 | default: 271 | std::abort(); // Should not be possible. 272 | } 273 | move_history_.push_back(history); 274 | AdvanceToNextPlayer(); 275 | } 276 | 277 | double HanabiState::ChanceOutcomeProb(HanabiMove move) const { 278 | return static_cast(deck_.CardCount(move.Color(), move.Rank())) / 279 | static_cast(deck_.Size()); 280 | } 281 | 282 | void HanabiState::ApplyRandomChance() { 283 | auto chance_outcomes = ChanceOutcomes(); 284 | REQUIRE(!chance_outcomes.second.empty()); 285 | ApplyMove(ParentGame()->PickRandomChance(chance_outcomes)); 286 | } 287 | 288 | std::vector HanabiState::LegalMoves(int player) const { 289 | std::vector movelist; 290 | // kChancePlayer=-1 must be handled by ChanceOutcome. 291 | REQUIRE(player >= 0 && player < ParentGame()->NumPlayers()); 292 | if (player != cur_player_) { 293 | // Turn-based game. Empty move list for other players. 294 | return movelist; 295 | } 296 | int max_move_uid = ParentGame()->MaxMoves(); 297 | for (int uid = 0; uid < max_move_uid; ++uid) { 298 | HanabiMove move = ParentGame()->GetMove(uid); 299 | if (MoveIsLegal(move)) { 300 | movelist.push_back(move); 301 | } 302 | } 303 | return movelist; 304 | } 305 | 306 | bool HanabiState::CardPlayableOnFireworks(int color, int rank) const { 307 | if (color < 0 || color >= ParentGame()->NumColors()) { 308 | return false; 309 | } 310 | return rank == fireworks_[color]; 311 | } 312 | 313 | std::pair, std::vector> 314 | HanabiState::ChanceOutcomes() const { 315 | std::pair, std::vector> rv; 316 | int max_outcome_uid = ParentGame()->MaxChanceOutcomes(); 317 | for (int uid = 0; uid < max_outcome_uid; ++uid) { 318 | HanabiMove move = ParentGame()->GetChanceOutcome(uid); 319 | if (MoveIsLegal(move)) { 320 | rv.first.push_back(move); 321 | rv.second.push_back(ChanceOutcomeProb(move)); 322 | } 323 | } 324 | return rv; 325 | } 326 | 327 | // Format: :: 328 | // -....:: 329 | // -.... || -... 330 | // :.... 331 | // ::-... 
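// Illustrative rendering (made-up values; assumes ColorIndexToChar yields
// R/Y/G/W/B and that player 0 is to act):
//   Life tokens: 3
//   Info tokens: 7
//   Fireworks: R1 Y0 G2 W0 B0
//   Hands:
//   Cur player
//   <current player's hand>
//   -----
//   <next player's hand>
//   Deck size: 30
//   Discards: <space-separated discarded cards>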
332 | std::string HanabiState::ToString() const { 333 | std::string result; 334 | result += "Life tokens: " + std::to_string(LifeTokens()) + "\n"; 335 | result += "Info tokens: " + std::to_string(InformationTokens()) + "\n"; 336 | result += "Fireworks: "; 337 | for (int i = 0; i < ParentGame()->NumColors(); ++i) { 338 | result += ColorIndexToChar(i); 339 | result += std::to_string(fireworks_[i]) + " "; 340 | } 341 | result += "\nHands:\n"; 342 | for (int i = 0; i < hands_.size(); ++i) { 343 | if (i > 0) { 344 | result += "-----\n"; 345 | } 346 | if (i == CurPlayer()) { 347 | result += "Cur player\n"; 348 | } 349 | result += hands_[i].ToString(); 350 | } 351 | result += "Deck size: " + std::to_string(Deck().Size()) + "\n"; 352 | result += "Discards:"; 353 | for (int i = 0; i < discard_pile_.size(); ++i) { 354 | result += " " + discard_pile_[i].ToString(); 355 | } 356 | return result; 357 | } 358 | 359 | int HanabiState::Score() const { 360 | if (LifeTokens() <= 0) { 361 | return 0; 362 | } 363 | return std::accumulate(fireworks_.begin(), fireworks_.end(), 0); 364 | } 365 | 366 | HanabiState::EndOfGameType HanabiState::EndOfGameStatus() const { 367 | if (LifeTokens() < 1) { 368 | return kOutOfLifeTokens; 369 | } 370 | if (Score() >= ParentGame()->NumColors() * ParentGame()->NumRanks()) { 371 | return kCompletedFireworks; 372 | } 373 | if (turns_to_play_ <= 0) { 374 | return kOutOfCards; 375 | } 376 | return kNotFinished; 377 | } 378 | 379 | } // namespace hanabi_learning_env 380 | -------------------------------------------------------------------------------- /agents/rainbow/rainbow_agent.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Dopamine Authors and Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # 17 | # 18 | # This file is a fork of the original Dopamine code incorporating changes for 19 | # the multiplayer setting and the Hanabi Learning Environment. 20 | # 21 | """Implementation of a Rainbow agent adapted to the multiplayer setting.""" 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import functools 28 | 29 | import dqn_agent 30 | import gin.tf 31 | import numpy as np 32 | import prioritized_replay_memory 33 | import tensorflow as tf 34 | 35 | 36 | slim = tf.contrib.slim 37 | 38 | 39 | @gin.configurable 40 | def rainbow_template(state, 41 | num_actions, 42 | num_atoms=51, 43 | layer_size=512, 44 | num_layers=1): 45 | r"""Builds a Rainbow Network mapping states to value distributions. 46 | 47 | Args: 48 | state: A `tf.placeholder` for the RL state. 49 | num_actions: int, number of actions that the RL agent can take. 50 | num_atoms: int, number of atoms to approximate the distribution with. 51 | layer_size: int, number of hidden units per layer. 52 | num_layers: int, number of hidden layers. 
53 | 54 | Returns: 55 | net: A `tf.Graphdef` for Rainbow: 56 | `\theta : \mathcal{X}\rightarrow\mathbb{R}^{|\mathcal{A}| \times N}`, 57 | where `N` is num_atoms. 58 | """ 59 | weights_initializer = slim.variance_scaling_initializer( 60 | factor=1.0 / np.sqrt(3.0), mode='FAN_IN', uniform=True) 61 | 62 | net = tf.cast(state, tf.float32) 63 | net = tf.squeeze(net, axis=2) 64 | 65 | for _ in range(num_layers): 66 | net = slim.fully_connected(net, layer_size, 67 | activation_fn=tf.nn.relu) 68 | net = slim.fully_connected(net, num_actions * num_atoms, activation_fn=None, 69 | weights_initializer=weights_initializer) 70 | net = tf.reshape(net, [-1, num_actions, num_atoms]) 71 | return net 72 | 73 | 74 | @gin.configurable 75 | class RainbowAgent(dqn_agent.DQNAgent): 76 | """A compact implementation of the multiplayer Rainbow agent.""" 77 | 78 | @gin.configurable 79 | def __init__(self, 80 | num_actions=None, 81 | observation_size=None, 82 | num_players=None, 83 | num_atoms=51, 84 | vmax=25., 85 | gamma=0.99, 86 | update_horizon=1, 87 | min_replay_history=500, 88 | update_period=4, 89 | target_update_period=500, 90 | epsilon_train=0.0, 91 | epsilon_eval=0.0, 92 | epsilon_decay_period=1000, 93 | learning_rate=0.000025, 94 | optimizer_epsilon=0.00003125, 95 | tf_device='/cpu:*'): 96 | """Initializes the agent and constructs its graph. 97 | 98 | Args: 99 | num_actions: int, number of actions the agent can take at any state. 100 | observation_size: int, size of observation vector. 101 | num_players: int, number of players playing this game. 102 | num_atoms: Int, the number of buckets for the value function distribution. 103 | vmax: float, maximum return predicted by a value distribution. 104 | gamma: float, discount factor as commonly used in the RL literature. 105 | update_horizon: int, horizon at which updates are performed, the 'n' in 106 | n-step update. 107 | min_replay_history: int, number of stored transitions before training. 108 | update_period: int, period between DQN updates. 109 | target_update_period: int, update period for the target network. 110 | epsilon_train: float, final epsilon for training. 111 | epsilon_eval: float, epsilon during evaluation. 112 | epsilon_decay_period: int, number of steps for epsilon to decay. 113 | learning_rate: float, learning rate for the optimizer. 114 | optimizer_epsilon: float, epsilon for Adam optimizer. 115 | tf_device: str, Tensorflow device on which to run computations. 116 | """ 117 | # We need this because some tools convert round floats into ints. 118 | vmax = float(vmax) 119 | self.num_atoms = num_atoms 120 | # Using -vmax as the minimum return is is wasteful, because all rewards are 121 | # positive -- but does not unduly affect performance. 
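    # With the constructor defaults (vmax=25.0, num_atoms=51) the atoms are
    # the evenly spaced values -25, -24, ..., 24, 25: the bin width is
    # 2 * vmax / (num_atoms - 1) = 1.0.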
122 | self.support = tf.linspace(-vmax, vmax, num_atoms) 123 | self.learning_rate = learning_rate 124 | self.optimizer_epsilon = optimizer_epsilon 125 | 126 | graph_template = functools.partial(rainbow_template, num_atoms=num_atoms) 127 | super(RainbowAgent, self).__init__( 128 | num_actions=num_actions, 129 | observation_size=observation_size, 130 | num_players=num_players, 131 | gamma=gamma, 132 | update_horizon=update_horizon, 133 | min_replay_history=min_replay_history, 134 | update_period=update_period, 135 | target_update_period=target_update_period, 136 | epsilon_train=epsilon_train, 137 | epsilon_eval=epsilon_eval, 138 | epsilon_decay_period=epsilon_decay_period, 139 | graph_template=graph_template, 140 | tf_device=tf_device) 141 | tf.logging.info('\t learning_rate: %f', learning_rate) 142 | tf.logging.info('\t optimizer_epsilon: %f', optimizer_epsilon) 143 | 144 | def _build_replay_memory(self, use_staging): 145 | """Creates the replay memory used by the agent. 146 | 147 | Rainbow uses prioritized replay. 148 | 149 | Args: 150 | use_staging: bool, whether to use a staging area in the replay memory. 151 | 152 | Returns: 153 | A replay memory object. 154 | """ 155 | return prioritized_replay_memory.WrappedPrioritizedReplayMemory( 156 | num_actions=self.num_actions, 157 | observation_size=self.observation_size, 158 | stack_size=1, 159 | use_staging=use_staging, 160 | update_horizon=self.update_horizon, 161 | gamma=self.gamma) 162 | 163 | def _reshape_networks(self): 164 | # self._q is actually logits now, rename things. 165 | # size of _logits: 1 x num_actions x num_atoms 166 | self._logits = self._q 167 | # size of _probabilities: 1 x num_actions x num_atoms 168 | self._probabilities = tf.contrib.layers.softmax(self._q) 169 | # size of _q: 1 x num_actions 170 | self._q = tf.reduce_sum(self.support * self._probabilities, axis=2) 171 | # Recompute argmax from q values. Ignore illegal actions. 172 | self._q_argmax = tf.argmax(self._q + self.legal_actions_ph, axis=1)[0] 173 | 174 | # size of _replay_logits: 1 x num_actions x num_atoms 175 | self._replay_logits = self._replay_qs 176 | # size of _replay_next_logits: 1 x num_actions x num_atoms 177 | self._replay_next_logits = self._replay_next_qt 178 | del self._replay_qs 179 | del self._replay_next_qt 180 | 181 | def _build_target_distribution(self): 182 | self._reshape_networks() 183 | batch_size = tf.shape(self._replay.rewards)[0] 184 | # size of rewards: batch_size x 1 185 | rewards = self._replay.rewards[:, None] 186 | # size of tiled_support: batch_size x num_atoms 187 | tiled_support = tf.tile(self.support, [batch_size]) 188 | tiled_support = tf.reshape(tiled_support, [batch_size, self.num_atoms]) 189 | # size of target_support: batch_size x num_atoms 190 | 191 | is_terminal_multiplier = 1. - tf.cast(self._replay.terminals, tf.float32) 192 | # Incorporate terminal state to discount factor. 
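    # When a transition is terminal the multiplier below is 0, so the target
    # support collapses to the reward alone and no bootstrapping occurs.
    # (cumulative_gamma is assumed to be gamma ** update_horizon, inherited
    # from the parent DQN agent.)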
193 | # size of gamma_with_terminal: batch_size x 1 194 | gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier 195 | gamma_with_terminal = gamma_with_terminal[:, None] 196 | 197 | target_support = rewards + gamma_with_terminal * tiled_support 198 | # size of next_probabilities: batch_size x num_actions x num_atoms 199 | next_probabilities = tf.contrib.layers.softmax( 200 | self._replay_next_logits) 201 | 202 | # size of next_qt: 1 x num_actions 203 | next_qt = tf.reduce_sum(self.support * next_probabilities, 2) 204 | # size of next_qt_argmax: 1 x batch_size 205 | next_qt_argmax = tf.argmax( 206 | next_qt + self._replay.next_legal_actions, axis=1)[:, None] 207 | batch_indices = tf.range(tf.to_int64(batch_size))[:, None] 208 | # size of next_qt_argmax: batch_size x 2 209 | next_qt_argmax = tf.concat([batch_indices, next_qt_argmax], axis=1) 210 | # size of next_probabilities: batch_size x num_atoms 211 | next_probabilities = tf.gather_nd(next_probabilities, next_qt_argmax) 212 | return project_distribution(target_support, next_probabilities, 213 | self.support) 214 | 215 | def _build_train_op(self): 216 | """Builds the training op for Rainbow. 217 | 218 | Returns: 219 | train_op: An op performing one step of training. 220 | """ 221 | target_distribution = tf.stop_gradient(self._build_target_distribution()) 222 | 223 | # size of indices: batch_size x 1. 224 | indices = tf.range(tf.shape(self._replay_logits)[0])[:, None] 225 | # size of reshaped_actions: batch_size x 2. 226 | reshaped_actions = tf.concat([indices, self._replay.actions[:, None]], 1) 227 | # For each element of the batch, fetch the logits for its selected action. 228 | chosen_action_logits = tf.gather_nd(self._replay_logits, reshaped_actions) 229 | 230 | loss = tf.nn.softmax_cross_entropy_with_logits( 231 | labels=target_distribution, 232 | logits=chosen_action_logits) 233 | 234 | optimizer = tf.train.AdamOptimizer( 235 | learning_rate=self.learning_rate, 236 | epsilon=self.optimizer_epsilon) 237 | 238 | update_priorities_op = self._replay.tf_set_priority( 239 | self._replay.indices, tf.sqrt(loss + 1e-10)) 240 | 241 | target_priorities = self._replay.tf_get_priority(self._replay.indices) 242 | target_priorities = tf.math.add(target_priorities, 1e-10) 243 | target_priorities = 1.0 / tf.sqrt(target_priorities) 244 | target_priorities /= tf.reduce_max(target_priorities) 245 | 246 | weighted_loss = target_priorities * loss 247 | 248 | with tf.control_dependencies([update_priorities_op]): 249 | return optimizer.minimize(tf.reduce_mean(weighted_loss)), weighted_loss 250 | 251 | 252 | def project_distribution(supports, weights, target_support, 253 | validate_args=False): 254 | """Projects a batch of (support, weights) onto target_support. 255 | 256 | Based on equation (7) in (Bellemare et al., 2017): 257 | https://arxiv.org/abs/1707.06887 258 | In the rest of the comments we will refer to this equation simply as Eq7. 259 | 260 | This code is not easy to digest, so we will use a running example to clarify 261 | what is going on, with the following sample inputs: 262 | * supports = [[0, 2, 4, 6, 8], 263 | [1, 3, 4, 5, 6]] 264 | * weights = [[0.1, 0.6, 0.1, 0.1, 0.1], 265 | [0.1, 0.2, 0.5, 0.1, 0.1]] 266 | * target_support = [4, 5, 6, 7, 8] 267 | In the code below, comments preceded with 'Ex:' will be referencing the above 268 | values. 269 | 270 | Args: 271 | supports: Tensor of shape (batch_size, num_dims) defining supports for the 272 | distribution. 
273 | weights: Tensor of shape (batch_size, num_dims) defining weights on the 274 | original support points. Although for the CategoricalDQN agent these 275 | weights are probabilities, it is not required that they are. 276 | target_support: Tensor of shape (num_dims) defining support of the projected 277 | distribution. The values must be monotonically increasing. Vmin and Vmax 278 | will be inferred from the first and last elements of this tensor, 279 | respectively. The values in this tensor must be equally spaced. 280 | validate_args: Whether we will verify the contents of the 281 | target_support parameter. 282 | 283 | Returns: 284 | A Tensor of shape (batch_size, num_dims) with the projection of a batch of 285 | (support, weights) onto target_support. 286 | 287 | Raises: 288 | ValueError: If target_support has no dimensions, or if shapes of supports, 289 | weights, and target_support are incompatible. 290 | """ 291 | target_support_deltas = target_support[1:] - target_support[:-1] 292 | # delta_z = `\Delta z` in Eq7. 293 | delta_z = target_support_deltas[0] 294 | validate_deps = [] 295 | supports.shape.assert_is_compatible_with(weights.shape) 296 | supports[0].shape.assert_is_compatible_with(target_support.shape) 297 | target_support.shape.assert_has_rank(1) 298 | if validate_args: 299 | # Assert that supports and weights have the same shapes. 300 | validate_deps.append( 301 | tf.Assert( 302 | tf.reduce_all(tf.equal(tf.shape(supports), tf.shape(weights))), 303 | [supports, weights])) 304 | # Assert that elements of supports and target_support have the same shape. 305 | validate_deps.append( 306 | tf.Assert( 307 | tf.reduce_all( 308 | tf.equal(tf.shape(supports)[1], tf.shape(target_support))), 309 | [supports, target_support])) 310 | # Assert that target_support has a single dimension. 311 | validate_deps.append( 312 | tf.Assert( 313 | tf.equal(tf.size(tf.shape(target_support)), 1), [target_support])) 314 | # Assert that the target_support is monotonically increasing. 315 | validate_deps.append( 316 | tf.Assert(tf.reduce_all(target_support_deltas > 0), [target_support])) 317 | # Assert that the values in target_support are equally spaced. 318 | validate_deps.append( 319 | tf.Assert( 320 | tf.reduce_all(tf.equal(target_support_deltas, delta_z)), 321 | [target_support])) 322 | 323 | with tf.control_dependencies(validate_deps): 324 | # Ex: `v_min, v_max = 4, 8`. 325 | v_min, v_max = target_support[0], target_support[-1] 326 | # Ex: `batch_size = 2`. 327 | batch_size = tf.shape(supports)[0] 328 | # `N` in Eq7. 329 | # Ex: `num_dims = 5`. 330 | num_dims = tf.shape(target_support)[0] 331 | # clipped_support = `[\hat{T}_{z_j}]^{V_max}_{V_min}` in Eq7. 332 | # Ex: `clipped_support = [[[ 4. 4. 4. 6. 8.]] 333 | # [[ 4. 4. 4. 5. 6.]]]`. 334 | clipped_support = tf.clip_by_value(supports, v_min, v_max)[:, None, :] 335 | # Ex: `tiled_support = [[[[ 4. 4. 4. 6. 8.] 336 | # [ 4. 4. 4. 6. 8.] 337 | # [ 4. 4. 4. 6. 8.] 338 | # [ 4. 4. 4. 6. 8.] 339 | # [ 4. 4. 4. 6. 8.]] 340 | # [[ 4. 4. 4. 5. 6.] 341 | # [ 4. 4. 4. 5. 6.] 342 | # [ 4. 4. 4. 5. 6.] 343 | # [ 4. 4. 4. 5. 6.] 344 | # [ 4. 4. 4. 5. 6.]]]]`. 345 | tiled_support = tf.tile([clipped_support], [1, 1, num_dims, 1]) 346 | # Ex: `reshaped_target_support = [[[ 4.] 347 | # [ 5.] 348 | # [ 6.] 349 | # [ 7.] 350 | # [ 8.]] 351 | # [[ 4.] 352 | # [ 5.] 353 | # [ 6.] 354 | # [ 7.] 355 | # [ 8.]]]`. 
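    # The ops below realize Eq7 directly: for every target atom `z_i`, measure
    # its distance to each clipped source atom, convert that distance into the
    # hat-function weight `[1 - |.|/\Delta z]_0^1`, and use those weights to
    # redistribute the source probabilities onto target_support.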
356 | reshaped_target_support = tf.tile(target_support[:, None], [batch_size, 1]) 357 | reshaped_target_support = tf.reshape(reshaped_target_support, 358 | [batch_size, num_dims, 1]) 359 | # numerator = `|clipped_support - z_i|` in Eq7. 360 | # Ex: `numerator = [[[[ 0. 0. 0. 2. 4.] 361 | # [ 1. 1. 1. 1. 3.] 362 | # [ 2. 2. 2. 0. 2.] 363 | # [ 3. 3. 3. 1. 1.] 364 | # [ 4. 4. 4. 2. 0.]] 365 | # [[ 0. 0. 0. 1. 2.] 366 | # [ 1. 1. 1. 0. 1.] 367 | # [ 2. 2. 2. 1. 0.] 368 | # [ 3. 3. 3. 2. 1.] 369 | # [ 4. 4. 4. 3. 2.]]]]`. 370 | numerator = tf.abs(tiled_support - reshaped_target_support) 371 | quotient = 1 - (numerator / delta_z) 372 | # clipped_quotient = `[1 - numerator / (\Delta z)]_0^1` in Eq7. 373 | # Ex: `clipped_quotient = [[[[ 1. 1. 1. 0. 0.] 374 | # [ 0. 0. 0. 0. 0.] 375 | # [ 0. 0. 0. 1. 0.] 376 | # [ 0. 0. 0. 0. 0.] 377 | # [ 0. 0. 0. 0. 1.]] 378 | # [[ 1. 1. 1. 0. 0.] 379 | # [ 0. 0. 0. 1. 0.] 380 | # [ 0. 0. 0. 0. 1.] 381 | # [ 0. 0. 0. 0. 0.] 382 | # [ 0. 0. 0. 0. 0.]]]]`. 383 | clipped_quotient = tf.clip_by_value(quotient, 0, 1) 384 | # Ex: `weights = [[ 0.1 0.6 0.1 0.1 0.1] 385 | # [ 0.1 0.2 0.5 0.1 0.1]]`. 386 | weights = weights[:, None, :] 387 | # inner_prod = `\sum_{j=0}^{N-1} clipped_quotient * p_j(x', \pi(x'))` 388 | # in Eq7. 389 | # Ex: `inner_prod = [[[[ 0.1 0.6 0.1 0. 0. ] 390 | # [ 0. 0. 0. 0. 0. ] 391 | # [ 0. 0. 0. 0.1 0. ] 392 | # [ 0. 0. 0. 0. 0. ] 393 | # [ 0. 0. 0. 0. 0.1]] 394 | # [[ 0.1 0.2 0.5 0. 0. ] 395 | # [ 0. 0. 0. 0.1 0. ] 396 | # [ 0. 0. 0. 0. 0.1] 397 | # [ 0. 0. 0. 0. 0. ] 398 | # [ 0. 0. 0. 0. 0. ]]]]`. 399 | inner_prod = clipped_quotient * weights 400 | # Ex: `projection = [[ 0.8 0.0 0.1 0.0 0.1] 401 | # [ 0.8 0.1 0.1 0.0 0.0]]`. 402 | projection = tf.reduce_sum(inner_prod, 3) 403 | projection = tf.reshape(projection, [batch_size, num_dims]) 404 | return projection 405 | -------------------------------------------------------------------------------- /hanabi_lib/canonical_encoders.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2018 Google LLC 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // https://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #include "canonical_encoders.h" 22 | 23 | namespace hanabi_learning_env { 24 | 25 | namespace { 26 | 27 | // Computes the product of dimensions in shape, i.e. how many individual 28 | // pieces of data the encoded observation requires. 29 | int FlatLength(const std::vector& shape) { 30 | return std::accumulate(std::begin(shape), std::end(shape), 1, 31 | std::multiplies()); 32 | } 33 | 34 | const HanabiHistoryItem* GetLastNonDealMove( 35 | const std::vector& past_moves) { 36 | auto it = std::find_if( 37 | past_moves.begin(), past_moves.end(), [](const HanabiHistoryItem& item) { 38 | return item.move.MoveType() != HanabiMove::Type::kDeal; 39 | }); 40 | return it == past_moves.end() ? 
nullptr : &(*it); 41 | } 42 | 43 | int BitsPerCard(const HanabiGame& game) { 44 | return game.NumColors() * game.NumRanks(); 45 | } 46 | 47 | // The card's one-hot index using a color-major ordering. 48 | int CardIndex(int color, int rank, int num_ranks) { 49 | return color * num_ranks + rank; 50 | } 51 | 52 | int HandsSectionLength(const HanabiGame& game) { 53 | return (game.NumPlayers() - 1) * game.HandSize() * BitsPerCard(game) + 54 | game.NumPlayers(); 55 | } 56 | 57 | // Enocdes cards in all other player's hands (excluding our unknown hand), 58 | // and whether the hand is missing a card for all players (when deck is empty.) 59 | // Each card in a hand is encoded with a one-hot representation using 60 | // * bits (25 bits in a standard game) per card. 61 | // Returns the number of entries written to the encoding. 62 | int EncodeHands(const HanabiGame& game, const HanabiObservation& obs, 63 | int start_offset, std::vector* encoding) { 64 | int bits_per_card = BitsPerCard(game); 65 | int num_ranks = game.NumRanks(); 66 | int num_players = game.NumPlayers(); 67 | int hand_size = game.HandSize(); 68 | 69 | int offset = start_offset; 70 | const std::vector& hands = obs.Hands(); 71 | assert(hands.size() == num_players); 72 | for (int player = 1; player < num_players; ++player) { 73 | const std::vector& cards = hands[player].Cards(); 74 | int num_cards = 0; 75 | 76 | for (const HanabiCard& card : cards) { 77 | // Only a player's own cards can be invalid/unobserved. 78 | assert(card.IsValid()); 79 | assert(card.Color() < game.NumColors()); 80 | assert(card.Rank() < num_ranks); 81 | (*encoding)[offset + CardIndex(card.Color(), card.Rank(), num_ranks)] = 1; 82 | 83 | ++num_cards; 84 | offset += bits_per_card; 85 | } 86 | 87 | // A player's hand can have fewer cards than the initial hand size. 88 | // Leave the bits for the absent cards empty (adjust the offset to skip 89 | // bits for the missing cards). 90 | if (num_cards < hand_size) { 91 | offset += (hand_size - num_cards) * bits_per_card; 92 | } 93 | } 94 | 95 | // For each player, set a bit if their hand is missing a card. 96 | for (int player = 0; player < num_players; ++player) { 97 | if (hands[player].Cards().size() < game.HandSize()) { 98 | (*encoding)[offset + player] = 1; 99 | } 100 | } 101 | offset += num_players; 102 | 103 | assert(offset - start_offset == HandsSectionLength(game)); 104 | return offset - start_offset; 105 | } 106 | 107 | int BoardSectionLength(const HanabiGame& game) { 108 | return game.MaxDeckSize() - game.NumPlayers() * game.HandSize() + // deck 109 | game.NumColors() * game.NumRanks() + // fireworks 110 | game.MaxInformationTokens() + // info tokens 111 | game.MaxLifeTokens(); // life tokens 112 | } 113 | 114 | // Encode the board, including: 115 | // - remaining deck size 116 | // (max_deck_size - num_players * hand_size bits; thermometer) 117 | // - state of the fireworks ( bits per color; one-hot) 118 | // - information tokens remaining (max_information_tokens bits; thermometer) 119 | // - life tokens remaining (max_life_tokens bits; thermometer) 120 | // We note several features use a thermometer representation instead of one-hot. 121 | // For example, life tokens could be: 000 (0), 100 (1), 110 (2), 111 (3). 122 | // Returns the number of entries written to the encoding. 
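// For example, in a standard 2-player game the deck-size thermometer is
// 50 - 2 * 5 = 40 bits wide; with 30 cards left in the deck, the first 30 of
// those bits are set and the remaining 10 are left at 0.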
123 | int EncodeBoard(const HanabiGame& game, const HanabiObservation& obs, 124 | int start_offset, std::vector* encoding) { 125 | int num_colors = game.NumColors(); 126 | int num_ranks = game.NumRanks(); 127 | int num_players = game.NumPlayers(); 128 | int hand_size = game.HandSize(); 129 | int max_deck_size = game.MaxDeckSize(); 130 | 131 | int offset = start_offset; 132 | // Encode the deck size 133 | for (int i = 0; i < obs.DeckSize(); ++i) { 134 | (*encoding)[offset + i] = 1; 135 | } 136 | offset += (max_deck_size - hand_size * num_players); // 40 in normal 2P game 137 | 138 | // fireworks 139 | const std::vector& fireworks = obs.Fireworks(); 140 | for (int c = 0; c < num_colors; ++c) { 141 | // fireworks[color] is the number of successfully played cards. 142 | // If some were played, one-hot encode the highest (0-indexed) rank played 143 | if (fireworks[c] > 0) { 144 | (*encoding)[offset + fireworks[c] - 1] = 1; 145 | } 146 | offset += num_ranks; 147 | } 148 | 149 | // info tokens 150 | assert(obs.InformationTokens() >= 0); 151 | assert(obs.InformationTokens() <= game.MaxInformationTokens()); 152 | for (int i = 0; i < obs.InformationTokens(); ++i) { 153 | (*encoding)[offset + i] = 1; 154 | } 155 | offset += game.MaxInformationTokens(); 156 | 157 | // life tokens 158 | assert(obs.LifeTokens() >= 0); 159 | assert(obs.LifeTokens() <= game.MaxLifeTokens()); 160 | for (int i = 0; i < obs.LifeTokens(); ++i) { 161 | (*encoding)[offset + i] = 1; 162 | } 163 | offset += game.MaxLifeTokens(); 164 | 165 | assert(offset - start_offset == BoardSectionLength(game)); 166 | return offset - start_offset; 167 | } 168 | 169 | int DiscardSectionLength(const HanabiGame& game) { return game.MaxDeckSize(); } 170 | 171 | // Encode the discard pile. (max_deck_size bits) 172 | // Encoding is in color-major ordering, as in kColorStr ("RYGWB"), with each 173 | // color and rank using a thermometer to represent the number of cards 174 | // discarded. For example, in a standard game, there are 3 cards of lowest rank 175 | // (1), 1 card of highest rank (5), 2 of all else. So each color would be 176 | // ordered like so: 177 | // 178 | // LLL H 179 | // 1100011101 180 | // 181 | // This means for this color: 182 | // - 2 cards of the lowest rank have been discarded 183 | // - none of the second lowest rank have been discarded 184 | // - both of the third lowest rank have been discarded 185 | // - one of the second highest rank have been discarded 186 | // - the highest rank card has been discarded 187 | // Returns the number of entries written to the encoding. 
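// Each rank's sub-thermometer is NumberCardInstances(color, rank) bits wide
// (3/2/2/2/1 from lowest to highest rank in the standard game, i.e. 10 bits
// per color), so the whole section spans MaxDeckSize() bits.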
188 | int EncodeDiscards(const HanabiGame& game, const HanabiObservation& obs, 189 | int start_offset, std::vector* encoding) { 190 | int num_colors = game.NumColors(); 191 | int num_ranks = game.NumRanks(); 192 | 193 | int offset = start_offset; 194 | std::vector discard_counts(num_colors * num_ranks, 0); 195 | for (const HanabiCard& card : obs.DiscardPile()) { 196 | ++discard_counts[card.Color() * num_ranks + card.Rank()]; 197 | } 198 | 199 | for (int c = 0; c < num_colors; ++c) { 200 | for (int r = 0; r < num_ranks; ++r) { 201 | int num_discarded = discard_counts[c * num_ranks + r]; 202 | for (int i = 0; i < num_discarded; ++i) { 203 | (*encoding)[offset + i] = 1; 204 | } 205 | offset += game.NumberCardInstances(c, r); 206 | } 207 | } 208 | 209 | assert(offset - start_offset == DiscardSectionLength(game)); 210 | return offset - start_offset; 211 | } 212 | 213 | int LastActionSectionLength(const HanabiGame& game) { 214 | return game.NumPlayers() + // player id 215 | 4 + // move types (play, dis, rev col, rev rank) 216 | game.NumPlayers() + // target player id (if hint action) 217 | game.NumColors() + // color (if hint action) 218 | game.NumRanks() + // rank (if hint action) 219 | game.HandSize() + // outcome (if hint action) 220 | game.HandSize() + // position (if play action) 221 | BitsPerCard(game) + // card (if play or discard action) 222 | 2; // play (successful, added information token) 223 | } 224 | 225 | // Encode the last player action (not chance's deal of cards). This encodes: 226 | // - Acting player index, relative to ourself ( bits; one-hot) 227 | // - The MoveType (4 bits; one-hot) 228 | // - Target player index, relative to acting player, if a reveal move 229 | // ( bits; one-hot) 230 | // - Color revealed, if a reveal color move ( bits; one-hot) 231 | // - Rank revealed, if a reveal rank move ( bits; one-hot) 232 | // - Reveal outcome ( bits; each bit is 1 if the card was hinted at) 233 | // - Position played/discarded ( bits; one-hot) 234 | // - Card played/discarded ( * bits; one-hot) 235 | // Returns the number of entries written to the encoding. 236 | int EncodeLastAction(const HanabiGame& game, const HanabiObservation& obs, 237 | int start_offset, std::vector* encoding) { 238 | int num_colors = game.NumColors(); 239 | int num_ranks = game.NumRanks(); 240 | int num_players = game.NumPlayers(); 241 | int hand_size = game.HandSize(); 242 | 243 | int offset = start_offset; 244 | const HanabiHistoryItem* last_move = GetLastNonDealMove(obs.LastMoves()); 245 | if (last_move == nullptr) { 246 | offset += LastActionSectionLength(game); 247 | } else { 248 | HanabiMove::Type last_move_type = last_move->move.MoveType(); 249 | 250 | // player_id 251 | // Note: no assertion here. At a terminal state, the last player could have 252 | // been me (player id 0). 
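    // last_move->player is already expressed relative to the observing player
    // (0 means the observer itself), so it can be one-hot encoded directly.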
253 | (*encoding)[offset + last_move->player] = 1; 254 | offset += num_players; 255 | 256 | // move type 257 | switch (last_move_type) { 258 | case HanabiMove::Type::kPlay: 259 | (*encoding)[offset] = 1; 260 | break; 261 | case HanabiMove::Type::kDiscard: 262 | (*encoding)[offset + 1] = 1; 263 | break; 264 | case HanabiMove::Type::kRevealColor: 265 | (*encoding)[offset + 2] = 1; 266 | break; 267 | case HanabiMove::Type::kRevealRank: 268 | (*encoding)[offset + 3] = 1; 269 | break; 270 | default: 271 | std::abort(); 272 | } 273 | offset += 4; 274 | 275 | // target player (if hint action) 276 | if (last_move_type == HanabiMove::Type::kRevealColor || 277 | last_move_type == HanabiMove::Type::kRevealRank) { 278 | int8_t observer_relative_target = 279 | (last_move->player + last_move->move.TargetOffset()) % num_players; 280 | (*encoding)[offset + observer_relative_target] = 1; 281 | } 282 | offset += num_players; 283 | 284 | // color (if hint action) 285 | if (last_move_type == HanabiMove::Type::kRevealColor) { 286 | (*encoding)[offset + last_move->move.Color()] = 1; 287 | } 288 | offset += num_colors; 289 | 290 | // rank (if hint action) 291 | if (last_move_type == HanabiMove::Type::kRevealRank) { 292 | (*encoding)[offset + last_move->move.Rank()] = 1; 293 | } 294 | offset += num_ranks; 295 | 296 | // outcome (if hinted action) 297 | if (last_move_type == HanabiMove::Type::kRevealColor || 298 | last_move_type == HanabiMove::Type::kRevealRank) { 299 | for (int i = 0, mask = 1; i < hand_size; ++i, mask <<= 1) { 300 | if ((last_move->reveal_bitmask & mask) > 0) { 301 | (*encoding)[offset + i] = 1; 302 | } 303 | } 304 | } 305 | offset += hand_size; 306 | 307 | // position (if play or discard action) 308 | if (last_move_type == HanabiMove::Type::kPlay || 309 | last_move_type == HanabiMove::Type::kDiscard) { 310 | (*encoding)[offset + last_move->move.CardIndex()] = 1; 311 | } 312 | offset += hand_size; 313 | 314 | // card (if play or discard action) 315 | if (last_move_type == HanabiMove::Type::kPlay || 316 | last_move_type == HanabiMove::Type::kDiscard) { 317 | assert(last_move->color >= 0); 318 | assert(last_move->rank >= 0); 319 | (*encoding)[offset + 320 | CardIndex(last_move->color, last_move->rank, num_ranks)] = 1; 321 | } 322 | offset += BitsPerCard(game); 323 | 324 | // was successful and/or added information token (if play action) 325 | if (last_move_type == HanabiMove::Type::kPlay) { 326 | if (last_move->scored) { 327 | (*encoding)[offset] = 1; 328 | } 329 | if (last_move->information_token) { 330 | (*encoding)[offset + 1] = 1; 331 | } 332 | } 333 | offset += 2; 334 | } 335 | 336 | assert(offset - start_offset == LastActionSectionLength(game)); 337 | return offset - start_offset; 338 | } 339 | 340 | int CardKnowledgeSectionLength(const HanabiGame& game) { 341 | return game.NumPlayers() * game.HandSize() * 342 | (BitsPerCard(game) + game.NumColors() + game.NumRanks()); 343 | } 344 | 345 | // Encode the common card knowledge. 346 | // For each card/position in each player's hand, including the observing player, 347 | // encode the possible cards that could be in that position and whether the 348 | // color and rank were directly revealed by a Reveal action. Possible card 349 | // values are in color-major order, using * bits per 350 | // card. For example, if you knew nothing about a card, and a player revealed 351 | // that is was green, the knowledge would be encoded as follows. 352 | // R Y G W B 353 | // 0000000000111110000000000 Only green cards are possible. 
354 | // 0 0 1 0 0 Card was revealed to be green. 355 | // 00000 Card rank was not revealed. 356 | // 357 | // Similarly, if the player revealed that one of your other cards was green, you 358 | // would know that this card could not be green, resulting in: 359 | // R Y G W B 360 | // 1111111111000001111111111 Any card that is not green is possible. 361 | // 0 0 0 0 0 Card color was not revealed. 362 | // 00000 Card rank was not revealed. 363 | // Uses * * 364 | // ( * + + ) bits. 365 | // Returns the number of entries written to the encoding. 366 | int EncodeCardKnowledge(const HanabiGame& game, const HanabiObservation& obs, 367 | int start_offset, std::vector* encoding) { 368 | int bits_per_card = BitsPerCard(game); 369 | int num_colors = game.NumColors(); 370 | int num_ranks = game.NumRanks(); 371 | int num_players = game.NumPlayers(); 372 | int hand_size = game.HandSize(); 373 | 374 | int offset = start_offset; 375 | const std::vector& hands = obs.Hands(); 376 | assert(hands.size() == num_players); 377 | for (int player = 0; player < num_players; ++player) { 378 | const std::vector& knowledge = 379 | hands[player].Knowledge(); 380 | int num_cards = 0; 381 | 382 | for (const HanabiHand::CardKnowledge& card_knowledge : knowledge) { 383 | // Add bits for plausible card. 384 | for (int color = 0; color < num_colors; ++color) { 385 | if (card_knowledge.ColorPlausible(color)) { 386 | for (int rank = 0; rank < num_ranks; ++rank) { 387 | if (card_knowledge.RankPlausible(rank)) { 388 | (*encoding)[offset + CardIndex(color, rank, num_ranks)] = 1; 389 | } 390 | } 391 | } 392 | } 393 | offset += bits_per_card; 394 | 395 | // Add bits for explicitly revealed colors and ranks. 396 | if (card_knowledge.ColorHinted()) { 397 | (*encoding)[offset + card_knowledge.Color()] = 1; 398 | } 399 | offset += num_colors; 400 | if (card_knowledge.RankHinted()) { 401 | (*encoding)[offset + card_knowledge.Rank()] = 1; 402 | } 403 | offset += num_ranks; 404 | 405 | ++num_cards; 406 | } 407 | 408 | // A player's hand can have fewer cards than the initial hand size. 409 | // Leave the bits for the absent cards empty (adjust the offset to skip 410 | // bits for the missing cards). 411 | if (num_cards < hand_size) { 412 | offset += 413 | (hand_size - num_cards) * (bits_per_card + num_colors + num_ranks); 414 | } 415 | } 416 | 417 | assert(offset - start_offset == CardKnowledgeSectionLength(game)); 418 | return offset - start_offset; 419 | } 420 | 421 | } // namespace 422 | 423 | std::vector CanonicalObservationEncoder::Shape() const { 424 | return {HandsSectionLength(*parent_game_) + 425 | BoardSectionLength(*parent_game_) + 426 | DiscardSectionLength(*parent_game_) + 427 | LastActionSectionLength(*parent_game_) + 428 | (parent_game_->ObservationType() == HanabiGame::kMinimal 429 | ? 0 430 | : CardKnowledgeSectionLength(*parent_game_))}; 431 | } 432 | 433 | std::vector CanonicalObservationEncoder::Encode( 434 | const HanabiObservation& obs) const { 435 | // Make an empty bit string of the proper size. 436 | std::vector encoding(FlatLength(Shape()), 0); 437 | 438 | // This offset is an index to the start of each section of the bit vector. 439 | // It is incremented at the end of each section. 
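  // The sections appear in the same order as in Shape(): hands, board,
  // discards, last action, and (unless the game uses minimal observations)
  // card knowledge; the final assert checks that the section lengths exactly
  // cover the encoding.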
440 | int offset = 0; 441 | offset += EncodeHands(*parent_game_, obs, offset, &encoding); 442 | offset += EncodeBoard(*parent_game_, obs, offset, &encoding); 443 | offset += EncodeDiscards(*parent_game_, obs, offset, &encoding); 444 | offset += EncodeLastAction(*parent_game_, obs, offset, &encoding); 445 | if (parent_game_->ObservationType() != HanabiGame::kMinimal) { 446 | offset += EncodeCardKnowledge(*parent_game_, obs, offset, &encoding); 447 | } 448 | 449 | assert(offset == encoding.size()); 450 | return encoding; 451 | } 452 | 453 | } // namespace hanabi_learning_env 454 | -------------------------------------------------------------------------------- /agents/rainbow/run_experiment.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Dopamine Authors and Google LLC. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # 17 | # 18 | # This file is a fork of the original Dopamine code incorporating changes for 19 | # the multiplayer setting and the Hanabi Learning Environment. 20 | # 21 | """Run methods for training a DQN agent on Atari. 22 | 23 | Methods in this module are usually referenced by |train.py|. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import time 31 | 32 | from third_party.dopamine import checkpointer 33 | from third_party.dopamine import iteration_statistics 34 | import dqn_agent 35 | import gin.tf 36 | import rl_env 37 | import numpy as np 38 | import rainbow_agent 39 | import tensorflow as tf 40 | 41 | LENIENT_SCORE = False 42 | 43 | 44 | class ObservationStacker(object): 45 | """Class for stacking agent observations.""" 46 | 47 | def __init__(self, history_size, observation_size, num_players): 48 | """Initializer for observation stacker. 49 | 50 | Args: 51 | history_size: int, number of time steps to stack. 52 | observation_size: int, size of observation vector on one time step. 53 | num_players: int, number of players. 54 | """ 55 | self._history_size = history_size 56 | self._observation_size = observation_size 57 | self._num_players = num_players 58 | self._obs_stacks = list() 59 | for _ in range(0, self._num_players): 60 | self._obs_stacks.append(np.zeros(self._observation_size * 61 | self._history_size)) 62 | 63 | def add_observation(self, observation, current_player): 64 | """Adds observation for the current player. 65 | 66 | Args: 67 | observation: observation vector for current player. 68 | current_player: int, current player id. 69 | """ 70 | self._obs_stacks[current_player] = np.roll(self._obs_stacks[current_player], 71 | -self._observation_size) 72 | self._obs_stacks[current_player][(self._history_size - 1) * 73 | self._observation_size:] = observation 74 | 75 | def get_observation_stack(self, current_player): 76 | """Returns the stacked observation for current player. 77 | 78 | Args: 79 | current_player: int, current player id. 
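
    Returns:
      np.array of length observation_size * history_size holding the stacked
      observation for current_player, oldest time step first.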
80 | """ 81 | 82 | return self._obs_stacks[current_player] 83 | 84 | def reset_stack(self): 85 | """Resets the observation stacks to all zero.""" 86 | 87 | for i in range(0, self._num_players): 88 | self._obs_stacks[i].fill(0.0) 89 | 90 | @property 91 | def history_size(self): 92 | """Returns number of steps to stack.""" 93 | return self._history_size 94 | 95 | def observation_size(self): 96 | """Returns the size of the observation vector after history stacking.""" 97 | return self._observation_size * self._history_size 98 | 99 | 100 | def load_gin_configs(gin_files, gin_bindings): 101 | """Loads gin configuration files. 102 | 103 | Args: 104 | gin_files: A list of paths to the gin configuration files for this 105 | experiment. 106 | gin_bindings: List of gin parameter bindings to override the values in the 107 | config files. 108 | """ 109 | gin.parse_config_files_and_bindings(gin_files, 110 | bindings=gin_bindings, 111 | skip_unknown=False) 112 | 113 | 114 | @gin.configurable 115 | def create_environment(game_type='Hanabi-Full', num_players=2): 116 | """Creates the Hanabi environment. 117 | 118 | Args: 119 | game_type: Type of game to play. Currently the following are supported: 120 | Hanabi-Full: Regular game. 121 | Hanabi-Small: The small version of Hanabi, with 2 cards and 2 colours. 122 | num_players: Int, number of players to play this game. 123 | 124 | Returns: 125 | A Hanabi environment. 126 | """ 127 | return rl_env.make( 128 | environment_name=game_type, num_players=num_players, pyhanabi_path=None) 129 | 130 | 131 | @gin.configurable 132 | def create_obs_stacker(environment, history_size=4): 133 | """Creates an observation stacker. 134 | 135 | Args: 136 | environment: environment object. 137 | history_size: int, number of steps to stack. 138 | 139 | Returns: 140 | An observation stacker object. 141 | """ 142 | 143 | return ObservationStacker(history_size, 144 | environment.vectorized_observation_shape()[0], 145 | environment.players) 146 | 147 | 148 | @gin.configurable 149 | def create_agent(environment, obs_stacker, agent_type='DQN'): 150 | """Creates the Hanabi agent. 151 | 152 | Args: 153 | environment: The environment. 154 | obs_stacker: Observation stacker object. 155 | agent_type: str, type of agent to construct. 156 | 157 | Returns: 158 | An agent for playing Hanabi. 159 | 160 | Raises: 161 | ValueError: if an unknown agent type is requested. 162 | """ 163 | if agent_type == 'DQN': 164 | return dqn_agent.DQNAgent(observation_size=obs_stacker.observation_size(), 165 | num_actions=environment.num_moves(), 166 | num_players=environment.players) 167 | elif agent_type == 'Rainbow': 168 | return rainbow_agent.RainbowAgent( 169 | observation_size=obs_stacker.observation_size(), 170 | num_actions=environment.num_moves(), 171 | num_players=environment.players) 172 | else: 173 | raise ValueError('Expected valid agent_type, got {}'.format(agent_type)) 174 | 175 | 176 | def initialize_checkpointing(agent, experiment_logger, checkpoint_dir, 177 | checkpoint_file_prefix='ckpt'): 178 | """Reloads the latest checkpoint if it exists. 179 | 180 | The following steps will be taken: 181 | - This method will first create a Checkpointer object, which will be used in 182 | the method and then returned to the caller for later use. 183 | - It will then call checkpointer.get_latest_checkpoint_number to determine 184 | whether there is a valid checkpoint in checkpoint_dir, and what is the 185 | largest file number. 
186 | - If a valid checkpoint file is found, it will load the bundled data from 187 | this file and will pass it to the agent for it to reload its data. 188 | - If the agent is able to successfully unbundle, this method will verify that 189 | the unbundled data contains the keys, 'logs' and 'current_iteration'. It 190 | will then load the Logger's data from the bundle, and will return the 191 | iteration number keyed by 'current_iteration' as one of the return values 192 | (along with the Checkpointer object). 193 | 194 | Args: 195 | agent: The agent that will unbundle the checkpoint from checkpoint_dir. 196 | experiment_logger: The Logger object that will be loaded from the 197 | checkpoint. 198 | checkpoint_dir: str, the directory containing the checkpoints. 199 | checkpoint_file_prefix: str, the checkpoint file prefix. 200 | 201 | Returns: 202 | start_iteration: int, The iteration number to start the experiment from. 203 | experiment_checkpointer: The experiment checkpointer. 204 | """ 205 | experiment_checkpointer = checkpointer.Checkpointer( 206 | checkpoint_dir, checkpoint_file_prefix) 207 | 208 | start_iteration = 0 209 | 210 | # Check if checkpoint exists. Note that the existence of checkpoint 0 means 211 | # that we have finished iteration 0 (so we will start from iteration 1). 212 | latest_checkpoint_version = checkpointer.get_latest_checkpoint_number( 213 | checkpoint_dir) 214 | if latest_checkpoint_version >= 0: 215 | dqn_dictionary = experiment_checkpointer.load_checkpoint( 216 | latest_checkpoint_version) 217 | if agent.unbundle( 218 | checkpoint_dir, latest_checkpoint_version, dqn_dictionary): 219 | assert 'logs' in dqn_dictionary 220 | assert 'current_iteration' in dqn_dictionary 221 | experiment_logger.data = dqn_dictionary['logs'] 222 | start_iteration = dqn_dictionary['current_iteration'] + 1 223 | tf.logging.info('Reloaded checkpoint and will start from iteration %d', 224 | start_iteration) 225 | 226 | return start_iteration, experiment_checkpointer 227 | 228 | 229 | def format_legal_moves(legal_moves, action_dim): 230 | """Returns formatted legal moves. 231 | 232 | This function takes a list of actions and converts it into a fixed size vector 233 | of size action_dim. If an action is legal, its position is set to 0 and -Inf 234 | otherwise. 235 | Ex: legal_moves = [0, 1, 3], action_dim = 5 236 | returns [0, 0, -Inf, 0, -Inf] 237 | 238 | Args: 239 | legal_moves: list of legal actions. 240 | action_dim: int, number of actions. 241 | 242 | Returns: 243 | a vector of size action_dim. 244 | """ 245 | new_legal_moves = np.full(action_dim, -float('inf')) 246 | if legal_moves: 247 | new_legal_moves[legal_moves] = 0 248 | return new_legal_moves 249 | 250 | 251 | def parse_observations(observations, num_actions, obs_stacker): 252 | """Deconstructs the rich observation data into relevant components. 253 | 254 | Args: 255 | observations: dict, containing full observations. 256 | num_actions: int, The number of available actions. 257 | obs_stacker: Observation stacker object. 258 | 259 | Returns: 260 | current_player: int, Whose turn it is. 261 | legal_moves: `np.array` of floats, of length num_actions, whose elements 262 | are -inf for indices corresponding to illegal moves and 0, for those 263 | corresponding to legal moves. 264 | observation_vector: Vectorized observation for the current player. 
265 | """ 266 | current_player = observations['current_player'] 267 | current_player_observation = ( 268 | observations['player_observations'][current_player]) 269 | 270 | legal_moves = current_player_observation['legal_moves_as_int'] 271 | legal_moves = format_legal_moves(legal_moves, num_actions) 272 | 273 | observation_vector = current_player_observation['vectorized'] 274 | obs_stacker.add_observation(observation_vector, current_player) 275 | observation_vector = obs_stacker.get_observation_stack(current_player) 276 | 277 | return current_player, legal_moves, observation_vector 278 | 279 | 280 | def run_one_episode(agent, environment, obs_stacker): 281 | """Runs the agent on a single game of Hanabi in self-play mode. 282 | 283 | Args: 284 | agent: Agent playing Hanabi. 285 | environment: The Hanabi environment. 286 | obs_stacker: Observation stacker object. 287 | 288 | Returns: 289 | step_number: int, number of actions in this episode. 290 | total_reward: float, undiscounted return for this episode. 291 | """ 292 | obs_stacker.reset_stack() 293 | observations = environment.reset() 294 | current_player, legal_moves, observation_vector = ( 295 | parse_observations(observations, environment.num_moves(), obs_stacker)) 296 | action = agent.begin_episode(current_player, legal_moves, observation_vector) 297 | 298 | is_done = False 299 | total_reward = 0 300 | step_number = 0 301 | 302 | has_played = {current_player} 303 | 304 | # Keep track of per-player reward. 305 | reward_since_last_action = np.zeros(environment.players) 306 | 307 | while not is_done: 308 | observations, reward, is_done, _ = environment.step(action.item()) 309 | 310 | modified_reward = max(reward, 0) if LENIENT_SCORE else reward 311 | total_reward += modified_reward 312 | 313 | reward_since_last_action += modified_reward 314 | 315 | step_number += 1 316 | if is_done: 317 | break 318 | current_player, legal_moves, observation_vector = ( 319 | parse_observations(observations, environment.num_moves(), obs_stacker)) 320 | if current_player in has_played: 321 | action = agent.step(reward_since_last_action[current_player], 322 | current_player, legal_moves, observation_vector) 323 | else: 324 | # Each player begins the episode on their first turn (which may not be 325 | # the first move of the game). 326 | action = agent.begin_episode(current_player, legal_moves, 327 | observation_vector) 328 | has_played.add(current_player) 329 | 330 | # Reset this player's reward accumulator. 331 | reward_since_last_action[current_player] = 0 332 | 333 | agent.end_episode(reward_since_last_action) 334 | 335 | tf.logging.info('EPISODE: %d %g', step_number, total_reward) 336 | return step_number, total_reward 337 | 338 | 339 | def run_one_phase(agent, environment, obs_stacker, min_steps, statistics, 340 | run_mode_str): 341 | """Runs the agent/environment loop until a desired number of steps. 342 | 343 | Args: 344 | agent: Agent playing hanabi. 345 | environment: environment object. 346 | obs_stacker: Observation stacker object. 347 | min_steps: int, minimum number of steps to generate in this phase. 348 | statistics: `IterationStatistics` object which records the experimental 349 | results. 350 | run_mode_str: str, describes the run mode for this agent. 351 | 352 | Returns: 353 | The number of steps taken in this phase, the sum of returns, and the 354 | number of episodes performed. 355 | """ 356 | step_count = 0 357 | num_episodes = 0 358 | sum_returns = 0. 
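  # Play complete episodes until at least `min_steps` environment steps have
  # been generated; the final episode is never truncated, so a phase can
  # finish slightly past min_steps.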
359 | 360 | while step_count < min_steps: 361 | episode_length, episode_return = run_one_episode(agent, environment, 362 | obs_stacker) 363 | statistics.append({ 364 | '{}_episode_lengths'.format(run_mode_str): episode_length, 365 | '{}_episode_returns'.format(run_mode_str): episode_return 366 | }) 367 | 368 | step_count += episode_length 369 | sum_returns += episode_return 370 | num_episodes += 1 371 | 372 | return step_count, sum_returns, num_episodes 373 | 374 | 375 | @gin.configurable 376 | def run_one_iteration(agent, environment, obs_stacker, 377 | iteration, training_steps, 378 | evaluate_every_n=100, 379 | num_evaluation_games=100): 380 | """Runs one iteration of agent/environment interaction. 381 | 382 | An iteration involves running several episodes until a certain number of 383 | steps are obtained. 384 | 385 | Args: 386 | agent: Agent playing hanabi. 387 | environment: The Hanabi environment. 388 | obs_stacker: Observation stacker object. 389 | iteration: int, current iteration number, used as a global_step. 390 | training_steps: int, the number of training steps to perform. 391 | evaluate_every_n: int, frequency of evaluation. 392 | num_evaluation_games: int, number of games per evaluation. 393 | 394 | Returns: 395 | A dict containing summary statistics for this iteration. 396 | """ 397 | start_time = time.time() 398 | 399 | statistics = iteration_statistics.IterationStatistics() 400 | 401 | # First perform the training phase, during which the agent learns. 402 | agent.eval_mode = False 403 | number_steps, sum_returns, num_episodes = ( 404 | run_one_phase(agent, environment, obs_stacker, training_steps, statistics, 405 | 'train')) 406 | time_delta = time.time() - start_time 407 | tf.logging.info('Average training steps per second: %.2f', 408 | number_steps / time_delta) 409 | 410 | average_return = sum_returns / num_episodes 411 | tf.logging.info('Average per episode return: %.2f', average_return) 412 | statistics.append({'average_return': average_return}) 413 | 414 | # Also run an evaluation phase if desired. 415 | if evaluate_every_n is not None and iteration % evaluate_every_n == 0: 416 | episode_data = [] 417 | agent.eval_mode = True 418 | # Collect episode data for all games. 419 | for _ in range(num_evaluation_games): 420 | episode_data.append(run_one_episode(agent, environment, obs_stacker)) 421 | 422 | eval_episode_length, eval_episode_return = map(np.mean, zip(*episode_data)) 423 | 424 | statistics.append({ 425 | 'eval_episode_lengths': eval_episode_length, 426 | 'eval_episode_returns': eval_episode_return 427 | }) 428 | tf.logging.info('Average eval. episode length: %.2f Return: %.2f', 429 | eval_episode_length, eval_episode_return) 430 | else: 431 | statistics.append({ 432 | 'eval_episode_lengths': -1, 433 | 'eval_episode_returns': -1 434 | }) 435 | 436 | return statistics.data_lists 437 | 438 | 439 | def log_experiment(experiment_logger, iteration, statistics, 440 | logging_file_prefix='log', log_every_n=1): 441 | """Records the results of the current iteration. 442 | 443 | Args: 444 | experiment_logger: A `Logger` object. 445 | iteration: int, iteration number. 446 | statistics: Object containing statistics to log. 447 | logging_file_prefix: str, prefix to use for the log files. 448 | log_every_n: int, specifies logging frequency. 
449 | """ 450 | if iteration % log_every_n == 0: 451 | experiment_logger['iter{:d}'.format(iteration)] = statistics 452 | experiment_logger.log_to_file(logging_file_prefix, iteration) 453 | 454 | 455 | def checkpoint_experiment(experiment_checkpointer, agent, experiment_logger, 456 | iteration, checkpoint_dir, checkpoint_every_n): 457 | """Checkpoint experiment data. 458 | 459 | Args: 460 | experiment_checkpointer: A `Checkpointer` object. 461 | agent: An RL agent. 462 | experiment_logger: a Logger object, to include its data in the checkpoint. 463 | iteration: int, iteration number for checkpointing. 464 | checkpoint_dir: str, the directory where to save checkpoints. 465 | checkpoint_every_n: int, the frequency for writing checkpoints. 466 | """ 467 | if iteration % checkpoint_every_n == 0: 468 | agent_dictionary = agent.bundle_and_checkpoint(checkpoint_dir, iteration) 469 | if agent_dictionary: 470 | agent_dictionary['current_iteration'] = iteration 471 | agent_dictionary['logs'] = experiment_logger.data 472 | experiment_checkpointer.save_checkpoint(iteration, agent_dictionary) 473 | 474 | 475 | @gin.configurable 476 | def run_experiment(agent, 477 | environment, 478 | start_iteration, 479 | obs_stacker, 480 | experiment_logger, 481 | experiment_checkpointer, 482 | checkpoint_dir, 483 | num_iterations=200, 484 | training_steps=5000, 485 | logging_file_prefix='log', 486 | log_every_n=1, 487 | checkpoint_every_n=1): 488 | """Runs a full experiment, spread over multiple iterations.""" 489 | tf.logging.info('Beginning training...') 490 | if num_iterations <= start_iteration: 491 | tf.logging.warning('num_iterations (%d) < start_iteration(%d)', 492 | num_iterations, start_iteration) 493 | return 494 | 495 | for iteration in range(start_iteration, num_iterations): 496 | start_time = time.time() 497 | statistics = run_one_iteration(agent, environment, obs_stacker, iteration, 498 | training_steps) 499 | tf.logging.info('Iteration %d took %d seconds', iteration, 500 | time.time() - start_time) 501 | start_time = time.time() 502 | log_experiment(experiment_logger, iteration, statistics, 503 | logging_file_prefix, log_every_n) 504 | tf.logging.info('Logging iteration %d took %d seconds', iteration, 505 | time.time() - start_time) 506 | start_time = time.time() 507 | checkpoint_experiment(experiment_checkpointer, agent, experiment_logger, 508 | iteration, checkpoint_dir, checkpoint_every_n) 509 | tf.logging.info('Checkpointing iteration %d took %d seconds', iteration, 510 | time.time() - start_time) 511 | --------------------------------------------------------------------------------
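The pieces defined in run_experiment.py compose into a complete training run: build the environment, observation stacker, and agent, restore any existing checkpoint, and hand everything to run_experiment. The sketch below is a minimal, illustrative driver rather than the repository's real entry point (that is train.py, which is not reproduced in this section); the gin config path is a placeholder, and the `logger.Logger` constructor is an assumption about the bundled Dopamine fork rather than an API shown above. With the default arguments it would run 200 iterations of 5,000 training steps each, evaluating over 100 games whenever the iteration number is a multiple of evaluate_every_n.

```
# Hypothetical driver for illustration; see agents/rainbow/train.py for the
# real entry point. Assumes it runs from agents/rainbow/ so that
# run_experiment and third_party.dopamine are importable.
import os

from third_party.dopamine import logger  # assumed to expose a Logger class
import run_experiment

base_dir = '/tmp/hanabi_rainbow'                       # illustrative paths
checkpoint_dir = os.path.join(base_dir, 'checkpoints')

# Optionally apply a gin config (placeholder path) before building anything.
run_experiment.load_gin_configs(['path/to/hanabi_rainbow.gin'], [])

environment = run_experiment.create_environment()      # Hanabi-Full, 2 players
obs_stacker = run_experiment.create_obs_stacker(environment)
agent = run_experiment.create_agent(environment, obs_stacker,
                                    agent_type='Rainbow')

experiment_logger = logger.Logger(os.path.join(base_dir, 'logs'))  # assumed API
start_iteration, experiment_checkpointer = (
    run_experiment.initialize_checkpointing(agent, experiment_logger,
                                            checkpoint_dir))

run_experiment.run_experiment(agent, environment, start_iteration, obs_stacker,
                              experiment_logger, experiment_checkpointer,
                              checkpoint_dir)
```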