├── .gitignore
├── agents
│   ├── rainbow
│   │   ├── third_party
│   │   │   ├── __init__.py
│   │   │   └── dopamine
│   │   │       ├── __init__.py
│   │   │       ├── iteration_statistics.py
│   │   │       ├── logger.py
│   │   │       ├── checkpointer.py
│   │   │       ├── sum_tree.py
│   │   │       └── LICENSE
│   │   ├── configs
│   │   │   └── hanabi_rainbow.gin
│   │   ├── README.md
│   │   ├── train.py
│   │   ├── prioritized_replay_memory.py
│   │   ├── rainbow_agent.py
│   │   └── run_experiment.py
│   ├── __init__.py
│   ├── random_agent.py
│   └── simple_agent.py
├── hanabi_lib
│   ├── CMakeLists.txt
│   ├── hanabi_card.cc
│   ├── hanabi_card.h
│   ├── observation_encoder.h
│   ├── canonical_encoders.h
│   ├── hanabi_move.cc
│   ├── hanabi_history_item.cc
│   ├── hanabi_move.h
│   ├── util.h
│   ├── hanabi_history_item.h
│   ├── util.cc
│   ├── hanabi_observation.h
│   ├── hanabi_hand.cc
│   ├── hanabi_game.h
│   ├── hanabi_observation.cc
│   ├── hanabi_hand.h
│   ├── hanabi_state.h
│   ├── hanabi_game.cc
│   ├── hanabi_state.cc
│   └── canonical_encoders.cc
├── __init__.py
├── CMakeLists.txt
├── README.md
├── clean_all.sh
├── CONTRIBUTING.md
├── rl_env_example.py
├── game_example.py
├── game_example.cc
├── pyhanabi.h
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
--------------------------------------------------------------------------------
/agents/rainbow/third_party/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/hanabi_lib/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_library (hanabi hanabi_card.cc hanabi_game.cc hanabi_hand.cc hanabi_history_item.cc hanabi_move.cc hanabi_observation.cc hanabi_state.cc util.cc canonical_encoders.cc)
2 | target_include_directories(hanabi PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
3 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/agents/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/agents/rainbow/third_party/dopamine/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Dopamine Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required (VERSION 2.8.11)
2 | project (hanabi_learning_environment)
3 |
4 | set(CMAKE_C_FLAGS "-O2 -std=c++11 -fPIC")
5 | set(CMAKE_CXX_FLAGS "-O2 -std=c++11 -fPIC")
6 |
7 | add_subdirectory (hanabi_lib)
8 |
9 | add_library (pyhanabi SHARED pyhanabi.cc)
10 | target_link_libraries (pyhanabi LINK_PUBLIC hanabi)
11 |
12 | add_executable (game_example game_example.cc)
13 | target_link_libraries (game_example LINK_PUBLIC hanabi)
14 |
15 | install(TARGETS
16 | pyhanabi
17 | LIBRARY DESTINATION lib
18 | ARCHIVE DESTINATION lib
19 | RUNTIME DESTINATION lib)
20 |
21 | install(FILES pyhanabi.h DESTINATION include)
22 | install(DIRECTORY hanabi_lib/ DESTINATION include FILES_MATCHING PATTERN "*.h")
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is not an officially supported Google product.
2 |
3 | hanabi\_learning\_environment is a research platform for Hanabi experiments. The file rl\_env.py provides an RL environment using an API similar to OpenAI Gym. A lower level game interface is provided in pyhanabi.py for non-RL methods like Monte Carlo tree search.
4 |
5 | ### Getting started
6 | ```
7 | sudo apt-get install g++ # if you don't already have a CXX compiler
8 | sudo apt-get install cmake # if you don't already have CMake
9 | sudo apt-get install python-pip # if you don't already have pip
10 | pip install cffi # if you don't already have cffi
11 | cmake .
12 | make
13 | python rl_env_example.py # Runs RL episodes
14 | python game_example.py # Plays a game using the lower level interface
15 | ```
16 |
--------------------------------------------------------------------------------
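As a quick orientation to the Gym-like API mentioned in the README, here is a minimal sketch (not part of the repository) that mirrors what `rl_env_example.py` below does. It assumes the build steps above have been run so that `import rl_env` can find the compiled pyhanabi library.

```python
# Minimal sketch of the rl_env.py API, following rl_env_example.py.
import random

import rl_env

env = rl_env.make('Hanabi-Full', num_players=2)
observations = env.reset()
done = False
while not done:
  # Each player has their own observation; act for whoever is current.
  current = observations['player_observations'][0]['current_player']
  obs = observations['player_observations'][current]
  action = random.choice(obs['legal_moves'])
  observations, reward, done, _ = env.step(action)
```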
/clean_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # Copyright 2018 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Clean-up script to bring the repository back to a pre-cmake state.
17 |
18 | if [ -f Makefile ]
19 | then
20 | make clean
21 | fi
22 |
23 | rm -rf *.pyc agents/*.pyc __pycache__ agents/__pycache__ CMakeCache.txt CMakeFiles Makefile cmake_install.cmake hanabi_lib/CMakeFiles hanabi_lib/Makefile hanabi_lib/cmake_install.cmake
24 |
--------------------------------------------------------------------------------
/agents/random_agent.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Random Agent."""
15 |
16 | import random
17 | from rl_env import Agent
18 |
19 |
20 | class RandomAgent(Agent):
21 | """Agent that takes random legal actions."""
22 |
23 | def __init__(self, config, *args, **kwargs):
24 | """Initialize the agent."""
25 | self.config = config
26 |
27 | def act(self, observation):
28 | """Act based on an observation."""
29 | if observation['current_player_offset'] == 0:
30 | return random.choice(observation['legal_moves'])
31 | else:
32 | return None
33 |
--------------------------------------------------------------------------------
/hanabi_lib/hanabi_card.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "hanabi_card.h"
16 |
17 | #include "util.h"
18 |
19 | namespace hanabi_learning_env {
20 |
21 | bool HanabiCard::operator==(const HanabiCard& other_card) const {
22 | return other_card.Color() == Color() && other_card.Rank() == Rank();
23 | }
24 |
25 | std::string HanabiCard::ToString() const {
26 | if (!IsValid()) {
27 | return std::string("XX");
28 | }
29 | return std::string() + ColorIndexToChar(Color()) + RankIndexToChar(Rank());
30 | }
31 |
32 | } // namespace hanabi_learning_env
33 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/hanabi_lib/hanabi_card.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef __HANABI_CARD_H__
16 | #define __HANABI_CARD_H__
17 |
18 | #include <string>
19 |
20 | namespace hanabi_learning_env {
21 |
22 | class HanabiCard {
23 | public:
24 | HanabiCard(int color, int rank) : color_(color), rank_(rank) {}
25 | HanabiCard() = default; // Create an invalid card.
26 | bool operator==(const HanabiCard& other_card) const;
27 | bool IsValid() const { return color_ >= 0 && rank_ >= 0; }
28 | std::string ToString() const;
29 | int Color() const { return color_; }
30 | int Rank() const { return rank_; }
31 |
32 | private:
33 | int color_ = -1; // 0 indexed card color.
34 | int rank_ = -1; // 0 indexed card rank.
35 | };
36 |
37 | } // namespace hanabi_learning_env
38 |
39 | #endif
40 |
--------------------------------------------------------------------------------
/agents/rainbow/configs/hanabi_rainbow.gin:
--------------------------------------------------------------------------------
1 | import dqn_agent
2 | import rainbow_agent
3 | import run_experiment
4 |
5 | # This configures the DQN Agent.
6 | AGENT_CLASS = @DQNAgent
7 | DQNAgent.gamma = 0.99
8 | DQNAgent.update_horizon = 1
9 | DQNAgent.min_replay_history = 500 # agent steps
10 | DQNAgent.target_update_period = 500 # agent steps
11 | DQNAgent.epsilon_train = 0.0
12 | DQNAgent.epsilon_eval = 0.0
13 | DQNAgent.epsilon_decay_period = 1000 # agent steps
14 | DQNAgent.tf_device = '/gpu:0'  # use '/cpu:*' for the non-GPU version
15 |
16 | # This configures the Rainbow agent.
17 | AGENT_CLASS_2 = @RainbowAgent
18 | RainbowAgent.gamma = 0.99
19 | RainbowAgent.update_horizon = 1
20 | RainbowAgent.num_atoms = 51
21 | RainbowAgent.min_replay_history = 500 # agent steps
22 | RainbowAgent.target_update_period = 500 # agent steps
23 | RainbowAgent.epsilon_train = 0.0
24 | RainbowAgent.epsilon_eval = 0.0
25 | RainbowAgent.epsilon_decay_period = 1000 # agent steps
26 | RainbowAgent.tf_device = '/gpu:0'  # use '/cpu:*' for the non-GPU version
27 | WrappedReplayMemory.replay_capacity = 50000
28 |
29 | run_experiment.training_steps = 10000
30 | run_experiment.num_iterations = 10005
31 | run_experiment.checkpoint_every_n = 50
32 | run_one_iteration.evaluate_every_n = 10
33 |
34 | # Full Hanabi with card-knowledge observations.
35 | create_environment.game_type = 'Hanabi-Full-CardKnowledge'
36 | create_environment.num_players = 2
37 |
38 | create_agent.agent_type = 'Rainbow'
39 | create_obs_stacker.history_size = 1
40 |
41 | rainbow_template.layer_size=512
42 | rainbow_template.num_layers=2
43 |
--------------------------------------------------------------------------------
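The bindings in `hanabi_rainbow.gin` are applied through the [gin-config](https://github.com/google/gin-config) library. A minimal, self-contained sketch of how a binding such as `RainbowAgent.gamma = 0.99` reaches a Python constructor; the `ToyAgent` class here is hypothetical and only stands in for the configurable classes in this repository:

```python
import gin


@gin.configurable
class ToyAgent(object):
  """Hypothetical agent; stands in for RainbowAgent in this sketch."""

  def __init__(self, gamma=0.9):
    self.gamma = gamma


# Bindings are normally read from a .gin file (e.g. hanabi_rainbow.gin);
# parsing a binding string works the same way.
gin.parse_config('ToyAgent.gamma = 0.99')

agent = ToyAgent()
print(agent.gamma)  # -> 0.99, supplied by the binding rather than the default
```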
/hanabi_lib/observation_encoder.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | // A helper object to translate Hanabi observations to agent inputs
16 | // (e.g. tensors).
17 |
18 | #ifndef __OBSERVATION_ENCODER_H__
19 | #define __OBSERVATION_ENCODER_H__
20 |
21 | #include <vector>
22 |
23 | #include "hanabi_observation.h"
24 |
25 | namespace hanabi_learning_env {
26 |
27 | class ObservationEncoder {
28 | public:
29 | enum Type { kCanonical = 0 };
30 | virtual ~ObservationEncoder() = default;
31 |
32 | // Returns the shape (dimension sizes of the tensor).
33 | virtual std::vector<int> Shape() const = 0;
34 |
35 | // All of the canonical observation encodings are vectors of bits. We can
36 | // change this if we want something more general (e.g. floats or doubles).
37 | virtual std::vector<int> Encode(const HanabiObservation& obs) const = 0;
38 |
39 | // Return the type of this encoder.
40 | virtual Type type() const = 0;
41 | };
42 |
43 | } // namespace hanabi_learning_env
44 |
45 | #endif
46 |
--------------------------------------------------------------------------------
/hanabi_lib/canonical_encoders.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | // The standard Open Hanabi observation encoders. These encoders translate
16 | // HanabiObservations to input tensors that an agent can train on.
17 |
18 | #ifndef __CANONICAL_ENCODERS_H__
19 | #define __CANONICAL_ENCODERS_H__
20 |
21 | #include <vector>
22 |
23 | #include "hanabi_game.h"
24 | #include "hanabi_observation.h"
25 | #include "observation_encoder.h"
26 |
27 | namespace hanabi_learning_env {
28 |
29 | // This is the canonical observation encoding.
30 | class CanonicalObservationEncoder : public ObservationEncoder {
31 | public:
32 | explicit CanonicalObservationEncoder(const HanabiGame* parent_game)
33 | : parent_game_(parent_game) {}
34 |
35 | std::vector<int> Shape() const override;
36 | std::vector<int> Encode(const HanabiObservation& obs) const override;
37 |
38 | ObservationEncoder::Type type() const override {
39 | return ObservationEncoder::Type::kCanonical;
40 | }
41 |
42 | private:
43 | const HanabiGame* parent_game_ = nullptr;
44 | };
45 |
46 | } // namespace hanabi_learning_env
47 |
48 | #endif
49 |
--------------------------------------------------------------------------------
/agents/rainbow/third_party/dopamine/iteration_statistics.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Dopamine Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A class for storing iteration-specific metrics.
16 | """
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 |
23 | class IterationStatistics(object):
24 | """A class for storing iteration-specific metrics.
25 |
26 | The internal format is as follows: we maintain a mapping from keys to lists.
27 | Each list contains all the values corresponding to the given key.
28 |
29 | For example, self.data_lists['train_episode_returns'] might contain the
30 | per-episode returns achieved during this iteration.
31 |
32 | Attributes:
33 | data_lists: dict mapping each metric_name (str) to a list of said metric
34 | across episodes.
35 | """
36 |
37 | def __init__(self):
38 | self.data_lists = {}
39 |
40 | def append(self, data_pairs):
41 | """Add the given values to their corresponding key-indexed lists.
42 |
43 | Args:
44 | data_pairs: A dictionary of key-value pairs to be recorded.
45 | """
46 | for key, value in data_pairs.items():
47 | if key not in self.data_lists:
48 | self.data_lists[key] = []
49 | self.data_lists[key].append(value)
50 |
--------------------------------------------------------------------------------
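A short usage sketch (not part of the repository) of the `IterationStatistics` class above: each call to `append` extends the per-key lists.

```python
stats = IterationStatistics()
stats.append({'train_episode_returns': 12.0, 'train_episode_lengths': 57})
stats.append({'train_episode_returns': 15.0, 'train_episode_lengths': 61})
# Every value recorded for a key is kept, in insertion order.
print(stats.data_lists['train_episode_returns'])  # -> [12.0, 15.0]
```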
/agents/rainbow/README.md:
--------------------------------------------------------------------------------
1 | # Rainbow agent for the Hanabi Learning Environment
2 |
3 | ## Instructions
4 |
5 | The Rainbow agent is derived from the
6 | [Dopamine framework](https://github.com/google/dopamine) which is based on
7 | Tensorflow. We recommend you consult the
8 | [Tensorflow documentation](https://www.tensorflow.org/install)
9 | for additional details.
10 |
11 | To run the agent, some dependencies need to be pre-installed. If you don't have
12 | access to a GPU, then replace `tensorflow-gpu` with `tensorflow` in the line
13 | below
14 | (see [Tensorflow instructions](https://www.tensorflow.org/install/install_linux)
15 | for details).
16 |
17 | ```
18 | pip install absl-py gin-config tensorflow-gpu cffi
19 | ```
20 |
21 | If you prefer not to use the GPU, you may install tensorflow instead
22 | of tensorflow-gpu and set `RainbowAgent.tf_device = '/cpu:*'` in
23 | `configs/hanabi_rainbow.gin`.
24 |
25 | The entry point to run a Rainbow agent on the Hanabi environment is `train.py`.
26 | Assuming you are running from the agent directory `agents/rainbow`,
27 |
28 | ```
29 | PYTHONPATH=${PYTHONPATH}:../..
30 | python -um train \
31 | --base_dir=/tmp/hanabi_rainbow \
32 | --gin_files='configs/hanabi_rainbow.gin'
33 | ```
34 |
35 | The `PYTHONPATH` fix exposes `rl_env.py`, the main entry point to the Hanabi
36 | Learning Environment. The `--base_dir` argument must be provided.
37 |
38 | To get finer-grained information about the training process, you can adjust the
39 | experiment parameters in `configs/hanabi_rainbow.gin`, in particular by reducing
40 | `Runner.training_steps` and `Runner.evaluation_steps`, which together determine
41 | the total number of steps needed to complete an iteration. This is useful if you
42 | want to inspect log files or checkpoints, which are generated at the end of each
43 | iteration.
44 |
45 | More generally, most parameters are easily configured using the
46 | [gin configuration framework](https://github.com/google/gin-config).
47 |
--------------------------------------------------------------------------------
/hanabi_lib/hanabi_move.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "hanabi_move.h"
16 |
17 | #include "util.h"
18 |
19 | namespace hanabi_learning_env {
20 |
21 | bool HanabiMove::operator==(const HanabiMove& other_move) const {
22 | if (MoveType() != other_move.MoveType()) {
23 | return false;
24 | }
25 | switch (MoveType()) {
26 | case kPlay:
27 | case kDiscard:
28 | return CardIndex() == other_move.CardIndex();
29 | case kRevealColor:
30 | return TargetOffset() == other_move.TargetOffset() &&
31 | Color() == other_move.Color();
32 | case kRevealRank:
33 | return TargetOffset() == other_move.TargetOffset() &&
34 | Rank() == other_move.Rank();
35 | case kDeal:
36 | return Color() == other_move.Color() && Rank() == other_move.Rank();
37 | default:
38 | return true;
39 | }
40 | }
41 |
42 | std::string HanabiMove::ToString() const {
43 | switch (MoveType()) {
44 | case kPlay:
45 | return "(Play " + std::to_string(CardIndex()) + ")";
46 | case kDiscard:
47 | return "(Discard " + std::to_string(CardIndex()) + ")";
48 | case kRevealColor:
49 | return "(Reveal player +" + std::to_string(TargetOffset()) + " color " +
50 | ColorIndexToChar(Color()) + ")";
51 | case kRevealRank:
52 | return "(Reveal player +" + std::to_string(TargetOffset()) + " rank " +
53 | RankIndexToChar(Rank()) + ")";
54 | case kDeal:
55 | if (color_ >= 0) {
56 | return std::string("(Deal ") + ColorIndexToChar(Color()) +
57 | RankIndexToChar(Rank()) + ")";
58 | } else {
59 | return std::string("(Deal XX)");
60 | }
61 | default:
62 | return "(INVALID)";
63 | }
64 | }
65 |
66 | } // namespace hanabi_learning_env
67 |
--------------------------------------------------------------------------------
/hanabi_lib/hanabi_history_item.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "hanabi_history_item.h"
16 |
17 | #include <cassert>
18 |
19 | #include "util.h"
20 |
21 | namespace hanabi_learning_env {
22 |
23 | std::string HanabiHistoryItem::ToString() const {
24 | std::string str = "<" + move.ToString();
25 | if (player >= 0) {
26 | str += " by player " + std::to_string(player);
27 | }
28 | if (scored) {
29 | str += " scored";
30 | }
31 | if (information_token) {
32 | str += " info_token";
33 | }
34 | if (color >= 0) {
35 | assert(rank >= 0);
36 | str += " ";
37 | str += ColorIndexToChar(color);
38 | str += RankIndexToChar(rank);
39 | }
40 | if (reveal_bitmask) {
41 | str += " reveal ";
42 | bool first = true;
43 | for (int i = 0; i < 8; ++i) { // 8 bits in reveal_bitmask
44 | if (reveal_bitmask & (1 << i)) {
45 | if (first) {
46 | first = false;
47 | } else {
48 | str += ",";
49 | }
50 | str += std::to_string(i);
51 | }
52 | }
53 | }
54 | str += ">";
55 | return str;
56 | }
57 |
58 | void ChangeToObserverRelative(int observer_pid, int player_count,
59 | HanabiHistoryItem* item) {
60 | if (item->move.MoveType() == HanabiMove::kDeal) {
61 | assert(item->player < 0 && item->deal_to_player >= 0);
62 | item->deal_to_player =
63 | (item->deal_to_player - observer_pid + player_count) % player_count;
64 | if (item->deal_to_player == 0) {
65 | // Hide cards dealt to observer.
66 | item->move = HanabiMove(HanabiMove::kDeal, -1, -1, -1, -1);
67 | }
68 | } else {
69 | assert(item->player >= 0);
70 | item->player = (item->player - observer_pid + player_count) % player_count;
71 | }
72 | }
73 |
74 | } // namespace hanabi_learning_env
75 |
--------------------------------------------------------------------------------
/hanabi_lib/hanabi_move.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef __HANABI_MOVE_H__
16 | #define __HANABI_MOVE_H__
17 |
18 | #include <cstdint>
19 | #include <string>
20 |
21 | namespace hanabi_learning_env {
22 |
23 | // 5 types of moves:
24 | // "Play" card_index of card in player hand
25 | // "Discard" card_index of card in player hand
26 | // "RevealColor" target_offset color hints to player all cards of color
27 | // "RevealRank" target_offset rank hints to player all cards of given rank
28 | // NOTE: RevealXYZ target_offset field is an offset from the acting player
29 | // "Deal" color rank deal card with color and rank
30 | // "Invalid" move is not valid
31 | class HanabiMove {
32 | // HanabiMove is small, and intended to be passed by value.
33 | public:
34 | enum Type { kInvalid, kPlay, kDiscard, kRevealColor, kRevealRank, kDeal };
35 |
36 | HanabiMove(Type move_type, int8_t card_index, int8_t target_offset,
37 | int8_t color, int8_t rank)
38 | : move_type_(move_type),
39 | card_index_(card_index),
40 | target_offset_(target_offset),
41 | color_(color),
42 | rank_(rank) {}
43 | // Tests whether two moves are functionally equivalent.
44 | bool operator==(const HanabiMove& other_move) const;
45 | std::string ToString() const;
46 |
47 | Type MoveType() const { return move_type_; }
48 | bool IsValid() const { return move_type_ != kInvalid; }
49 | int8_t CardIndex() const { return card_index_; }
50 | int8_t TargetOffset() const { return target_offset_; }
51 | int8_t Color() const { return color_; }
52 | int8_t Rank() const { return rank_; }
53 |
54 | private:
55 | Type move_type_ = kInvalid;
56 | int8_t card_index_ = -1;
57 | int8_t target_offset_ = -1;
58 | int8_t color_ = -1;
59 | int8_t rank_ = -1;
60 | };
61 |
62 | } // namespace hanabi_learning_env
63 |
64 | #endif
65 |
--------------------------------------------------------------------------------
/hanabi_lib/util.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef __UTIL_H__
16 | #define __UTIL_H__
17 |
18 | #include <cassert>
19 | #include <cstdlib>
20 | #include <string>
21 | #include <unordered_map>
22 |
23 | namespace hanabi_learning_env {
24 |
25 | constexpr int kMaxNumColors = 5;
26 | constexpr int kMaxNumRanks = 5;
27 |
28 | // Returns a character representation of an integer color/rank index.
29 | char ColorIndexToChar(int color);
30 | char RankIndexToChar(int rank);
31 |
32 | // Returns string associated with key in params, parsed as template type.
33 | // If key is not in params, returns the provided default value.
34 | template <typename T>
35 | T ParameterValue(const std::unordered_map<std::string, std::string>& params,
36 |                  const std::string& key, T default_value);
37 |
38 | template <>
39 | int ParameterValue(const std::unordered_map<std::string, std::string>& params,
40 |                    const std::string& key, int default_value);
41 | template <>
42 | double ParameterValue(
43 |     const std::unordered_map<std::string, std::string>& params,
44 |     const std::string& key, double default_value);
45 | template <>
46 | std::string ParameterValue(
47 |     const std::unordered_map<std::string, std::string>& params,
48 |     const std::string& key, std::string default_value);
49 | template <>
50 | bool ParameterValue(const std::unordered_map<std::string, std::string>& params,
51 |                     const std::string& key, bool default_value);
52 |
53 | #if defined(NDEBUG)
54 | #define REQUIRE(expr) \
55 | (expr ? (void)0 \
56 | : (fprintf(stderr, "Input requirements failed at %s:%d in %s: %s\n", \
57 | __FILE__, __LINE__, __func__, #expr), \
58 | std::abort()))
59 | #else
60 | #define REQUIRE(expr) assert(expr)
61 | #endif
62 |
63 | } // namespace hanabi_learning_env
64 |
65 | #endif
66 |
--------------------------------------------------------------------------------
/hanabi_lib/hanabi_history_item.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef __HANABI_HISTORY_ITEM_H__
16 | #define __HANABI_HISTORY_ITEM_H__
17 |
18 | #include <cstdint>
19 | #include <string>
20 |
21 | #include "hanabi_move.h"
22 |
23 | namespace hanabi_learning_env {
24 |
25 | // A move that has been made within a Hanabi game, along with the side-effects
26 | // of making that move.
27 | struct HanabiHistoryItem {
28 | explicit HanabiHistoryItem(HanabiMove move_made) : move(move_made) {}
29 | HanabiHistoryItem(const HanabiHistoryItem& past_move) = default;
30 | std::string ToString() const;
31 |
32 | // Move that was made.
33 | HanabiMove move;
34 | // Index of player who made the move.
35 | int8_t player = -1;
36 | // Indicator of whether a Play move was successful.
37 | bool scored = false;
38 | // Indicator of whether a Play/Discard move added an information token
39 | bool information_token = false;
40 | // Color of card that was played or discarded. Valid if color_ >= 0.
41 | int8_t color = -1;
42 | // Rank of card that was played or discarded. Valid if rank_ >= 0.
43 | int8_t rank = -1;
44 | // Bitmask indicating whether a card was targeted by a RevealX move.
45 | // Bit_i=1 if color/rank of card_i matches X in a RevealX move.
46 | // For example, if cards 0 and 3 had rank 2, a RevealRank 2 move
47 | // would result in a reveal_bitmask of 9 (2^0+2^3).
48 | uint8_t reveal_bitmask = 0;
49 | // Bitmask indicating whether a card was newly revealed by a RevealX move.
50 | // Bit_i=1 if color/rank of card_i was not known before RevealX move.
51 | // For example, if cards 1, 2, and 4 had color 'R', and the color of
53 |   // card 1 had previously been revealed to be 'R', a RevealColor 'R' move
53 | // would result in a newly_revealed_bitmask of 20 (2^2+2^4).
54 | uint8_t newly_revealed_bitmask = 0;
55 | // Player that received a card from a Deal move.
56 | int8_t deal_to_player = -1;
57 | };
58 |
59 | } // namespace hanabi_learning_env
60 |
61 | #endif
62 |
--------------------------------------------------------------------------------
/hanabi_lib/util.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "util.h"
16 |
17 | #include <string>
18 |
19 | namespace hanabi_learning_env {
20 |
21 | char ColorIndexToChar(int color) {
22 | if (color >= 0 && color < kMaxNumColors) {
23 | return "RYGWB"[color];
24 | } else {
25 | return 'X';
26 | }
27 | }
28 |
29 | char RankIndexToChar(int rank) {
30 | if (rank >= 0 && rank < kMaxNumRanks) {
31 | return "12345"[rank];
32 | } else {
33 | return 'X';
34 | }
35 | }
36 |
37 | template <>
38 | int ParameterValue(
39 |     const std::unordered_map<std::string, std::string>& params,
40 | const std::string& key, int default_value) {
41 | auto iter = params.find(key);
42 | if (iter == params.end()) {
43 | return default_value;
44 | }
45 |
46 | return std::stoi(iter->second);
47 | }
48 |
49 | template <>
50 | std::string ParameterValue(
51 |     const std::unordered_map<std::string, std::string>& params,
52 | const std::string& key, std::string default_value) {
53 | auto iter = params.find(key);
54 | if (iter == params.end()) {
55 | return default_value;
56 | }
57 |
58 | return iter->second;
59 | }
60 |
61 | template <>
62 | double ParameterValue(
63 |     const std::unordered_map<std::string, std::string>& params,
64 | const std::string& key, double default_value) {
65 | auto iter = params.find(key);
66 | if (iter == params.end()) {
67 | return default_value;
68 | }
69 |
70 | return std::stod(iter->second);
71 | }
72 |
73 | template <>
74 | bool ParameterValue(
75 |     const std::unordered_map<std::string, std::string>& params,
76 | const std::string& key, bool default_value) {
77 | auto iter = params.find(key);
78 | if (iter == params.end()) {
79 | return default_value;
80 | }
81 |
82 | return (iter->second == "1" || iter->second == "true" ||
83 | iter->second == "True"
84 | ? true
85 | : false);
86 | }
87 |
88 | } // namespace hanabi_learning_env
89 |
--------------------------------------------------------------------------------
/agents/simple_agent.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Random Agent."""
15 |
16 | from rl_env import Agent
17 |
18 |
19 | class SimpleAgent(Agent):
20 | """Agent that applies a simple heuristic."""
21 |
22 | def __init__(self, config, *args, **kwargs):
23 | """Initialize the agent."""
24 | self.config = config
25 | # Extract max info tokens or set default to 8.
26 | self.max_information_tokens = config.get('information_tokens', 8)
27 |
28 | @staticmethod
29 | def playable_card(card, fireworks):
30 | """A card is playable if it can be placed on the fireworks pile."""
31 | return card['rank'] == fireworks[card['color']]
32 |
33 | def act(self, observation):
34 | """Act based on an observation."""
35 | if observation['current_player_offset'] != 0:
36 | return None
37 |
38 | # Check if there are any pending hints and play the card corresponding to
39 | # the hint.
40 | for card_index, hint in enumerate(observation['card_knowledge'][0]):
41 | if hint['color'] is not None or hint['rank'] is not None:
42 | return {'action_type': 'PLAY', 'card_index': card_index}
43 |
44 | # Check if it's possible to hint a card to your colleagues.
45 | fireworks = observation['fireworks']
46 | if observation['information_tokens'] > 0:
47 | # Check if there are any playable cards in the hands of the opponents.
48 | for player_offset in range(1, observation['num_players']):
49 | player_hand = observation['observed_hands'][player_offset]
50 | player_hints = observation['card_knowledge'][player_offset]
51 | # Check if the card in the hand of the opponent is playable.
52 | for card, hint in zip(player_hand, player_hints):
53 | if SimpleAgent.playable_card(card,
54 | fireworks) and hint['color'] is None:
55 | return {
56 | 'action_type': 'REVEAL_COLOR',
57 | 'color': card['color'],
58 | 'target_offset': player_offset
59 | }
60 |
61 | # If no card is hintable then discard or play.
62 | if observation['information_tokens'] < self.max_information_tokens:
63 | return {'action_type': 'DISCARD', 'card_index': 0}
64 | else:
65 | return {'action_type': 'PLAY', 'card_index': 0}
66 |
--------------------------------------------------------------------------------
/rl_env_example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """A simple episode runner using the RL environment."""
15 |
16 | from __future__ import print_function
17 |
18 | import sys
19 | import getopt
20 | import rl_env
21 | from agents.random_agent import RandomAgent
22 | from agents.simple_agent import SimpleAgent
23 |
24 | AGENT_CLASSES = {'SimpleAgent': SimpleAgent, 'RandomAgent': RandomAgent}
25 |
26 |
27 | class Runner(object):
28 | """Runner class."""
29 |
30 | def __init__(self, flags):
31 | """Initialize runner."""
32 | self.flags = flags
33 | self.agent_config = {'players': flags['players']}
34 | self.environment = rl_env.make('Hanabi-Full', num_players=flags['players'])
35 | self.agent_class = AGENT_CLASSES[flags['agent_class']]
36 |
37 | def run(self):
38 | """Run episodes."""
39 | rewards = []
40 |     for episode in range(self.flags['num_episodes']):
41 | observations = self.environment.reset()
42 | agents = [self.agent_class(self.agent_config)
43 | for _ in range(self.flags['players'])]
44 | done = False
45 | episode_reward = 0
46 | while not done:
47 | for agent_id, agent in enumerate(agents):
48 | observation = observations['player_observations'][agent_id]
49 | action = agent.act(observation)
50 | if observation['current_player'] == agent_id:
51 | assert action is not None
52 | current_player_action = action
53 | else:
54 | assert action is None
55 | # Make an environment step.
56 | print('Agent: {} action: {}'.format(observation['current_player'],
57 | current_player_action))
58 | observations, reward, done, unused_info = self.environment.step(
59 | current_player_action)
60 | episode_reward += reward
61 | rewards.append(episode_reward)
62 | print('Running episode: %d' % episode)
63 | print('Max Reward: %.3f' % max(rewards))
64 | return rewards
65 |
66 | if __name__ == "__main__":
67 | flags = {'players': 2, 'num_episodes': 1, 'agent_class': 'SimpleAgent'}
68 | options, arguments = getopt.getopt(sys.argv[1:], '',
69 | ['players=',
70 | 'num_episodes=',
71 | 'agent_class='])
72 | if arguments:
73 | sys.exit('usage: rl_env_example.py [options]\n'
74 | '--players number of players in the game.\n'
75 | '--num_episodes number of game episodes to run.\n'
76 | '--agent_class {}'.format(' or '.join(AGENT_CLASSES.keys())))
77 | for flag, value in options:
78 | flag = flag[2:] # Strip leading --.
79 | flags[flag] = type(flags[flag])(value)
80 | runner = Runner(flags)
81 | runner.run()
82 |
--------------------------------------------------------------------------------
/hanabi_lib/hanabi_observation.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef __HANABI_OBSERVATION_H__
16 | #define __HANABI_OBSERVATION_H__
17 |
18 | #include <string>
19 | #include <vector>
20 |
21 | #include "hanabi_card.h"
22 | #include "hanabi_game.h"
23 | #include "hanabi_hand.h"
24 | #include "hanabi_history_item.h"
25 | #include "hanabi_move.h"
26 | #include "hanabi_state.h"
27 |
28 | namespace hanabi_learning_env {
29 |
30 | // Agent observation of a HanabiState
31 | class HanabiObservation {
32 | public:
33 | HanabiObservation(const HanabiState& state, int observing_player);
34 |
35 | std::string ToString() const;
36 |
37 | // offset of current player from observing player.
38 | int CurPlayerOffset() const { return cur_player_offset_; }
39 | // observed hands are in relative order, with index 1 being the
40 | // first player clock-wise from observing_player. hands[0][] has
41 | // invalid cards as players don't see their own cards.
42 | const std::vector<HanabiHand>& Hands() const { return hands_; }
43 | // The element at the back is the most recent discard.
44 | const std::vector<HanabiCard>& DiscardPile() const { return discard_pile_; }
45 | const std::vector<int>& Fireworks() const { return fireworks_; }
46 | int DeckSize() const { return deck_size_; } // number of remaining cards
47 | const HanabiGame* ParentGame() const { return parent_game_; }
48 | // Moves made since observing_player's last action, most recent to oldest
49 | // (that is, last_moves[0] is the most recent move.)
50 | // Move targets are relative to observing_player not acting_player.
51 | // Note that the deal moves are included in this vector.
52 | const std::vector<HanabiHistoryItem>& LastMoves() const {
53 | return last_moves_;
54 | }
55 | int InformationTokens() const { return information_tokens_; }
56 | int LifeTokens() const { return life_tokens_; }
57 | const std::vector<HanabiMove>& LegalMoves() const { return legal_moves_; }
58 |
59 | // returns true if card with color and rank can be played on fireworks pile
60 | bool CardPlayableOnFireworks(int color, int rank) const;
61 | bool CardPlayableOnFireworks(HanabiCard card) const {
62 | return CardPlayableOnFireworks(card.Color(), card.Rank());
63 | }
64 |
65 | private:
66 | int cur_player_offset_; // offset of current_player from observing_player
67 | std::vector<HanabiHand> hands_;  // observing player is element 0
68 | std::vector<HanabiCard> discard_pile_;  // back is most recent discard
69 | std::vector<int> fireworks_;
70 | int deck_size_;
71 | std::vector<HanabiHistoryItem> last_moves_;
72 | int information_tokens_;
73 | int life_tokens_;
74 | std::vector<HanabiMove> legal_moves_;  // list of legal moves
75 | const HanabiGame* parent_game_ = nullptr;
76 | };
77 |
78 | } // namespace hanabi_learning_env
79 |
80 | #endif
81 |
--------------------------------------------------------------------------------
/agents/rainbow/third_party/dopamine/logger.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Dopamine Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A lightweight logging mechanism for dopamine agents."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import pickle
23 | import tensorflow as tf
24 |
25 |
26 | CHECKPOINT_DURATION = 4
27 |
28 |
29 | class Logger(object):
30 | """Class for maintaining a dictionary of data to log."""
31 |
32 | def __init__(self, logging_dir):
33 | """Initializes Logger.
34 |
35 | Args:
36 | logging_dir: str, Directory to which logs are written.
37 | """
38 | # Dict used by logger to store data.
39 | self.data = {}
40 | self._logging_enabled = True
41 |
42 | if not logging_dir:
43 | tf.logging.info('Logging directory not specified, will not log.')
44 | self._logging_enabled = False
45 | return
46 | # Try to create logging directory.
47 | try:
48 | tf.gfile.MakeDirs(logging_dir)
49 | except tf.errors.PermissionDeniedError:
50 | # If it already exists, ignore exception.
51 | pass
52 | if not tf.gfile.Exists(logging_dir):
53 | tf.logging.warning(
54 | 'Could not create directory %s, logging will be disabled.',
55 | logging_dir)
56 | self._logging_enabled = False
57 | return
58 | self._logging_dir = logging_dir
59 |
60 | def __setitem__(self, key, value):
61 | """This method will set an entry at key with value in the dictionary.
62 |
63 | It will effectively overwrite any previous data at the same key.
64 |
65 | Args:
66 | key: str, indicating key where to write the entry.
67 | value: A python object to store.
68 | """
69 | if self._logging_enabled:
70 | self.data[key] = value
71 |
72 | def _generate_filename(self, filename_prefix, iteration_number):
73 | filename = '{}_{}'.format(filename_prefix, iteration_number)
74 | return os.path.join(self._logging_dir, filename)
75 |
76 | def log_to_file(self, filename_prefix, iteration_number):
77 | """Save the pickled dictionary to a file.
78 |
79 | Args:
80 | filename_prefix: str, name of the file to use (without iteration
81 | number).
82 | iteration_number: int, the iteration number, appended to the end of
83 | filename_prefix.
84 | """
85 | if not self._logging_enabled:
86 | tf.logging.warning('Logging is disabled.')
87 | return
88 | log_file = self._generate_filename(filename_prefix, iteration_number)
89 | with tf.gfile.GFile(log_file, 'w') as fout:
90 | pickle.dump(self.data, fout, protocol=pickle.HIGHEST_PROTOCOL)
91 | # After writing a checkpoint file, we garbage collect the log file
92 | # that is CHECKPOINT_DURATION versions old.
93 | stale_iteration_number = iteration_number - CHECKPOINT_DURATION
94 | if stale_iteration_number >= 0:
95 | stale_file = self._generate_filename(filename_prefix,
96 | stale_iteration_number)
97 | try:
98 | tf.gfile.Remove(stale_file)
99 | except tf.errors.NotFoundError:
100 | # Ignore if file not found.
101 | pass
102 |
103 | def is_logging_enabled(self):
104 | """Return if logging is enabled."""
105 | return self._logging_enabled
106 |
--------------------------------------------------------------------------------
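A brief usage sketch (not part of the repository; the directory path is illustrative) of the `Logger` class above: values are stored with `__setitem__` and periodically pickled to disk.

```python
experiment_logger = Logger('/tmp/hanabi_rainbow/logs')

# Store arbitrary picklable objects under string keys.
experiment_logger['iteration_0'] = {'train_episode_returns': [12.0, 15.0]}

# Writes /tmp/hanabi_rainbow/logs/log_0 and garbage-collects the log file from
# CHECKPOINT_DURATION iterations ago, if one exists.
experiment_logger.log_to_file('log', 0)
```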
/agents/rainbow/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Dopamine Authors and Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | #
17 | #
18 | # This file is a fork of the original Dopamine code incorporating changes for
19 | # the multiplayer setting and the Hanabi Learning Environment.
20 | #
21 | """The entry point for running a Rainbow agent on Hanabi."""
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | from absl import app
28 | from absl import flags
29 |
30 | from third_party.dopamine import logger
31 |
32 | import run_experiment
33 |
34 | FLAGS = flags.FLAGS
35 |
36 | flags.DEFINE_multi_string(
37 | 'gin_files', [],
38 | 'List of paths to gin configuration files (e.g.'
39 | '"configs/hanabi_rainbow.gin").')
40 | flags.DEFINE_multi_string(
41 | 'gin_bindings', [],
42 | 'Gin bindings to override the values set in the config files '
43 | '(e.g. "DQNAgent.epsilon_train=0.1").')
44 |
45 | flags.DEFINE_string('base_dir', None,
46 | 'Base directory to host all required sub-directories.')
47 |
48 | flags.DEFINE_string('checkpoint_dir', '',
49 | 'Directory where checkpoint files should be saved. If '
50 | 'empty, no checkpoints will be saved.')
51 | flags.DEFINE_string('checkpoint_file_prefix', 'ckpt',
52 | 'Prefix to use for the checkpoint files.')
53 | flags.DEFINE_string('logging_dir', '',
54 | 'Directory where experiment data will be saved. If empty '
55 | 'no checkpoints will be saved.')
56 | flags.DEFINE_string('logging_file_prefix', 'log',
57 | 'Prefix to use for the log files.')
58 |
59 |
60 | def launch_experiment():
61 | """Launches the experiment.
62 |
63 | Specifically:
64 | - Load the gin configs and bindings.
65 | - Initialize the Logger object.
66 | - Initialize the environment.
67 | - Initialize the observation stacker.
68 | - Initialize the agent.
69 | - Reload from the latest checkpoint, if available, and initialize the
70 | Checkpointer object.
71 | - Run the experiment.
72 | """
73 |   if FLAGS.base_dir is None:
74 | raise ValueError('--base_dir is None: please provide a path for '
75 | 'logs and checkpoints.')
76 |
77 | run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
78 | experiment_logger = logger.Logger('{}/logs'.format(FLAGS.base_dir))
79 |
80 | environment = run_experiment.create_environment()
81 | obs_stacker = run_experiment.create_obs_stacker(environment)
82 | agent = run_experiment.create_agent(environment, obs_stacker)
83 |
84 | checkpoint_dir = '{}/checkpoints'.format(FLAGS.base_dir)
85 | start_iteration, experiment_checkpointer = (
86 | run_experiment.initialize_checkpointing(agent,
87 | experiment_logger,
88 | checkpoint_dir,
89 | FLAGS.checkpoint_file_prefix))
90 |
91 | run_experiment.run_experiment(agent, environment, start_iteration,
92 | obs_stacker,
93 | experiment_logger, experiment_checkpointer,
94 | checkpoint_dir,
95 | logging_file_prefix=FLAGS.logging_file_prefix)
96 |
97 |
98 | def main(unused_argv):
99 | """This main function acts as a wrapper around a gin-configurable experiment.
100 |
101 | Args:
102 | unused_argv: Arguments (unused).
103 | """
104 | launch_experiment()
105 |
106 | if __name__ == '__main__':
107 | app.run(main)
108 |
--------------------------------------------------------------------------------
/game_example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Example code demonstrating the Python Hanabi interface."""
16 |
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import pyhanabi
21 |
22 |
23 | def run_game(game_parameters):
24 | """Play a game, selecting random actions."""
25 |
26 | def print_state(state):
27 | """Print some basic information about the state."""
28 | print("")
29 | print("Current player: {}".format(state.cur_player()))
30 | print(state)
31 |
32 |     # Example of more queries to provide more information about this state. For
33 |     # example, bots could use these methods to get information
34 | # about the state in order to act accordingly.
35 | print("### Information about the state retrieved separately ###")
36 | print("### Information tokens: {}".format(state.information_tokens()))
37 | print("### Life tokens: {}".format(state.life_tokens()))
38 | print("### Fireworks: {}".format(state.fireworks()))
39 | print("### Deck size: {}".format(state.deck_size()))
40 | print("### Discard pile: {}".format(str(state.discard_pile())))
41 | print("### Player hands: {}".format(str(state.player_hands())))
42 | print("")
43 |
44 | def print_observation(observation):
45 | """Print some basic information about an agent observation."""
46 | print("--- Observation ---")
47 | print(observation)
48 |
49 | print("### Information about the observation retrieved separately ###")
50 | print("### Current player, relative to self: {}".format(
51 | observation.cur_player_offset()))
52 | print("### Observed hands: {}".format(observation.observed_hands()))
53 | print("### Card knowledge: {}".format(observation.card_knowledge()))
54 | print("### Discard pile: {}".format(observation.discard_pile()))
55 | print("### Fireworks: {}".format(observation.fireworks()))
56 | print("### Deck size: {}".format(observation.deck_size()))
57 | move_string = "### Last moves:"
58 | for move_tuple in observation.last_moves():
59 | move_string += " {}".format(move_tuple)
60 | print(move_string)
61 | print("### Information tokens: {}".format(observation.information_tokens()))
62 | print("### Life tokens: {}".format(observation.life_tokens()))
63 | print("### Legal moves: {}".format(observation.legal_moves()))
64 | print("--- EndObservation ---")
65 |
66 | def print_encoded_observations(encoder, state, num_players):
67 | print("--- EncodedObservations ---")
68 | print("Observation encoding shape: {}".format(encoder.shape()))
69 | print("Current actual player: {}".format(state.cur_player()))
70 | for i in range(num_players):
71 | print("Encoded observation for player {}: {}".format(
72 | i, encoder.encode(state.observation(i))))
73 | print("--- EndEncodedObservations ---")
74 |
75 | game = pyhanabi.HanabiGame(game_parameters)
76 | print(game.parameter_string(), end="")
77 | obs_encoder = pyhanabi.ObservationEncoder(
78 | game, enc_type=pyhanabi.ObservationEncoderType.CANONICAL)
79 |
80 | state = game.new_initial_state()
81 | while not state.is_terminal():
82 | if state.cur_player() == pyhanabi.CHANCE_PLAYER_ID:
83 | state.deal_random_card()
84 | continue
85 |
86 | print_state(state)
87 |
88 | observation = state.observation(state.cur_player())
89 | print_observation(observation)
90 | print_encoded_observations(obs_encoder, state, game.num_players())
91 |
92 | legal_moves = state.legal_moves()
93 | print("")
94 | print("Number of legal moves: {}".format(len(legal_moves)))
95 |
96 | move = np.random.choice(legal_moves)
97 | print("Chose random legal move: {}".format(move))
98 |
99 | state.apply_move(move)
100 |
101 | print("")
102 | print("Game done. Terminal state:")
103 | print("")
104 | print(state)
105 | print("")
106 | print("score: {}".format(state.score()))
107 |
108 |
109 | if __name__ == "__main__":
110 | # Check that the cdef and library were loaded from the standard paths.
111 | assert pyhanabi.cdef_loaded(), "cdef failed to load"
112 | assert pyhanabi.lib_loaded(), "lib failed to load"
113 | run_game({"players": 3, "random_start_player": True})
114 |
--------------------------------------------------------------------------------
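game_example.py plays a single verbose random game. For quick smoke tests, the same loop condenses to a few lines using only the pyhanabi calls exercised above (the player count here is arbitrary):

    # Condensed version of the loop in game_example.py: play one random game
    # and return its final score.
    import numpy as np
    import pyhanabi

    def play_random_game(players=2):
        game = pyhanabi.HanabiGame({"players": players})
        state = game.new_initial_state()
        while not state.is_terminal():
            if state.cur_player() == pyhanabi.CHANCE_PLAYER_ID:
                state.deal_random_card()   # resolve the chance node (card deal)
                continue
            state.apply_move(np.random.choice(state.legal_moves()))
        return state.score()

    print(play_random_game(players=2))
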
/hanabi_lib/hanabi_hand.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "hanabi_hand.h"
16 |
17 | #include <algorithm>
18 | #include <cassert>
19 |
20 | #include "util.h"
21 |
22 | namespace hanabi_learning_env {
23 |
24 | HanabiHand::ValueKnowledge::ValueKnowledge(int value_range)
25 | : value_(-1), value_plausible_(std::max(value_range, 0), true) {
26 | assert(value_range > 0);
27 | }
28 |
29 | void HanabiHand::ValueKnowledge::ApplyIsValueHint(int value) {
30 | assert(value >= 0 && value < value_plausible_.size());
31 | assert(value_ < 0 || value_ == value);
32 | assert(value_plausible_[value] == true);
33 | value_ = value;
34 | std::fill(value_plausible_.begin(), value_plausible_.end(), false);
35 | value_plausible_[value] = true;
36 | }
37 |
38 | void HanabiHand::ValueKnowledge::ApplyIsNotValueHint(int value) {
39 | assert(value >= 0 && value < value_plausible_.size());
40 | assert(value_ < 0 || value_ != value);
41 | value_plausible_[value] = false;
42 | }
43 |
44 | HanabiHand::CardKnowledge::CardKnowledge(int num_colors, int num_ranks)
45 | : color_(num_colors), rank_(num_ranks) {}
46 |
47 | std::string HanabiHand::CardKnowledge::ToString() const {
48 | std::string result;
49 | result = result + (ColorHinted() ? ColorIndexToChar(Color()) : 'X') +
50 | (RankHinted() ? RankIndexToChar(Rank()) : 'X') + '|';
51 | for (int c = 0; c < color_.Range(); ++c) {
52 | if (color_.IsPlausible(c)) {
53 | result += ColorIndexToChar(c);
54 | }
55 | }
56 | for (int r = 0; r < rank_.Range(); ++r) {
57 | if (rank_.IsPlausible(r)) {
58 | result += RankIndexToChar(r);
59 | }
60 | }
61 | return result;
62 | }
63 |
64 | HanabiHand::HanabiHand(const HanabiHand& hand, bool hide_cards,
65 | bool hide_knowledge) {
66 | if (hide_cards) {
67 | cards_.resize(hand.cards_.size(), HanabiCard());
68 | } else {
69 | cards_ = hand.cards_;
70 | }
71 | if (hide_knowledge && !hand.cards_.empty()) {
72 | card_knowledge_.resize(hand.cards_.size(),
73 | CardKnowledge(hand.card_knowledge_[0].NumColors(),
74 | hand.card_knowledge_[0].NumRanks()));
75 | } else {
76 | card_knowledge_ = hand.card_knowledge_;
77 | }
78 | }
79 |
80 | void HanabiHand::AddCard(HanabiCard card,
81 | const CardKnowledge& initial_knowledge) {
82 | REQUIRE(card.IsValid());
83 | cards_.push_back(card);
84 | card_knowledge_.push_back(initial_knowledge);
85 | }
86 |
87 | void HanabiHand::RemoveFromHand(int card_index,
88 |                                 std::vector<HanabiCard>* discard_pile) {
89 | if (discard_pile != nullptr) {
90 | discard_pile->push_back(cards_[card_index]);
91 | }
92 | cards_.erase(cards_.begin() + card_index);
93 | card_knowledge_.erase(card_knowledge_.begin() + card_index);
94 | }
95 |
96 | uint8_t HanabiHand::RevealColor(const int color) {
97 | uint8_t mask = 0;
98 | assert(cards_.size() <= 8); // More than 8 cards is currently not supported.
99 | for (int i = 0; i < cards_.size(); ++i) {
100 | if (cards_[i].Color() == color) {
101 | if (!card_knowledge_[i].ColorHinted()) {
102 |         mask |= static_cast<uint8_t>(1) << i;
103 | }
104 | card_knowledge_[i].ApplyIsColorHint(color);
105 | } else {
106 | card_knowledge_[i].ApplyIsNotColorHint(color);
107 | }
108 | }
109 | return mask;
110 | }
111 |
112 | uint8_t HanabiHand::RevealRank(const int rank) {
113 | uint8_t mask = 0;
114 | assert(cards_.size() <= 8); // More than 8 cards is currently not supported.
115 | for (int i = 0; i < cards_.size(); ++i) {
116 | if (cards_[i].Rank() == rank) {
117 | if (!card_knowledge_[i].RankHinted()) {
118 |         mask |= static_cast<uint8_t>(1) << i;
119 | }
120 | card_knowledge_[i].ApplyIsRankHint(rank);
121 | } else {
122 | card_knowledge_[i].ApplyIsNotRankHint(rank);
123 | }
124 | }
125 | return mask;
126 | }
127 |
128 | std::string HanabiHand::ToString() const {
129 | std::string result;
130 | assert(cards_.size() == card_knowledge_.size());
131 | for (int i = 0; i < cards_.size(); ++i) {
132 | result +=
133 | cards_[i].ToString() + " || " + card_knowledge_[i].ToString() + '\n';
134 | }
135 | return result;
136 | }
137 |
138 | } // namespace hanabi_learning_env
139 |
--------------------------------------------------------------------------------
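RevealColor and RevealRank above return a bitmask whose bit i is set only when card i matches the hinted attribute and that attribute had not been hinted before. A small Python model of the color case (names are illustrative, not the library API):

    # Illustrative model of HanabiHand::RevealColor's bitmask. Non-matching
    # cards also receive "is not this color" knowledge in the C++ code, which
    # is omitted here for brevity.
    def reveal_color(card_colors, color_hinted, color):
        mask = 0
        for i, card_color in enumerate(card_colors):
            if card_color == color:
                if not color_hinted[i]:
                    mask |= 1 << i       # newly revealed information
                color_hinted[i] = True   # record the positive hint
        return mask

    # Card 0's color is already known, so only card 2 contributes a new bit.
    print(bin(reveal_color(["R", "G", "R", "B"], [True, False, False, False], "R")))
    # -> 0b100
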
/hanabi_lib/hanabi_game.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef __HANABI_GAME_H__
16 | #define __HANABI_GAME_H__
17 |
18 | #include <random>
19 | #include <string>
20 | #include <unordered_map>
21 | #include <vector>
22 |
23 | #include "hanabi_card.h"
24 | #include "hanabi_move.h"
25 |
26 | namespace hanabi_learning_env {
27 |
28 | class HanabiGame {
29 | public:
30 |   // An agent's observation of a state does not include all state knowledge.
31 |   // For example, observations never include an agent's own cards.
32 |   // A kMinimal observation is similar to what a human sees, and does not
33 |   // include any memory of past RevealColor/RevealRank hints. A kCardKnowledge
34 |   // observation includes per-card knowledge of past hints, as well as simple
35 |   // inferred knowledge of the form "this card is not red, because it was
36 |   // not revealed as red in a past move". A kSeer observation
37 |   // shows all cards, including the player's own cards, regardless of what
38 |   // hints have been given.
39 | enum AgentObservationType { kMinimal = 0, kCardKnowledge = 1, kSeer = 2 };
40 |
41 | explicit HanabiGame(
42 |       const std::unordered_map<std::string, std::string>& params);
43 |
44 | // Number of different player moves.
45 | int MaxMoves() const;
46 | // Get a HanabiMove by unique id.
47 | HanabiMove GetMove(int uid) const { return moves_[uid]; }
48 | // Get unique id for a move. Returns -1 for invalid move.
49 | int GetMoveUid(HanabiMove move) const;
50 | int GetMoveUid(HanabiMove::Type move_type, int card_index, int target_offset,
51 | int color, int rank) const;
52 | // Number of different chance outcomes.
53 | int MaxChanceOutcomes() const;
54 | // Get a chance-outcome HanabiMove by unique id.
55 | HanabiMove GetChanceOutcome(int uid) const { return chance_outcomes_[uid]; }
56 | // Get unique id for a chance-outcome move. Returns -1 for invalid move.
57 | int GetChanceOutcomeUid(HanabiMove move) const;
58 | // Randomly sample a random chance-outcome move from list of moves and
59 | // associated probability distribution.
60 | HanabiMove PickRandomChance(
61 |       const std::pair<std::vector<HanabiMove>, std::vector<double>>&
62 |           chance_outcomes) const;
63 |
64 |   std::unordered_map<std::string, std::string> Parameters() const;
65 | int MinPlayers() const { return 2; }
66 | int MaxPlayers() const { return 5; }
67 | int MinScore() const { return 0; }
68 | int MaxScore() const { return num_ranks_ * num_colors_; }
69 | std::string Name() const { return "Hanabi"; }
70 |
71 | int NumColors() const { return num_colors_; }
72 | int NumRanks() const { return num_ranks_; }
73 | int NumPlayers() const { return num_players_; }
74 | int HandSize() const { return hand_size_; }
75 | int MaxInformationTokens() const { return max_information_tokens_; }
76 | int MaxLifeTokens() const { return max_life_tokens_; }
77 | int CardsPerColor() const { return cards_per_color_; }
78 | int MaxDeckSize() const { return cards_per_color_ * num_colors_; }
79 | int NumberCardInstances(int color, int rank) const;
80 | int NumberCardInstances(HanabiCard card) const {
81 | return NumberCardInstances(card.Color(), card.Rank());
82 | }
83 | AgentObservationType ObservationType() const { return observation_type_; }
84 |
85 | // Get the first player to act. Might be randomly generated at each call.
86 | int GetSampledStartPlayer() const;
87 |
88 | private:
89 | // Calculating max moves by move type.
90 | int MaxDiscardMoves() const { return hand_size_; }
91 | int MaxPlayMoves() const { return hand_size_; }
92 | int MaxRevealColorMoves() const { return (num_players_ - 1) * num_colors_; }
93 | int MaxRevealRankMoves() const { return (num_players_ - 1) * num_ranks_; }
94 |
95 | int HandSizeFromRules() const;
96 | HanabiMove ConstructMove(int uid) const;
97 | HanabiMove ConstructChanceOutcome(int uid) const;
98 |
99 | // Table of all possible moves in this game.
100 |   std::vector<HanabiMove> moves_;
101 |   // Table of all possible chance outcomes in this game.
102 |   std::vector<HanabiMove> chance_outcomes_;
103 |   std::unordered_map<std::string, std::string> params_;
104 | int num_colors_ = -1;
105 | int num_ranks_ = -1;
106 | int num_players_ = -1;
107 | int hand_size_ = -1;
108 | int max_information_tokens_ = -1;
109 | int max_life_tokens_ = -1;
110 | int cards_per_color_ = -1;
111 | int seed_ = -1;
112 | bool random_start_player_ = false;
113 | AgentObservationType observation_type_ = kCardKnowledge;
114 | mutable std::mt19937 rng_;
115 | };
116 |
117 | } // namespace hanabi_learning_env
118 |
119 | #endif
120 |
--------------------------------------------------------------------------------
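The private MaxDiscardMoves/MaxPlayMoves/MaxRevealColorMoves/MaxRevealRankMoves helpers above imply MaxMoves() = 2*hand_size + (players-1)*(colors+ranks). A quick check in Python, assuming the standard five colors and five ranks and the two-player hand size of five (see HandSizeFromRules in hanabi_game.cc below):

    # Move-count arithmetic implied by hanabi_game.h.
    def max_moves(hand_size, players, colors, ranks):
        discards = hand_size
        plays = hand_size
        color_hints = (players - 1) * colors
        rank_hints = (players - 1) * ranks
        return discards + plays + color_hints + rank_hints

    print(max_moves(hand_size=5, players=2, colors=5, ranks=5))  # -> 20
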
/hanabi_lib/hanabi_observation.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "hanabi_observation.h"
16 |
17 | #include <algorithm>
18 | #include <cassert>
19 |
20 | #include "util.h"
21 |
22 | namespace hanabi_learning_env {
23 |
24 | namespace {
25 | // Returns the offset of player ID pid relative to player ID observer_pid,
26 | // or pid for negative values. That is, offset such that for a non-negative
27 | // player id pid, we have (observer_pid + offset) % num_players == pid.
28 | int PlayerToOffset(int pid, int observer_pid, int num_players) {
29 | return pid >= 0 ? (pid - observer_pid + num_players) % num_players : pid;
30 | }
31 |
32 | // Switch members from absolute player indices to observer-relative offsets,
33 | // including player indices within the contained HanabiMove.
34 | void ChangeHistoryItemToObserverRelative(int observer_pid, int num_players,
35 | bool show_cards,
36 | HanabiHistoryItem* item) {
37 | if (item->move.MoveType() == HanabiMove::kDeal) {
38 | assert(item->player < 0 && item->deal_to_player >= 0);
39 | item->deal_to_player =
40 | (item->deal_to_player - observer_pid + num_players) % num_players;
41 | if (item->deal_to_player == 0 && !show_cards) {
42 | // Hide cards dealt to observer if they shouldn't be able to see them.
43 | item->move = HanabiMove(HanabiMove::kDeal, -1, -1, -1, -1);
44 | }
45 | } else {
46 | assert(item->player >= 0);
47 | item->player = (item->player - observer_pid + num_players) % num_players;
48 | }
49 | }
50 | } // namespace
51 |
52 | HanabiObservation::HanabiObservation(const HanabiState& state,
53 | int observing_player)
54 | : cur_player_offset_(PlayerToOffset(state.CurPlayer(), observing_player,
55 | state.ParentGame()->NumPlayers())),
56 | discard_pile_(state.DiscardPile()),
57 | fireworks_(state.Fireworks()),
58 | deck_size_(state.Deck().Size()),
59 | information_tokens_(state.InformationTokens()),
60 | life_tokens_(state.LifeTokens()),
61 | legal_moves_(state.LegalMoves(observing_player)),
62 | parent_game_(state.ParentGame()) {
63 | REQUIRE(observing_player >= 0 &&
64 | observing_player < state.ParentGame()->NumPlayers());
65 | hands_.reserve(state.Hands().size());
66 | const bool hide_knowledge =
67 | state.ParentGame()->ObservationType() == HanabiGame::kMinimal;
68 | const bool show_cards = state.ParentGame()->ObservationType() == HanabiGame::kSeer;
69 | hands_.push_back(
70 | HanabiHand(state.Hands()[observing_player], !show_cards, hide_knowledge));
71 | for (int offset = 1; offset < state.ParentGame()->NumPlayers(); ++offset) {
72 | hands_.push_back(HanabiHand(state.Hands()[(observing_player + offset) %
73 | state.ParentGame()->NumPlayers()],
74 | false, hide_knowledge));
75 | }
76 |
77 | const auto& history = state.MoveHistory();
78 | auto start = std::find_if(history.begin(), history.end(),
79 | [](const HanabiHistoryItem& item) {
80 | return item.player != kChancePlayerId;
81 | });
82 |   std::reverse_iterator<std::vector<HanabiHistoryItem>::const_iterator> rend(start);
83 | for (auto it = history.rbegin(); it != rend; ++it) {
84 | last_moves_.push_back(*it);
85 | ChangeHistoryItemToObserverRelative(observing_player,
86 | state.ParentGame()->NumPlayers(),
87 | show_cards,
88 | &last_moves_.back());
89 | if (it->player == observing_player) {
90 | break;
91 | }
92 | }
93 | }
94 |
95 | std::string HanabiObservation::ToString() const {
96 | std::string result;
97 | result += "Life tokens: " + std::to_string(LifeTokens()) + "\n";
98 | result += "Info tokens: " + std::to_string(InformationTokens()) + "\n";
99 | result += "Fireworks: ";
100 | for (int i = 0; i < ParentGame()->NumColors(); ++i) {
101 | result += ColorIndexToChar(i);
102 | result += std::to_string(fireworks_[i]) + " ";
103 | }
104 | result += "\nHands:\n";
105 | for (int i = 0; i < hands_.size(); ++i) {
106 | if (i > 0) {
107 | result += "-----\n";
108 | }
109 | if (i == CurPlayerOffset()) {
110 | result += "Cur player\n";
111 | }
112 | result += hands_[i].ToString();
113 | }
114 | result += "Deck size: " + std::to_string(DeckSize()) + "\n";
115 | result += "Discards:";
116 | for (int i = 0; i < discard_pile_.size(); ++i) {
117 | result += " " + discard_pile_[i].ToString();
118 | }
119 | return result;
120 | }
121 |
122 | bool HanabiObservation::CardPlayableOnFireworks(int color, int rank) const {
123 | if (color < 0 || color >= ParentGame()->NumColors()) {
124 | return false;
125 | }
126 | return rank == fireworks_[color];
127 | }
128 |
129 | } // namespace hanabi_learning_env
130 |
--------------------------------------------------------------------------------
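PlayerToOffset above converts an absolute player id into an offset relative to the observer, passing negative (chance) ids through unchanged. The same computation in Python, with a worked example:

    # Python restatement of PlayerToOffset from hanabi_observation.cc:
    # find the offset such that (observer_pid + offset) % num_players == pid.
    def player_to_offset(pid, observer_pid, num_players):
        if pid < 0:
            return pid  # chance player id is passed through unchanged
        return (pid - observer_pid + num_players) % num_players

    # In a 3-player game, player 0 seen from player 2 is the next player to act.
    print(player_to_offset(0, 2, 3))   # -> 1
    print(player_to_offset(-1, 2, 3))  # -> -1
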
/game_example.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include <cstring>
16 | #include <iostream>
17 | #include <numeric>
18 | #include <random>
19 | #include <string>
20 | #include <unordered_map>
21 | #include <vector>
22 |
23 | #include "hanabi_game.h"
24 | #include "hanabi_state.h"
25 |
26 | struct GameResult {
27 | int score;
28 | int fireworks_played; // Number of successful Play moves.
29 | int num_steps; // Number of moves by a player.
30 | };
31 |
32 | constexpr const char* kGameParamArgPrefix = "--config.hanabi.";
33 |
34 | GameResult SimulateGame(const hanabi_learning_env::HanabiGame& game,
35 | bool verbose, std::mt19937* rng) {
36 | hanabi_learning_env::HanabiState state(&game);
37 | GameResult result = {0, 0, 0};
38 | while (!state.IsTerminal()) {
39 | if (state.CurPlayer() == hanabi_learning_env::kChancePlayerId) {
40 | // All of this could be replaced with state.ApplyRandomChance().
41 | // Only done this way to demonstrate picking specific chance moves.
42 | auto chance_outcomes = state.ChanceOutcomes();
43 |       std::discrete_distribution<std::mt19937::result_type> dist(
44 | chance_outcomes.second.begin(), chance_outcomes.second.end());
45 | auto move = chance_outcomes.first[dist(*rng)];
46 | if (verbose) {
47 | std::cout << "Legal chance:";
48 | for (int i = 0; i < chance_outcomes.first.size(); ++i) {
49 | std::cout << " <" << chance_outcomes.first[i].ToString() << ", "
50 | << chance_outcomes.second[i] << ">";
51 | }
52 | std::cout << "\n";
53 | std::cout << "Sampled move: " << move.ToString() << "\n\n";
54 | }
55 | state.ApplyMove(move);
56 | continue;
57 | }
58 |
59 | auto legal_moves = state.LegalMoves(state.CurPlayer());
60 |     std::uniform_int_distribution<std::mt19937::result_type> dist(
61 | 0, legal_moves.size() - 1);
62 | auto move = legal_moves[dist(*rng)];
63 | if (verbose) {
64 | std::cout << "Current player: " << state.CurPlayer() << "\n";
65 | std::cout << state.ToString() << "\n\n";
66 | std::cout << "Legal moves:";
67 | for (int i = 0; i < legal_moves.size(); ++i) {
68 | std::cout << " " << legal_moves[i].ToString();
69 | }
70 | std::cout << "\n";
71 | std::cout << "Sampled move: " << move.ToString() << "\n\n";
72 | }
73 | state.ApplyMove(move);
74 | ++result.num_steps;
75 | if (state.MoveHistory().back().scored) {
76 | ++result.fireworks_played;
77 | }
78 | }
79 |
80 | if (verbose) {
81 | std::cout << "Game done, terminal state:\n" << state.ToString() << "\n\n";
82 | std::cout << "score = " << state.Score() << "\n\n";
83 | }
84 |
85 | result.score = state.Score();
86 | return result;
87 | }
88 |
89 | void SimulateGames(
90 |     const std::unordered_map<std::string, std::string>& game_params,
91 | int num_trials = 1, bool verbose = true) {
92 | std::mt19937 rng;
93 | rng.seed(std::random_device()());
94 |
95 | hanabi_learning_env::HanabiGame game(game_params);
96 | auto params = game.Parameters();
97 | std::cout << "Hanabi game created, with parameters:\n";
98 | for (const auto& item : params) {
99 | std::cout << " " << item.first << "=" << item.second << "\n";
100 | }
101 |
102 |   std::vector<GameResult> results;
103 | results.reserve(num_trials);
104 | for (int trial = 0; trial < num_trials; ++trial) {
105 | results.push_back(SimulateGame(game, verbose, &rng));
106 | }
107 |
108 | if (num_trials > 1) {
109 | GameResult avg_score = std::accumulate(
110 | results.begin(), results.end(), GameResult(),
111 | [](const GameResult& lhs, const GameResult& rhs) {
112 | GameResult result = {lhs.score + rhs.score,
113 | lhs.fireworks_played + rhs.fireworks_played,
114 | lhs.num_steps + rhs.num_steps};
115 | return result;
116 | });
117 | std::cout << "Average score: "
118 | << static_cast(avg_score.score) / results.size()
119 | << " average number of fireworks played: "
120 | << static_cast(avg_score.fireworks_played) /
121 | results.size()
122 | << " average num_steps: "
123 | << static_cast(avg_score.num_steps) / results.size()
124 | << "\n";
125 | }
126 | }
127 |
128 | std::unordered_map<std::string, std::string> ParseArguments(int argc,
129 |                                                              char** argv) {
130 |   std::unordered_map<std::string, std::string> game_params;
131 | const auto prefix_len = strlen(kGameParamArgPrefix);
132 | for (int i = 1; i < argc; ++i) {
133 | std::string param = argv[i];
134 | if (param.compare(0, prefix_len, kGameParamArgPrefix) == 0 &&
135 | param.size() > prefix_len) {
136 | std::string value;
137 | param = param.substr(prefix_len, std::string::npos);
138 | auto value_pos = param.find("=");
139 | if (value_pos != std::string::npos) {
140 | value = param.substr(value_pos + 1, std::string::npos);
141 | param = param.substr(0, value_pos);
142 | }
143 | game_params[param] = value;
144 | }
145 | }
146 | return game_params;
147 | }
148 |
149 | int main(int argc, char** argv) {
150 | auto game_params = ParseArguments(argc, argv);
151 | SimulateGames(game_params);
152 | return 0;
153 | }
154 |
--------------------------------------------------------------------------------
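ParseArguments above accepts game parameters as --config.hanabi.<name>=<value> flags and strips the prefix before handing them to HanabiGame. A Python sketch of the same convention, e.g. for driving the example from a wrapper script (the helper name is illustrative):

    # Mirrors the "--config.hanabi.<name>=<value>" convention parsed by
    # ParseArguments() in game_example.cc.
    PREFIX = "--config.hanabi."

    def parse_game_params(argv):
        params = {}
        for arg in argv:
            if arg.startswith(PREFIX) and len(arg) > len(PREFIX):
                key, _, value = arg[len(PREFIX):].partition("=")
                params[key] = value
        return params

    print(parse_game_params(["--config.hanabi.players=3",
                             "--config.hanabi.random_start_player=true"]))
    # -> {'players': '3', 'random_start_player': 'true'}
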
/hanabi_lib/hanabi_hand.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef __HANABI_HAND_H__
16 | #define __HANABI_HAND_H__
17 |
18 | #include <cstdint>
19 | #include <string>
20 | #include <vector>
21 |
22 | #include "hanabi_card.h"
23 |
24 | namespace hanabi_learning_env {
25 |
26 | class HanabiHand {
27 | public:
28 | class ValueKnowledge {
29 | // Knowledge about an unknown integer variable in range 0 to value_range-1.
30 | // Records hints that either reveal the exact value (no longer unknown),
31 | // or reveal that the variable is not some particular value.
32 | // For example, ValueKnowledge(3) tracks a variable that can be 0, 1, or 2.
33 | // Initially, ValueHinted()=false, value()=-1, and ValueCouldBe(v)=true
34 | // for v=0, 1, and 2.
35 | // After recording that the value is not 1, we have
36 | // ValueHinted()=false, value()=-1, and ValueCouldBe(1)=false.
37 | // After recording that the value is 0, we have
38 | // ValueHinted()=true, value()=0, and ValueCouldBe(v)=false for v=1, and 2.
39 | public:
40 | explicit ValueKnowledge(int value_range);
41 | int Range() const { return value_plausible_.size(); }
42 | // Returns true if and only if the exact value was revealed.
43 | // Does not perform inference to get a known value from not-value hints.
44 | bool ValueHinted() const { return value_ >= 0; }
45 | int Value() const { return value_; } // -1 if value was not hinted.
46 | // Returns true if we have no hint saying variable is not the given value.
47 | bool IsPlausible(int value) const { return value_plausible_[value]; }
48 | // Record a hint that gives the value of the variable.
49 | void ApplyIsValueHint(int value);
50 | // Record a hint that the variable does not have the given value.
51 | void ApplyIsNotValueHint(int value);
52 |
53 | private:
54 | // Value if hint directly provided the value, or -1 with no direct hint.
55 | int value_ = -1;
56 |     std::vector<bool> value_plausible_;  // Knowledge from not-value hints.
57 | };
58 |
59 | class CardKnowledge {
60 | // Hinted knowledge about color and rank of an initially unknown card.
61 | public:
62 | CardKnowledge(int num_colors, int num_ranks);
63 | // Returns number of possible colors being tracked.
64 | int NumColors() const { return color_.Range(); }
65 | // Returns true if and only if the exact color was revealed.
66 | // Does not perform inference to get a known color from not-color hints.
67 | bool ColorHinted() const { return color_.ValueHinted(); }
68 | // Color of card if it was hinted, -1 if not hinted.
69 | int Color() const { return color_.Value(); }
70 | // Returns true if we have no hint saying card is not the given color.
71 | bool ColorPlausible(int color) const { return color_.IsPlausible(color); }
72 | void ApplyIsColorHint(int color) { color_.ApplyIsValueHint(color); }
73 | void ApplyIsNotColorHint(int color) { color_.ApplyIsNotValueHint(color); }
74 | // Returns number of possible ranks being tracked.
75 | int NumRanks() const { return rank_.Range(); }
76 | // Returns true if and only if the exact rank was revealed.
77 | // Does not perform inference to get a known rank from not-rank hints.
78 | bool RankHinted() const { return rank_.ValueHinted(); }
79 | // Rank of card if it was hinted, -1 if not hinted.
80 | int Rank() const { return rank_.Value(); }
81 | // Returns true if we have no hint saying card is not the given rank.
82 | bool RankPlausible(int rank) const { return rank_.IsPlausible(rank); }
83 | void ApplyIsRankHint(int rank) { rank_.ApplyIsValueHint(rank); }
84 | void ApplyIsNotRankHint(int rank) { rank_.ApplyIsNotValueHint(rank); }
85 | std::string ToString() const;
86 |
87 | private:
88 | ValueKnowledge color_;
89 | ValueKnowledge rank_;
90 | };
91 |
92 | HanabiHand() {}
93 | HanabiHand(const HanabiHand& hand)
94 | : cards_(hand.cards_), card_knowledge_(hand.card_knowledge_) {}
95 | // Copy hand. Hide cards (set to invalid) if hide_cards is true.
96 | // Hide card knowledge (set to unknown) if hide_knowledge is true.
97 | HanabiHand(const HanabiHand& hand, bool hide_cards, bool hide_knowledge);
98 | // Cards and corresponding card knowledge are always arranged from oldest to
99 | // newest, with the oldest card or knowledge at index 0.
100 |   const std::vector<HanabiCard>& Cards() const { return cards_; }
101 |   const std::vector<CardKnowledge>& Knowledge() const {
102 | return card_knowledge_;
103 | }
104 | void AddCard(HanabiCard card, const CardKnowledge& initial_knowledge);
105 | // Remove card_index card from hand. Put in discard_pile if not nullptr
106 | // (pushes the card to the back of the discard_pile vector).
107 |   void RemoveFromHand(int card_index, std::vector<HanabiCard>* discard_pile);
108 | // Make cards with the given rank visible.
109 |   // Returns new information bitmask, bit_i set if card_i rank was revealed
110 | // and was previously unknown.
111 | uint8_t RevealRank(int rank);
112 | // Make cards with the given color visible.
113 | // Returns new information bitmask, bit_i set if card_i color was revealed
114 | // and was previously unknown.
115 | uint8_t RevealColor(int color);
116 | std::string ToString() const;
117 |
118 | private:
119 | // A set of cards and knowledge about them.
120 |   std::vector<HanabiCard> cards_;
121 |   std::vector<CardKnowledge> card_knowledge_;
122 | };
123 |
124 | } // namespace hanabi_learning_env
125 |
126 | #endif
127 |
--------------------------------------------------------------------------------
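The ValueKnowledge comment above walks through how positive and negative hints are recorded. The same walk-through as a toy Python model (this mirrors the documented behaviour, not the C++ implementation):

    # Toy model of HanabiHand::ValueKnowledge for a variable in [0, value_range).
    class ValueKnowledge(object):
        def __init__(self, value_range):
            self.value = -1                        # -1 until directly hinted
            self.plausible = [True] * value_range  # knowledge from not-value hints

        def apply_is_not_value_hint(self, value):
            self.plausible[value] = False

        def apply_is_value_hint(self, value):
            self.value = value
            self.plausible = [v == value for v in range(len(self.plausible))]

    k = ValueKnowledge(3)
    k.apply_is_not_value_hint(1)
    print(k.value, k.plausible)  # -> -1 [True, False, True]
    k.apply_is_value_hint(0)
    print(k.value, k.plausible)  # -> 0 [True, False, False]
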
/hanabi_lib/hanabi_state.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef __HANABI_STATE_H__
16 | #define __HANABI_STATE_H__
17 |
18 | #include <random>
19 | #include <string>
20 | #include <vector>
21 |
22 | #include "hanabi_card.h"
23 | #include "hanabi_game.h"
24 | #include "hanabi_hand.h"
25 | #include "hanabi_history_item.h"
26 | #include "hanabi_move.h"
27 |
28 | namespace hanabi_learning_env {
29 |
30 | constexpr int kChancePlayerId = -1;
31 |
32 | class HanabiState {
33 | public:
34 | class HanabiDeck {
35 | public:
36 | explicit HanabiDeck(const HanabiGame& game);
37 | // DealCard returns invalid card on failure.
38 | HanabiCard DealCard(int color, int rank);
39 | HanabiCard DealCard(std::mt19937* rng);
40 | int Size() const { return total_count_; }
41 | bool Empty() const { return total_count_ == 0; }
42 | int CardCount(int color, int rank) const {
43 | return card_count_[CardToIndex(color, rank)];
44 | }
45 |
46 | private:
47 | int CardToIndex(int color, int rank) const {
48 | return color * num_ranks_ + rank;
49 | }
50 | int IndexToColor(int index) const { return index / num_ranks_; }
51 | int IndexToRank(int index) const { return index % num_ranks_; }
52 |
53 | // Number of instances in the deck for each card.
54 | // E.g., if card_count_[CardToIndex(card)] == 2, then there are two
55 | // instances of card remaining in the deck, available to be dealt out.
56 |     std::vector<int> card_count_;
57 | int total_count_ = -1; // Total number of cards available to be dealt out.
58 | int num_ranks_ = -1; // From game.NumRanks(), used to map card to index.
59 | };
60 |
61 | enum EndOfGameType {
62 | kNotFinished, // Not the end of game.
63 | kOutOfLifeTokens, // Players ran out of life tokens.
64 | kOutOfCards, // Players ran out of cards.
65 | kCompletedFireworks // All fireworks played.
66 | };
67 |
68 | // Construct a HanabiState, initialised to the start of the game.
69 | // If start_player >= 0, the game-provided start player is overridden
70 | // and the first player after chance is start_player.
71 | explicit HanabiState(const HanabiGame* parent_game, int start_player = -1);
72 | // Copy constructor for recursive game traversals using copy + apply-move.
73 | HanabiState(const HanabiState& state) = default;
74 |
75 | bool MoveIsLegal(HanabiMove move) const;
76 | void ApplyMove(HanabiMove move);
77 | // Legal moves for state. Moves point into an unchanging list in parent_game.
78 |   std::vector<HanabiMove> LegalMoves(int player) const;
79 | // Returns true if card with color and rank can be played on fireworks pile.
80 | bool CardPlayableOnFireworks(int color, int rank) const;
81 | bool CardPlayableOnFireworks(HanabiCard card) const {
82 | return CardPlayableOnFireworks(card.Color(), card.Rank());
83 | }
84 | bool ChanceOutcomeIsLegal(HanabiMove move) const { return MoveIsLegal(move); }
85 | double ChanceOutcomeProb(HanabiMove move) const;
86 | void ApplyChanceOutcome(HanabiMove move) { ApplyMove(move); }
87 | void ApplyRandomChance();
88 | // Get the valid chance moves, and associated probabilities.
89 | // Guaranteed that moves.size() == probabilities.size().
90 |   std::pair<std::vector<HanabiMove>, std::vector<double>> ChanceOutcomes()
91 | const;
92 | EndOfGameType EndOfGameStatus() const;
93 | bool IsTerminal() const { return EndOfGameStatus() != kNotFinished; }
94 | int Score() const;
95 | std::string ToString() const;
96 |
97 | int CurPlayer() const { return cur_player_; }
98 | int LifeTokens() const { return life_tokens_; }
99 | int InformationTokens() const { return information_tokens_; }
100 |   const std::vector<HanabiHand>& Hands() const { return hands_; }
101 |   const std::vector<int>& Fireworks() const { return fireworks_; }
102 | const HanabiGame* ParentGame() const { return parent_game_; }
103 | const HanabiDeck& Deck() const { return deck_; }
104 | // Get the discard pile (the element at the back is the most recent discard.)
105 |   const std::vector<HanabiCard>& DiscardPile() const { return discard_pile_; }
106 |   // Sequence of moves from beginning of game. Stored as <HanabiHistoryItem>.
107 |   const std::vector<HanabiHistoryItem>& MoveHistory() const {
108 | return move_history_;
109 | }
110 |
111 | private:
112 | // Add card to table if possible, if not lose a life token.
113 |   // Returns <success, information_token_added>:
114 | // success is true iff card was successfully added to fireworks.
115 | // information_token_added is true iff information_tokens increase
116 | // (i.e., success=true, highest rank was added, and not at max tokens.)
117 |   std::pair<bool, bool> AddToFireworks(HanabiCard card);
118 | const HanabiHand& HandByOffset(int offset) const {
119 | return hands_[(cur_player_ + offset) % hands_.size()];
120 | }
121 | HanabiHand* HandByOffset(int offset) {
122 | return &hands_[(cur_player_ + offset) % hands_.size()];
123 | }
124 | void AdvanceToNextPlayer(); // Set cur_player to next player to act.
125 | bool HintingIsLegal(HanabiMove move) const;
126 | int PlayerToDeal() const; // -1 if no player needs a card.
127 | bool IncrementInformationTokens();
128 | void DecrementInformationTokens();
129 | void DecrementLifeTokens();
130 |
131 | const HanabiGame* parent_game_ = nullptr;
132 | HanabiDeck deck_;
133 | // Back element of discard_pile_ is most recently discarded card.
134 |   std::vector<HanabiCard> discard_pile_;
135 |   std::vector<HanabiHand> hands_;
136 |   std::vector<HanabiHistoryItem> move_history_;
137 | int cur_player_ = -1;
138 | int next_non_chance_player_ = -1; // Next non-chance player to act.
139 | int information_tokens_ = -1;
140 | int life_tokens_ = -1;
141 |   std::vector<int> fireworks_;
142 | int turns_to_play_ = -1; // Number of turns to play once deck is empty.
143 | };
144 |
145 | } // namespace hanabi_learning_env
146 |
147 | #endif
148 |
--------------------------------------------------------------------------------
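CardPlayableOnFireworks, declared above and implemented for observations in hanabi_observation.cc, reduces to one comparison: a card is playable exactly when its rank equals the current firework height of its color. A tiny Python check:

    # The playability rule behind CardPlayableOnFireworks.
    def card_playable(fireworks, color, rank):
        return 0 <= color < len(fireworks) and rank == fireworks[color]

    fireworks = [2, 0, 4, 1, 0]            # next playable rank for each color
    print(card_playable(fireworks, 0, 2))  # -> True  (continues color 0's pile)
    print(card_playable(fireworks, 0, 3))  # -> False (would skip rank 2)
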
/agents/rainbow/third_party/dopamine/checkpointer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Dopamine Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A checkpointing mechanism for Dopamine agents.
16 |
17 | This Checkpointer expects a base directory where checkpoints for different
18 | iterations are stored. Specifically, Checkpointer.save_checkpoint() takes in
19 | as input a dictionary 'data' to be pickled to disk. At each iteration, we
20 | write a file called 'ckpt.#', where # is the iteration number. The
21 | Checkpointer also cleans up old files, maintaining up to the CHECKPOINT_DURATION
22 | most recent iterations.
23 |
24 | The Checkpointer writes a sentinel file to indicate that checkpointing was
25 | globally successful. This means that all other checkpointing activities
26 | (saving the Tensorflow graph, the replay buffer) should be performed *prior*
27 | to calling Checkpointer.save_checkpoint(). This allows the Checkpointer to
28 | detect incomplete checkpoints.
29 |
30 | #### Example
31 |
32 | After running 10 iterations (numbered 0...9) with base_directory='/checkpoint',
33 | the following files will exist:
34 | ```
35 | /checkpoint/ckpt.6
36 | /checkpoint/ckpt.7
37 | /checkpoint/ckpt.8
38 | /checkpoint/ckpt.9
39 | /checkpoint/sentinel_checkpoint_complete.6
40 | /checkpoint/sentinel_checkpoint_complete.7
41 | /checkpoint/sentinel_checkpoint_complete.8
42 | /checkpoint/sentinel_checkpoint_complete.9
43 | ```
44 | """
45 |
46 | from __future__ import absolute_import
47 | from __future__ import division
48 | from __future__ import print_function
49 |
50 | import os
51 | import pickle
52 | import tensorflow as tf
53 |
54 | CHECKPOINT_DURATION = 4
55 |
56 |
57 | def get_latest_checkpoint_number(base_directory):
58 | """Returns the version number of the latest completed checkpoint.
59 |
60 | Args:
61 | base_directory: str, directory in which to look for checkpoint files.
62 |
63 | Returns:
64 | int, the iteration number of the latest checkpoint, or -1 if none was found.
65 | """
66 | glob = os.path.join(base_directory, 'sentinel_checkpoint_complete.*')
67 | def extract_iteration(x):
68 | return int(x[x.rfind('.') + 1:])
69 | try:
70 | checkpoint_files = tf.gfile.Glob(glob)
71 | except tf.errors.NotFoundError:
72 | return -1
73 | try:
74 | latest_iteration = max(extract_iteration(x) for x in checkpoint_files)
75 | return latest_iteration
76 | except ValueError:
77 | return -1
78 |
79 |
80 | class Checkpointer(object):
81 | """Class for managing checkpoints for Dopamine agents.
82 | """
83 |
84 | def __init__(self, base_directory, checkpoint_file_prefix='ckpt',
85 | checkpoint_frequency=1):
86 | """Initializes Checkpointer.
87 |
88 | Args:
89 | base_directory: str, directory where all checkpoints are saved/loaded.
90 | checkpoint_file_prefix: str, prefix to use for naming checkpoint files.
91 | checkpoint_frequency: int, the frequency at which to checkpoint.
92 |
93 | Raises:
94 | ValueError: if base_directory is empty, or not creatable.
95 | """
96 | if not base_directory:
97 | raise ValueError('No path provided to Checkpointer.')
98 | self._checkpoint_file_prefix = checkpoint_file_prefix
99 | self._checkpoint_frequency = checkpoint_frequency
100 | self._base_directory = base_directory
101 | try:
102 | tf.gfile.MakeDirs(base_directory)
103 | except tf.errors.PermissionDeniedError:
104 | # We catch the PermissionDeniedError and issue a more useful exception.
105 | raise ValueError('Unable to create checkpoint path: {}.'.format(
106 | base_directory))
107 |
108 | def _generate_filename(self, file_prefix, iteration_number):
109 | """Returns a checkpoint filename from prefix and iteration number."""
110 | filename = '{}.{}'.format(file_prefix, iteration_number)
111 | return os.path.join(self._base_directory, filename)
112 |
113 | def _save_data_to_file(self, data, filename):
114 | """Saves the given 'data' object to a file."""
115 | with tf.gfile.GFile(filename, 'w') as fout:
116 | pickle.dump(data, fout)
117 |
118 | def save_checkpoint(self, iteration_number, data):
119 | """Saves a new checkpoint at the current iteration_number.
120 |
121 | Args:
122 | iteration_number: int, the current iteration number for this checkpoint.
123 | data: Any (picklable) python object containing the data to store in the
124 | checkpoint.
125 | """
126 | if iteration_number % self._checkpoint_frequency != 0:
127 | return
128 |
129 | filename = self._generate_filename(self._checkpoint_file_prefix,
130 | iteration_number)
131 | self._save_data_to_file(data, filename)
132 | filename = self._generate_filename('sentinel_checkpoint_complete',
133 | iteration_number)
134 | with tf.gfile.GFile(filename, 'wb') as fout:
135 | fout.write('done')
136 |
137 | self._clean_up_old_checkpoints(iteration_number)
138 |
139 | def _clean_up_old_checkpoints(self, iteration_number):
140 | """Removes sufficiently old checkpoints."""
141 |     # After writing the checkpoint and sentinel file, we garbage collect files
142 | # that are CHECKPOINT_DURATION * self._checkpoint_frequency versions old.
143 | stale_iteration_number = iteration_number - (self._checkpoint_frequency *
144 | CHECKPOINT_DURATION)
145 |
146 | if stale_iteration_number >= 0:
147 | stale_file = self._generate_filename(self._checkpoint_file_prefix,
148 | stale_iteration_number)
149 | stale_sentinel = self._generate_filename('sentinel_checkpoint_complete',
150 | stale_iteration_number)
151 | try:
152 | tf.gfile.Remove(stale_file)
153 | tf.gfile.Remove(stale_sentinel)
154 | except tf.errors.NotFoundError:
155 | # Ignore if file not found.
156 | tf.logging.info('Unable to remove {} or {}.'.format(stale_file,
157 | stale_sentinel))
158 |
159 | def _load_data_from_file(self, filename):
160 | if not tf.gfile.Exists(filename):
161 | return None
162 | with tf.gfile.GFile(filename, 'rb') as fin:
163 | return pickle.load(fin)
164 |
165 | def load_checkpoint(self, iteration_number):
166 | """Tries to reload a checkpoint at the selected iteration number.
167 |
168 | Args:
169 | iteration_number: The checkpoint iteration number to try to load.
170 |
171 | Returns:
172 |       If the checkpoint file exists, the unpickled object that was passed in
173 |       as data to save_checkpoint; returns None if the file does not exist.
174 | """
175 | checkpoint_file = self._generate_filename(self._checkpoint_file_prefix,
176 | iteration_number)
177 | return self._load_data_from_file(checkpoint_file)
178 |
--------------------------------------------------------------------------------
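A short usage sketch of the Checkpointer above; the directory, the data dict, and the import path (assumed relative to agents/rainbow) are placeholders:

    # Usage sketch for Checkpointer; paths and data are placeholders.
    from third_party.dopamine import checkpointer

    ckpt = checkpointer.Checkpointer('/tmp/hanabi_ckpts')   # default prefix 'ckpt'
    ckpt.save_checkpoint(0, {'iteration': 0, 'agent_state': None})

    latest = checkpointer.get_latest_checkpoint_number('/tmp/hanabi_ckpts')
    if latest >= 0:
        data = ckpt.load_checkpoint(latest)  # None if the files were removed
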
/agents/rainbow/third_party/dopamine/sum_tree.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Dopamine Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A sum tree data structure.
16 |
17 | Used for prioritized experience replay. See prioritized_replay_buffer.py
18 | and Schaul et al. (2015).
19 | """
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import math
25 | import random
26 |
27 | import numpy as np
28 |
29 |
30 | class SumTree(object):
31 | """A sum tree data structure for storing replay priorities.
32 |
33 | A sum tree is a complete binary tree whose leaves contain values called
34 | priorities. Internal nodes maintain the sum of the priorities of all leaf
35 | nodes in their subtree.
36 |
37 | For capacity = 4, the tree may look like this:
38 |
39 | +---+
40 | |2.5|
41 | +-+-+
42 | |
43 | +-------+--------+
44 | | |
45 | +-+-+ +-+-+
46 | |1.5| |1.0|
47 | +-+-+ +-+-+
48 | | |
49 | +----+----+ +----+----+
50 | | | | |
51 | +-+-+ +-+-+ +-+-+ +-+-+
52 | |0.5| |1.0| |0.5| |0.5|
53 | +---+ +---+ +---+ +---+
54 |
55 | This is stored in a list of numpy arrays:
56 | self.nodes = [ [2.5], [1.5, 1], [0.5, 1, 0.5, 0.5] ]
57 |
58 | For conciseness, we allocate arrays as powers of two, and pad the excess
59 | elements with zero values.
60 |
61 | This is similar to the usual array-based representation of a complete binary
62 | tree, but is a little more user-friendly.
63 | """
64 |
65 | def __init__(self, capacity):
66 | """Creates the sum tree data structure for the given replay capacity.
67 |
68 | Args:
69 | capacity: int, the maximum number of elements that can be stored in this
70 | data structure.
71 |
72 | Raises:
73 | ValueError: If requested capacity is not positive.
74 | """
75 | assert isinstance(capacity, int)
76 | if capacity <= 0:
77 | raise ValueError('Sum tree capacity should be positive. Got: {}'.
78 | format(capacity))
79 |
80 | self.nodes = []
81 | tree_depth = int(math.ceil(np.log2(capacity)))
82 | level_size = 1
83 | for _ in range(tree_depth + 1):
84 | nodes_at_this_depth = np.zeros(level_size)
85 | self.nodes.append(nodes_at_this_depth)
86 |
87 | level_size *= 2
88 |
89 | self.max_recorded_priority = 1.0
90 |
91 | def _total_priority(self):
92 | """Returns the sum of all priorities stored in this sum tree.
93 |
94 | Returns:
95 | float, sum of priorities stored in this sum tree.
96 | """
97 | return self.nodes[0][0]
98 |
99 | def sample(self, query_value=None):
100 | """Samples an element from the sum tree.
101 |
102 | Each element has probability p_i / sum_j p_j of being picked, where p_i is
103 | the (positive) value associated with node i (possibly unnormalized).
104 |
105 | Args:
106 | query_value: float in [0, 1], used as the random value to select a
107 | sample. If None, will select one randomly in [0, 1).
108 |
109 | Returns:
110 | int, a random element from the sum tree.
111 |
112 | Raises:
113 | Exception: If the sum tree is empty (i.e. its node values sum to 0), or if
114 | the supplied query_value is larger than the total sum.
115 | """
116 | if self._total_priority() == 0.0:
117 | raise Exception('Cannot sample from an empty sum tree.')
118 |
119 |     if query_value is not None and (query_value < 0. or query_value > 1.):
120 | raise ValueError('query_value must be in [0, 1].')
121 |
122 | # Sample a value in range [0, R), where R is the value stored at the root.
123 | query_value = random.random() if query_value is None else query_value
124 | query_value *= self._total_priority()
125 |
126 | # Now traverse the sum tree.
127 | node_index = 0
128 | for nodes_at_this_depth in self.nodes[1:]:
129 | # Compute children of previous depth's node.
130 | left_child = node_index * 2
131 |
132 | left_sum = nodes_at_this_depth[left_child]
133 | # Each subtree describes a range [0, a), where a is its value.
134 | if query_value < left_sum: # Recurse into left subtree.
135 | node_index = left_child
136 | else: # Recurse into right subtree.
137 | node_index = left_child + 1
138 | # Adjust query to be relative to right subtree.
139 | query_value -= left_sum
140 |
141 | return node_index
142 |
143 | def stratified_sample(self, batch_size):
144 | """Performs stratified sampling using the sum tree.
145 |
146 | Let R be the value at the root (total value of sum tree). This method will
147 | divide [0, R) into batch_size segments, pick a random number from each of
148 | those segments, and use that random number to sample from the sum_tree. This
149 | is as specified in Schaul et al. (2015).
150 |
151 | Args:
152 | batch_size: int, the number of strata to use.
153 | Returns:
154 | list of batch_size elements sampled from the sum tree.
155 |
156 | Raises:
157 | Exception: If the sum tree is empty (i.e. its node values sum to 0).
158 | """
159 | if self._total_priority() == 0.0:
160 | raise Exception('Cannot sample from an empty sum tree.')
161 |
162 | bounds = np.linspace(0., 1., batch_size + 1)
163 | assert len(bounds) == batch_size + 1
164 | segments = [(bounds[i], bounds[i+1]) for i in range(batch_size)]
165 | query_values = [random.uniform(x[0], x[1]) for x in segments]
166 | return [self.sample(query_value=x) for x in query_values]
167 |
168 | def get(self, node_index):
169 | """Returns the value of the leaf node corresponding to the index.
170 |
171 | Args:
172 | node_index: The index of the leaf node.
173 | Returns:
174 | The value of the leaf node.
175 | """
176 | return self.nodes[-1][node_index]
177 |
178 | def set(self, node_index, value):
179 | """Sets the value of a leaf node and updates internal nodes accordingly.
180 |
181 | This operation takes O(log(capacity)).
182 | Args:
183 | node_index: int, the index of the leaf node to be updated.
184 | value: float, the value which we assign to the node. This value must be
185 | nonnegative. Setting value = 0 will cause the element to never be
186 | sampled.
187 |
188 | Raises:
189 | ValueError: If the given value is negative.
190 | """
191 | if value < 0.0:
192 | raise ValueError('Sum tree values should be nonnegative. Got {}'.
193 | format(value))
194 | self.max_recorded_priority = max(value, self.max_recorded_priority)
195 |
196 | delta_value = value - self.nodes[-1][node_index]
197 |
198 | # Now traverse back the tree, adjusting all sums along the way.
199 | for nodes_at_this_depth in reversed(self.nodes):
200 | # Note: Adding a delta leads to some tolerable numerical inaccuracies.
201 | nodes_at_this_depth[node_index] += delta_value
202 | node_index //= 2
203 |
204 | assert node_index == 0, ('Sum tree traversal failed, final node index '
205 | 'is not 0.')
206 |
--------------------------------------------------------------------------------
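A brief usage sketch of SumTree: write a few leaf priorities, then sample indices proportionally. The priorities match the docstring's example tree; the import path is assumed relative to agents/rainbow.

    # Usage sketch for SumTree: leaves hold priorities, sampling is proportional.
    from third_party.dopamine import sum_tree

    tree = sum_tree.SumTree(capacity=4)
    for index, priority in enumerate([0.5, 1.0, 0.5, 0.5]):
        tree.set(index, priority)        # O(log capacity) per update

    print(tree.sample())                 # leaf index, drawn proportionally
    print(tree.stratified_sample(4))     # one index per quarter of the total mass
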
/hanabi_lib/hanabi_game.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "hanabi_game.h"
16 |
17 | #include "util.h"
18 |
19 | namespace hanabi_learning_env {
20 |
21 | namespace {
22 | // Constants.
23 | const int kDefaultPlayers = 2;
24 | const int kInformationTokens = 8;
25 | const int kLifeTokens = 3;
26 | const bool kDefaultRandomStart = false;
27 | } // namespace
28 |
29 | HanabiGame::HanabiGame(
30 |     const std::unordered_map<std::string, std::string>& params) {
31 | params_ = params;
32 |   num_players_ = ParameterValue<int>(params_, "players", kDefaultPlayers);
33 |   REQUIRE(num_players_ >= MinPlayers() && num_players_ <= MaxPlayers());
34 |   num_colors_ = ParameterValue<int>(params_, "colors", kMaxNumColors);
35 |   REQUIRE(num_colors_ > 0 && num_colors_ <= kMaxNumColors);
36 |   num_ranks_ = ParameterValue<int>(params_, "ranks", kMaxNumRanks);
37 |   REQUIRE(num_ranks_ > 0 && num_ranks_ <= kMaxNumRanks);
38 |   hand_size_ = ParameterValue<int>(params_, "hand_size", HandSizeFromRules());
39 |   max_information_tokens_ = ParameterValue<int>(
40 |       params_, "max_information_tokens", kInformationTokens);
41 |   max_life_tokens_ =
42 |       ParameterValue<int>(params_, "max_life_tokens", kLifeTokens);
43 |   seed_ = ParameterValue<int>(params_, "seed", -1);
44 |   random_start_player_ =
45 |       ParameterValue<bool>(params_, "random_start_player", kDefaultRandomStart);
46 |   observation_type_ = AgentObservationType(ParameterValue<int>(
47 |       params_, "observation_type", AgentObservationType::kCardKnowledge));
48 | while (seed_ == -1) {
49 | seed_ = std::random_device()();
50 | }
51 | rng_.seed(seed_);
52 |
53 | // Work out number of cards per color, and check deck size is large enough.
54 | cards_per_color_ = 0;
55 | for (int rank = 0; rank < num_ranks_; ++rank) {
56 | cards_per_color_ += NumberCardInstances(0, rank);
57 | }
58 | REQUIRE(hand_size_ * num_players_ <= cards_per_color_ * num_colors_);
59 |
60 | // Build static list of moves.
61 | for (int uid = 0; uid < MaxMoves(); ++uid) {
62 | moves_.push_back(ConstructMove(uid));
63 | }
64 | for (int uid = 0; uid < MaxChanceOutcomes(); ++uid) {
65 | chance_outcomes_.push_back(ConstructChanceOutcome(uid));
66 | }
67 | }
68 |
69 | int HanabiGame::MaxMoves() const {
70 | return MaxDiscardMoves() + MaxPlayMoves() + MaxRevealColorMoves() +
71 | MaxRevealRankMoves();
72 | }
73 |
74 | int HanabiGame::GetMoveUid(HanabiMove move) const {
75 | return GetMoveUid(move.MoveType(), move.CardIndex(), move.TargetOffset(),
76 | move.Color(), move.Rank());
77 | }
78 |
79 | int HanabiGame::GetMoveUid(HanabiMove::Type move_type, int card_index,
80 | int target_offset, int color, int rank) const {
81 | switch (move_type) {
82 | case HanabiMove::kDiscard:
83 | return card_index;
84 | case HanabiMove::kPlay:
85 | return MaxDiscardMoves() + card_index;
86 | case HanabiMove::kRevealColor:
87 | return MaxDiscardMoves() + MaxPlayMoves() +
88 | (target_offset - 1) * NumColors() + color;
89 | case HanabiMove::kRevealRank:
90 | return MaxDiscardMoves() + MaxPlayMoves() + MaxRevealColorMoves() +
91 | (target_offset - 1) * NumRanks() + rank;
92 | default:
93 | return -1;
94 | }
95 | }
96 |
97 | int HanabiGame::MaxChanceOutcomes() const { return NumColors() * NumRanks(); }
98 |
99 | int HanabiGame::GetChanceOutcomeUid(HanabiMove move) const {
100 | if (move.MoveType() != HanabiMove::kDeal) {
101 | return -1;
102 | }
103 | return (move.TargetOffset() * NumColors() + move.Color()) * NumRanks() +
104 | move.Rank();
105 | }
106 |
107 | HanabiMove HanabiGame::PickRandomChance(
108 |     const std::pair<std::vector<HanabiMove>, std::vector<double>>&
109 |         chance_outcomes) const {
110 |   std::discrete_distribution<std::mt19937::result_type> dist(
111 | chance_outcomes.second.begin(), chance_outcomes.second.end());
112 | return chance_outcomes.first[dist(rng_)];
113 | }
114 |
115 | std::unordered_map<std::string, std::string> HanabiGame::Parameters() const {
116 | return {{"players", std::to_string(num_players_)},
117 | {"colors", std::to_string(NumColors())},
118 | {"ranks", std::to_string(NumRanks())},
119 | {"hand_size", std::to_string(HandSize())},
120 | {"max_information_tokens", std::to_string(MaxInformationTokens())},
121 | {"max_life_tokens", std::to_string(MaxLifeTokens())},
122 | {"seed", std::to_string(seed_)},
123 | {"random_start_player", random_start_player_ ? "true" : "false"},
124 | {"observation_type", std::to_string(observation_type_)}};
125 | }
126 |
127 | int HanabiGame::NumberCardInstances(int color, int rank) const {
128 | if (color < 0 || color >= NumColors() || rank < 0 || rank >= NumRanks()) {
129 | return 0;
130 | }
131 | if (rank == 0) {
132 | return 3;
133 | } else if (rank == NumRanks() - 1) {
134 | return 1;
135 | }
136 | return 2;
137 | }
138 |
139 | int HanabiGame::GetSampledStartPlayer() const {
140 | if (random_start_player_) {
141 |     std::uniform_int_distribution<std::mt19937::result_type> dist(
142 | 0, num_players_ - 1);
143 | return dist(rng_);
144 | }
145 | return 0;
146 | }
147 |
148 | int HanabiGame::HandSizeFromRules() const {
149 | if (num_players_ < 4) {
150 | return 5;
151 | }
152 | return 4;
153 | }
154 |
155 | // Uid mapping. h=hand_size, p=num_players, c=colors, r=ranks
156 | // 0, h-1: discard
157 | // h, 2h-1: play
158 | // 2h, 2h+(p-1)c-1: color hint
159 | // 2h+(p-1)c, 2h+(p-1)c+(p-1)r-1: rank hint
160 | HanabiMove HanabiGame::ConstructMove(int uid) const {
161 | if (uid < 0 || uid >= MaxMoves()) {
162 | return HanabiMove(HanabiMove::kInvalid, /*card_index=*/-1,
163 | /*target_offset=*/-1, /*color=*/-1, /*rank=*/-1);
164 | }
165 | if (uid < MaxDiscardMoves()) {
166 | return HanabiMove(HanabiMove::kDiscard, /*card_index=*/uid,
167 | /*target_offset=*/-1, /*color=*/-1, /*rank=*/-1);
168 | }
169 | uid -= MaxDiscardMoves();
170 | if (uid < MaxPlayMoves()) {
171 | return HanabiMove(HanabiMove::kPlay, /*card_index=*/uid,
172 | /*target_offset=*/-1, /*color=*/-1, /*rank=*/-1);
173 | }
174 | uid -= MaxPlayMoves();
175 | if (uid < MaxRevealColorMoves()) {
176 | return HanabiMove(HanabiMove::kRevealColor, /*card_index=*/-1,
177 | /*target_offset=*/1 + uid / NumColors(),
178 | /*color=*/uid % NumColors(), /*rank=*/-1);
179 | }
180 | uid -= MaxRevealColorMoves();
181 | return HanabiMove(HanabiMove::kRevealRank, /*card_index=*/-1,
182 | /*target_offset=*/1 + uid / NumRanks(),
183 | /*color=*/-1, /*rank=*/uid % NumRanks());
184 | }
185 |
186 | HanabiMove HanabiGame::ConstructChanceOutcome(int uid) const {
187 | if (uid < 0 || uid >= MaxChanceOutcomes()) {
188 | return HanabiMove(HanabiMove::kInvalid, /*card_index=*/-1,
189 | /*target_offset=*/-1, /*color=*/-1, /*rank=*/-1);
190 | }
191 | return HanabiMove(HanabiMove::kDeal, /*card_index=*/-1,
192 | /*target_offset=*/-1,
193 | /*color=*/uid / NumRanks() % NumColors(),
194 | /*rank=*/uid % NumRanks());
195 | }
196 |
197 | } // namespace hanabi_learning_env
198 |
--------------------------------------------------------------------------------
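
Note on the uid layout above: GetMoveUid() and ConstructMove() in hanabi_game.cc are inverses over the flat move-id space described by the "Uid mapping" comment. As a quick illustration (a standalone sketch, not part of the library), here is the same arithmetic in Python, assuming the default two-player game with hand_size=5, players=2, colors=5, ranks=5:

    # Standalone sketch of the uid layout from hanabi_game.cc (not library code).
    # Default game assumed: hand_size=5, players=2, colors=5, ranks=5.
    HAND_SIZE, PLAYERS, COLORS, RANKS = 5, 2, 5, 5

    max_discard = HAND_SIZE                    # uids [0, h)
    max_play = HAND_SIZE                       # uids [h, 2h)
    max_reveal_color = (PLAYERS - 1) * COLORS  # uids [2h, 2h + (p-1)c)
    max_reveal_rank = (PLAYERS - 1) * RANKS    # uids [2h + (p-1)c, 2h + (p-1)(c+r))

    def reveal_color_uid(target_offset, color):
        # Mirrors HanabiGame::GetMoveUid for kRevealColor.
        return max_discard + max_play + (target_offset - 1) * COLORS + color

    def reveal_rank_uid(target_offset, rank):
        # Mirrors HanabiGame::GetMoveUid for kRevealRank.
        return (max_discard + max_play + max_reveal_color
                + (target_offset - 1) * RANKS + rank)

    # A default 2-player game has 5 + 5 + 5 + 5 = 20 distinct moves.
    assert max_discard + max_play + max_reveal_color + max_reveal_rank == 20
    assert reveal_color_uid(target_offset=1, color=0) == 10  # first color hint
    assert reveal_rank_uid(target_offset=1, rank=4) == 19    # last rank hint
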
/pyhanabi.h:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef __PYHANABI_H__
16 | #define __PYHANABI_H__
17 |
18 | /**
19 | * This is a pure C API to the C++ code.
20 | * All the declarations are loaded in pyhanabi.py.
21 | * The set of functions below is referred to as the 'cdef' throughout the code.
22 | */
23 |
24 | extern "C" {
25 |
26 | typedef struct PyHanabiCard {
27 | int color;
28 | int rank;
29 | } pyhanabi_card_t;
30 |
31 | typedef struct PyHanabiCardKnowledge {
32 | /* Points to a hanabi_learning_env::HanabiHand::CardKnowledge. */
33 | const void* knowledge;
34 | } pyhanabi_card_knowledge_t;
35 |
36 | typedef struct PyHanabiMove {
37 | /* Points to a hanabi_learning_env::HanabiMove. */
38 | void* move;
39 | } pyhanabi_move_t;
40 |
41 | typedef struct PyHanabiHistoryItem {
42 | /* Points to a hanabi_learning_env::HanabiHistoryItem. */
43 | void* item;
44 | } pyhanabi_history_item_t;
45 |
46 | typedef struct PyHanabiState {
47 | /* Points to a hanabi_learning_env::HanabiState. */
48 | void* state;
49 | } pyhanabi_state_t;
50 |
51 | typedef struct PyHanabiGame {
52 | /* Points to a hanabi_learning_env::HanabiGame. */
53 | void* game;
54 | } pyhanabi_game_t;
55 |
56 | typedef struct PyHanabiObservation {
57 | /* Points to a hanabi_learning_env::HanabiObservation. */
58 | void* observation;
59 | } pyhanabi_observation_t;
60 |
61 | typedef struct PyHanabiObservationEncoder {
62 | /* Points to a hanabi_learning_env::ObservationEncoder. */
63 | void* encoder;
64 | } pyhanabi_observation_encoder_t;
65 |
66 | /* Utility Functions. */
67 | void DeleteString(char* str);
68 |
69 | /* Card functions. */
70 | int CardValid(pyhanabi_card_t* card);
71 |
72 | /* CardKnowledge functions */
73 | char* CardKnowledgeToString(pyhanabi_card_knowledge_t* knowledge);
74 | int ColorWasHinted(pyhanabi_card_knowledge_t* knowledge);
75 | int KnownColor(pyhanabi_card_knowledge_t* knowledge);
76 | int ColorIsPlausible(pyhanabi_card_knowledge_t* knowledge, int color);
77 | int RankWasHinted(pyhanabi_card_knowledge_t* knowledge);
78 | int KnownRank(pyhanabi_card_knowledge_t* knowledge);
79 | int RankIsPlausible(pyhanabi_card_knowledge_t* knowledge, int rank);
80 |
81 | /* Move functions. */
82 | void DeleteMoveList(void* movelist);
83 | int NumMoves(void* movelist);
84 | void GetMove(void* movelist, int index, pyhanabi_move_t* move);
85 | void DeleteMove(pyhanabi_move_t* move);
86 | char* MoveToString(pyhanabi_move_t* move);
87 | int MoveType(pyhanabi_move_t* move);
88 | int CardIndex(pyhanabi_move_t* move);
89 | int TargetOffset(pyhanabi_move_t* move);
90 | int MoveColor(pyhanabi_move_t* move);
91 | int MoveRank(pyhanabi_move_t* move);
92 | bool GetDiscardMove(int card_index, pyhanabi_move_t* move);
93 | bool GetPlayMove(int card_index, pyhanabi_move_t* move);
94 | bool GetRevealColorMove(int target_offset, int color, pyhanabi_move_t* move);
95 | bool GetRevealRankMove(int target_offset, int rank, pyhanabi_move_t* move);
96 |
97 | /* HistoryItem functions. */
98 | void DeleteHistoryItem(pyhanabi_history_item_t* item);
99 | char* HistoryItemToString(pyhanabi_history_item_t* item);
100 | void HistoryItemMove(pyhanabi_history_item_t* item, pyhanabi_move_t* move);
101 | int HistoryItemPlayer(pyhanabi_history_item_t* item);
102 | int HistoryItemScored(pyhanabi_history_item_t* item);
103 | int HistoryItemInformationToken(pyhanabi_history_item_t* item);
104 | int HistoryItemColor(pyhanabi_history_item_t* item);
105 | int HistoryItemRank(pyhanabi_history_item_t* item);
106 | int HistoryItemRevealBitmask(pyhanabi_history_item_t* item);
107 | int HistoryItemNewlyRevealedBitmask(pyhanabi_history_item_t* item);
108 | int HistoryItemDealToPlayer(pyhanabi_history_item_t* item);
109 |
110 | /* State functions. */
111 | void NewState(pyhanabi_game_t* game, pyhanabi_state_t* state);
112 | void CopyState(const pyhanabi_state_t* src, pyhanabi_state_t* dest);
113 | void DeleteState(pyhanabi_state_t* state);
114 | const void* StateParentGame(pyhanabi_state_t* state);
115 | void StateApplyMove(pyhanabi_state_t* state, pyhanabi_move_t* move);
116 | int StateCurPlayer(pyhanabi_state_t* state);
117 | void StateDealRandomCard(pyhanabi_state_t* state);
118 | int StateDeckSize(pyhanabi_state_t* state);
119 | int StateFireworks(pyhanabi_state_t* state, int color);
120 | int StateDiscardPileSize(pyhanabi_state_t* state);
121 | void StateGetDiscard(pyhanabi_state_t* state, int index, pyhanabi_card_t* card);
122 | int StateGetHandSize(pyhanabi_state_t* state, int pid);
123 | void StateGetHandCard(pyhanabi_state_t* state, int pid, int index,
124 | pyhanabi_card_t* card);
125 | int StateEndOfGameStatus(pyhanabi_state_t* state);
126 | int StateInformationTokens(pyhanabi_state_t* state);
127 | void* StateLegalMoves(pyhanabi_state_t* state);
128 | int StateLifeTokens(pyhanabi_state_t* state);
129 | int StateNumPlayers(pyhanabi_state_t* state);
130 | int StateScore(pyhanabi_state_t* state);
131 | char* StateToString(pyhanabi_state_t* state);
132 | bool MoveIsLegal(const pyhanabi_state_t* state, const pyhanabi_move_t* move);
133 | bool CardPlayableOnFireworks(const pyhanabi_state_t* state, int color,
134 | int rank);
135 | int StateLenMoveHistory(pyhanabi_state_t* state);
136 | void StateGetMoveHistory(pyhanabi_state_t* state, int index,
137 | pyhanabi_history_item_t* item);
138 |
139 | /* Game functions. */
140 | void DeleteGame(pyhanabi_game_t* game);
141 | void NewDefaultGame(pyhanabi_game_t* game);
142 | void NewGame(pyhanabi_game_t* game, int list_length, const char** param_list);
143 | char* GameParamString(pyhanabi_game_t* game);
144 | int NumPlayers(pyhanabi_game_t* game);
145 | int NumColors(pyhanabi_game_t* game);
146 | int NumRanks(pyhanabi_game_t* game);
147 | int HandSize(pyhanabi_game_t* game);
148 | int MaxInformationTokens(pyhanabi_game_t* game);
149 | int MaxLifeTokens(pyhanabi_game_t* game);
150 | int ObservationType(pyhanabi_game_t* game);
151 | int NumCards(pyhanabi_game_t* game, int color, int rank);
152 | int GetMoveUid(pyhanabi_game_t* game, pyhanabi_move_t* move);
153 | void GetMoveByUid(pyhanabi_game_t* game, int move_uid, pyhanabi_move_t* move);
154 | int MaxMoves(pyhanabi_game_t* game);
155 |
156 | /* Observation functions. */
157 | void NewObservation(pyhanabi_state_t* state, int player,
158 | pyhanabi_observation_t* observation);
159 | void DeleteObservation(pyhanabi_observation_t* observation);
160 | char* ObsToString(pyhanabi_observation_t* observation);
161 | int ObsCurPlayerOffset(pyhanabi_observation_t* observation);
162 | int ObsNumPlayers(pyhanabi_observation_t* observation);
163 | int ObsGetHandSize(pyhanabi_observation_t* observation, int pid);
164 | void ObsGetHandCard(pyhanabi_observation_t* observation, int pid, int index,
165 | pyhanabi_card_t* card);
166 | void ObsGetHandCardKnowledge(pyhanabi_observation_t* observation, int pid,
167 | int index, pyhanabi_card_knowledge_t* knowledge);
168 | int ObsDiscardPileSize(pyhanabi_observation_t* observation);
169 | void ObsGetDiscard(pyhanabi_observation_t* observation, int index,
170 | pyhanabi_card_t* card);
171 | int ObsFireworks(pyhanabi_observation_t* observation, int color);
172 | int ObsDeckSize(pyhanabi_observation_t* observation);
173 | int ObsNumLastMoves(pyhanabi_observation_t* observation);
174 | void ObsGetLastMove(pyhanabi_observation_t* observation, int index,
175 | pyhanabi_history_item_t* item);
176 | int ObsInformationTokens(pyhanabi_observation_t* observation);
177 | int ObsLifeTokens(pyhanabi_observation_t* observation);
178 | int ObsNumLegalMoves(pyhanabi_observation_t* observation);
179 | void ObsGetLegalMove(pyhanabi_observation_t* observation, int index,
180 | pyhanabi_move_t* move);
181 | bool ObsCardPlayableOnFireworks(const pyhanabi_observation_t* observation,
182 | int color, int rank);
183 |
184 | /* ObservationEncoder functions. */
185 | void NewObservationEncoder(pyhanabi_observation_encoder_t* encoder,
186 | pyhanabi_game_t* game, int type);
187 | void DeleteObservationEncoder(pyhanabi_observation_encoder_t* encoder);
188 | char* ObservationShape(pyhanabi_observation_encoder_t* encoder);
189 | char* EncodeObservation(pyhanabi_observation_encoder_t* encoder,
190 | pyhanabi_observation_t* observation);
191 |
192 | void EncodeObs(pyhanabi_observation_encoder_t *encoder,
193 | pyhanabi_observation_t *observation,
194 | int *encoding);
195 |
196 | } /* extern "C" */
197 |
198 | #endif
199 |
--------------------------------------------------------------------------------
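
The header above notes that all of these declarations are loaded in pyhanabi.py. A typical way to consume such a pure C API from Python is via cffi: pass the declarations to cdef() and open the compiled shared library with dlopen(). The sketch below is illustrative only; the shared-library filename is an assumption and the real loading logic lives in pyhanabi.py, but the declarations used (pyhanabi_game_t, NewDefaultGame, NumPlayers, MaxMoves, DeleteGame) come straight from the header.

    # Illustrative only: real loading is done by pyhanabi.py; the shared-library
    # name "libpyhanabi.so" is an assumption for this sketch.
    import cffi

    ffi = cffi.FFI()
    # In practice the cdef text is the contents of pyhanabi.h with the include
    # guard and the extern "C" wrapper stripped; a minimal subset is used here.
    ffi.cdef("""
        typedef struct PyHanabiGame { void* game; } pyhanabi_game_t;
        void NewDefaultGame(pyhanabi_game_t* game);
        void DeleteGame(pyhanabi_game_t* game);
        int NumPlayers(pyhanabi_game_t* game);
        int MaxMoves(pyhanabi_game_t* game);
    """)
    lib = ffi.dlopen("libpyhanabi.so")  # assumed library name

    game = ffi.new("pyhanabi_game_t*")
    lib.NewDefaultGame(game)
    print("players:", lib.NumPlayers(game), "max moves:", lib.MaxMoves(game))
    lib.DeleteGame(game)
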
/agents/rainbow/prioritized_replay_memory.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Dopamine Authors and Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | #
17 | #
18 | # This file is a fork of the original Dopamine code incorporating changes for
19 | # the multiplayer setting and the Hanabi Learning Environment.
20 | #
21 | """An implementation of Prioritized Experience Replay (PER).
22 |
23 | This implementation is based on the paper "Prioritized Experience Replay"
24 | by Tom Schaul et al. (2015). Many thanks to Tom Schaul, John Quan, and Matteo
25 | Hessel for providing useful pointers on the algorithm and its implementation.
26 | """
27 |
28 | from __future__ import absolute_import
29 | from __future__ import division
30 | from __future__ import print_function
31 |
32 | from third_party.dopamine import sum_tree
33 | import gin.tf
34 | import numpy as np
35 | import replay_memory
36 | import tensorflow as tf
37 |
38 | DEFAULT_PRIORITY = 100.0
39 |
40 |
41 | class OutOfGraphPrioritizedReplayMemory(replay_memory.OutOfGraphReplayMemory):
42 | """An Out of Graph Replay Memory for Prioritized Experience Replay.
43 |
44 | See replay_memory.py for details.
45 | """
46 |
47 | def __init__(self, num_actions, observation_size, stack_size, replay_capacity,
48 | batch_size, update_horizon=1, gamma=1.0):
49 | """This data structure does the heavy lifting in the replay memory.
50 |
51 | Args:
52 | num_actions: int, number of actions.
53 | observation_size: int, size of an input observation.
54 | stack_size: int, number of frames to use in state stack.
55 | replay_capacity: int, number of transitions to keep in memory.
56 | batch_size: int, batch size.
57 | update_horizon: int, length of update ('n' in n-step update).
58 | gamma: float, the discount factor.
59 | """
60 | super(OutOfGraphPrioritizedReplayMemory, self).__init__(
61 | num_actions=num_actions,
62 | observation_size=observation_size, stack_size=stack_size,
63 | replay_capacity=replay_capacity, batch_size=batch_size,
64 | update_horizon=update_horizon, gamma=gamma)
65 |
66 | self.sum_tree = sum_tree.SumTree(replay_capacity)
67 |
68 | def add(self, observation, action, reward, terminal, legal_actions):
69 | """Adds a transition to the replay memory.
70 |
71 | Since the next_observation in the transition will be the observation added
72 | next there is no need to pass it.
73 |
74 | If the replay memory is at capacity the oldest transition will be discarded.
75 |
76 | Compared to OutOfGraphReplayMemory.add(), this version also sets the
77 | priority of dummy frames to 0.
78 |
79 | Args:
80 | observation: `np.array` uint8, (observation_size, observation_size).
81 | action: int, indicating the action in the transition.
82 | reward: float, indicating the reward received in the transition.
83 | terminal: int, acting as a boolean indicating whether the transition
84 | was terminal (1) or not (0).
85 | legal_actions: Binary vector indicating legal actions (1 == legal).
86 | """
87 | if self.is_empty() or self.terminals[self.cursor() - 1] == 1:
88 | dummy_observation = np.zeros((self._observation_size))
89 | dummy_legal_actions = np.zeros((self._num_actions))
90 | for _ in range(self._stack_size - 1):
91 | self._add(dummy_observation, 0, 0, 0, dummy_legal_actions, priority=0.0)
92 |
93 | self._add(observation, action, reward, terminal, legal_actions,
94 | priority=DEFAULT_PRIORITY)
95 |
96 | def _add(self, observation, action, reward, terminal, legal_actions,
97 | priority=DEFAULT_PRIORITY):
98 | new_element_index = self.cursor()
99 |
100 | super(OutOfGraphPrioritizedReplayMemory, self)._add(
101 | observation, action, reward, terminal, legal_actions)
102 |
103 | self.sum_tree.set(new_element_index, priority)
104 |
105 | def sample_index_batch(self, batch_size):
106 | """Returns a batch of valid indices.
107 |
108 | Args:
109 | batch_size: int, number of indices returned.
110 |
111 | Returns:
112 | List of size batch_size containing valid indices.
113 |
114 | Raises:
115 | Exception: If the batch was not constructed after maximum number of tries.
116 | """
117 | indices = []
118 | allowed_attempts = replay_memory.MAX_SAMPLE_ATTEMPTS
119 |
120 | while len(indices) < batch_size and allowed_attempts > 0:
121 | index = self.sum_tree.sample()
122 |
123 | if self.is_valid_transition(index):
124 | indices.append(index)
125 | else:
126 | allowed_attempts -= 1
127 |
128 | if len(indices) != batch_size:
129 | raise Exception('Could only sample {} valid transitions'.format(
130 | len(indices)))
131 | else:
132 | return indices
133 |
134 | def set_priority(self, indices, priorities):
135 | """Sets the priority of the given elements according to Schaul et al.
136 |
137 | Args:
138 | indices: `np.array` of indices in range [0, replay_capacity).
139 | priorities: list of floats, the corresponding priorities.
140 | """
141 | assert indices.dtype == np.int32, ('Indices must be integers, '
142 | 'given: {}'.format(indices.dtype))
143 | for i, memory_index in enumerate(indices):
144 | self.sum_tree.set(memory_index, priorities[i])
145 |
146 | def get_priority(self, indices, batch_size=None):
147 | """Fetches the priorities correspond to a batch of memory indices.
148 |
149 | For any memory location not yet used, the corresponding priority is 0.
150 |
151 | Args:
152 | indices: `np.array` of indices in range [0, replay_capacity).
153 | batch_size: int, requested number of items.
154 | Returns:
155 | The corresponding priorities.
156 | """
157 | if batch_size is None:
158 | batch_size = self._batch_size
159 | if batch_size != self._state_batch.shape[0]:
160 | self.reset_state_batch_arrays(batch_size)
161 |
162 | priority_batch = np.empty((batch_size), dtype=np.float32)
163 |
164 | assert indices.dtype == np.int32, ('Indices must be integers, '
165 | 'given: {}'.format(indices.dtype))
166 | for i, memory_index in enumerate(indices):
167 | priority_batch[i] = self.sum_tree.get(memory_index)
168 |
169 | return priority_batch
170 |
171 |
172 | @gin.configurable(blacklist=['observation_size', 'stack_size'])
173 | class WrappedPrioritizedReplayMemory(replay_memory.WrappedReplayMemory):
174 | """In graph wrapper for the python Replay Memory.
175 |
176 | Usage:
177 | To add a transition: run the operation add_transition_op
178 | (and feed all the placeholders in add_transition_ph)
179 |
180 | To sample a batch: Construct operations that depend on any of the
181 | sampling tensors. Every sess.run using any of these
182 | tensors will sample a new transition.
183 |
184 | When using staging: Need to prefetch the next batch with each train_op by
185 | calling self.prefetch_batch.
186 |
187 | Every time this op is called, a new transition batch
188 | would be prefetched.
189 |
190 | Attributes:
191 | # The following tensors are sampled randomly each sess.run
192 | states
193 | actions
194 | rewards
195 | next_states
196 | terminals
197 |
198 | add_transition_op: tf operation to add a transition to the replay
199 | memory. All the following placeholders need to be fed.
200 | add_obs_ph
201 | add_action_ph
202 | add_reward_ph
203 | add_terminal_ph
204 | """
205 |
206 | def __init__(self,
207 | num_actions,
208 | observation_size,
209 | stack_size,
210 | use_staging=True,
211 | replay_capacity=1000000,
212 | batch_size=32,
213 | update_horizon=1,
214 | gamma=1.0):
215 | """Initializes a graph wrapper for the python Replay Memory.
216 |
217 | Args:
218 | num_actions: int, number of possible actions.
219 | observation_size: int, size of an input observation.
220 | stack_size: int, number of frames to use in state stack.
221 | use_staging: bool, when True it would use a staging area to prefetch
222 | the next sampling batch.
223 | replay_capacity: int, number of transitions to keep in memory.
224 | batch_size: int.
225 | update_horizon: int, length of update ('n' in n-step update).
226 | gamma: float, the discount factor.
227 |
228 | Raises:
229 | ValueError: If update_horizon is not positive.
230 | ValueError: If discount factor is not in [0, 1].
231 | """
232 | memory = OutOfGraphPrioritizedReplayMemory(num_actions, observation_size,
233 | stack_size, replay_capacity,
234 | batch_size, update_horizon,
235 | gamma)
236 | super(WrappedPrioritizedReplayMemory, self).__init__(
237 | num_actions,
238 | observation_size, stack_size, use_staging, replay_capacity, batch_size,
239 | update_horizon, gamma, wrapped_memory=memory)
240 |
241 | def tf_set_priority(self, indices, losses):
242 | """Sets the priorities for the given indices.
243 |
244 | Args:
245 | indices: tensor of indices (int32), size k.
246 | losses: tensor of losses (float), size k.
247 |
248 | Returns:
249 | A TF op setting the priorities according to Prioritized Experience
250 | Replay.
251 | """
252 | return tf.py_func(
253 | self.memory.set_priority, [indices, losses],
254 | [],
255 | name='prioritized_replay_set_priority_py_func')
256 |
257 | def tf_get_priority(self, indices):
258 | """Gets the priorities for the given indices.
259 |
260 | Args:
261 | indices: tensor of indices (int32), size k.
262 |
263 | Returns:
264 | A tensor (float32) of priorities.
265 | """
266 | return tf.py_func(
267 | self.memory.get_priority, [indices],
268 | [tf.float32],
269 | name='prioritized_replay_get_priority_py_func')
270 |
--------------------------------------------------------------------------------
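
The class above delegates all priority bookkeeping to the Dopamine sum tree: add() seeds new transitions with DEFAULT_PRIORITY (and dummy frames with 0), set_priority() writes training losses back, and sample_index_batch() draws indices in proportion to the stored priorities. As a rough sketch of the sampling scheme it realizes (proportional prioritization from Schaul et al., 2015), written with plain numpy rather than the O(log n) sum tree:

    # Rough numpy sketch of proportional prioritized sampling; illustrative only,
    # not the sum-tree implementation used by the replay memory above.
    import numpy as np

    def sample_indices(priorities, batch_size, rng=np.random):
        """Draw indices with probability proportional to their priority."""
        priorities = np.asarray(priorities, dtype=np.float64)
        probs = priorities / priorities.sum()
        return rng.choice(len(priorities), size=batch_size, p=probs)

    # Fresh transitions carry a large default priority so they are replayed soon;
    # after a training step their priority is overwritten with the loss.
    priorities = [100.0, 0.5, 2.0, 0.0]  # index 3: dummy frame, never sampled
    print(sample_indices(priorities, batch_size=2))
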
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/agents/rainbow/third_party/dopamine/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2018 The Dopamine Authors. All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright [yyyy] [name of copyright owner]
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
--------------------------------------------------------------------------------
/hanabi_lib/hanabi_state.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include "hanabi_state.h"
16 |
17 | #include <algorithm>
18 | #include <cstdlib>
19 | #include <numeric>
20 |
21 | #include "util.h"
22 |
23 | namespace hanabi_learning_env {
24 |
25 | namespace {
26 | // Returns bitmask of card indices which match color.
27 | uint8_t HandColorBitmask(const HanabiHand& hand, int color) {
28 | uint8_t mask = 0;
29 | const auto& cards = hand.Cards();
30 | assert(cards.size() <= 8); // More than 8 cards is not supported.
31 | for (int i = 0; i < cards.size(); ++i) {
32 | if (cards[i].Color() == color) {
33 | mask |= static_cast<uint8_t>(1) << i;
34 | }
35 | }
36 | return mask;
37 | }
38 |
39 | // Returns bitmask of card indices which match rank.
40 | uint8_t HandRankBitmask(const HanabiHand& hand, int rank) {
41 | uint8_t mask = 0;
42 | const auto& cards = hand.Cards();
43 | assert(cards.size() <= 8); // More than 8 cards is not supported.
44 | for (int i = 0; i < cards.size(); ++i) {
45 | if (cards[i].Rank() == rank) {
46 | mask |= static_cast<uint8_t>(1) << i;
47 | }
48 | }
49 | return mask;
50 | }
51 | } // namespace
52 |
53 | HanabiState::HanabiDeck::HanabiDeck(const HanabiGame& game)
54 | : card_count_(game.NumColors() * game.NumRanks(), 0),
55 | total_count_(0),
56 | num_ranks_(game.NumRanks()) {
57 | for (int color = 0; color < game.NumColors(); ++color) {
58 | for (int rank = 0; rank < game.NumRanks(); ++rank) {
59 | auto count = game.NumberCardInstances(color, rank);
60 | card_count_[CardToIndex(color, rank)] = count;
61 | total_count_ += count;
62 | }
63 | }
64 | }
65 |
66 | HanabiCard HanabiState::HanabiDeck::DealCard(std::mt19937* rng) {
67 | if (Empty()) {
68 | return HanabiCard();
69 | }
70 | std::discrete_distribution<std::mt19937::result_type> dist(
71 | card_count_.begin(), card_count_.end());
72 | int index = dist(*rng);
73 | assert(card_count_[index] > 0);
74 | --card_count_[index];
75 | --total_count_;
76 | return HanabiCard(IndexToColor(index), IndexToRank(index));
77 | }
78 |
79 | HanabiCard HanabiState::HanabiDeck::DealCard(int color, int rank) {
80 | int index = CardToIndex(color, rank);
81 | if (card_count_[index] <= 0) {
82 | return HanabiCard();
83 | }
84 | assert(card_count_[index] > 0);
85 | --card_count_[index];
86 | --total_count_;
87 | return HanabiCard(IndexToColor(index), IndexToRank(index));
88 | }
89 |
90 | HanabiState::HanabiState(const HanabiGame* parent_game, int start_player)
91 | : parent_game_(parent_game),
92 | deck_(*parent_game),
93 | hands_(parent_game->NumPlayers()),
94 | cur_player_(kChancePlayerId),
95 | next_non_chance_player_(start_player >= 0 &&
96 | start_player < parent_game->NumPlayers()
97 | ? start_player
98 | : parent_game->GetSampledStartPlayer()),
99 | information_tokens_(parent_game->MaxInformationTokens()),
100 | life_tokens_(parent_game->MaxLifeTokens()),
101 | fireworks_(parent_game->NumColors(), 0),
102 | turns_to_play_(parent_game->NumPlayers()) {}
103 |
104 | void HanabiState::AdvanceToNextPlayer() {
105 | if (!deck_.Empty() && PlayerToDeal() >= 0) {
106 | cur_player_ = kChancePlayerId;
107 | } else {
108 | cur_player_ = next_non_chance_player_;
109 | next_non_chance_player_ = (cur_player_ + 1) % hands_.size();
110 | }
111 | }
112 |
113 | bool HanabiState::IncrementInformationTokens() {
114 | if (information_tokens_ < ParentGame()->MaxInformationTokens()) {
115 | ++information_tokens_;
116 | return true;
117 | } else {
118 | return false;
119 | }
120 | }
121 |
122 | void HanabiState::DecrementInformationTokens() {
123 | assert(information_tokens_ > 0);
124 | --information_tokens_;
125 | }
126 |
127 | void HanabiState::DecrementLifeTokens() {
128 | assert(life_tokens_ > 0);
129 | --life_tokens_;
130 | }
131 |
132 | std::pair HanabiState::AddToFireworks(HanabiCard card) {
133 | if (CardPlayableOnFireworks(card)) {
134 | ++fireworks_[card.Color()];
135 | // Check if player completed a stack.
136 | if (fireworks_[card.Color()] == ParentGame()->NumRanks()) {
137 | return {true, IncrementInformationTokens()};
138 | }
139 | return {true, false};
140 | } else {
141 | DecrementLifeTokens();
142 | return {false, false};
143 | }
144 | }
145 |
146 | bool HanabiState::HintingIsLegal(HanabiMove move) const {
147 | if (InformationTokens() <= 0) {
148 | return false;
149 | }
150 | if (move.TargetOffset() < 1 ||
151 | move.TargetOffset() >= ParentGame()->NumPlayers()) {
152 | return false;
153 | }
154 | return true;
155 | }
156 |
157 | int HanabiState::PlayerToDeal() const {
158 | for (int i = 0; i < hands_.size(); ++i) {
159 | if (hands_[i].Cards().size() < ParentGame()->HandSize()) {
160 | return i;
161 | }
162 | }
163 | return -1;
164 | }
165 |
166 | bool HanabiState::MoveIsLegal(HanabiMove move) const {
167 | switch (move.MoveType()) {
168 | case HanabiMove::kDeal:
169 | if (cur_player_ != kChancePlayerId) {
170 | return false;
171 | }
172 | if (deck_.CardCount(move.Color(), move.Rank()) == 0) {
173 | return false;
174 | }
175 | break;
176 | case HanabiMove::kDiscard:
177 | if (InformationTokens() >= ParentGame()->MaxInformationTokens()) {
178 | return false;
179 | }
180 | if (move.CardIndex() >= hands_[cur_player_].Cards().size()) {
181 | return false;
182 | }
183 | break;
184 | case HanabiMove::kPlay:
185 | if (move.CardIndex() >= hands_[cur_player_].Cards().size()) {
186 | return false;
187 | }
188 | break;
189 | case HanabiMove::kRevealColor: {
190 | if (!HintingIsLegal(move)) {
191 | return false;
192 | }
193 | const auto& cards = HandByOffset(move.TargetOffset()).Cards();
194 | if (!std::any_of(cards.begin(), cards.end(),
195 | [move](const HanabiCard& card) {
196 | return card.Color() == move.Color();
197 | })) {
198 | return false;
199 | }
200 | break;
201 | }
202 | case HanabiMove::kRevealRank: {
203 | if (!HintingIsLegal(move)) {
204 | return false;
205 | }
206 | const auto& cards = HandByOffset(move.TargetOffset()).Cards();
207 | if (!std::any_of(cards.begin(), cards.end(),
208 | [move](const HanabiCard& card) {
209 | return card.Rank() == move.Rank();
210 | })) {
211 | return false;
212 | }
213 | break;
214 | }
215 | default:
216 | return false;
217 | }
218 | return true;
219 | }
220 |
221 | void HanabiState::ApplyMove(HanabiMove move) {
222 | REQUIRE(MoveIsLegal(move));
223 | if (deck_.Empty()) {
224 | --turns_to_play_;
225 | }
226 | HanabiHistoryItem history(move);
227 | history.player = cur_player_;
228 | switch (move.MoveType()) {
229 | case HanabiMove::kDeal: {
230 | history.deal_to_player = PlayerToDeal();
231 | HanabiHand::CardKnowledge card_knowledge(ParentGame()->NumColors(),
232 | ParentGame()->NumRanks());
233 | if (parent_game_->ObservationType() == HanabiGame::kSeer){
234 | card_knowledge.ApplyIsColorHint(move.Color());
235 | card_knowledge.ApplyIsRankHint(move.Rank());
236 | }
237 | hands_[history.deal_to_player].AddCard(
238 | deck_.DealCard(move.Color(), move.Rank()),
239 | card_knowledge);
240 | }
241 | break;
242 | case HanabiMove::kDiscard:
243 | history.information_token = IncrementInformationTokens();
244 | history.color = hands_[cur_player_].Cards()[move.CardIndex()].Color();
245 | history.rank = hands_[cur_player_].Cards()[move.CardIndex()].Rank();
246 | hands_[cur_player_].RemoveFromHand(move.CardIndex(), &discard_pile_);
247 | break;
248 | case HanabiMove::kPlay:
249 | history.color = hands_[cur_player_].Cards()[move.CardIndex()].Color();
250 | history.rank = hands_[cur_player_].Cards()[move.CardIndex()].Rank();
251 | std::tie(history.scored, history.information_token) =
252 | AddToFireworks(hands_[cur_player_].Cards()[move.CardIndex()]);
253 | hands_[cur_player_].RemoveFromHand(
254 | move.CardIndex(), history.scored ? nullptr : &discard_pile_);
255 | break;
256 | case HanabiMove::kRevealColor:
257 | DecrementInformationTokens();
258 | history.reveal_bitmask =
259 | HandColorBitmask(*HandByOffset(move.TargetOffset()), move.Color());
260 | history.newly_revealed_bitmask =
261 | HandByOffset(move.TargetOffset())->RevealColor(move.Color());
262 | break;
263 | case HanabiMove::kRevealRank:
264 | DecrementInformationTokens();
265 | history.reveal_bitmask =
266 | HandRankBitmask(*HandByOffset(move.TargetOffset()), move.Rank());
267 | history.newly_revealed_bitmask =
268 | HandByOffset(move.TargetOffset())->RevealRank(move.Rank());
269 | break;
270 | default:
271 | std::abort(); // Should not be possible.
272 | }
273 | move_history_.push_back(history);
274 | AdvanceToNextPlayer();
275 | }
276 |
277 | double HanabiState::ChanceOutcomeProb(HanabiMove move) const {
278 | return static_cast(deck_.CardCount(move.Color(), move.Rank())) /
279 | static_cast(deck_.Size());
280 | }
281 |
282 | void HanabiState::ApplyRandomChance() {
283 | auto chance_outcomes = ChanceOutcomes();
284 | REQUIRE(!chance_outcomes.second.empty());
285 | ApplyMove(ParentGame()->PickRandomChance(chance_outcomes));
286 | }
287 |
288 | std::vector<HanabiMove> HanabiState::LegalMoves(int player) const {
289 | std::vector movelist;
290 | // kChancePlayer=-1 must be handled by ChanceOutcome.
291 | REQUIRE(player >= 0 && player < ParentGame()->NumPlayers());
292 | if (player != cur_player_) {
293 | // Turn-based game. Empty move list for other players.
294 | return movelist;
295 | }
296 | int max_move_uid = ParentGame()->MaxMoves();
297 | for (int uid = 0; uid < max_move_uid; ++uid) {
298 | HanabiMove move = ParentGame()->GetMove(uid);
299 | if (MoveIsLegal(move)) {
300 | movelist.push_back(move);
301 | }
302 | }
303 | return movelist;
304 | }
305 |
306 | bool HanabiState::CardPlayableOnFireworks(int color, int rank) const {
307 | if (color < 0 || color >= ParentGame()->NumColors()) {
308 | return false;
309 | }
310 | return rank == fireworks_[color];
311 | }
312 |
313 | std::pair<std::vector<HanabiMove>, std::vector<double>>
314 | HanabiState::ChanceOutcomes() const {
315 | std::pair<std::vector<HanabiMove>, std::vector<double>> rv;
316 | int max_outcome_uid = ParentGame()->MaxChanceOutcomes();
317 | for (int uid = 0; uid < max_outcome_uid; ++uid) {
318 | HanabiMove move = ParentGame()->GetChanceOutcome(uid);
319 | if (MoveIsLegal(move)) {
320 | rv.first.push_back(move);
321 | rv.second.push_back(ChanceOutcomeProb(move));
322 | }
323 | }
324 | return rv;
325 | }
326 |
327 | // Produces a human-readable summary of the state:
328 | // life tokens and info tokens,
329 | // fireworks per color,
330 | // each player's hand (with the current player marked), deck size,
331 | // and the discard pile.
332 | std::string HanabiState::ToString() const {
333 | std::string result;
334 | result += "Life tokens: " + std::to_string(LifeTokens()) + "\n";
335 | result += "Info tokens: " + std::to_string(InformationTokens()) + "\n";
336 | result += "Fireworks: ";
337 | for (int i = 0; i < ParentGame()->NumColors(); ++i) {
338 | result += ColorIndexToChar(i);
339 | result += std::to_string(fireworks_[i]) + " ";
340 | }
341 | result += "\nHands:\n";
342 | for (int i = 0; i < hands_.size(); ++i) {
343 | if (i > 0) {
344 | result += "-----\n";
345 | }
346 | if (i == CurPlayer()) {
347 | result += "Cur player\n";
348 | }
349 | result += hands_[i].ToString();
350 | }
351 | result += "Deck size: " + std::to_string(Deck().Size()) + "\n";
352 | result += "Discards:";
353 | for (int i = 0; i < discard_pile_.size(); ++i) {
354 | result += " " + discard_pile_[i].ToString();
355 | }
356 | return result;
357 | }
358 |
359 | int HanabiState::Score() const {
360 | if (LifeTokens() <= 0) {
361 | return 0;
362 | }
363 | return std::accumulate(fireworks_.begin(), fireworks_.end(), 0);
364 | }
365 |
366 | HanabiState::EndOfGameType HanabiState::EndOfGameStatus() const {
367 | if (LifeTokens() < 1) {
368 | return kOutOfLifeTokens;
369 | }
370 | if (Score() >= ParentGame()->NumColors() * ParentGame()->NumRanks()) {
371 | return kCompletedFireworks;
372 | }
373 | if (turns_to_play_ <= 0) {
374 | return kOutOfCards;
375 | }
376 | return kNotFinished;
377 | }
378 |
379 | } // namespace hanabi_learning_env
380 |
--------------------------------------------------------------------------------
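
Deck composition and dealing tie together across the two C++ files above: HanabiGame::NumberCardInstances() gives three copies of rank 0, one of the highest rank, and two of every rank in between, while HanabiState::ChanceOutcomeProb() reports a deal probability of remaining count divided by deck size. A small standalone sketch of those numbers, assuming the default 5-color, 5-rank game:

    # Standalone sketch of the initial deck composition implied by
    # HanabiGame::NumberCardInstances and the deal probability reported by
    # HanabiState::ChanceOutcomeProb. Default game assumed: 5 colors, 5 ranks.
    NUM_COLORS, NUM_RANKS = 5, 5

    def num_card_instances(rank):
        # Rank 0 (the "1"s) has three copies, the highest rank one, others two.
        if rank == 0:
            return 3
        if rank == NUM_RANKS - 1:
            return 1
        return 2

    deck_size = NUM_COLORS * sum(num_card_instances(r) for r in range(NUM_RANKS))
    assert deck_size == 50  # 5 colors * (3 + 2 + 2 + 2 + 1)

    # Probability that the first card dealt is one specific rank-0 card:
    print(num_card_instances(0) / deck_size)  # 3 / 50 = 0.06
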
/agents/rainbow/rainbow_agent.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Dopamine Authors and Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | #
17 | #
18 | # This file is a fork of the original Dopamine code incorporating changes for
19 | # the multiplayer setting and the Hanabi Learning Environment.
20 | #
21 | """Implementation of a Rainbow agent adapted to the multiplayer setting."""
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | import functools
28 |
29 | import dqn_agent
30 | import gin.tf
31 | import numpy as np
32 | import prioritized_replay_memory
33 | import tensorflow as tf
34 |
35 |
36 | slim = tf.contrib.slim
37 |
38 |
39 | @gin.configurable
40 | def rainbow_template(state,
41 | num_actions,
42 | num_atoms=51,
43 | layer_size=512,
44 | num_layers=1):
45 | r"""Builds a Rainbow Network mapping states to value distributions.
46 |
47 | Args:
48 | state: A `tf.placeholder` for the RL state.
49 | num_actions: int, number of actions that the RL agent can take.
50 | num_atoms: int, number of atoms to approximate the distribution with.
51 | layer_size: int, number of hidden units per layer.
52 | num_layers: int, number of hidden layers.
53 |
54 | Returns:
55 | net: A `tf.GraphDef` for Rainbow:
56 | `\theta : \mathcal{X}\rightarrow\mathbb{R}^{|\mathcal{A}| \times N}`,
57 | where `N` is num_atoms.
58 | """
59 | weights_initializer = slim.variance_scaling_initializer(
60 | factor=1.0 / np.sqrt(3.0), mode='FAN_IN', uniform=True)
61 |
62 | net = tf.cast(state, tf.float32)
63 | net = tf.squeeze(net, axis=2)
64 |
65 | for _ in range(num_layers):
66 | net = slim.fully_connected(net, layer_size,
67 | activation_fn=tf.nn.relu)
68 | net = slim.fully_connected(net, num_actions * num_atoms, activation_fn=None,
69 | weights_initializer=weights_initializer)
70 | net = tf.reshape(net, [-1, num_actions, num_atoms])
71 | return net
72 |
73 |
74 | @gin.configurable
75 | class RainbowAgent(dqn_agent.DQNAgent):
76 | """A compact implementation of the multiplayer Rainbow agent."""
77 |
78 | @gin.configurable
79 | def __init__(self,
80 | num_actions=None,
81 | observation_size=None,
82 | num_players=None,
83 | num_atoms=51,
84 | vmax=25.,
85 | gamma=0.99,
86 | update_horizon=1,
87 | min_replay_history=500,
88 | update_period=4,
89 | target_update_period=500,
90 | epsilon_train=0.0,
91 | epsilon_eval=0.0,
92 | epsilon_decay_period=1000,
93 | learning_rate=0.000025,
94 | optimizer_epsilon=0.00003125,
95 | tf_device='/cpu:*'):
96 | """Initializes the agent and constructs its graph.
97 |
98 | Args:
99 | num_actions: int, number of actions the agent can take at any state.
100 | observation_size: int, size of observation vector.
101 | num_players: int, number of players playing this game.
102 | num_atoms: Int, the number of buckets for the value function distribution.
103 | vmax: float, maximum return predicted by a value distribution.
104 | gamma: float, discount factor as commonly used in the RL literature.
105 | update_horizon: int, horizon at which updates are performed, the 'n' in
106 | n-step update.
107 | min_replay_history: int, number of stored transitions before training.
108 | update_period: int, period between DQN updates.
109 | target_update_period: int, update period for the target network.
110 | epsilon_train: float, final epsilon for training.
111 | epsilon_eval: float, epsilon during evaluation.
112 | epsilon_decay_period: int, number of steps for epsilon to decay.
113 | learning_rate: float, learning rate for the optimizer.
114 | optimizer_epsilon: float, epsilon for Adam optimizer.
115 | tf_device: str, Tensorflow device on which to run computations.
116 | """
117 | # We need this because some tools convert round floats into ints.
118 | vmax = float(vmax)
119 | self.num_atoms = num_atoms
120 | # Using -vmax as the minimum return is wasteful, because all rewards are
121 | # positive -- but does not unduly affect performance.
122 | self.support = tf.linspace(-vmax, vmax, num_atoms)
123 | self.learning_rate = learning_rate
124 | self.optimizer_epsilon = optimizer_epsilon
125 |
126 | graph_template = functools.partial(rainbow_template, num_atoms=num_atoms)
127 | super(RainbowAgent, self).__init__(
128 | num_actions=num_actions,
129 | observation_size=observation_size,
130 | num_players=num_players,
131 | gamma=gamma,
132 | update_horizon=update_horizon,
133 | min_replay_history=min_replay_history,
134 | update_period=update_period,
135 | target_update_period=target_update_period,
136 | epsilon_train=epsilon_train,
137 | epsilon_eval=epsilon_eval,
138 | epsilon_decay_period=epsilon_decay_period,
139 | graph_template=graph_template,
140 | tf_device=tf_device)
141 | tf.logging.info('\t learning_rate: %f', learning_rate)
142 | tf.logging.info('\t optimizer_epsilon: %f', optimizer_epsilon)
143 |
144 | def _build_replay_memory(self, use_staging):
145 | """Creates the replay memory used by the agent.
146 |
147 | Rainbow uses prioritized replay.
148 |
149 | Args:
150 | use_staging: bool, whether to use a staging area in the replay memory.
151 |
152 | Returns:
153 | A replay memory object.
154 | """
155 | return prioritized_replay_memory.WrappedPrioritizedReplayMemory(
156 | num_actions=self.num_actions,
157 | observation_size=self.observation_size,
158 | stack_size=1,
159 | use_staging=use_staging,
160 | update_horizon=self.update_horizon,
161 | gamma=self.gamma)
162 |
163 | def _reshape_networks(self):
164 | # self._q is actually logits now, rename things.
165 | # size of _logits: 1 x num_actions x num_atoms
166 | self._logits = self._q
167 | # size of _probabilities: 1 x num_actions x num_atoms
168 | self._probabilities = tf.contrib.layers.softmax(self._q)
169 | # size of _q: 1 x num_actions
170 | self._q = tf.reduce_sum(self.support * self._probabilities, axis=2)
171 | # Recompute argmax from q values. Ignore illegal actions.
172 | self._q_argmax = tf.argmax(self._q + self.legal_actions_ph, axis=1)[0]
173 |
174 | # size of _replay_logits: 1 x num_actions x num_atoms
175 | self._replay_logits = self._replay_qs
176 | # size of _replay_next_logits: 1 x num_actions x num_atoms
177 | self._replay_next_logits = self._replay_next_qt
178 | del self._replay_qs
179 | del self._replay_next_qt
180 |
181 | def _build_target_distribution(self):
182 | self._reshape_networks()
183 | batch_size = tf.shape(self._replay.rewards)[0]
184 | # size of rewards: batch_size x 1
185 | rewards = self._replay.rewards[:, None]
186 | # size of tiled_support: batch_size x num_atoms
187 | tiled_support = tf.tile(self.support, [batch_size])
188 | tiled_support = tf.reshape(tiled_support, [batch_size, self.num_atoms])
189 | # size of target_support: batch_size x num_atoms
190 |
191 | is_terminal_multiplier = 1. - tf.cast(self._replay.terminals, tf.float32)
192 | # Incorporate terminal state to discount factor.
193 | # size of gamma_with_terminal: batch_size x 1
194 | gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier
195 | gamma_with_terminal = gamma_with_terminal[:, None]
196 |
197 | target_support = rewards + gamma_with_terminal * tiled_support
198 | # size of next_probabilities: batch_size x num_actions x num_atoms
199 | next_probabilities = tf.contrib.layers.softmax(
200 | self._replay_next_logits)
201 |
202 | # size of next_qt: 1 x num_actions
203 | next_qt = tf.reduce_sum(self.support * next_probabilities, 2)
204 | # size of next_qt_argmax: 1 x batch_size
205 | next_qt_argmax = tf.argmax(
206 | next_qt + self._replay.next_legal_actions, axis=1)[:, None]
207 | batch_indices = tf.range(tf.to_int64(batch_size))[:, None]
208 | # size of next_qt_argmax: batch_size x 2
209 | next_qt_argmax = tf.concat([batch_indices, next_qt_argmax], axis=1)
210 | # size of next_probabilities: batch_size x num_atoms
211 | next_probabilities = tf.gather_nd(next_probabilities, next_qt_argmax)
212 | return project_distribution(target_support, next_probabilities,
213 | self.support)
214 |
215 | def _build_train_op(self):
216 | """Builds the training op for Rainbow.
217 |
218 | Returns:
219 | train_op: An op performing one step of training.
220 | """
221 | target_distribution = tf.stop_gradient(self._build_target_distribution())
222 |
223 | # size of indices: batch_size x 1.
224 | indices = tf.range(tf.shape(self._replay_logits)[0])[:, None]
225 | # size of reshaped_actions: batch_size x 2.
226 | reshaped_actions = tf.concat([indices, self._replay.actions[:, None]], 1)
227 | # For each element of the batch, fetch the logits for its selected action.
228 | chosen_action_logits = tf.gather_nd(self._replay_logits, reshaped_actions)
229 |
230 | loss = tf.nn.softmax_cross_entropy_with_logits(
231 | labels=target_distribution,
232 | logits=chosen_action_logits)
233 |
234 | optimizer = tf.train.AdamOptimizer(
235 | learning_rate=self.learning_rate,
236 | epsilon=self.optimizer_epsilon)
237 |
238 | update_priorities_op = self._replay.tf_set_priority(
239 | self._replay.indices, tf.sqrt(loss + 1e-10))
240 |
241 | target_priorities = self._replay.tf_get_priority(self._replay.indices)
242 | target_priorities = tf.math.add(target_priorities, 1e-10)
243 | target_priorities = 1.0 / tf.sqrt(target_priorities)
244 | target_priorities /= tf.reduce_max(target_priorities)
245 |
246 | weighted_loss = target_priorities * loss
247 |
248 | with tf.control_dependencies([update_priorities_op]):
249 | return optimizer.minimize(tf.reduce_mean(weighted_loss)), weighted_loss
250 |
251 |
252 | def project_distribution(supports, weights, target_support,
253 | validate_args=False):
254 | """Projects a batch of (support, weights) onto target_support.
255 |
256 | Based on equation (7) in (Bellemare et al., 2017):
257 | https://arxiv.org/abs/1707.06887
258 | In the rest of the comments we will refer to this equation simply as Eq7.
259 |
260 | This code is not easy to digest, so we will use a running example to clarify
261 | what is going on, with the following sample inputs:
262 | * supports = [[0, 2, 4, 6, 8],
263 | [1, 3, 4, 5, 6]]
264 | * weights = [[0.1, 0.6, 0.1, 0.1, 0.1],
265 | [0.1, 0.2, 0.5, 0.1, 0.1]]
266 | * target_support = [4, 5, 6, 7, 8]
267 | In the code below, comments preceded with 'Ex:' will be referencing the above
268 | values.
269 |
270 | Args:
271 | supports: Tensor of shape (batch_size, num_dims) defining supports for the
272 | distribution.
273 | weights: Tensor of shape (batch_size, num_dims) defining weights on the
274 | original support points. Although for the CategoricalDQN agent these
275 | weights are probabilities, it is not required that they are.
276 | target_support: Tensor of shape (num_dims) defining support of the projected
277 | distribution. The values must be monotonically increasing. Vmin and Vmax
278 | will be inferred from the first and last elements of this tensor,
279 | respectively. The values in this tensor must be equally spaced.
280 | validate_args: Whether we will verify the contents of the
281 | target_support parameter.
282 |
283 | Returns:
284 | A Tensor of shape (batch_size, num_dims) with the projection of a batch of
285 | (support, weights) onto target_support.
286 |
287 | Raises:
288 | ValueError: If target_support has no dimensions, or if shapes of supports,
289 | weights, and target_support are incompatible.
290 | """
291 | target_support_deltas = target_support[1:] - target_support[:-1]
292 | # delta_z = `\Delta z` in Eq7.
293 | delta_z = target_support_deltas[0]
294 | validate_deps = []
295 | supports.shape.assert_is_compatible_with(weights.shape)
296 | supports[0].shape.assert_is_compatible_with(target_support.shape)
297 | target_support.shape.assert_has_rank(1)
298 | if validate_args:
299 | # Assert that supports and weights have the same shapes.
300 | validate_deps.append(
301 | tf.Assert(
302 | tf.reduce_all(tf.equal(tf.shape(supports), tf.shape(weights))),
303 | [supports, weights]))
304 | # Assert that elements of supports and target_support have the same shape.
305 | validate_deps.append(
306 | tf.Assert(
307 | tf.reduce_all(
308 | tf.equal(tf.shape(supports)[1], tf.shape(target_support))),
309 | [supports, target_support]))
310 | # Assert that target_support has a single dimension.
311 | validate_deps.append(
312 | tf.Assert(
313 | tf.equal(tf.size(tf.shape(target_support)), 1), [target_support]))
314 | # Assert that the target_support is monotonically increasing.
315 | validate_deps.append(
316 | tf.Assert(tf.reduce_all(target_support_deltas > 0), [target_support]))
317 | # Assert that the values in target_support are equally spaced.
318 | validate_deps.append(
319 | tf.Assert(
320 | tf.reduce_all(tf.equal(target_support_deltas, delta_z)),
321 | [target_support]))
322 |
323 | with tf.control_dependencies(validate_deps):
324 | # Ex: `v_min, v_max = 4, 8`.
325 | v_min, v_max = target_support[0], target_support[-1]
326 | # Ex: `batch_size = 2`.
327 | batch_size = tf.shape(supports)[0]
328 | # `N` in Eq7.
329 | # Ex: `num_dims = 5`.
330 | num_dims = tf.shape(target_support)[0]
331 | # clipped_support = `[\hat{T}_{z_j}]^{V_max}_{V_min}` in Eq7.
332 | # Ex: `clipped_support = [[[ 4. 4. 4. 6. 8.]]
333 | # [[ 4. 4. 4. 5. 6.]]]`.
334 | clipped_support = tf.clip_by_value(supports, v_min, v_max)[:, None, :]
335 | # Ex: `tiled_support = [[[[ 4. 4. 4. 6. 8.]
336 | # [ 4. 4. 4. 6. 8.]
337 | # [ 4. 4. 4. 6. 8.]
338 | # [ 4. 4. 4. 6. 8.]
339 | # [ 4. 4. 4. 6. 8.]]
340 | # [[ 4. 4. 4. 5. 6.]
341 | # [ 4. 4. 4. 5. 6.]
342 | # [ 4. 4. 4. 5. 6.]
343 | # [ 4. 4. 4. 5. 6.]
344 | # [ 4. 4. 4. 5. 6.]]]]`.
345 | tiled_support = tf.tile([clipped_support], [1, 1, num_dims, 1])
346 | # Ex: `reshaped_target_support = [[[ 4.]
347 | # [ 5.]
348 | # [ 6.]
349 | # [ 7.]
350 | # [ 8.]]
351 | # [[ 4.]
352 | # [ 5.]
353 | # [ 6.]
354 | # [ 7.]
355 | # [ 8.]]]`.
356 | reshaped_target_support = tf.tile(target_support[:, None], [batch_size, 1])
357 | reshaped_target_support = tf.reshape(reshaped_target_support,
358 | [batch_size, num_dims, 1])
359 | # numerator = `|clipped_support - z_i|` in Eq7.
360 | # Ex: `numerator = [[[[ 0. 0. 0. 2. 4.]
361 | # [ 1. 1. 1. 1. 3.]
362 | # [ 2. 2. 2. 0. 2.]
363 | # [ 3. 3. 3. 1. 1.]
364 | # [ 4. 4. 4. 2. 0.]]
365 | # [[ 0. 0. 0. 1. 2.]
366 | # [ 1. 1. 1. 0. 1.]
367 | # [ 2. 2. 2. 1. 0.]
368 | # [ 3. 3. 3. 2. 1.]
369 | # [ 4. 4. 4. 3. 2.]]]]`.
370 | numerator = tf.abs(tiled_support - reshaped_target_support)
371 | quotient = 1 - (numerator / delta_z)
372 | # clipped_quotient = `[1 - numerator / (\Delta z)]_0^1` in Eq7.
373 | # Ex: `clipped_quotient = [[[[ 1. 1. 1. 0. 0.]
374 | # [ 0. 0. 0. 0. 0.]
375 | # [ 0. 0. 0. 1. 0.]
376 | # [ 0. 0. 0. 0. 0.]
377 | # [ 0. 0. 0. 0. 1.]]
378 | # [[ 1. 1. 1. 0. 0.]
379 | # [ 0. 0. 0. 1. 0.]
380 | # [ 0. 0. 0. 0. 1.]
381 | # [ 0. 0. 0. 0. 0.]
382 | # [ 0. 0. 0. 0. 0.]]]]`.
383 | clipped_quotient = tf.clip_by_value(quotient, 0, 1)
384 | # Ex: `weights = [[ 0.1 0.6 0.1 0.1 0.1]
385 | # [ 0.1 0.2 0.5 0.1 0.1]]`.
386 | weights = weights[:, None, :]
387 | # inner_prod = `\sum_{j=0}^{N-1} clipped_quotient * p_j(x', \pi(x'))`
388 | # in Eq7.
389 | # Ex: `inner_prod = [[[[ 0.1 0.6 0.1 0. 0. ]
390 | # [ 0. 0. 0. 0. 0. ]
391 | # [ 0. 0. 0. 0.1 0. ]
392 | # [ 0. 0. 0. 0. 0. ]
393 | # [ 0. 0. 0. 0. 0.1]]
394 | # [[ 0.1 0.2 0.5 0. 0. ]
395 | # [ 0. 0. 0. 0.1 0. ]
396 | # [ 0. 0. 0. 0. 0.1]
397 | # [ 0. 0. 0. 0. 0. ]
398 | # [ 0. 0. 0. 0. 0. ]]]]`.
399 | inner_prod = clipped_quotient * weights
400 | # Ex: `projection = [[ 0.8 0.0 0.1 0.0 0.1]
401 | # [ 0.8 0.1 0.1 0.0 0.0]]`.
402 | projection = tf.reduce_sum(inner_prod, 3)
403 | projection = tf.reshape(projection, [batch_size, num_dims])
404 | return projection
405 |
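The project_distribution docstring above walks through Eq. (7) of Bellemare et al. (2017) with a concrete running example. Below is a NumPy-only sketch that reproduces the same categorical projection, useful for checking the expected outputs by hand. The helper name project_distribution_np is hypothetical and is not part of this repository; the agent uses the TensorFlow implementation above.

import numpy as np

def project_distribution_np(supports, weights, target_support):
    # Illustrative NumPy sketch of the categorical projection (Eq. 7).
    supports = np.asarray(supports, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    target_support = np.asarray(target_support, dtype=np.float64)

    v_min, v_max = target_support[0], target_support[-1]
    delta_z = target_support[1] - target_support[0]
    # Clip each source atom into [v_min, v_max], then spread its probability
    # mass onto the neighbouring target atoms in proportion to proximity.
    clipped = np.clip(supports, v_min, v_max)                  # (batch, n)
    distance = np.abs(clipped[:, None, :] - target_support[None, :, None])
    quotient = np.clip(1.0 - distance / delta_z, 0.0, 1.0)     # (batch, m, n)
    return np.einsum('bmn,bn->bm', quotient, weights)

# Running example from the docstring above:
supports = [[0, 2, 4, 6, 8], [1, 3, 4, 5, 6]]
weights = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.2, 0.5, 0.1, 0.1]]
target_support = [4, 5, 6, 7, 8]
print(project_distribution_np(supports, weights, target_support))
# -> [[0.8 0.  0.1 0.  0.1]
#     [0.8 0.1 0.1 0.  0. ]]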
--------------------------------------------------------------------------------
/hanabi_lib/canonical_encoders.cc:
--------------------------------------------------------------------------------
1 | // Copyright 2018 Google LLC
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // https://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #include <algorithm>
16 | #include <cassert>
17 | #include <cstdint>
18 | #include <numeric>
19 | #include <vector>
20 |
21 | #include "canonical_encoders.h"
22 |
23 | namespace hanabi_learning_env {
24 |
25 | namespace {
26 |
27 | // Computes the product of dimensions in shape, i.e. how many individual
28 | // pieces of data the encoded observation requires.
29 | int FlatLength(const std::vector<int>& shape) {
30 | return std::accumulate(std::begin(shape), std::end(shape), 1,
31 | std::multiplies<int>());
32 | }
33 |
34 | const HanabiHistoryItem* GetLastNonDealMove(
35 | const std::vector<HanabiHistoryItem>& past_moves) {
36 | auto it = std::find_if(
37 | past_moves.begin(), past_moves.end(), [](const HanabiHistoryItem& item) {
38 | return item.move.MoveType() != HanabiMove::Type::kDeal;
39 | });
40 | return it == past_moves.end() ? nullptr : &(*it);
41 | }
42 |
43 | int BitsPerCard(const HanabiGame& game) {
44 | return game.NumColors() * game.NumRanks();
45 | }
46 |
47 | // The card's one-hot index using a color-major ordering.
48 | int CardIndex(int color, int rank, int num_ranks) {
49 | return color * num_ranks + rank;
50 | }
51 |
52 | int HandsSectionLength(const HanabiGame& game) {
53 | return (game.NumPlayers() - 1) * game.HandSize() * BitsPerCard(game) +
54 | game.NumPlayers();
55 | }
56 |
57 | // Encodes cards in all other players' hands (excluding our unknown hand),
58 | // and whether the hand is missing a card for all players (when the deck is empty).
59 | // Each card in a hand is encoded with a one-hot representation using
60 | // <num_colors> * <num_ranks> bits (25 bits in a standard game) per card.
61 | // Returns the number of entries written to the encoding.
62 | int EncodeHands(const HanabiGame& game, const HanabiObservation& obs,
63 | int start_offset, std::vector<int>* encoding) {
64 | int bits_per_card = BitsPerCard(game);
65 | int num_ranks = game.NumRanks();
66 | int num_players = game.NumPlayers();
67 | int hand_size = game.HandSize();
68 |
69 | int offset = start_offset;
70 | const std::vector<HanabiHand>& hands = obs.Hands();
71 | assert(hands.size() == num_players);
72 | for (int player = 1; player < num_players; ++player) {
73 | const std::vector<HanabiCard>& cards = hands[player].Cards();
74 | int num_cards = 0;
75 |
76 | for (const HanabiCard& card : cards) {
77 | // Only a player's own cards can be invalid/unobserved.
78 | assert(card.IsValid());
79 | assert(card.Color() < game.NumColors());
80 | assert(card.Rank() < num_ranks);
81 | (*encoding)[offset + CardIndex(card.Color(), card.Rank(), num_ranks)] = 1;
82 |
83 | ++num_cards;
84 | offset += bits_per_card;
85 | }
86 |
87 | // A player's hand can have fewer cards than the initial hand size.
88 | // Leave the bits for the absent cards empty (adjust the offset to skip
89 | // bits for the missing cards).
90 | if (num_cards < hand_size) {
91 | offset += (hand_size - num_cards) * bits_per_card;
92 | }
93 | }
94 |
95 | // For each player, set a bit if their hand is missing a card.
96 | for (int player = 0; player < num_players; ++player) {
97 | if (hands[player].Cards().size() < game.HandSize()) {
98 | (*encoding)[offset + player] = 1;
99 | }
100 | }
101 | offset += num_players;
102 |
103 | assert(offset - start_offset == HandsSectionLength(game));
104 | return offset - start_offset;
105 | }
106 |
107 | int BoardSectionLength(const HanabiGame& game) {
108 | return game.MaxDeckSize() - game.NumPlayers() * game.HandSize() + // deck
109 | game.NumColors() * game.NumRanks() + // fireworks
110 | game.MaxInformationTokens() + // info tokens
111 | game.MaxLifeTokens(); // life tokens
112 | }
113 |
114 | // Encode the board, including:
115 | // - remaining deck size
116 | // (max_deck_size - num_players * hand_size bits; thermometer)
117 | // - state of the fireworks (<num_ranks> bits per color; one-hot)
118 | // - information tokens remaining (max_information_tokens bits; thermometer)
119 | // - life tokens remaining (max_life_tokens bits; thermometer)
120 | // We note several features use a thermometer representation instead of one-hot.
121 | // For example, life tokens could be: 000 (0), 100 (1), 110 (2), 111 (3).
122 | // Returns the number of entries written to the encoding.
123 | int EncodeBoard(const HanabiGame& game, const HanabiObservation& obs,
124 | int start_offset, std::vector<int>* encoding) {
125 | int num_colors = game.NumColors();
126 | int num_ranks = game.NumRanks();
127 | int num_players = game.NumPlayers();
128 | int hand_size = game.HandSize();
129 | int max_deck_size = game.MaxDeckSize();
130 |
131 | int offset = start_offset;
132 | // Encode the deck size
133 | for (int i = 0; i < obs.DeckSize(); ++i) {
134 | (*encoding)[offset + i] = 1;
135 | }
136 | offset += (max_deck_size - hand_size * num_players); // 40 in normal 2P game
137 |
138 | // fireworks
139 | const std::vector<int>& fireworks = obs.Fireworks();
140 | for (int c = 0; c < num_colors; ++c) {
141 | // fireworks[color] is the number of successfully played cards.
142 | // If some were played, one-hot encode the highest (0-indexed) rank played
143 | if (fireworks[c] > 0) {
144 | (*encoding)[offset + fireworks[c] - 1] = 1;
145 | }
146 | offset += num_ranks;
147 | }
148 |
149 | // info tokens
150 | assert(obs.InformationTokens() >= 0);
151 | assert(obs.InformationTokens() <= game.MaxInformationTokens());
152 | for (int i = 0; i < obs.InformationTokens(); ++i) {
153 | (*encoding)[offset + i] = 1;
154 | }
155 | offset += game.MaxInformationTokens();
156 |
157 | // life tokens
158 | assert(obs.LifeTokens() >= 0);
159 | assert(obs.LifeTokens() <= game.MaxLifeTokens());
160 | for (int i = 0; i < obs.LifeTokens(); ++i) {
161 | (*encoding)[offset + i] = 1;
162 | }
163 | offset += game.MaxLifeTokens();
164 |
165 | assert(offset - start_offset == BoardSectionLength(game));
166 | return offset - start_offset;
167 | }
168 |
169 | int DiscardSectionLength(const HanabiGame& game) { return game.MaxDeckSize(); }
170 |
171 | // Encode the discard pile. (max_deck_size bits)
172 | // Encoding is in color-major ordering, as in kColorStr ("RYGWB"), with each
173 | // color and rank using a thermometer to represent the number of cards
174 | // discarded. For example, in a standard game, there are 3 cards of lowest rank
175 | // (1), 1 card of highest rank (5), 2 of all else. So each color would be
176 | // ordered like so:
177 | //
178 | // LLL      H
179 | // 1100011101
180 | //
181 | // This means for this color:
182 | // - 2 cards of the lowest rank have been discarded
183 | // - none of the second lowest rank have been discarded
184 | // - both of the third lowest rank have been discarded
185 | // - one of the second highest rank has been discarded
186 | // - the highest rank card has been discarded
187 | // Returns the number of entries written to the encoding.
188 | int EncodeDiscards(const HanabiGame& game, const HanabiObservation& obs,
189 | int start_offset, std::vector<int>* encoding) {
190 | int num_colors = game.NumColors();
191 | int num_ranks = game.NumRanks();
192 |
193 | int offset = start_offset;
194 | std::vector<int> discard_counts(num_colors * num_ranks, 0);
195 | for (const HanabiCard& card : obs.DiscardPile()) {
196 | ++discard_counts[card.Color() * num_ranks + card.Rank()];
197 | }
198 |
199 | for (int c = 0; c < num_colors; ++c) {
200 | for (int r = 0; r < num_ranks; ++r) {
201 | int num_discarded = discard_counts[c * num_ranks + r];
202 | for (int i = 0; i < num_discarded; ++i) {
203 | (*encoding)[offset + i] = 1;
204 | }
205 | offset += game.NumberCardInstances(c, r);
206 | }
207 | }
208 |
209 | assert(offset - start_offset == DiscardSectionLength(game));
210 | return offset - start_offset;
211 | }
212 |
213 | int LastActionSectionLength(const HanabiGame& game) {
214 | return game.NumPlayers() + // player id
215 | 4 + // move types (play, dis, rev col, rev rank)
216 | game.NumPlayers() + // target player id (if hint action)
217 | game.NumColors() + // color (if hint action)
218 | game.NumRanks() + // rank (if hint action)
219 | game.HandSize() + // outcome (if hint action)
220 | game.HandSize() + // position (if play action)
221 | BitsPerCard(game) + // card (if play or discard action)
222 | 2; // play (successful, added information token)
223 | }
224 |
225 | // Encode the last player action (not chance's deal of cards). This encodes:
226 | // - Acting player index, relative to ourself (<num_players> bits; one-hot)
227 | // - The MoveType (4 bits; one-hot)
228 | // - Target player index, relative to acting player, if a reveal move
229 | // (<num_players> bits; one-hot)
230 | // - Color revealed, if a reveal color move (<num_colors> bits; one-hot)
231 | // - Rank revealed, if a reveal rank move (<num_ranks> bits; one-hot)
232 | // - Reveal outcome (<hand_size> bits; each bit is 1 if the card was hinted at)
233 | // - Position played/discarded (<hand_size> bits; one-hot)
234 | // - Card played/discarded (<num_colors> * <num_ranks> bits; one-hot)
235 | // Returns the number of entries written to the encoding.
236 | int EncodeLastAction(const HanabiGame& game, const HanabiObservation& obs,
237 | int start_offset, std::vector<int>* encoding) {
238 | int num_colors = game.NumColors();
239 | int num_ranks = game.NumRanks();
240 | int num_players = game.NumPlayers();
241 | int hand_size = game.HandSize();
242 |
243 | int offset = start_offset;
244 | const HanabiHistoryItem* last_move = GetLastNonDealMove(obs.LastMoves());
245 | if (last_move == nullptr) {
246 | offset += LastActionSectionLength(game);
247 | } else {
248 | HanabiMove::Type last_move_type = last_move->move.MoveType();
249 |
250 | // player_id
251 | // Note: no assertion here. At a terminal state, the last player could have
252 | // been me (player id 0).
253 | (*encoding)[offset + last_move->player] = 1;
254 | offset += num_players;
255 |
256 | // move type
257 | switch (last_move_type) {
258 | case HanabiMove::Type::kPlay:
259 | (*encoding)[offset] = 1;
260 | break;
261 | case HanabiMove::Type::kDiscard:
262 | (*encoding)[offset + 1] = 1;
263 | break;
264 | case HanabiMove::Type::kRevealColor:
265 | (*encoding)[offset + 2] = 1;
266 | break;
267 | case HanabiMove::Type::kRevealRank:
268 | (*encoding)[offset + 3] = 1;
269 | break;
270 | default:
271 | std::abort();
272 | }
273 | offset += 4;
274 |
275 | // target player (if hint action)
276 | if (last_move_type == HanabiMove::Type::kRevealColor ||
277 | last_move_type == HanabiMove::Type::kRevealRank) {
278 | int8_t observer_relative_target =
279 | (last_move->player + last_move->move.TargetOffset()) % num_players;
280 | (*encoding)[offset + observer_relative_target] = 1;
281 | }
282 | offset += num_players;
283 |
284 | // color (if hint action)
285 | if (last_move_type == HanabiMove::Type::kRevealColor) {
286 | (*encoding)[offset + last_move->move.Color()] = 1;
287 | }
288 | offset += num_colors;
289 |
290 | // rank (if hint action)
291 | if (last_move_type == HanabiMove::Type::kRevealRank) {
292 | (*encoding)[offset + last_move->move.Rank()] = 1;
293 | }
294 | offset += num_ranks;
295 |
296 | // outcome (if hinted action)
297 | if (last_move_type == HanabiMove::Type::kRevealColor ||
298 | last_move_type == HanabiMove::Type::kRevealRank) {
299 | for (int i = 0, mask = 1; i < hand_size; ++i, mask <<= 1) {
300 | if ((last_move->reveal_bitmask & mask) > 0) {
301 | (*encoding)[offset + i] = 1;
302 | }
303 | }
304 | }
305 | offset += hand_size;
306 |
307 | // position (if play or discard action)
308 | if (last_move_type == HanabiMove::Type::kPlay ||
309 | last_move_type == HanabiMove::Type::kDiscard) {
310 | (*encoding)[offset + last_move->move.CardIndex()] = 1;
311 | }
312 | offset += hand_size;
313 |
314 | // card (if play or discard action)
315 | if (last_move_type == HanabiMove::Type::kPlay ||
316 | last_move_type == HanabiMove::Type::kDiscard) {
317 | assert(last_move->color >= 0);
318 | assert(last_move->rank >= 0);
319 | (*encoding)[offset +
320 | CardIndex(last_move->color, last_move->rank, num_ranks)] = 1;
321 | }
322 | offset += BitsPerCard(game);
323 |
324 | // was successful and/or added information token (if play action)
325 | if (last_move_type == HanabiMove::Type::kPlay) {
326 | if (last_move->scored) {
327 | (*encoding)[offset] = 1;
328 | }
329 | if (last_move->information_token) {
330 | (*encoding)[offset + 1] = 1;
331 | }
332 | }
333 | offset += 2;
334 | }
335 |
336 | assert(offset - start_offset == LastActionSectionLength(game));
337 | return offset - start_offset;
338 | }
339 |
340 | int CardKnowledgeSectionLength(const HanabiGame& game) {
341 | return game.NumPlayers() * game.HandSize() *
342 | (BitsPerCard(game) + game.NumColors() + game.NumRanks());
343 | }
344 |
345 | // Encode the common card knowledge.
346 | // For each card/position in each player's hand, including the observing player,
347 | // encode the possible cards that could be in that position and whether the
348 | // color and rank were directly revealed by a Reveal action. Possible card
349 | // values are in color-major order, using <num_colors> * <num_ranks> bits per
350 | // card. For example, if you knew nothing about a card, and a player revealed
351 | // that it was green, the knowledge would be encoded as follows.
352 | // R    Y    G    W    B
353 | // 0000000000111110000000000   Only green cards are possible.
354 | // 0    0    1    0    0       Card was revealed to be green.
355 | // 00000 Card rank was not revealed.
356 | //
357 | // Similarly, if the player revealed that one of your other cards was green, you
358 | // would know that this card could not be green, resulting in:
359 | // R    Y    G    W    B
360 | // 1111111111000001111111111   Any card that is not green is possible.
361 | // 0    0    0    0    0       Card color was not revealed.
362 | // 00000 Card rank was not revealed.
363 | // Uses <num_players> * <hand_size> *
364 | // (<num_colors> * <num_ranks> + <num_colors> + <num_ranks>) bits.
365 | // Returns the number of entries written to the encoding.
366 | int EncodeCardKnowledge(const HanabiGame& game, const HanabiObservation& obs,
367 | int start_offset, std::vector<int>* encoding) {
368 | int bits_per_card = BitsPerCard(game);
369 | int num_colors = game.NumColors();
370 | int num_ranks = game.NumRanks();
371 | int num_players = game.NumPlayers();
372 | int hand_size = game.HandSize();
373 |
374 | int offset = start_offset;
375 | const std::vector<HanabiHand>& hands = obs.Hands();
376 | assert(hands.size() == num_players);
377 | for (int player = 0; player < num_players; ++player) {
378 | const std::vector<HanabiHand::CardKnowledge>& knowledge =
379 | hands[player].Knowledge();
380 | int num_cards = 0;
381 |
382 | for (const HanabiHand::CardKnowledge& card_knowledge : knowledge) {
383 | // Add bits for plausible card.
384 | for (int color = 0; color < num_colors; ++color) {
385 | if (card_knowledge.ColorPlausible(color)) {
386 | for (int rank = 0; rank < num_ranks; ++rank) {
387 | if (card_knowledge.RankPlausible(rank)) {
388 | (*encoding)[offset + CardIndex(color, rank, num_ranks)] = 1;
389 | }
390 | }
391 | }
392 | }
393 | offset += bits_per_card;
394 |
395 | // Add bits for explicitly revealed colors and ranks.
396 | if (card_knowledge.ColorHinted()) {
397 | (*encoding)[offset + card_knowledge.Color()] = 1;
398 | }
399 | offset += num_colors;
400 | if (card_knowledge.RankHinted()) {
401 | (*encoding)[offset + card_knowledge.Rank()] = 1;
402 | }
403 | offset += num_ranks;
404 |
405 | ++num_cards;
406 | }
407 |
408 | // A player's hand can have fewer cards than the initial hand size.
409 | // Leave the bits for the absent cards empty (adjust the offset to skip
410 | // bits for the missing cards).
411 | if (num_cards < hand_size) {
412 | offset +=
413 | (hand_size - num_cards) * (bits_per_card + num_colors + num_ranks);
414 | }
415 | }
416 |
417 | assert(offset - start_offset == CardKnowledgeSectionLength(game));
418 | return offset - start_offset;
419 | }
420 |
421 | } // namespace
422 |
423 | std::vector<int> CanonicalObservationEncoder::Shape() const {
424 | return {HandsSectionLength(*parent_game_) +
425 | BoardSectionLength(*parent_game_) +
426 | DiscardSectionLength(*parent_game_) +
427 | LastActionSectionLength(*parent_game_) +
428 | (parent_game_->ObservationType() == HanabiGame::kMinimal
429 | ? 0
430 | : CardKnowledgeSectionLength(*parent_game_))};
431 | }
432 |
433 | std::vector<int> CanonicalObservationEncoder::Encode(
434 | const HanabiObservation& obs) const {
435 | // Make an empty bit string of the proper size.
436 | std::vector<int> encoding(FlatLength(Shape()), 0);
437 |
438 | // This offset is an index to the start of each section of the bit vector.
439 | // It is incremented at the end of each section.
440 | int offset = 0;
441 | offset += EncodeHands(*parent_game_, obs, offset, &encoding);
442 | offset += EncodeBoard(*parent_game_, obs, offset, &encoding);
443 | offset += EncodeDiscards(*parent_game_, obs, offset, &encoding);
444 | offset += EncodeLastAction(*parent_game_, obs, offset, &encoding);
445 | if (parent_game_->ObservationType() != HanabiGame::kMinimal) {
446 | offset += EncodeCardKnowledge(*parent_game_, obs, offset, &encoding);
447 | }
448 |
449 | assert(offset == encoding.size());
450 | return encoding;
451 | }
452 |
453 | } // namespace hanabi_learning_env
454 |
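The section-length helpers above fully determine the size of the canonical observation vector, and Shape() is just their sum. As a sanity check, here is a small Python sketch (illustrative only, not part of the repository) that evaluates the same formulas for a 2-player game; the parameter values are the standard Hanabi defaults and are assumptions of this example rather than values read from the library.

# Re-evaluating the C++ section-length formulas above for a standard
# 2-player game. All parameter values below are assumed defaults.
num_colors, num_ranks = 5, 5
num_players, hand_size = 2, 5
max_deck_size = 50
max_info_tokens, max_life_tokens = 8, 3

bits_per_card = num_colors * num_ranks                                   # 25
hands = (num_players - 1) * hand_size * bits_per_card + num_players      # 127
board = (max_deck_size - num_players * hand_size                         # deck thermometer
         + num_colors * num_ranks                                        # fireworks one-hot
         + max_info_tokens + max_life_tokens)                            # 76
discards = max_deck_size                                                 # 50
last_action = (num_players + 4 + num_players + num_colors + num_ranks
               + hand_size + hand_size + bits_per_card + 2)              # 55
card_knowledge = num_players * hand_size * (bits_per_card
                                            + num_colors + num_ranks)    # 350

print(hands + board + discards + last_action + card_knowledge)           # 658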
--------------------------------------------------------------------------------
/agents/rainbow/run_experiment.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Dopamine Authors and Google LLC.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | #
17 | #
18 | # This file is a fork of the original Dopamine code incorporating changes for
19 | # the multiplayer setting and the Hanabi Learning Environment.
20 | #
21 | """Run methods for training a DQN agent on the Hanabi Learning Environment.
22 |
23 | Methods in this module are usually referenced by |train.py|.
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | import time
31 |
32 | from third_party.dopamine import checkpointer
33 | from third_party.dopamine import iteration_statistics
34 | import dqn_agent
35 | import gin.tf
36 | import rl_env
37 | import numpy as np
38 | import rainbow_agent
39 | import tensorflow as tf
40 |
41 | LENIENT_SCORE = False
42 |
43 |
44 | class ObservationStacker(object):
45 | """Class for stacking agent observations."""
46 |
47 | def __init__(self, history_size, observation_size, num_players):
48 | """Initializer for observation stacker.
49 |
50 | Args:
51 | history_size: int, number of time steps to stack.
52 | observation_size: int, size of observation vector on one time step.
53 | num_players: int, number of players.
54 | """
55 | self._history_size = history_size
56 | self._observation_size = observation_size
57 | self._num_players = num_players
58 | self._obs_stacks = list()
59 | for _ in range(0, self._num_players):
60 | self._obs_stacks.append(np.zeros(self._observation_size *
61 | self._history_size))
62 |
63 | def add_observation(self, observation, current_player):
64 | """Adds observation for the current player.
65 |
66 | Args:
67 | observation: observation vector for current player.
68 | current_player: int, current player id.
69 | """
70 | self._obs_stacks[current_player] = np.roll(self._obs_stacks[current_player],
71 | -self._observation_size)
72 | self._obs_stacks[current_player][(self._history_size - 1) *
73 | self._observation_size:] = observation
74 |
75 | def get_observation_stack(self, current_player):
76 | """Returns the stacked observation for current player.
77 |
78 | Args:
79 | current_player: int, current player id.
80 | """
81 |
82 | return self._obs_stacks[current_player]
83 |
84 | def reset_stack(self):
85 | """Resets the observation stacks to all zero."""
86 |
87 | for i in range(0, self._num_players):
88 | self._obs_stacks[i].fill(0.0)
89 |
90 | @property
91 | def history_size(self):
92 | """Returns number of steps to stack."""
93 | return self._history_size
94 |
95 | def observation_size(self):
96 | """Returns the size of the observation vector after history stacking."""
97 | return self._observation_size * self._history_size
98 |
99 |
100 | def load_gin_configs(gin_files, gin_bindings):
101 | """Loads gin configuration files.
102 |
103 | Args:
104 | gin_files: A list of paths to the gin configuration files for this
105 | experiment.
106 | gin_bindings: List of gin parameter bindings to override the values in the
107 | config files.
108 | """
109 | gin.parse_config_files_and_bindings(gin_files,
110 | bindings=gin_bindings,
111 | skip_unknown=False)
112 |
113 |
114 | @gin.configurable
115 | def create_environment(game_type='Hanabi-Full', num_players=2):
116 | """Creates the Hanabi environment.
117 |
118 | Args:
119 | game_type: Type of game to play. Currently the following are supported:
120 | Hanabi-Full: Regular game.
121 | Hanabi-Small: The small version of Hanabi, with 2 cards and 2 colours.
122 | num_players: Int, number of players to play this game.
123 |
124 | Returns:
125 | A Hanabi environment.
126 | """
127 | return rl_env.make(
128 | environment_name=game_type, num_players=num_players, pyhanabi_path=None)
129 |
130 |
131 | @gin.configurable
132 | def create_obs_stacker(environment, history_size=4):
133 | """Creates an observation stacker.
134 |
135 | Args:
136 | environment: environment object.
137 | history_size: int, number of steps to stack.
138 |
139 | Returns:
140 | An observation stacker object.
141 | """
142 |
143 | return ObservationStacker(history_size,
144 | environment.vectorized_observation_shape()[0],
145 | environment.players)
146 |
147 |
148 | @gin.configurable
149 | def create_agent(environment, obs_stacker, agent_type='DQN'):
150 | """Creates the Hanabi agent.
151 |
152 | Args:
153 | environment: The environment.
154 | obs_stacker: Observation stacker object.
155 | agent_type: str, type of agent to construct.
156 |
157 | Returns:
158 | An agent for playing Hanabi.
159 |
160 | Raises:
161 | ValueError: if an unknown agent type is requested.
162 | """
163 | if agent_type == 'DQN':
164 | return dqn_agent.DQNAgent(observation_size=obs_stacker.observation_size(),
165 | num_actions=environment.num_moves(),
166 | num_players=environment.players)
167 | elif agent_type == 'Rainbow':
168 | return rainbow_agent.RainbowAgent(
169 | observation_size=obs_stacker.observation_size(),
170 | num_actions=environment.num_moves(),
171 | num_players=environment.players)
172 | else:
173 | raise ValueError('Expected valid agent_type, got {}'.format(agent_type))
174 |
175 |
176 | def initialize_checkpointing(agent, experiment_logger, checkpoint_dir,
177 | checkpoint_file_prefix='ckpt'):
178 | """Reloads the latest checkpoint if it exists.
179 |
180 | The following steps will be taken:
181 | - This method will first create a Checkpointer object, which will be used in
182 | the method and then returned to the caller for later use.
183 | - It will then call checkpointer.get_latest_checkpoint_number to determine
184 | whether there is a valid checkpoint in checkpoint_dir, and what is the
185 | largest file number.
186 | - If a valid checkpoint file is found, it will load the bundled data from
187 | this file and will pass it to the agent for it to reload its data.
188 | - If the agent is able to successfully unbundle, this method will verify that
189 | the unbundled data contains the keys, 'logs' and 'current_iteration'. It
190 | will then load the Logger's data from the bundle, and will return the
191 | iteration number keyed by 'current_iteration' as one of the return values
192 | (along with the Checkpointer object).
193 |
194 | Args:
195 | agent: The agent that will unbundle the checkpoint from checkpoint_dir.
196 | experiment_logger: The Logger object that will be loaded from the
197 | checkpoint.
198 | checkpoint_dir: str, the directory containing the checkpoints.
199 | checkpoint_file_prefix: str, the checkpoint file prefix.
200 |
201 | Returns:
202 | start_iteration: int, The iteration number to start the experiment from.
203 | experiment_checkpointer: The experiment checkpointer.
204 | """
205 | experiment_checkpointer = checkpointer.Checkpointer(
206 | checkpoint_dir, checkpoint_file_prefix)
207 |
208 | start_iteration = 0
209 |
210 | # Check if checkpoint exists. Note that the existence of checkpoint 0 means
211 | # that we have finished iteration 0 (so we will start from iteration 1).
212 | latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
213 | checkpoint_dir)
214 | if latest_checkpoint_version >= 0:
215 | dqn_dictionary = experiment_checkpointer.load_checkpoint(
216 | latest_checkpoint_version)
217 | if agent.unbundle(
218 | checkpoint_dir, latest_checkpoint_version, dqn_dictionary):
219 | assert 'logs' in dqn_dictionary
220 | assert 'current_iteration' in dqn_dictionary
221 | experiment_logger.data = dqn_dictionary['logs']
222 | start_iteration = dqn_dictionary['current_iteration'] + 1
223 | tf.logging.info('Reloaded checkpoint and will start from iteration %d',
224 | start_iteration)
225 |
226 | return start_iteration, experiment_checkpointer
227 |
228 |
229 | def format_legal_moves(legal_moves, action_dim):
230 | """Returns formatted legal moves.
231 |
232 | This function takes a list of actions and converts it into a fixed size vector
233 | of size action_dim. If an action is legal, its position is set to 0 and -Inf
234 | otherwise.
235 | Ex: legal_moves = [0, 1, 3], action_dim = 5
236 | returns [0, 0, -Inf, 0, -Inf]
237 |
238 | Args:
239 | legal_moves: list of legal actions.
240 | action_dim: int, number of actions.
241 |
242 | Returns:
243 | a vector of size action_dim.
244 | """
245 | new_legal_moves = np.full(action_dim, -float('inf'))
246 | if legal_moves:
247 | new_legal_moves[legal_moves] = 0
248 | return new_legal_moves
249 |
250 |
251 | def parse_observations(observations, num_actions, obs_stacker):
252 | """Deconstructs the rich observation data into relevant components.
253 |
254 | Args:
255 | observations: dict, containing full observations.
256 | num_actions: int, The number of available actions.
257 | obs_stacker: Observation stacker object.
258 |
259 | Returns:
260 | current_player: int, Whose turn it is.
261 | legal_moves: `np.array` of floats, of length num_actions, whose elements
262 | are -inf for indices corresponding to illegal moves and 0, for those
263 | corresponding to legal moves.
264 | observation_vector: Vectorized observation for the current player.
265 | """
266 | current_player = observations['current_player']
267 | current_player_observation = (
268 | observations['player_observations'][current_player])
269 |
270 | legal_moves = current_player_observation['legal_moves_as_int']
271 | legal_moves = format_legal_moves(legal_moves, num_actions)
272 |
273 | observation_vector = current_player_observation['vectorized']
274 | obs_stacker.add_observation(observation_vector, current_player)
275 | observation_vector = obs_stacker.get_observation_stack(current_player)
276 |
277 | return current_player, legal_moves, observation_vector
278 |
279 |
280 | def run_one_episode(agent, environment, obs_stacker):
281 | """Runs the agent on a single game of Hanabi in self-play mode.
282 |
283 | Args:
284 | agent: Agent playing Hanabi.
285 | environment: The Hanabi environment.
286 | obs_stacker: Observation stacker object.
287 |
288 | Returns:
289 | step_number: int, number of actions in this episode.
290 | total_reward: float, undiscounted return for this episode.
291 | """
292 | obs_stacker.reset_stack()
293 | observations = environment.reset()
294 | current_player, legal_moves, observation_vector = (
295 | parse_observations(observations, environment.num_moves(), obs_stacker))
296 | action = agent.begin_episode(current_player, legal_moves, observation_vector)
297 |
298 | is_done = False
299 | total_reward = 0
300 | step_number = 0
301 |
302 | has_played = {current_player}
303 |
304 | # Keep track of per-player reward.
305 | reward_since_last_action = np.zeros(environment.players)
306 |
307 | while not is_done:
308 | observations, reward, is_done, _ = environment.step(action.item())
309 |
310 | modified_reward = max(reward, 0) if LENIENT_SCORE else reward
311 | total_reward += modified_reward
312 |
313 | reward_since_last_action += modified_reward
314 |
315 | step_number += 1
316 | if is_done:
317 | break
318 | current_player, legal_moves, observation_vector = (
319 | parse_observations(observations, environment.num_moves(), obs_stacker))
320 | if current_player in has_played:
321 | action = agent.step(reward_since_last_action[current_player],
322 | current_player, legal_moves, observation_vector)
323 | else:
324 | # Each player begins the episode on their first turn (which may not be
325 | # the first move of the game).
326 | action = agent.begin_episode(current_player, legal_moves,
327 | observation_vector)
328 | has_played.add(current_player)
329 |
330 | # Reset this player's reward accumulator.
331 | reward_since_last_action[current_player] = 0
332 |
333 | agent.end_episode(reward_since_last_action)
334 |
335 | tf.logging.info('EPISODE: %d %g', step_number, total_reward)
336 | return step_number, total_reward
337 |
338 |
339 | def run_one_phase(agent, environment, obs_stacker, min_steps, statistics,
340 | run_mode_str):
341 | """Runs the agent/environment loop until a desired number of steps.
342 |
343 | Args:
344 | agent: Agent playing hanabi.
345 | environment: environment object.
346 | obs_stacker: Observation stacker object.
347 | min_steps: int, minimum number of steps to generate in this phase.
348 | statistics: `IterationStatistics` object which records the experimental
349 | results.
350 | run_mode_str: str, describes the run mode for this agent.
351 |
352 | Returns:
353 | The number of steps taken in this phase, the sum of returns, and the
354 | number of episodes performed.
355 | """
356 | step_count = 0
357 | num_episodes = 0
358 | sum_returns = 0.
359 |
360 | while step_count < min_steps:
361 | episode_length, episode_return = run_one_episode(agent, environment,
362 | obs_stacker)
363 | statistics.append({
364 | '{}_episode_lengths'.format(run_mode_str): episode_length,
365 | '{}_episode_returns'.format(run_mode_str): episode_return
366 | })
367 |
368 | step_count += episode_length
369 | sum_returns += episode_return
370 | num_episodes += 1
371 |
372 | return step_count, sum_returns, num_episodes
373 |
374 |
375 | @gin.configurable
376 | def run_one_iteration(agent, environment, obs_stacker,
377 | iteration, training_steps,
378 | evaluate_every_n=100,
379 | num_evaluation_games=100):
380 | """Runs one iteration of agent/environment interaction.
381 |
382 | An iteration involves running several episodes until a certain number of
383 | steps are obtained.
384 |
385 | Args:
386 | agent: Agent playing hanabi.
387 | environment: The Hanabi environment.
388 | obs_stacker: Observation stacker object.
389 | iteration: int, current iteration number, used as a global_step.
390 | training_steps: int, the number of training steps to perform.
391 | evaluate_every_n: int, frequency of evaluation.
392 | num_evaluation_games: int, number of games per evaluation.
393 |
394 | Returns:
395 | A dict containing summary statistics for this iteration.
396 | """
397 | start_time = time.time()
398 |
399 | statistics = iteration_statistics.IterationStatistics()
400 |
401 | # First perform the training phase, during which the agent learns.
402 | agent.eval_mode = False
403 | number_steps, sum_returns, num_episodes = (
404 | run_one_phase(agent, environment, obs_stacker, training_steps, statistics,
405 | 'train'))
406 | time_delta = time.time() - start_time
407 | tf.logging.info('Average training steps per second: %.2f',
408 | number_steps / time_delta)
409 |
410 | average_return = sum_returns / num_episodes
411 | tf.logging.info('Average per episode return: %.2f', average_return)
412 | statistics.append({'average_return': average_return})
413 |
414 | # Also run an evaluation phase if desired.
415 | if evaluate_every_n is not None and iteration % evaluate_every_n == 0:
416 | episode_data = []
417 | agent.eval_mode = True
418 | # Collect episode data for all games.
419 | for _ in range(num_evaluation_games):
420 | episode_data.append(run_one_episode(agent, environment, obs_stacker))
421 |
422 | eval_episode_length, eval_episode_return = map(np.mean, zip(*episode_data))
423 |
424 | statistics.append({
425 | 'eval_episode_lengths': eval_episode_length,
426 | 'eval_episode_returns': eval_episode_return
427 | })
428 | tf.logging.info('Average eval. episode length: %.2f Return: %.2f',
429 | eval_episode_length, eval_episode_return)
430 | else:
431 | statistics.append({
432 | 'eval_episode_lengths': -1,
433 | 'eval_episode_returns': -1
434 | })
435 |
436 | return statistics.data_lists
437 |
438 |
439 | def log_experiment(experiment_logger, iteration, statistics,
440 | logging_file_prefix='log', log_every_n=1):
441 | """Records the results of the current iteration.
442 |
443 | Args:
444 | experiment_logger: A `Logger` object.
445 | iteration: int, iteration number.
446 | statistics: Object containing statistics to log.
447 | logging_file_prefix: str, prefix to use for the log files.
448 | log_every_n: int, specifies logging frequency.
449 | """
450 | if iteration % log_every_n == 0:
451 | experiment_logger['iter{:d}'.format(iteration)] = statistics
452 | experiment_logger.log_to_file(logging_file_prefix, iteration)
453 |
454 |
455 | def checkpoint_experiment(experiment_checkpointer, agent, experiment_logger,
456 | iteration, checkpoint_dir, checkpoint_every_n):
457 | """Checkpoint experiment data.
458 |
459 | Args:
460 | experiment_checkpointer: A `Checkpointer` object.
461 | agent: An RL agent.
462 | experiment_logger: a Logger object, to include its data in the checkpoint.
463 | iteration: int, iteration number for checkpointing.
464 | checkpoint_dir: str, the directory where to save checkpoints.
465 | checkpoint_every_n: int, the frequency for writing checkpoints.
466 | """
467 | if iteration % checkpoint_every_n == 0:
468 | agent_dictionary = agent.bundle_and_checkpoint(checkpoint_dir, iteration)
469 | if agent_dictionary:
470 | agent_dictionary['current_iteration'] = iteration
471 | agent_dictionary['logs'] = experiment_logger.data
472 | experiment_checkpointer.save_checkpoint(iteration, agent_dictionary)
473 |
474 |
475 | @gin.configurable
476 | def run_experiment(agent,
477 | environment,
478 | start_iteration,
479 | obs_stacker,
480 | experiment_logger,
481 | experiment_checkpointer,
482 | checkpoint_dir,
483 | num_iterations=200,
484 | training_steps=5000,
485 | logging_file_prefix='log',
486 | log_every_n=1,
487 | checkpoint_every_n=1):
488 | """Runs a full experiment, spread over multiple iterations."""
489 | tf.logging.info('Beginning training...')
490 | if num_iterations <= start_iteration:
491 | tf.logging.warning('num_iterations (%d) <= start_iteration (%d)',
492 | num_iterations, start_iteration)
493 | return
494 |
495 | for iteration in range(start_iteration, num_iterations):
496 | start_time = time.time()
497 | statistics = run_one_iteration(agent, environment, obs_stacker, iteration,
498 | training_steps)
499 | tf.logging.info('Iteration %d took %d seconds', iteration,
500 | time.time() - start_time)
501 | start_time = time.time()
502 | log_experiment(experiment_logger, iteration, statistics,
503 | logging_file_prefix, log_every_n)
504 | tf.logging.info('Logging iteration %d took %d seconds', iteration,
505 | time.time() - start_time)
506 | start_time = time.time()
507 | checkpoint_experiment(experiment_checkpointer, agent, experiment_logger,
508 | iteration, checkpoint_dir, checkpoint_every_n)
509 | tf.logging.info('Checkpointing iteration %d took %d seconds', iteration,
510 | time.time() - start_time)
511 |
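The format_legal_moves docstring above describes the -Inf masking trick that keeps the agent from selecting illegal actions. Below is a standalone reproduction of that function together with a toy usage example; the q_values array is made up purely for illustration.

import numpy as np

# Standalone copy of the masking logic from format_legal_moves() above.
def format_legal_moves(legal_moves, action_dim):
    mask = np.full(action_dim, -float('inf'))
    if legal_moves:
        mask[legal_moves] = 0
    return mask

print(format_legal_moves([0, 1, 3], 5))    # [  0.   0. -inf   0. -inf]

# Adding the mask to the Q-values before the argmax, as the agents do,
# makes illegal actions unselectable.
q_values = np.array([1.0, 5.0, 9.0, 2.0, 7.0])
best = int(np.argmax(q_values + format_legal_moves([0, 1, 3], 5)))
print(best)                                # 1 -- action 2 is illegal and skipped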
--------------------------------------------------------------------------------