├── pyhanabi ├── generate_cp.sh ├── common_utils │ ├── __init__.py │ ├── logger.py │ ├── assert_utils.py │ ├── stopwatch.py │ ├── saver.py │ ├── multi_counter.py │ └── helper.py ├── set_path.py ├── tools │ ├── continual_evaluation.sh │ ├── pretrain.sh │ ├── testing.sh │ ├── continual_learning_scripts │ │ ├── MTL_easy_interactive.sh │ │ ├── MTL_hard_interactive.sh │ │ ├── ER_easy_interactive.sh │ │ ├── AGEM_easy_interactive.sh │ │ ├── AGEM_hard_interactive.sh │ │ ├── ER_hard_interactive.sh │ │ ├── Naive_easy_interactive.sh │ │ ├── Naive_hard_interactive.sh │ │ ├── EWC_online_easy_interactive.sh │ │ ├── EWC_offline_easy_interactive.sh │ │ ├── EWC_offline_hard_interactive.sh │ │ └── EWC_online_hard_interactive.sh │ └── eval_model.py ├── testing.py ├── ewc.py ├── eval.py ├── create.py ├── selfplay.py ├── continual_evaluation.py └── r2d2_gru.py ├── results ├── Pre-trained agents pool for Continual Hanabi.xlsx ├── scores_data_motivation_rand.csv └── sem_data_motivation_rand.csv ├── .gitignore ├── .gitmodules ├── .clang-format ├── requirements.txt ├── CMakeLists.txt ├── rela ├── thread_loop.h ├── utils.h ├── CMakeLists.txt ├── transition.h ├── context.h ├── env.h ├── batch_runner.h ├── pybind.cc ├── batcher.h ├── r2d2_actor.h ├── transition.cc ├── tensor_dict.h ├── transition_buffer.h └── prioritized_replay.h ├── cpp ├── pybind.cc ├── thread_loop.h ├── hanabi_env.h └── hanabi_env.cc └── README.md /pyhanabi/generate_cp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python tools/eval_model.py \ 3 | --weight_1_dir \ 4 | --num_player 2 \ 5 | -------------------------------------------------------------------------------- /results/Pre-trained agents pool for Continual Hanabi.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chandar-lab/Lifelong-Hanabi/HEAD/results/Pre-trained agents pool for Continual Hanabi.xlsx -------------------------------------------------------------------------------- /pyhanabi/common_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .assert_utils import * 2 | from .helper import * 3 | from .logger import * 4 | from .saver import * 5 | from .multi_counter import MultiCounter 6 | from .stopwatch import Stopwatch 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/* 2 | .idea/* 3 | cmake-build-debug/ 4 | *.pyc 5 | */__pycache__/* 6 | .DS_Store 7 | pyhanabi/exps 8 | sweep 9 | sweep_* 10 | plot.png 11 | sweep*.py 12 | models/* 13 | logs/* 14 | unused/* 15 | !models/download.sh 16 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "hanabi-learning-environment"] 2 | path = hanabi-learning-environment 3 | url = git@github.com:hengyuan-hu/hanabi-learning-environment.git 4 | [submodule "third_party/pybind11"] 5 | path = third_party/pybind11 6 | url = git@github.com:pybind/pybind11.git 7 | branch = v2.3 8 | -------------------------------------------------------------------------------- /pyhanabi/set_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | def append_sys_path(): 6 | root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | 
tube = os.path.join(root, "build", "rela") 8 | if tube not in sys.path: 9 | sys.path.append(tube) 10 | 11 | hanalearn = os.path.join(root, "build") 12 | if hanalearn not in sys.path: 13 | sys.path.append(hanalearn) 14 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | AccessModifierOffset: -1 2 | AllowShortFunctionsOnASingleLine: false 3 | AllowShortIfStatementsOnASingleLine: false 4 | AllowShortLoopsOnASingleLine: false 5 | BinPackParameters: false 6 | BinPackArguments: false 7 | BreakConstructorInitializersBeforeComma: true 8 | PenaltyReturnTypeOnItsOwnLine: 200 9 | PenaltyBreakBeforeFirstCallParameter: 0 10 | PointerBindsToType: true 11 | SpacesBeforeTrailingComments: 2 12 | UseTab: Never 13 | IndentWidth: 2 14 | AlwaysBreakTemplateDeclarations: true 15 | ColumnLimit: 90 16 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_evaluation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # the last agent in --weight_2 is the pre-trained agent itself that was used for continual training 3 | LOAD_MODEL_DIR= 4 | python continual_evaluation.py \ 5 | --weight_1_dir \ 6 | --weight_2 ${LOAD_MODEL_DIR}/iql_2p_310.pthw ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 7 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_710.pthw \ 8 | ${LOAD_MODEL_DIR}/vdn_op_2p_729.pthw ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 9 | --num_player 2 \ 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.12.5 2 | chardet==4.0.0 3 | click==7.1.2 4 | configparser==5.0.1 5 | docker-pycreds==0.4.0 6 | filelock==3.0.12 7 | future==0.18.2 8 | gdown==3.12.2 9 | gitdb==4.0.5 10 | GitPython==3.1.12 11 | idna==2.10 12 | numpy==1.19.5 13 | pandas==1.2.0 14 | pathtools==0.1.2 15 | promise==2.3 16 | protobuf==3.14.0 17 | psutil==5.8.0 18 | PySocks==1.7.1 19 | python-dateutil==2.8.1 20 | pytz==2020.5 21 | PyYAML==5.4 22 | requests==2.25.1 23 | sentry-sdk==0.19.5 24 | shortuuid==1.0.1 25 | six==1.15.0 26 | smmap==3.0.4 27 | subprocess32==3.5.4 28 | torch==1.5.1+cu101 29 | torchvision==0.6.1+cu101 30 | tqdm==4.58.0 31 | wandb==0.10.20 32 | watchdog==0.10.4 33 | -------------------------------------------------------------------------------- /pyhanabi/common_utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | class Logger: 6 | def __init__(self, path, mode="w"): 7 | assert mode in {"w", "a"}, "unknown mode for logger %s" % mode 8 | self.terminal = sys.stdout 9 | if not os.path.exists(os.path.dirname(path)): 10 | os.makedirs(os.path.dirname(path)) 11 | if mode == "w" or not os.path.exists(path): 12 | self.log = open(path, "w") 13 | else: 14 | self.log = open(path, "a") 15 | 16 | def write(self, message): 17 | self.terminal.write(message) 18 | self.log.write(message) 19 | self.log.flush() 20 | 21 | def flush(self): 22 | # for python 3 compatibility. 
23 | pass 24 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | CMAKE_MINIMUM_REQUIRED(VERSION 3.3) 2 | project(hanalearn) 3 | add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) 4 | 5 | set(CMAKE_CXX_STANDARD 14) 6 | set(CMAKE_CXX_FLAGS 7 | "${CMAKE_CXX_FLAGS} -O3 -Wall -Wextra -Wno-register -fPIC -Wfatal-errors") 8 | 9 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/rela) 10 | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hanabi-learning-environment) 11 | 12 | find_package(Torch REQUIRED) 13 | 14 | pybind11_add_module( 15 | hanalearn 16 | ${CMAKE_CURRENT_SOURCE_DIR}/cpp/hanabi_env.cc 17 | ${CMAKE_CURRENT_SOURCE_DIR}/cpp/pybind.cc 18 | ) 19 | target_link_libraries(hanalearn PUBLIC hanabi) 20 | target_link_libraries(hanalearn PUBLIC rela) 21 | target_include_directories(hanalearn PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) 22 | -------------------------------------------------------------------------------- /pyhanabi/tools/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python selfplay.py \ 3 | --save_dir \ 4 | --num_thread 80 \ 5 | --num_game_per_thread 80 \ 6 | --method iql \ 7 | --pred_weight 0 \ 8 | --sad 0 \ 9 | --act_base_eps 0.1 \ 10 | --act_eps_alpha 7 \ 11 | --lr 6.25e-05 \ 12 | --eps 1.5e-05 \ 13 | --grad_clip 5 \ 14 | --gamma 0.999 \ 15 | --seed 25000 \ 16 | --batchsize 128 \ 17 | --burn_in_frames 10000 \ 18 | --replay_buffer_size 65536 \ 19 | --epoch_len 1000 \ 20 | --priority_exponent 0.9 \ 21 | --priority_weight 0.6 \ 22 | --train_bomb 0 \ 23 | --eval_bomb 0 \ 24 | --num_player 2 \ 25 | --num_fflayer 1 \ 26 | --rnn_type lstm \ 27 | --rnn_hid_dim 512 \ 28 | --num_rnn_layer 2 \ 29 | --shuffle_color 0 \ 30 | --multi_step 3 \ 31 | --act_device cuda:1,cuda:2 \ 32 | -------------------------------------------------------------------------------- /results/scores_data_motivation_rand.csv: -------------------------------------------------------------------------------- 1 | ,iql_2p_5,iql_2p_616,iql_2p_210,iql_2p_700,iql_2p_614,iql_2p_11,iql_2p_6,iql_2p_113,iql_2p_612,iql_2p_618 2 | iql_2p_5,21.883,18.1586,16.732,18.0486,18.1176,10.1126,15.7796,15.0336,18.308,18.999 3 | iql_2p_616,18.159,23.8616,12.046,23.396,20.296,13.7568,13.108,9.4598,21.1536,22.5106 4 | iql_2p_210,16.7912,11.682,23.8912,12.379,10.5546,6.523,6.4594,5.5086,10.8806,12.0396 5 | iql_2p_700,17.841,23.465,12.6986,23.995,20.3232,14.584,13.0032,10.2902,20.7572,22.218 6 | iql_2p_614,18.0426,20.2396,10.0654,20.1002,23.6172,10.771,15.5602,10.9876,22.6794,20.7506 7 | iql_2p_11,9.9138,13.6374,7.1394,14.588,10.971,23.9404,7.6746,7.289,12.4624,13.1434 8 | iql_2p_6,15.8602,12.8124,6.156,13.1458,15.4516,7.0722,23.3894,7.9024,15.7628,16.0992 9 | iql_2p_113,15.8614,9.6156,5.9836,10.2168,10.8778,7.061,7.918,23.9586,11.011,11.774 10 | iql_2p_612,18.5346,20.8924,11.0882,20.8322,22.7274,12.1282,16.166,10.8834,23.5472,21.0722 11 | iql_2p_618,19.097,22.5306,12.1048,22.2638,20.6478,12.8732,16.1214,11.3818,21.0354,23.7726 12 | -------------------------------------------------------------------------------- /pyhanabi/common_utils/assert_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for assertions""" 2 | 3 | 4 | def assert_eq(real, expected): 5 | assert real == expected, "%s (true) vs %s (expected)" % (real, expected) 6 | 7 | 8 | def assert_neq(real, expected): 9 | assert real != expected, "%s (true) vs %s 
(expected)" % (real, expected) 10 | 11 | 12 | def assert_lt(real, expected): 13 | assert real < expected, "%s (true) vs %s (expected)" % (real, expected) 14 | 15 | 16 | def assert_lteq(real, expected): 17 | assert real <= expected, "%s (true) vs %s (expected)" % (real, expected) 18 | 19 | 20 | def assert_tensor_eq(t1, t2, eps=1e-6): 21 | if t1.size() != t2.size(): 22 | print("Warning: size mismatch", t1.size(), "vs", t2.size()) 23 | return False 24 | 25 | t1 = t1.cpu().numpy() 26 | t2 = t2.cpu().numpy() 27 | diff = abs(t1 - t2) 28 | eq = (diff < eps).all() 29 | if not eq: 30 | import pdb 31 | 32 | pdb.set_trace() 33 | assert eq, (diff < eps).max() 34 | 35 | 36 | def assert_zero_grad(params): 37 | for p in params: 38 | if p.grad is not None: 39 | assert p.grad.sum().item() == 0 40 | -------------------------------------------------------------------------------- /pyhanabi/tools/testing.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOAD_MODEL_DIR= 3 | python testing.py \ 4 | --weight_1_dir \ 5 | --weight_2 ${LOAD_MODEL_DIR}/iql_2p_510.pthw ${LOAD_MODEL_DIR}/iql_2p_7.pthw \ 6 | ${LOAD_MODEL_DIR}/iql_op_2p_612.pthw ${LOAD_MODEL_DIR}/iql_op_2p_6140.pthw \ 7 | ${LOAD_MODEL_DIR}/vdn_aux_2p_941.pthw ${LOAD_MODEL_DIR}/vdn_aux_2p_970.pthw \ 8 | ${LOAD_MODEL_DIR}/sad_op_2p_1.pthw ${LOAD_MODEL_DIR}/sad_op_2p_2501.pthw \ 9 | ${LOAD_MODEL_DIR}/sad_aux_op_2p_1.pthw ${LOAD_MODEL_DIR}/sad_aux_op_2p_25001.pthw \ 10 | ${LOAD_MODEL_DIR}/sad_aux_2p_1.pthw ${LOAD_MODEL_DIR}/sad_aux_2p_20001.pthw \ 11 | ${LOAD_MODEL_DIR}/sad_2p_1.pthw ${LOAD_MODEL_DIR}/sad_2p_2006.pthw \ 12 | ${LOAD_MODEL_DIR}/iql_aux_2p_800.pthw ${LOAD_MODEL_DIR}/iql_aux_2p_811.pthw \ 13 | ${LOAD_MODEL_DIR}/vdn_2p_726.pthw ${LOAD_MODEL_DIR}/vdn_2p_740.pthw \ 14 | ${LOAD_MODEL_DIR}/vdn_op_2p_727.pthw ${LOAD_MODEL_DIR}/vdn_op_2p_77111.pthw \ 15 | --num_player 2 \ 16 | -------------------------------------------------------------------------------- /rela/thread_loop.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | namespace rela { 8 | 9 | class ThreadLoop { 10 | public: 11 | ThreadLoop() = default; 12 | 13 | ThreadLoop(const ThreadLoop&) = delete; 14 | ThreadLoop& operator=(const ThreadLoop&) = delete; 15 | 16 | virtual ~ThreadLoop() { 17 | } 18 | 19 | virtual void terminate() { 20 | terminated_ = true; 21 | resume(); 22 | } 23 | 24 | virtual void pause() { 25 | std::lock_guard lk(mPaused_); 26 | paused_ = true; 27 | } 28 | 29 | virtual void resume() { 30 | { 31 | std::lock_guard lk(mPaused_); 32 | paused_ = false; 33 | } 34 | cvPaused_.notify_one(); 35 | } 36 | 37 | virtual void waitUntilResume() { 38 | std::unique_lock lk(mPaused_); 39 | cvPaused_.wait(lk, [this] { return !paused_; }); 40 | } 41 | 42 | virtual bool terminated() { 43 | return terminated_; 44 | } 45 | 46 | virtual bool paused() { 47 | return paused_; 48 | } 49 | 50 | virtual void mainLoop() = 0; 51 | 52 | private: 53 | std::atomic_bool terminated_{false}; 54 | 55 | std::mutex mPaused_; 56 | bool paused_ = false; 57 | std::condition_variable cvPaused_; 58 | }; 59 | } // namespace rela 60 | -------------------------------------------------------------------------------- /rela/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace rela { 10 | 11 | namespace utils { 12 | 13 | inline int getProduct(const std::vector& nums) { 14 | int prod = 1; 15 | for (auto v : nums) { 16 | prod *= v; 17 | } 18 | return prod; 19 | } 20 | 21 | template 22 | inline std::vector pushLeft(T left, const std::vector& vals) { 23 | std::vector vec; 24 | vec.reserve(1 + vals.size()); 25 | vec.push_back(left); 26 | for (auto v : vals) { 27 | vec.push_back(v); 28 | } 29 | return vec; 30 | } 31 | 32 | template 33 | inline void printVector(const std::vector& vec) { 34 | for (const auto& v : vec) { 35 | std::cout << v << ", "; 36 | } 37 | std::cout << std::endl; 38 | } 39 | 40 | template 41 | inline void printMapKey(const T& map) { 42 | for (const auto& name2sth : map) { 43 | std::cout << name2sth.first << ", "; 44 | } 45 | std::cout << std::endl; 46 | } 47 | 48 | template 49 | inline void printMap(const T& map) { 50 | for (const auto& name2sth : map) { 51 | std::cout << name2sth.first << ": " << name2sth.second << std::endl; 52 | } 53 | // std::cout << std::endl; 54 | } 55 | } // namespace utils 56 | } // namespace rela 57 | -------------------------------------------------------------------------------- /rela/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0 FATAL_ERROR) 2 | 3 | 4 | set(CMAKE_CXX_STANDARD 14) 5 | set(CMAKE_CXX_FLAGS 6 | "${CMAKE_CXX_FLAGS} -O3 -Wall -Wextra -Wno-register -fPIC -march=native -Wfatal-errors") 7 | 8 | 9 | # get and append paths for finding dep 10 | execute_process( 11 | COMMAND python -c "import torch; import os; print(os.path.dirname(torch.__file__), end='')" 12 | OUTPUT_VARIABLE TorchPath 13 | ) 14 | list(APPEND CMAKE_PREFIX_PATH ${TorchPath}) 15 | 16 | 17 | # find packages & third_party 18 | find_package(PythonInterp 3.7 REQUIRED) 19 | find_package(PythonLibs 3.7 REQUIRED) 20 | find_package(Torch REQUIRED) 21 | 22 | # Temp fix for PyTorch 1.5. 23 | set(TORCH_PYTHON_LIBRARIES "${TorchPath}/lib/libtorch_python.so") 24 | 25 | # message(${CMAKE_CURRENT_SOURCE_DIR}/../) 26 | add_subdirectory( 27 | ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/pybind11 third_party/pybind11 28 | ) 29 | 30 | 31 | # # lib for other c++ programs 32 | # add_library(_rela 33 | # transition.cc 34 | # ) 35 | # target_include_directories(_rela PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..) 36 | # target_include_directories(_rela PUBLIC ${TORCH_INCLUDE_DIRS}) 37 | # target_include_directories(_rela PUBLIC ${PYTHON_INCLUDE_DIRS}) 38 | # target_link_libraries(_rela PUBLIC ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARIES} ) 39 | 40 | 41 | # python lib 42 | pybind11_add_module(rela SHARED transition.cc pybind.cc) 43 | # target_link_libraries(rela PUBLIC _rela) 44 | target_include_directories(rela PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..) 45 | target_include_directories(rela PUBLIC ${TORCH_INCLUDE_DIRS}) 46 | target_include_directories(rela PUBLIC ${PYTHON_INCLUDE_DIRS}) 47 | target_link_libraries(rela PUBLIC ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARIES} ) 48 | -------------------------------------------------------------------------------- /pyhanabi/common_utils/stopwatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | from collections import defaultdict 4 | from datetime import datetime 5 | import numpy as np 6 | 7 | 8 | def millis_interval(start, end): 9 | """start and end are datetime instances""" 10 | diff = end - start 11 | millis = diff.days * 24 * 60 * 60 * 1000 12 | millis += diff.seconds * 1000 13 | millis += diff.microseconds / 1000 14 | return millis 15 | 16 | 17 | class Stopwatch: 18 | def __init__(self): 19 | self.last_time = datetime.now() 20 | self.times = defaultdict(list) 21 | self.keys = [] 22 | 23 | def reset(self): 24 | self.last_time = datetime.now() 25 | self.times = defaultdict(list) 26 | self.keys = [] 27 | 28 | def time(self, key): 29 | if key not in self.times: 30 | self.keys.append(key) 31 | self.times[key].append(millis_interval(self.last_time, datetime.now())) 32 | self.last_time = datetime.now() 33 | 34 | def summary(self): 35 | num_elems = -1 36 | total = 0 37 | max_key_len = 0 38 | for k, v in self.times.items(): 39 | if num_elems == -1: 40 | num_elems = len(v) 41 | 42 | assert len(v) == num_elems 43 | total += np.sum(v) 44 | max_key_len = max(max_key_len, len(k)) 45 | 46 | print("@@@Time") 47 | for k in self.keys: 48 | v = self.times[k] 49 | print( 50 | "\t%s: %d MS, %.2f%%" 51 | % (k.ljust(max_key_len), np.mean(v), 100.0 * np.sum(v) / total) 52 | ) 53 | print("@@@total time per iter: %.2f ms" % (float(total) / num_elems)) 54 | self.reset() 55 | -------------------------------------------------------------------------------- /rela/transition.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | #pragma once 4 | 5 | #include 6 | 7 | #include "tensor_dict.h" 8 | 9 | namespace rela { 10 | 11 | class FFTransition { 12 | public: 13 | FFTransition() = default; 14 | 15 | FFTransition( 16 | TensorDict& obs, 17 | TensorDict& action, 18 | torch::Tensor& reward, 19 | torch::Tensor& terminal, 20 | torch::Tensor& bootstrap, 21 | TensorDict& nextObs) 22 | : obs(obs) 23 | , action(action) 24 | , reward(reward) 25 | , terminal(terminal) 26 | , bootstrap(bootstrap) 27 | , nextObs(nextObs) { 28 | } 29 | 30 | FFTransition index(int i) const; 31 | 32 | FFTransition padLike() const; 33 | 34 | std::vector toVectorIValue(const torch::Device& device) const; 35 | 36 | TensorDict toDict(); 37 | 38 | static FFTransition makeBatch( 39 | const std::vector& transitions, const std::string& device); 40 | 41 | TensorDict obs; 42 | TensorDict action; 43 | torch::Tensor reward; 44 | torch::Tensor terminal; 45 | torch::Tensor bootstrap; 46 | TensorDict nextObs; 47 | }; 48 | 49 | class RNNTransition { 50 | public: 51 | RNNTransition() = default; 52 | 53 | RNNTransition( 54 | const std::vector& transitions, TensorDict h0, torch::Tensor seqLen); 55 | 56 | RNNTransition index(int i) const; 57 | 58 | static RNNTransition makeBatch( 59 | const std::vector& transitions, const std::string& device); 60 | 61 | TensorDict obs; 62 | TensorDict h0; 63 | TensorDict action; 64 | torch::Tensor reward; 65 | torch::Tensor terminal; 66 | torch::Tensor bootstrap; 67 | torch::Tensor seqLen; 68 | }; 69 | 70 | } // namespace rela 71 | -------------------------------------------------------------------------------- /rela/context.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | // 7 | #pragma once 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "rela/thread_loop.h" 15 | 16 | namespace rela { 17 | 18 | class Context { 19 | public: 20 | Context() 21 | : started_(false) 22 | , numTerminatedThread_(0) { 23 | } 24 | 25 | Context(const Context&) = delete; 26 | Context& operator=(const Context&) = delete; 27 | 28 | ~Context() { 29 | for (auto& v : loops_) { 30 | v->terminate(); 31 | } 32 | for (auto& v : threads_) { 33 | v.join(); 34 | } 35 | } 36 | 37 | int pushThreadLoop(std::shared_ptr env) { 38 | assert(!started_); 39 | loops_.push_back(std::move(env)); 40 | return (int)loops_.size(); 41 | } 42 | 43 | void start() { 44 | for (int i = 0; i < (int)loops_.size(); ++i) { 45 | threads_.emplace_back([this, i]() { 46 | loops_[i]->mainLoop(); 47 | ++numTerminatedThread_; 48 | }); 49 | } 50 | } 51 | 52 | void pause() { 53 | for (auto& v : loops_) { 54 | v->pause(); 55 | } 56 | } 57 | 58 | void resume() { 59 | for (auto& v : loops_) { 60 | v->resume(); 61 | } 62 | } 63 | 64 | void terminate() { 65 | for (auto& v : loops_) { 66 | v->terminate(); 67 | } 68 | } 69 | 70 | bool terminated() { 71 | // std::cout << ">>> " << numTerminatedThread_ << std::endl; 72 | return numTerminatedThread_ == (int)loops_.size(); 73 | } 74 | 75 | private: 76 | bool started_; 77 | std::atomic numTerminatedThread_; 78 | std::vector> loops_; 79 | std::vector threads_; 80 | }; 81 | } // namespace rela 82 | -------------------------------------------------------------------------------- /cpp/pybind.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | // 7 | #include 8 | 9 | #include "cpp/hanabi_env.h" 10 | #include "cpp/thread_loop.h" 11 | 12 | namespace py = pybind11; 13 | 14 | PYBIND11_MODULE(hanalearn, m) { 15 | py::class_>(m, "HanabiEnv") 16 | .def(py::init< 17 | const std::unordered_map&, 18 | const std::vector&, 19 | int, // maxLen 20 | bool, // sad 21 | bool, // shuffleObs 22 | bool, // shuffleColor 23 | bool>()) 24 | .def("feature_size", &HanabiEnv::featureSize) 25 | .def("num_action", &HanabiEnv::numAction) 26 | .def("reset", &HanabiEnv::reset) 27 | .def("step", &HanabiEnv::step) 28 | .def("terminated", &HanabiEnv::terminated) 29 | .def("get_current_player", &HanabiEnv::getCurrentPlayer) 30 | .def("move_is_legal", &HanabiEnv::moveIsLegal) 31 | .def("last_score", &HanabiEnv::lastScore) 32 | .def("hand_feature_size", &HanabiEnv::handFeatureSize) 33 | .def("deck_history", &HanabiEnv::deckHistory) 34 | .def("get_score", &HanabiEnv::getScore) 35 | .def("get_life", &HanabiEnv::getLife) 36 | .def("get_info", &HanabiEnv::getInfo) 37 | .def("get_fireworks", &HanabiEnv::getFireworks) 38 | ; 39 | 40 | py::class_>(m, "HanabiVecEnv") 41 | .def(py::init<>()) 42 | .def("append", &HanabiVecEnv::append, py::keep_alive<1, 2>()) 43 | ; 44 | 45 | py::class_>( 46 | m, "HanabiThreadLoop") 47 | .def(py::init< 48 | std::shared_ptr, 49 | std::shared_ptr, 50 | bool>()) 51 | .def(py::init< 52 | std::vector>, 53 | std::shared_ptr, 54 | bool>()) 55 | ; 56 | } 57 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/MTL_easy_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
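## MTL (multitask) baseline: multitask.py trains the learnable agent against all
## five partners listed under --load_partner_model at once, which is often used as
## an upper bound for the sequential continual-learning runs in this repo.
## A minimal sketch of filling in the placeholders below, with hypothetical paths:
##   LOAD_MODEL_DIR=../models/pretrained_pool
##   SAVE_DIR=exps/mtl_easy_seed10
## and, to enable learning-rate decay, append "--decay_lr \" to the python call.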
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=128 12 | python multitask.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 17 | --load_partner_model ${LOAD_MODEL_DIR}/iql_2p_310.pthw ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 18 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_710.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_op_2p_729.pthw \ 20 | --num_thread 10 \ 21 | --num_game_per_thread 80 \ 22 | --eval_num_thread 10 \ 23 | --eval_num_game_per_thread 80 \ 24 | --sad 0 \ 25 | --act_base_eps 0.1 \ 26 | --act_eps_alpha 7 \ 27 | --eps 1.5e-05 \ 28 | --grad_clip 5 \ 29 | --gamma 0.999 \ 30 | --seed ${SEED} \ 31 | --initial_lr ${INITIAL_LR} \ 32 | --final_lr 6.25e-05 \ 33 | --lr_gamma 0.2 \ 34 | --dropout_p 0 \ 35 | --sgd_momentum 0.8 \ 36 | --optim_name ${OPTIM_NAME} \ 37 | --batchsize ${BATCH_SIZE} \ 38 | --max_train_steps 1000000000 \ 39 | --max_eval_steps 2500000 \ 40 | --burn_in_frames 50000 \ 41 | --eval_burn_in_frames 1000 \ 42 | --replay_buffer_size 163840 \ 43 | --eval_replay_buffer_size 10000 \ 44 | --epoch_len 200 \ 45 | --priority_exponent 0.9 \ 46 | --priority_weight 0.6 \ 47 | --train_bomb 0 \ 48 | --eval_bomb 0 \ 49 | --eval_epoch_len 50 \ 50 | --eval_method ${EVAL_METHOD} \ 51 | --eval_freq 25 \ 52 | --num_player 2 \ 53 | --rnn_type lstm \ 54 | --num_fflayer 1 \ 55 | --num_rnn_layer 2 \ 56 | --rnn_hid_dim 512 \ 57 | --act_device cuda:1,cuda:2 \ 58 | --shuffle_color 0 \ 59 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/MTL_hard_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
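## "Hard" variant of the MTL baseline: the settings are identical to
## MTL_easy_interactive.sh; only the partner pool passed to --load_partner_model
## differs. As far as the naming suggests, the easy/hard split refers to how
## difficult those partners are for the learnable iql_2p_210 agent to coordinate with.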
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=128 12 | python multitask.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 17 | --load_partner_model ${LOAD_MODEL_DIR}/vdn_op_2p_7771.pthw ${LOAD_MODEL_DIR}/vdn_2p_726.pthw \ 18 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_600.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 20 | --num_thread 10 \ 21 | --num_game_per_thread 80 \ 22 | --eval_num_thread 10 \ 23 | --eval_num_game_per_thread 80 \ 24 | --sad 0 \ 25 | --act_base_eps 0.1 \ 26 | --act_eps_alpha 7 \ 27 | --eps 1.5e-05 \ 28 | --grad_clip 5 \ 29 | --gamma 0.999 \ 30 | --seed ${SEED} \ 31 | --initial_lr ${INITIAL_LR} \ 32 | --final_lr 6.25e-05 \ 33 | --lr_gamma 0.2 \ 34 | --dropout_p 0 \ 35 | --sgd_momentum 0.8 \ 36 | --optim_name ${OPTIM_NAME} \ 37 | --batchsize ${BATCH_SIZE} \ 38 | --max_train_steps 1000000000 \ 39 | --max_eval_steps 2500000 \ 40 | --burn_in_frames 50000 \ 41 | --eval_burn_in_frames 1000 \ 42 | --replay_buffer_size 163840 \ 43 | --eval_replay_buffer_size 10000 \ 44 | --epoch_len 200 \ 45 | --priority_exponent 0.9 \ 46 | --priority_weight 0.6 \ 47 | --train_bomb 0 \ 48 | --eval_bomb 0 \ 49 | --eval_epoch_len 50 \ 50 | --eval_method ${EVAL_METHOD} \ 51 | --eval_freq 25 \ 52 | --num_player 2 \ 53 | --rnn_type lstm \ 54 | --num_fflayer 1 \ 55 | --num_rnn_layer 2 \ 56 | --rnn_hid_dim 512 \ 57 | --act_device cuda:1,cuda:2 \ 58 | --shuffle_color 0 \ 59 | -------------------------------------------------------------------------------- /results/sem_data_motivation_rand.csv: -------------------------------------------------------------------------------- 1 | ,iql_2p_5,iql_2p_616,iql_2p_210,iql_2p_700,iql_2p_614,iql_2p_11,iql_2p_6,iql_2p_113,iql_2p_612,iql_2p_618 2 | iql_2p_5,0.06922212218648023,0.1111617254633986,0.1258988292241036,0.10968813795483995,0.09852164253604383,0.14271896947497903,0.10629978724343712,0.12380991159030848,0.09213070715022217,0.09581711642498952 3 | iql_2p_616,0.11107305613874141,0.028003019265786324,0.14306605746996734,0.04154319198135839,0.08620949367674073,0.1394475197054433,0.13057376152964267,0.14385074484339663,0.07417170220508627,0.047529964527653505 4 | iql_2p_210,0.12646327732587037,0.1437584613161952,0.023353640230165403,0.139039389383009,0.12860156985044932,0.12973964004883007,0.11767748437148035,0.11837763812477423,0.1261535125472137,0.13422818768053155 5 | iql_2p_700,0.11298576813032692,0.03915501245051519,0.13921936506104315,0.022776193711856248,0.0828400166103315,0.13358453802742293,0.1301829403262962,0.14558735107144438,0.08040627856081888,0.05163889231964605 6 | iql_2p_614,0.0997857557369788,0.08760295867149694,0.12812612757747732,0.0858367752889168,0.024072657352274176,0.13327157161225345,0.10731297774267565,0.13497247589045702,0.03766328620818954,0.06758165378266501 7 | iql_2p_11,0.14325415844574985,0.13997937079441383,0.1327029522203632,0.13257379529907107,0.1340678626666361,0.02325703265681157,0.12492006631442362,0.12604878341340706,0.12533777263060006,0.13164530864409865 8 | iql_2p_6,0.10447071930450176,0.13236404817018857,0.11645674218352496,0.13050374888101873,0.10823867833635072,0.12282262589604571,0.026368798379903472,0.12747240818310446,0.10581770755407621,0.1109267860888433 9 | 
iql_2p_113,0.11717592759607241,0.14552596788202443,0.1215579952450681,0.14555150137322528,0.13411999639129132,0.12448942043402723,0.1276730793863765,0.030941512697345613,0.13381874233454744,0.14535509898176946 10 | iql_2p_612,0.08931875820901229,0.08016086606318572,0.12528241756926628,0.07739669651864994,0.038148890521219614,0.12774487446469232,0.10229021849619836,0.13321907103714542,0.02848217744485137,0.06230728233521343 11 | iql_2p_618,0.09365211262966788,0.04661558460429301,0.13311965817263804,0.049076694183695785,0.06864423524229839,0.13204235817342858,0.11063766270127003,0.14569037631909665,0.06256220398930971,0.02612772182950515 12 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/ER_easy_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=32 12 | python continual_training.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --ll_algo ER \ 17 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 18 | --load_partner_model ${LOAD_MODEL_DIR}/iql_2p_310.pthw ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_710.pthw \ 20 | ${LOAD_MODEL_DIR}/vdn_op_2p_729.pthw \ 21 | --num_thread 10 \ 22 | --num_game_per_thread 80 \ 23 | --eval_num_thread 10 \ 24 | --eval_num_game_per_thread 80 \ 25 | --sad 0 \ 26 | --act_base_eps 0.1 \ 27 | --act_eps_alpha 7 \ 28 | --eps 1.5e-05 \ 29 | --grad_clip 5 \ 30 | --gamma 0.999 \ 31 | --seed ${SEED} \ 32 | --initial_lr ${INITIAL_LR} \ 33 | --final_lr 6.25e-05 \ 34 | --lr_gamma 0.2 \ 35 | --dropout_p 0 \ 36 | --sgd_momentum 0.8 \ 37 | --optim_name ${OPTIM_NAME} \ 38 | --batchsize ${BATCH_SIZE} \ 39 | --max_train_steps 200000000 \ 40 | --max_eval_steps 500000 \ 41 | --burn_in_frames 10000 \ 42 | --eval_burn_in_frames 1000 \ 43 | --replay_buffer_size 32768 \ 44 | --eval_replay_buffer_size 10000 \ 45 | --epoch_len 200 \ 46 | --priority_exponent 0.9 \ 47 | --priority_weight 0.6 \ 48 | --train_bomb 0 \ 49 | --eval_bomb 0 \ 50 | --eval_epoch_len 50 \ 51 | --eval_method ${EVAL_METHOD} \ 52 | --eval_freq 25 \ 53 | --num_player 2 \ 54 | --rnn_type lstm \ 55 | --num_fflayer 1 \ 56 | --num_rnn_layer 2 \ 57 | --rnn_hid_dim 512 \ 58 | --act_device cuda:1,cuda:2 \ 59 | --shuffle_color 0 \ 60 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/AGEM_easy_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
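## A-GEM (selected by --ll_algo AGEM below): before each update, the gradient g on
## the current partner is compared with a reference gradient g_ref computed on a
## small batch replayed from earlier partners; whenever g . g_ref < 0 the update is
## projected to
##   g  <-  g - (g . g_ref / g_ref . g_ref) * g_ref
## so the replayed loss is not increased (Chaudhry et al., 2019). This is the
## standard A-GEM rule; the exact wiring lives in continual_training.py, so treat
## this as a summary of the method rather than of that file.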
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=32 12 | python continual_training.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --ll_algo AGEM \ 17 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 18 | --load_partner_model ${LOAD_MODEL_DIR}/iql_2p_310.pthw ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_710.pthw \ 20 | ${LOAD_MODEL_DIR}/vdn_op_2p_729.pthw \ 21 | --num_thread 10 \ 22 | --num_game_per_thread 80 \ 23 | --eval_num_thread 10 \ 24 | --eval_num_game_per_thread 80 \ 25 | --sad 0 \ 26 | --act_base_eps 0.1 \ 27 | --act_eps_alpha 7 \ 28 | --eps 1.5e-05 \ 29 | --grad_clip 5 \ 30 | --gamma 0.999 \ 31 | --seed ${SEED} \ 32 | --initial_lr ${INITIAL_LR} \ 33 | --final_lr 6.25e-05 \ 34 | --lr_gamma 0.2 \ 35 | --dropout_p 0 \ 36 | --sgd_momentum 0.8 \ 37 | --optim_name ${OPTIM_NAME} \ 38 | --batchsize ${BATCH_SIZE} \ 39 | --max_train_steps 200000000 \ 40 | --max_eval_steps 500000 \ 41 | --burn_in_frames 10000 \ 42 | --eval_burn_in_frames 1000 \ 43 | --replay_buffer_size 32768 \ 44 | --eval_replay_buffer_size 10000 \ 45 | --epoch_len 200 \ 46 | --priority_exponent 0.9 \ 47 | --priority_weight 0.6 \ 48 | --train_bomb 0 \ 49 | --eval_bomb 0 \ 50 | --eval_epoch_len 50 \ 51 | --eval_method ${EVAL_METHOD} \ 52 | --eval_freq 25 \ 53 | --num_player 2 \ 54 | --rnn_type lstm \ 55 | --num_fflayer 1 \ 56 | --num_rnn_layer 2 \ 57 | --rnn_hid_dim 512 \ 58 | --act_device cuda:1,cuda:2 \ 59 | --shuffle_color 0 \ 60 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/AGEM_hard_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=32 12 | python continual_training.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --ll_algo AGEM \ 17 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 18 | --load_partner_model ${LOAD_MODEL_DIR}/vdn_op_2p_7771.pthw ${LOAD_MODEL_DIR}/vdn_2p_726.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_600.pthw \ 20 | ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 21 | --num_thread 10 \ 22 | --num_game_per_thread 80 \ 23 | --eval_num_thread 10 \ 24 | --eval_num_game_per_thread 80 \ 25 | --sad 0 \ 26 | --act_base_eps 0.1 \ 27 | --act_eps_alpha 7 \ 28 | --eps 1.5e-05 \ 29 | --grad_clip 5 \ 30 | --gamma 0.999 \ 31 | --seed ${SEED} \ 32 | --initial_lr ${INITIAL_LR} \ 33 | --final_lr 6.25e-05 \ 34 | --lr_gamma 0.2 \ 35 | --dropout_p 0 \ 36 | --sgd_momentum 0.8 \ 37 | --optim_name ${OPTIM_NAME} \ 38 | --batchsize ${BATCH_SIZE} \ 39 | --max_train_steps 200000000 \ 40 | --max_eval_steps 500000 \ 41 | --burn_in_frames 10000 \ 42 | --eval_burn_in_frames 1000 \ 43 | --replay_buffer_size 32768 \ 44 | --eval_replay_buffer_size 10000 \ 45 | --epoch_len 200 \ 46 | --priority_exponent 0.9 \ 47 | --priority_weight 0.6 \ 48 | --train_bomb 0 \ 49 | --eval_bomb 0 \ 50 | --eval_epoch_len 50 \ 51 | --eval_method ${EVAL_METHOD} \ 52 | --eval_freq 25 \ 53 | --num_player 2 \ 54 | --rnn_type lstm \ 55 | --num_fflayer 1 \ 56 | --num_rnn_layer 2 \ 57 | --rnn_hid_dim 512 \ 58 | --act_device cuda:1,cuda:2 \ 59 | --shuffle_color 0 \ 60 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/ER_hard_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
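## ER (selected by --ll_algo ER below): plain experience replay -- each update mixes
## fresh transitions gathered with the current partner with transitions replayed
## from buffers filled while training with earlier partners. The mixing details are
## set inside continual_training.py; this comment summarizes the method, not those
## defaults.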
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=32 12 | python continual_training.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --ll_algo ER \ 17 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 18 | --load_partner_model ${LOAD_MODEL_DIR}/vdn_op_2p_7771.pthw ${LOAD_MODEL_DIR}/vdn_2p_726.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_600.pthw \ 20 | ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 21 | --num_thread 10 \ 22 | --num_game_per_thread 80 \ 23 | --eval_num_thread 10 \ 24 | --eval_num_game_per_thread 80 \ 25 | --sad 0 \ 26 | --act_base_eps 0.1 \ 27 | --act_eps_alpha 7 \ 28 | --eps 1.5e-05 \ 29 | --grad_clip 5 \ 30 | --gamma 0.999 \ 31 | --seed ${SEED} \ 32 | --initial_lr ${INITIAL_LR} \ 33 | --final_lr 6.25e-05 \ 34 | --lr_gamma 0.2 \ 35 | --dropout_p 0 \ 36 | --sgd_momentum 0.8 \ 37 | --optim_name ${OPTIM_NAME} \ 38 | --batchsize ${BATCH_SIZE} \ 39 | --max_train_steps 200000000 \ 40 | --max_eval_steps 500000 \ 41 | --burn_in_frames 10000 \ 42 | --eval_burn_in_frames 1000 \ 43 | --replay_buffer_size 32768 \ 44 | --eval_replay_buffer_size 10000 \ 45 | --epoch_len 200 \ 46 | --priority_exponent 0.9 \ 47 | --priority_weight 0.6 \ 48 | --train_bomb 0 \ 49 | --eval_bomb 0 \ 50 | --eval_epoch_len 50 \ 51 | --eval_method ${EVAL_METHOD} \ 52 | --eval_freq 25 \ 53 | --num_player 2 \ 54 | --rnn_type lstm \ 55 | --num_fflayer 1 \ 56 | --num_rnn_layer 2 \ 57 | --rnn_hid_dim 512 \ 58 | --act_device cuda:1,cuda:2 \ 59 | --shuffle_color 0 \ 60 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/Naive_easy_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
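## Naive (selected by --ll_algo Naive below): sequential fine-tuning with no replay
## and no regularization -- the agent simply keeps training on each new partner in
## turn. This is the lower-bound baseline and is the run expected to forget the most.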
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=32 12 | python continual_training.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --ll_algo Naive \ 17 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 18 | --load_partner_model ${LOAD_MODEL_DIR}/iql_2p_310.pthw ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_710.pthw \ 20 | ${LOAD_MODEL_DIR}/vdn_op_2p_729.pthw \ 21 | --num_thread 10 \ 22 | --num_game_per_thread 80 \ 23 | --eval_num_thread 10 \ 24 | --eval_num_game_per_thread 80 \ 25 | --sad 0 \ 26 | --act_base_eps 0.1 \ 27 | --act_eps_alpha 7 \ 28 | --eps 1.5e-05 \ 29 | --grad_clip 5 \ 30 | --gamma 0.999 \ 31 | --seed ${SEED} \ 32 | --initial_lr ${INITIAL_LR} \ 33 | --final_lr 6.25e-05 \ 34 | --lr_gamma 0.2 \ 35 | --dropout_p 0 \ 36 | --sgd_momentum 0.8 \ 37 | --optim_name ${OPTIM_NAME} \ 38 | --batchsize ${BATCH_SIZE} \ 39 | --max_train_steps 200000000 \ 40 | --max_eval_steps 500000 \ 41 | --burn_in_frames 10000 \ 42 | --eval_burn_in_frames 1000 \ 43 | --replay_buffer_size 32768 \ 44 | --eval_replay_buffer_size 10000 \ 45 | --epoch_len 200 \ 46 | --priority_exponent 0.9 \ 47 | --priority_weight 0.6 \ 48 | --train_bomb 0 \ 49 | --eval_bomb 0 \ 50 | --eval_epoch_len 50 \ 51 | --eval_method ${EVAL_METHOD} \ 52 | --eval_freq 25 \ 53 | --num_player 2 \ 54 | --rnn_type lstm \ 55 | --num_fflayer 1 \ 56 | --num_rnn_layer 2 \ 57 | --rnn_hid_dim 512 \ 58 | --act_device cuda:1,cuda:2 \ 59 | --shuffle_color 0 \ 60 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/Naive_hard_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=32 12 | python continual_training.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --ll_algo Naive \ 17 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 18 | --load_partner_model ${LOAD_MODEL_DIR}/vdn_op_2p_7771.pthw ${LOAD_MODEL_DIR}/vdn_2p_726.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_600.pthw \ 20 | ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 21 | --num_thread 10 \ 22 | --num_game_per_thread 80 \ 23 | --eval_num_thread 10 \ 24 | --eval_num_game_per_thread 80 \ 25 | --sad 0 \ 26 | --act_base_eps 0.1 \ 27 | --act_eps_alpha 7 \ 28 | --eps 1.5e-05 \ 29 | --grad_clip 5 \ 30 | --gamma 0.999 \ 31 | --seed ${SEED} \ 32 | --initial_lr ${INITIAL_LR} \ 33 | --final_lr 6.25e-05 \ 34 | --lr_gamma 0.2 \ 35 | --dropout_p 0 \ 36 | --sgd_momentum 0.8 \ 37 | --optim_name ${OPTIM_NAME} \ 38 | --batchsize ${BATCH_SIZE} \ 39 | --max_train_steps 200000000 \ 40 | --max_eval_steps 500000 \ 41 | --burn_in_frames 10000 \ 42 | --eval_burn_in_frames 1000 \ 43 | --replay_buffer_size 32768 \ 44 | --eval_replay_buffer_size 10000 \ 45 | --epoch_len 200 \ 46 | --priority_exponent 0.9 \ 47 | --priority_weight 0.6 \ 48 | --train_bomb 0 \ 49 | --eval_bomb 0 \ 50 | --eval_epoch_len 50 \ 51 | --eval_method ${EVAL_METHOD} \ 52 | --eval_freq 25 \ 53 | --num_player 2 \ 54 | --rnn_type lstm \ 55 | --num_fflayer 1 \ 56 | --num_rnn_layer 2 \ 57 | --rnn_hid_dim 512 \ 58 | --act_device cuda:1,cuda:2 \ 59 | --shuffle_color 0 \ 60 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/EWC_online_easy_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
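## Online EWC (selected by --ll_algo EWC with --online 1 below): training on a new
## partner adds a quadratic penalty that anchors parameters important to earlier
## partners,
##   L(theta) = L_current(theta) + (lambda / 2) * sum_i F_i * (theta_i - theta_i*)^2
## where F is a Fisher-information estimate and theta* are the parameters kept from
## the previous task; --ewc_lambda sets lambda. In the online variant a single
## running Fisher is kept and decayed by --ewc_gamma at each task switch
## (Schwarz et al., 2018). The Fisher estimator itself lives in ewc.py -- the
## formula above is the standard one, not a line-by-line account of that file.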
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=32 12 | python continual_training.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --ll_algo EWC \ 17 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 18 | --load_partner_model ${LOAD_MODEL_DIR}/iql_2p_310.pthw ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_710.pthw \ 20 | ${LOAD_MODEL_DIR}/vdn_op_2p_729.pthw \ 21 | --num_thread 10 \ 22 | --num_game_per_thread 80 \ 23 | --eval_num_thread 10 \ 24 | --eval_num_game_per_thread 80 \ 25 | --sad 0 \ 26 | --act_base_eps 0.1 \ 27 | --act_eps_alpha 7 \ 28 | --eps 1.5e-05 \ 29 | --grad_clip 5 \ 30 | --gamma 0.999 \ 31 | --seed ${SEED} \ 32 | --initial_lr ${INITIAL_LR} \ 33 | --final_lr 6.25e-05 \ 34 | --lr_gamma 0.2 \ 35 | --dropout_p 0 \ 36 | --sgd_momentum 0.8 \ 37 | --optim_name ${OPTIM_NAME} \ 38 | --batchsize ${BATCH_SIZE} \ 39 | --online 1 \ 40 | --ewc_lambda 50000 \ 41 | --ewc_gamma 1 \ 42 | --max_train_steps 200000000 \ 43 | --max_eval_steps 500000 \ 44 | --burn_in_frames 10000 \ 45 | --eval_burn_in_frames 1000 \ 46 | --replay_buffer_size 32768 \ 47 | --eval_replay_buffer_size 10000 \ 48 | --epoch_len 200 \ 49 | --priority_exponent 0.9 \ 50 | --priority_weight 0.6 \ 51 | --train_bomb 0 \ 52 | --eval_bomb 0 \ 53 | --eval_epoch_len 50 \ 54 | --eval_method ${EVAL_METHOD} \ 55 | --eval_freq 25 \ 56 | --num_player 2 \ 57 | --rnn_type lstm \ 58 | --num_fflayer 1 \ 59 | --num_rnn_layer 2 \ 60 | --rnn_hid_dim 512 \ 61 | --act_device cuda:1,cuda:2 \ 62 | --shuffle_color 0 \ 63 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/EWC_offline_easy_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
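## Offline EWC (--online 0 below): unlike the online variant above, the standard
## formulation stores a separate Fisher estimate and parameter anchor for every
## partner seen so far and sums their penalties, so the memory cost grows with the
## number of tasks (Kirkpatrick et al., 2017).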
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=32 12 | python continual_training.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --ll_algo EWC \ 17 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 18 | --load_partner_model ${LOAD_MODEL_DIR}/iql_2p_310.pthw ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_710.pthw \ 20 | ${LOAD_MODEL_DIR}/vdn_op_2p_729.pthw \ 21 | --num_thread 10 \ 22 | --num_game_per_thread 80 \ 23 | --eval_num_thread 10 \ 24 | --eval_num_game_per_thread 80 \ 25 | --sad 0 \ 26 | --act_base_eps 0.1 \ 27 | --act_eps_alpha 7 \ 28 | --eps 1.5e-05 \ 29 | --grad_clip 5 \ 30 | --gamma 0.999 \ 31 | --seed ${SEED} \ 32 | --initial_lr ${INITIAL_LR} \ 33 | --final_lr 6.25e-05 \ 34 | --lr_gamma 0.2 \ 35 | --dropout_p 0 \ 36 | --sgd_momentum 0.8 \ 37 | --optim_name ${OPTIM_NAME} \ 38 | --batchsize ${BATCH_SIZE} \ 39 | --online 0 \ 40 | --ewc_lambda 50000 \ 41 | --ewc_gamma 1 \ 42 | --max_train_steps 200000000 \ 43 | --max_eval_steps 500000 \ 44 | --burn_in_frames 10000 \ 45 | --eval_burn_in_frames 1000 \ 46 | --replay_buffer_size 32768 \ 47 | --eval_replay_buffer_size 10000 \ 48 | --epoch_len 200 \ 49 | --priority_exponent 0.9 \ 50 | --priority_weight 0.6 \ 51 | --train_bomb 0 \ 52 | --eval_bomb 0 \ 53 | --eval_epoch_len 50 \ 54 | --eval_method ${EVAL_METHOD} \ 55 | --eval_freq 25 \ 56 | --num_player 2 \ 57 | --rnn_type lstm \ 58 | --num_fflayer 1 \ 59 | --num_rnn_layer 2 \ 60 | --rnn_hid_dim 512 \ 61 | --act_device cuda:1,cuda:2 \ 62 | --shuffle_color 0 \ 63 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/EWC_offline_hard_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=32 12 | python continual_training.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --ll_algo EWC \ 17 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 18 | --load_partner_model ${LOAD_MODEL_DIR}/vdn_op_2p_7771.pthw ${LOAD_MODEL_DIR}/vdn_2p_726.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_600.pthw \ 20 | ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 21 | --num_thread 10 \ 22 | --num_game_per_thread 80 \ 23 | --eval_num_thread 10 \ 24 | --eval_num_game_per_thread 80 \ 25 | --sad 0 \ 26 | --act_base_eps 0.1 \ 27 | --act_eps_alpha 7 \ 28 | --eps 1.5e-05 \ 29 | --grad_clip 5 \ 30 | --gamma 0.999 \ 31 | --seed ${SEED} \ 32 | --initial_lr ${INITIAL_LR} \ 33 | --final_lr 6.25e-05 \ 34 | --lr_gamma 0.2 \ 35 | --dropout_p 0 \ 36 | --sgd_momentum 0.8 \ 37 | --optim_name ${OPTIM_NAME} \ 38 | --batchsize ${BATCH_SIZE} \ 39 | --online 0 \ 40 | --ewc_lambda 50000 \ 41 | --ewc_gamma 1 \ 42 | --max_train_steps 200000000 \ 43 | --max_eval_steps 500000 \ 44 | --burn_in_frames 10000 \ 45 | --eval_burn_in_frames 1000 \ 46 | --replay_buffer_size 32768 \ 47 | --eval_replay_buffer_size 10000 \ 48 | --epoch_len 200 \ 49 | --priority_exponent 0.9 \ 50 | --priority_weight 0.6 \ 51 | --train_bomb 0 \ 52 | --eval_bomb 0 \ 53 | --eval_epoch_len 50 \ 54 | --eval_method ${EVAL_METHOD} \ 55 | --eval_freq 25 \ 56 | --num_player 2 \ 57 | --rnn_type lstm \ 58 | --num_fflayer 1 \ 59 | --num_rnn_layer 2 \ 60 | --rnn_hid_dim 512 \ 61 | --act_device cuda:1,cuda:2 \ 62 | --shuffle_color 0 \ 63 | -------------------------------------------------------------------------------- /pyhanabi/tools/continual_learning_scripts/EWC_online_hard_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ## specify optim_name to be either Adam or SGD. 3 | ## specify --decay_lr for learning rate decay. 4 | ## dropout_p should be 0 for no dropout. dropout_p is drop probability. 
5 | OPTIM_NAME="SGD" 6 | SEED=10 7 | EVAL_METHOD="few_shot" 8 | LOAD_MODEL_DIR= 9 | SAVE_DIR= 10 | INITIAL_LR=0.02 11 | BATCH_SIZE=32 12 | python continual_training.py \ 13 | --save_dir ${SAVE_DIR} \ 14 | --load_model_dir ${LOAD_MODEL_DIR} \ 15 | --method iql \ 16 | --ll_algo EWC \ 17 | --load_learnable_model ${LOAD_MODEL_DIR}/iql_2p_210.pthw \ 18 | --load_partner_model ${LOAD_MODEL_DIR}/vdn_op_2p_7771.pthw ${LOAD_MODEL_DIR}/vdn_2p_726.pthw \ 19 | ${LOAD_MODEL_DIR}/vdn_2p_7140.pthw ${LOAD_MODEL_DIR}/iql_op_2p_600.pthw \ 20 | ${LOAD_MODEL_DIR}/vdn_2p_720.pthw \ 21 | --num_thread 10 \ 22 | --num_game_per_thread 80 \ 23 | --eval_num_thread 10 \ 24 | --eval_num_game_per_thread 80 \ 25 | --sad 0 \ 26 | --act_base_eps 0.1 \ 27 | --act_eps_alpha 7 \ 28 | --eps 1.5e-05 \ 29 | --grad_clip 5 \ 30 | --gamma 0.999 \ 31 | --seed ${SEED} \ 32 | --initial_lr ${INITIAL_LR} \ 33 | --final_lr 6.25e-05 \ 34 | --lr_gamma 0.2 \ 35 | --dropout_p 0 \ 36 | --sgd_momentum 0.8 \ 37 | --optim_name ${OPTIM_NAME} \ 38 | --batchsize ${BATCH_SIZE} \ 39 | --online 1 \ 40 | --ewc_lambda 50000 \ 41 | --ewc_gamma 1 \ 42 | --max_train_steps 200000000 \ 43 | --max_eval_steps 500000 \ 44 | --burn_in_frames 10000 \ 45 | --eval_burn_in_frames 1000 \ 46 | --replay_buffer_size 32768 \ 47 | --eval_replay_buffer_size 10000 \ 48 | --epoch_len 200 \ 49 | --priority_exponent 0.9 \ 50 | --priority_weight 0.6 \ 51 | --train_bomb 0 \ 52 | --eval_bomb 0 \ 53 | --eval_epoch_len 50 \ 54 | --eval_method ${EVAL_METHOD} \ 55 | --eval_freq 25 \ 56 | --num_player 2 \ 57 | --rnn_type lstm \ 58 | --num_fflayer 1 \ 59 | --num_rnn_layer 2 \ 60 | --rnn_hid_dim 512 \ 61 | --act_device cuda:1,cuda:2 \ 62 | --shuffle_color 0 \ 63 | -------------------------------------------------------------------------------- /pyhanabi/tools/eval_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Requires only 1 GPU. 
3 | Sample usage: 4 | python tools/eval_model.py --weight_1_dir ../models/iql_2p --num_player 2 5 | It dumps a .csv as output 6 | """ 7 | import argparse 8 | import os 9 | import sys 10 | import json 11 | import glob 12 | 13 | lib_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | sys.path.append(lib_path) 15 | 16 | import numpy as np 17 | import pandas as pd 18 | import torch 19 | import r2d2_gru as r2d2_gru 20 | import r2d2_lstm as r2d2_lstm 21 | import utils 22 | from eval import evaluate_legacy_model 23 | 24 | 25 | if __name__ == "__main__": 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--weight_1_dir", default=None, type=str, required=True) 28 | parser.add_argument("--num_player", default=None, type=int, required=True) 29 | parser.add_argument("--is_rand", action="store_true", default=True) 30 | 31 | args = parser.parse_args() 32 | 33 | assert os.path.exists(args.weight_1_dir) 34 | weight_1 = [] 35 | weight_1 = glob.glob(f"{args.weight_1_dir}/*.pthw") 36 | 37 | scores_arr = np.zeros([len(weight_1), len(weight_1)]) 38 | sem_arr = np.zeros([len(weight_1), len(weight_1)]) 39 | ag1_names = [] 40 | 41 | for ag1 in weight_1: 42 | ag1_names.append(ag1.split("/")[-1].split(".")[0]) 43 | 44 | for ag1_idx, ag1 in enumerate(weight_1): 45 | for ag2_idx, ag2 in enumerate(weight_1): 46 | ## we are doing cross player, the 2 players use different weights 47 | print("Current game is ", str(ag1_idx) + " vs " + str(ag2_idx)) 48 | weight_files = [ag1, ag2] 49 | agent_args = {} 50 | # # fast evaluation for 5k games 51 | mean, sem, _ = evaluate_legacy_model( 52 | weight_files, 53 | 1000, 54 | 1, 55 | 0, 56 | agent_args, 57 | args, 58 | num_run=5, 59 | gen_cross_play=True, 60 | ) 61 | scores_arr[ag1_idx, ag2_idx] = mean 62 | sem_arr[ag1_idx, ag2_idx] = sem 63 | np.save("scores_data", scores_arr) 64 | np.save("sem_data", sem_arr) 65 | 66 | scores_df = pd.DataFrame(data=scores_arr, index=ag1_names, columns=ag1_names) 67 | sem_df = pd.DataFrame(data=sem_arr, index=ag1_names, columns=ag1_names) 68 | 69 | scores_df.to_csv("scores_data.csv") 70 | sem_df.to_csv("sem_data.csv") 71 | -------------------------------------------------------------------------------- /pyhanabi/common_utils/saver.py: -------------------------------------------------------------------------------- 1 | # model saver that saves top-k performing model 2 | import os 3 | import torch 4 | 5 | 6 | class TopkSaver: 7 | def __init__(self, save_dir, topk): 8 | self.save_dir = save_dir 9 | self.topk = topk 10 | self.worse_perf = -float("inf") 11 | self.worse_perf_idx = 0 12 | self.perfs = [self.worse_perf] 13 | 14 | if not os.path.exists(save_dir): 15 | os.makedirs(save_dir) 16 | 17 | def save(self, model, state_dict, save_latest=False, force_save_name=None): 18 | # print("worst perf idx inside save is ", self.worse_perf_idx) 19 | if force_save_name is not None: 20 | model_name = "%s.pthm" % force_save_name 21 | weight_name = "%s.pthw" % force_save_name 22 | if model is not None: 23 | model.save(os.path.join(self.save_dir, model_name)) 24 | if state_dict is not None: 25 | torch.save(state_dict, os.path.join(self.save_dir, weight_name)) 26 | 27 | if save_latest: 28 | model_name = "latest.pthm" 29 | weight_name = "latest.pthw" 30 | if model is not None: 31 | model.save(os.path.join(self.save_dir, model_name)) 32 | if state_dict is not None: 33 | torch.save(state_dict, os.path.join(self.save_dir, weight_name)) 34 | 35 | # if perf <= self.worse_perf: 36 | # # print('i am sorry') 37 | # # [print(i) for i in 
self.perfs] 38 | # return False 39 | 40 | # model_name = "model%i.pthm" % self.worse_perf_idx 41 | # weight_name = "model%i.pthw" % self.worse_perf_idx 42 | # if model is not None: 43 | # model.save(os.path.join(self.save_dir, model_name)) 44 | # if state_dict is not None: 45 | # torch.save(state_dict, os.path.join(self.save_dir, weight_name)) 46 | 47 | # if len(self.perfs) < self.topk: 48 | # self.perfs.append(perf) 49 | # return True 50 | 51 | # # neesd to replace 52 | # self.perfs[self.worse_perf_idx] = perf 53 | # worse_perf = self.perfs[0] 54 | # worse_perf_idx = 0 55 | # for i, perf in enumerate(self.perfs): 56 | # if perf < worse_perf: 57 | # worse_perf = perf 58 | # worse_perf_idx = i 59 | 60 | # self.worse_perf = worse_perf 61 | # self.worse_perf_idx = worse_perf_idx 62 | return True 63 | -------------------------------------------------------------------------------- /cpp/thread_loop.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | // 7 | #pragma once 8 | 9 | #include "rela/r2d2_actor.h" 10 | #include "rela/thread_loop.h" 11 | 12 | using HanabiVecEnv = rela::VectorEnv; 13 | 14 | class HanabiThreadLoop : public rela::ThreadLoop { 15 | public: 16 | HanabiThreadLoop( 17 | std::shared_ptr actor, 18 | std::shared_ptr vecEnv, 19 | bool eval) 20 | : actors_({std::move(actor)}) 21 | , vecEnv_(std::move(vecEnv)) 22 | , eval_(eval) { 23 | assert(actors_.size() >= 1); 24 | if (eval_) { 25 | assert(vecEnv_->size() == 1); 26 | } 27 | } 28 | 29 | HanabiThreadLoop( 30 | std::vector> actors, 31 | std::shared_ptr vecEnv, 32 | bool eval) 33 | : actors_(std::move(actors)) 34 | , vecEnv_(std::move(vecEnv)) 35 | , eval_(eval) { 36 | assert(actors_.size() >= 1); 37 | if (eval_) { 38 | assert(vecEnv_->size() == 1); 39 | } 40 | } 41 | 42 | void mainLoop() final { 43 | rela::TensorDict obs = {}; 44 | torch::Tensor r; 45 | torch::Tensor t; 46 | while (!terminated()) { 47 | obs = vecEnv_->reset(obs); 48 | while (!vecEnv_->anyTerminated()) { 49 | if (terminated()) { 50 | break; 51 | } 52 | 53 | if (paused()) { 54 | waitUntilResume(); 55 | } 56 | 57 | rela::TensorDict reply; 58 | if (actors_.size() == 1) { 59 | reply = actors_[0]->act(obs); 60 | } else { 61 | std::vector replyVec; 62 | for (int i = 0; i < (int)actors_.size(); ++i) { 63 | auto input = rela::tensor_dict::narrow(obs, 1, i, 1, true); 64 | // if (!logFile_.empty()) { 65 | // logState(*file, input); 66 | // } 67 | auto rep = actors_[i]->act(input); 68 | replyVec.push_back(rep); 69 | } 70 | reply = rela::tensor_dict::stack(replyVec, 1); 71 | } 72 | std::tie(obs, r, t) = vecEnv_->step(reply); 73 | 74 | if (eval_) { 75 | continue; 76 | } 77 | 78 | for (int i = 0; i < (int)actors_.size(); ++i) { 79 | actors_[i]->postAct(r, t); 80 | } 81 | } 82 | 83 | // eval only runs for one game 84 | if (eval_) { 85 | break; 86 | } 87 | } 88 | } 89 | 90 | private: 91 | std::vector> actors_; 92 | std::shared_ptr vecEnv_; 93 | const bool eval_; 94 | }; 95 | -------------------------------------------------------------------------------- /rela/env.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | #include "rela/tensor_dict.h" 9 | #include "rela/utils.h" 10 | 11 | namespace rela { 12 | 13 | class Env { 14 | public: 15 | Env() = default; 16 | 17 | virtual ~Env() { 18 | } 19 | 20 | virtual TensorDict reset() = 0; 21 | 22 | // return 'obs', 'reward', 'terminal' 23 | virtual std::tuple step(const TensorDict& action) = 0; 24 | 25 | virtual bool terminated() const = 0; 26 | }; 27 | 28 | // a "container" as if it is a vector of envs 29 | template < 30 | typename EnvType, 31 | typename = std::enable_if_t::value>> 32 | class VectorEnv { 33 | public: 34 | VectorEnv() = default; 35 | 36 | virtual ~VectorEnv() { 37 | } 38 | 39 | void append(std::shared_ptr env) { 40 | envs_.push_back(std::move(env)); 41 | } 42 | 43 | int size() const { 44 | return envs_.size(); 45 | } 46 | 47 | // reset envs that have reached end of terminal 48 | virtual TensorDict reset(const TensorDict& input) { 49 | std::vector batch; 50 | for (size_t i = 0; i < envs_.size(); i++) { 51 | if (envs_[i]->terminated()) { 52 | TensorDict obs = envs_[i]->reset(); 53 | batch.push_back(obs); 54 | } else { 55 | assert(!input.empty()); 56 | batch.push_back(tensor_dict::index(input, i)); 57 | } 58 | } 59 | return tensor_dict::stack(batch, 0); 60 | } 61 | 62 | // return 'obs', 'reward', 'terminal' 63 | // obs: [num_envs, obs_dims] 64 | // reward: float32 tensor [num_envs] 65 | // terminal: bool tensor [num_envs] 66 | virtual std::tuple step( 67 | const TensorDict& action) { 68 | std::vector vObs; 69 | std::vector vReward(envs_.size()); 70 | std::vector vTerminal(envs_.size()); 71 | for (size_t i = 0; i < envs_.size(); i++) { 72 | TensorDict obs; 73 | float reward; 74 | bool terminal; 75 | 76 | auto a = tensor_dict::index(action, i); 77 | std::tie(obs, reward, terminal) = envs_[i]->step(a); 78 | 79 | vObs.push_back(obs); 80 | vReward[i] = reward; 81 | vTerminal[i] = (float)terminal; 82 | } 83 | auto batchObs = tensor_dict::stack(vObs, 0); 84 | auto batchReward = torch::tensor(vReward); 85 | auto batchTerminal = torch::tensor(vTerminal).to(torch::kBool); 86 | return std::make_tuple(batchObs, batchReward, batchTerminal); 87 | } 88 | 89 | virtual bool anyTerminated() const { 90 | for (size_t i = 0; i < envs_.size(); i++) { 91 | if (envs_[i]->terminated()) { 92 | return true; 93 | } 94 | } 95 | return false; 96 | } 97 | 98 | virtual bool allTerminated() const { 99 | for (size_t i = 0; i < envs_.size(); i++) { 100 | if (!envs_[i]->terminated()) 101 | return false; 102 | } 103 | return true; 104 | } 105 | 106 | private: 107 | std::vector> envs_; 108 | }; 109 | } // namespace rela 110 | -------------------------------------------------------------------------------- /pyhanabi/common_utils/multi_counter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict, Counter 3 | from datetime import datetime 4 | 5 | # from tensorboardX import SummaryWriter 6 | 7 | 8 | class ValueStats: 9 | def __init__(self, name=None): 10 | self.name = name 11 | self.reset() 12 | 13 | def feed(self, v): 14 | self.summation += v 15 | if v > self.max_value: 16 | self.max_value = v 17 | self.max_idx = self.counter 18 | if v < self.min_value: 19 | self.min_value = v 20 | self.min_idx = self.counter 21 | 22 | self.counter += 1 23 | 24 | def mean(self): 25 | if self.counter == 0: 26 | print("Counter %s is 0" % self.name) 27 | assert False 28 | return self.summation / self.counter 29 | 30 | def summary(self, 
info=None): 31 | info = "" if info is None else info 32 | name = "" if self.name is None else self.name 33 | if self.counter > 0: 34 | # try: 35 | return "%s%s[%4d]: avg: %8.4f, min: %8.4f[%4d], max: %8.4f[%4d]" % ( 36 | info, 37 | name, 38 | self.counter, 39 | self.summation / self.counter, 40 | self.min_value, 41 | self.min_idx, 42 | self.max_value, 43 | self.max_idx, 44 | ) 45 | # except BaseException: 46 | # return "%s%s[Err]:" % (info, name) 47 | else: 48 | return "%s%s[0]" % (info, name) 49 | 50 | def reset(self): 51 | self.counter = 0 52 | self.summation = 0.0 53 | self.max_value = -1e38 54 | self.min_value = 1e38 55 | self.max_idx = None 56 | self.min_idx = None 57 | 58 | 59 | class MultiCounter: 60 | def __init__(self, root, verbose=False): 61 | # TODO: rethink counters 62 | self.last_time = None 63 | self.verbose = verbose 64 | self.counts = Counter() 65 | self.stats = defaultdict(lambda: ValueStats()) 66 | self.total_count = 0 67 | self.max_key_len = 0 68 | # if root is not None: 69 | # self.tb_writer = SummaryWriter(os.path.join(root, "stat.tb")) 70 | # else: 71 | # self.tb_writer = None 72 | 73 | def __getitem__(self, key): 74 | if len(key) > self.max_key_len: 75 | self.max_key_len = len(key) 76 | 77 | if key in self.counts: 78 | return self.counts[key] 79 | 80 | return self.stats[key] 81 | 82 | def inc(self, key): 83 | if self.verbose: 84 | print("[MultiCounter]: %s" % key) 85 | self.counts[key] += 1 86 | self.total_count += 1 87 | if self.last_time is None: 88 | self.last_time = datetime.now() 89 | 90 | def reset(self): 91 | for k in self.stats.keys(): 92 | self.stats[k].reset() 93 | 94 | self.counts = Counter() 95 | self.total_count = 0 96 | self.last_time = datetime.now() 97 | 98 | def time_elapsed(self): 99 | return (datetime.now() - self.last_time).total_seconds() 100 | 101 | def summary(self, global_counter): 102 | assert self.last_time is not None 103 | time_elapsed = (datetime.now() - self.last_time).total_seconds() 104 | print("[%d] Time spent = %.2f s" % (global_counter, time_elapsed)) 105 | 106 | for key, count in self.counts.items(): 107 | print("%s: %d/%d" % (key, count, self.total_count)) 108 | 109 | for k in sorted(self.stats.keys()): 110 | v = self.stats[k] 111 | info = str(global_counter) + ":" + k 112 | print(v.summary(info=info.ljust(self.max_key_len + 4))) 113 | 114 | # if self.tb_writer is not None: 115 | # self.tb_writer.add_scalar(k, v.mean(), global_counter) 116 | -------------------------------------------------------------------------------- /rela/batch_runner.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | // 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include "rela/tensor_dict.h" 13 | #include "rela/batcher.h" 14 | 15 | namespace rela { 16 | 17 | class BatchRunner { 18 | public: 19 | BatchRunner(py::object pyModel, 20 | const std::string& device, 21 | int maxBatchsize, 22 | const std::vector& methods) 23 | : pyModel_(pyModel) 24 | , jitModel_(pyModel_.attr("_c").cast()) 25 | , device_(torch::Device(device)) 26 | , batchsize_(maxBatchsize) 27 | , methods_(methods) { 28 | } 29 | 30 | BatchRunner(const BatchRunner&) = delete; 31 | BatchRunner& operator=(const BatchRunner&) = delete; 32 | 33 | ~BatchRunner() { 34 | stop(); 35 | } 36 | 37 | std::shared_ptr call(const std::string& method, const TensorDict& t, int* slot) { 38 | auto batcherIt = batchers_.find(method); 39 | if (batcherIt == batchers_.end()) { 40 | std::cerr << "Error: Cannot find method: " << method << std::endl; 41 | assert(false); 42 | } 43 | return batcherIt->second->send(t, slot); 44 | } 45 | 46 | void start() { 47 | if (batchers_.empty()) { 48 | for (auto& name : methods_) { 49 | batchers_.emplace(name, std::make_unique(batchsize_)); 50 | } 51 | } else { 52 | for (auto& kv : batchers_) { 53 | kv.second->reset(); 54 | } 55 | } 56 | 57 | for (auto& kv : batchers_) { 58 | threads_.emplace_back(&BatchRunner::runnerLoop, this, kv.first); 59 | } 60 | } 61 | 62 | void stop() { 63 | for (auto& kv : batchers_) { 64 | kv.second->exit(); 65 | } 66 | // batchers_.clear(); 67 | 68 | for (auto& v : threads_) { 69 | v.join(); 70 | } 71 | threads_.clear(); 72 | } 73 | 74 | void updateModel(py::object agent) { 75 | std::lock_guard lk(mtxUpdate_); 76 | pyModel_.attr("load_state_dict")(agent.attr("state_dict")()); 77 | } 78 | 79 | const torch::jit::script::Module& jitModel() { 80 | return *jitModel_; 81 | } 82 | 83 | private: 84 | void runnerLoop(const std::string& method) { 85 | auto batcherIt = batchers_.find(method); 86 | if (batcherIt == batchers_.end()) { 87 | std::cerr << "Error: RunnerLoop, Cannot find method: " << method << std::endl; 88 | assert(false); 89 | } 90 | auto& batcher = *(batcherIt->second); 91 | 92 | while(!batcher.terminated()) { 93 | auto batch = batcher.get(); 94 | if (batch.empty()) { 95 | assert(batcher.terminated()); 96 | break; 97 | } 98 | 99 | { 100 | std::lock_guard lk(mtxDevice_); 101 | 102 | torch::NoGradGuard ng; 103 | std::vector input; 104 | input.push_back(tensor_dict::toIValue(batch, device_)); 105 | torch::jit::IValue output; 106 | { 107 | std::lock_guard lk(mtxUpdate_); 108 | output = jitModel_->get_method(method)(input); 109 | } 110 | batcher.set(tensor_dict::fromIValue(output, torch::kCPU, true)); 111 | } 112 | } 113 | } 114 | 115 | py::object pyModel_; 116 | torch::jit::script::Module* const jitModel_; 117 | const torch::Device device_; 118 | // const int batchsize_; 119 | int batchsize_; 120 | const std::vector methods_; 121 | 122 | // ideally this mutex should be 1 per device, thus global 123 | std::mutex mtxDevice_; 124 | std::mutex mtxUpdate_; 125 | bool updateDue_; 126 | py::object stateDict_; 127 | 128 | // std::unordered_map batchers_; 129 | std::map> batchers_; 130 | std::vector threads_; 131 | }; 132 | } 133 | -------------------------------------------------------------------------------- /pyhanabi/testing.py: -------------------------------------------------------------------------------- 1 | """ Evaluating all the checkpoints saved periodically during train args.eval_freq 2 | Requires only 1 GPU. 
3 | Sample usage: 4 | python testing.py 5 | --weight_1_dir 6 | --weight_2 i.e a.pthw b.pthw ... 7 | --num_player 2 8 | note the last arg of --weight_2 is the self-play agent that is the agent that was being trained in continual fashion... 9 | """ 10 | 11 | import argparse 12 | import os 13 | import sys 14 | import glob 15 | import wandb 16 | import json 17 | 18 | lib_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 19 | sys.path.append(lib_path) 20 | 21 | import numpy as np 22 | import torch 23 | import utils 24 | from eval import evaluate_legacy_model 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--weight_1_dir", default=None, type=str, required=True) 30 | parser.add_argument("--weight_2", default=None, type=str, nargs="+", required=True) 31 | parser.add_argument("--is_rand", action="store_true", default=True) 32 | parser.add_argument("--num_player", default=None, type=int, required=True) 33 | args = parser.parse_args() 34 | 35 | ## Note: This assumes that we have access to configuration file 36 | ## during continual training in the models directory specifying architecture details 37 | ## like type of RNN, num of RNN layers etc. 38 | ## else to evaluate models otherwise created, please specify these in 39 | cont_train_args_txt = glob.glob(f"{args.weight_1_dir}/*.txt") 40 | with open(cont_train_args_txt[0], "r") as f: 41 | agent_args = {**json.load(f)} 42 | save_dir = agent_args["save_dir"].split("/")[-1] 43 | exp_name = f"test_{save_dir}" 44 | wandb.init(project="Lifelong_Hanabi_project", name=exp_name) 45 | wandb.config.update(agent_args) 46 | 47 | assert os.path.exists(args.weight_1_dir) 48 | weight_1 = [] 49 | weight_1 = glob.glob(f"{args.weight_1_dir}/*.pthw") 50 | weight_1.sort(key=os.path.getmtime) 51 | 52 | ## check if everything in weights_2 exist 53 | for ag2 in args.weight_2: 54 | assert os.path.exists(ag2) 55 | 56 | for ag1_idx, ag1 in enumerate(weight_1): 57 | ag1_name = ag1.split("/")[-1].split("_")[-1] 58 | act_epoch_cnt = int(ag1.split("/")[-1].split("_")[1][5:]) 59 | 60 | ### this is for different zero-shot evaluations... 61 | if ag1_name == "shot.pthw": 62 | for fixed_agent_idx in range(len(args.weight_2)): 63 | weight_files = [ag1, args.weight_2[fixed_agent_idx]] 64 | mean_score, sem, perfect_rate = evaluate_legacy_model( 65 | weight_files, 66 | 1000, 67 | 1, 68 | 0, 69 | agent_args, 70 | args, 71 | num_run=5, 72 | ) 73 | wandb.log( 74 | { 75 | "epoch_zeroshot": act_epoch_cnt, 76 | "final_eval_score_zeroshot_" + str(fixed_agent_idx): mean_score, 77 | "perfect_zeroshot_" + str(fixed_agent_idx): perfect_rate, 78 | "sem_zeroshot_" + str(fixed_agent_idx): sem, 79 | } 80 | ) 81 | else: 82 | ## for different few shot evaluations ... 83 | for i in range(len(args.weight_2)): 84 | if ag1_name == f"{i}.pthw": 85 | weight_files = [ag1, args.weight_2[i]] 86 | 87 | mean_score, sem, perfect_rate = evaluate_legacy_model( 88 | weight_files, 89 | 1000, 90 | 1, 91 | 0, 92 | agent_args, 93 | args, 94 | num_run=5, 95 | ) 96 | wandb.log( 97 | { 98 | "epoch_fewshot": act_epoch_cnt, 99 | "final_eval_score_fewshot_" + ag1_name.split(".")[0]: mean_score, 100 | "perfect_fewshot_" + ag1_name.split(".")[0]: perfect_rate, 101 | "sem_fewshot_" + ag1_name.split(".")[0]: sem, 102 | } 103 | ) 104 | -------------------------------------------------------------------------------- /rela/pybind.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved 2 | 3 | #include 4 | #include 5 | 6 | #include "rela/context.h" 7 | #include "rela/env.h" 8 | #include "rela/prioritized_replay.h" 9 | #include "rela/r2d2_actor.h" 10 | #include "rela/thread_loop.h" 11 | #include "rela/transition.h" 12 | 13 | namespace py = pybind11; 14 | using namespace rela; 15 | 16 | PYBIND11_MODULE(rela, m) { 17 | py::class_>(m, "FFTransition") 18 | .def_readwrite("obs", &FFTransition::obs) 19 | .def_readwrite("action", &FFTransition::action) 20 | .def_readwrite("reward", &FFTransition::reward) 21 | .def_readwrite("terminal", &FFTransition::terminal) 22 | .def_readwrite("bootstrap", &FFTransition::bootstrap) 23 | .def_readwrite("next_obs", &FFTransition::nextObs); 24 | 25 | py::class_>(m, "RNNTransition") 26 | .def_readwrite("obs", &RNNTransition::obs) 27 | .def_readwrite("h0", &RNNTransition::h0) 28 | .def_readwrite("action", &RNNTransition::action) 29 | .def_readwrite("reward", &RNNTransition::reward) 30 | .def_readwrite("terminal", &RNNTransition::terminal) 31 | .def_readwrite("bootstrap", &RNNTransition::bootstrap) 32 | .def_readwrite("seq_len", &RNNTransition::seqLen); 33 | 34 | // py::class_>( 35 | // m, "FFPrioritizedReplay") 36 | // .def(py::init()) 41 | // .def("size", &FFPrioritizedReplay::size) 42 | // .def("num_add", &FFPrioritizedReplay::numAdd) 43 | // .def("sample", &FFPrioritizedReplay::sample) 44 | // .def("update_priority", &FFPrioritizedReplay::updatePriority); 45 | 46 | py::class_>( 47 | m, "RNNPrioritizedReplay") 48 | .def(py::init< 49 | int, // capacity, 50 | int, // seed, 51 | float, // alpha, priority exponent 52 | float, // beta, importance sampling exponent 53 | int>()) 54 | .def("size", &RNNPrioritizedReplay::size) 55 | .def("num_add", &RNNPrioritizedReplay::numAdd) 56 | .def("sample", &RNNPrioritizedReplay::sample) 57 | .def("update_priority", &RNNPrioritizedReplay::updatePriority) 58 | .def("slice", &RNNPrioritizedReplay::slice) 59 | .def("get", &RNNPrioritizedReplay::get); 60 | 61 | py::class_>(m, "ThreadLoop"); 62 | 63 | py::class_(m, "Context") 64 | .def(py::init<>()) 65 | .def("push_env_thread", &Context::pushThreadLoop, py::keep_alive<1, 2>()) 66 | .def("start", &Context::start) 67 | .def("pause", &Context::pause) 68 | .def("resume", &Context::resume) 69 | .def("terminate", &Context::terminate) 70 | .def("terminated", &Context::terminated); 71 | 72 | py::class_>(m, "R2D2Actor") 73 | .def(py::init< 74 | std::shared_ptr, // runner 75 | int, // multiStep 76 | int, // batchsize 77 | float, // gamma 78 | float, // eta 79 | int, // seqLen 80 | int, // numPlayer 81 | std::shared_ptr>()) // replayBuffer 82 | .def(py::init, int>()) // evaluation mode 83 | .def("num_act", &R2D2Actor::numAct) 84 | ; 85 | 86 | py::class_>(m, "BatchRunner") 87 | .def(py::init&>()) 88 | .def("start", &BatchRunner::start) 89 | .def("stop", &BatchRunner::stop) 90 | .def("update_model", &BatchRunner::updateModel) 91 | ; 92 | 93 | m.def("aggregate_priority", &aggregatePriority); 94 | } 95 | -------------------------------------------------------------------------------- /pyhanabi/ewc.py: -------------------------------------------------------------------------------- 1 | # Implements EWC loss 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class EWC(nn.Module): 8 | """ 9 | Estimates the fisher matrix and calculates ewc loss based on that. 
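    The quadratic penalty computed in compute_ewc_loss below has the standard EWC form
    0.5 * sum_i F_i * (theta_i - theta_i*)^2, where F_i is a diagonal Fisher estimate
    and theta_i* are the parameters stored at the end of each task; in the online
    variant, older Fisher terms are decayed by ewc_gamma before being accumulated.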
10 | Args: 11 | args.ewc_gamma (online EWC): decay-term for old tasks' contribution to quadratic term 12 | args.online: Bool "online" (=single quadratic term) or "offline" (=quadratic term per task) EWC 13 | args.batchsize: int 14 | args.pred_weight: 0.0 (Auxilary task loss coefficient) 15 | args.train_device: cuda:0 16 | Returns: 17 | float: ewc loss 18 | """ 19 | 20 | def __init__(self, args): 21 | super().__init__() 22 | self.ewc_gamma = args.ewc_gamma 23 | self.online = args.online 24 | self.batchsize = args.batchsize 25 | self.pred_weight = args.pred_weight 26 | self.train_device = args.train_device 27 | 28 | def forward(self, x): 29 | pass 30 | 31 | # ----------------- EWC-specifc functions -----------------# 32 | 33 | def estimate_fisher(self, learnable_agent, batch, weight, stat, task_idx): 34 | # Prepare to store estimated Fisher Information matrix 35 | est_fisher_info = {} 36 | 37 | tmp_loss, tmp_priority = learnable_agent.loss(batch, self.pred_weight, stat) 38 | tmp_loss = (tmp_loss * weight).mean() 39 | tmp_loss.backward() 40 | 41 | # Square gradients and keep running sum 42 | for n, p in learnable_agent.online_net.named_parameters(): 43 | if p.requires_grad: 44 | n = n.replace(".", "__") 45 | est_fisher_info[n] = p.detach().clone().zero_() 46 | if p.grad is not None: 47 | est_fisher_info[n] += p.grad.detach() ** 2 48 | 49 | # Normalize by sample size used for estimation 50 | est_fisher_info = {n: p / self.batchsize for n, p in est_fisher_info.items()} 51 | 52 | # Store new values in the network 53 | for n, p in learnable_agent.online_net.named_parameters(): 54 | if p.requires_grad: 55 | n = n.replace(".", "__") 56 | self.register_buffer( 57 | "{}_EWC_prev_task{}".format(n, "" if self.online else task_idx + 1), 58 | p.detach().clone(), 59 | ) 60 | if self.online and task_idx > 0: 61 | existing_values = getattr(self, "{}_EWC_estimated_fisher".format(n)) 62 | est_fisher_info[n] += self.ewc_gamma * existing_values 63 | self.register_buffer( 64 | "{}_EWC_estimated_fisher{}".format( 65 | n, "" if self.online else task_idx + 1 66 | ), 67 | est_fisher_info[n], 68 | ) 69 | return tmp_priority 70 | 71 | def compute_ewc_loss(self, learnable_agent, task_idx): 72 | if task_idx > 0: 73 | losses = [] 74 | # If "offline EWC", loop over all previous tasks (if "online EWC", [EWC_task_count]=1 so only 1 iteration) 75 | for task in range(1, task_idx + 1): 76 | for n, p in learnable_agent.online_net.named_parameters(): 77 | if p.requires_grad: 78 | # Retrieve stored mode (MAP estimate) and precision (Fisher Information matrix) 79 | n = n.replace(".", "__") 80 | mean = getattr( 81 | self, 82 | "{}_EWC_prev_task{}".format(n, "" if self.online else task), 83 | ) 84 | fisher = getattr( 85 | self, 86 | "{}_EWC_estimated_fisher{}".format( 87 | n, "" if self.online else task 88 | ), 89 | ) 90 | # If "online EWC", apply decay-term to the running sum of the Fisher Information matrices 91 | fisher = self.ewc_gamma * fisher if self.online else fisher 92 | 93 | # Calculate EWC-loss 94 | losses.append((fisher * (p - mean) ** 2).sum()) 95 | # Sum EWC-loss from all parameters (and from all tasks, if "offline EWC") 96 | return (1.0 / 2) * sum(losses) 97 | else: 98 | # EWC-loss is 0 if there are no stored mode and precision yet 99 | return torch.tensor(0.0, device=self.train_device) 100 | -------------------------------------------------------------------------------- /cpp/hanabi_env.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. 
and its affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | // 7 | #pragma once 8 | 9 | #include "hanabi-learning-environment/hanabi_lib/canonical_encoders.h" 10 | #include "hanabi-learning-environment/hanabi_lib/hanabi_game.h" 11 | #include "hanabi-learning-environment/hanabi_lib/hanabi_state.h" 12 | 13 | #include "rela/env.h" 14 | 15 | namespace hle = hanabi_learning_env; 16 | 17 | class HanabiEnv : public rela::Env { 18 | public: 19 | HanabiEnv( 20 | const std::unordered_map& gameParams, 21 | const std::vector& epsList, 22 | int maxLen, 23 | bool sad, 24 | bool shuffleObs, 25 | bool shuffleColor, 26 | bool verbose) 27 | : game_(gameParams) 28 | , obsEncoder_(&game_) 29 | , state_(nullptr) 30 | , epsList_(epsList) 31 | , maxLen_(maxLen) 32 | , sad_(sad) 33 | , shuffleObs_(shuffleObs) 34 | , shuffleColor_(shuffleColor) 35 | , verbose_(verbose) 36 | , playerEps_(game_.NumPlayers()) 37 | , numStep_(0) 38 | , colorPermutes_(game_.NumPlayers()) 39 | , invColorPermutes_(game_.NumPlayers()) 40 | , lastScore_(-1) { 41 | auto params = game_.Parameters(); 42 | if (verbose_) { 43 | std::cout << "Hanabi game created, with parameters:\n"; 44 | for (const auto& item : params) { 45 | std::cout << " " << item.first << "=" << item.second << "\n"; 46 | } 47 | } 48 | } 49 | 50 | virtual ~HanabiEnv() { 51 | } 52 | 53 | int featureSize() const { 54 | int size = obsEncoder_.Shape()[0]; 55 | if (sad_) { 56 | size += hle::LastActionSectionLength(game_); 57 | } 58 | 59 | return size; 60 | } 61 | 62 | int numAction() const { 63 | return game_.MaxMoves() + 1; 64 | } 65 | 66 | int noOpUid() const { 67 | return numAction() - 1; 68 | } 69 | 70 | int handFeatureSize() const { 71 | return game_.HandSize() * game_.NumColors() * game_.NumRanks(); 72 | } 73 | 74 | virtual rela::TensorDict reset() override; 75 | 76 | // return {'obs', 'reward', 'terminal'} 77 | // action_p0 is a tensor of size 1, representing uid of move 78 | virtual std::tuple step( 79 | const rela::TensorDict& action) override; 80 | 81 | bool terminated() const final { 82 | if (state_ == nullptr) { 83 | return true; 84 | } 85 | 86 | bool term = false; 87 | if (maxLen_ <= 0) { 88 | term = state_->IsTerminal(); 89 | } else { 90 | term = state_->IsTerminal() || numStep_ >= maxLen_; 91 | } 92 | if (term) { 93 | lastScore_ = state_->Score(); 94 | } 95 | return term; 96 | } 97 | 98 | int getCurrentPlayer() const { 99 | assert(state_ != nullptr); 100 | return state_->CurPlayer(); 101 | } 102 | 103 | bool moveIsLegal(int actionUid) const { 104 | hle::HanabiMove move = game_.GetMove(actionUid); 105 | return state_->MoveIsLegal(move); 106 | } 107 | 108 | int lastScore() const { 109 | return lastScore_; 110 | } 111 | 112 | std::vector deckHistory() const { 113 | return state_->DeckHistory(); 114 | } 115 | 116 | const hle::HanabiState& getHanabiState() const { 117 | assert(state_ != nullptr); 118 | return *state_; 119 | } 120 | 121 | int getScore() const { 122 | return state_->Score(); 123 | } 124 | 125 | int getLife() const { 126 | return state_->LifeTokens(); 127 | } 128 | 129 | int getInfo() const { 130 | return state_->InformationTokens(); 131 | } 132 | 133 | std::vector getFireworks() const { 134 | return state_->Fireworks(); 135 | } 136 | 137 | protected: 138 | bool maybeInversePermuteColor_(hle::HanabiMove& move, int curPlayer) { 139 | if (shuffleColor_ && move.MoveType() == hle::HanabiMove::Type::kRevealColor) { 140 | int realColor 
= invColorPermutes_[curPlayer][move.Color()]; 141 | move.SetColor(realColor); 142 | return true; 143 | } else { 144 | return false; 145 | } 146 | } 147 | 148 | rela::TensorDict computeFeatureAndLegalMove( 149 | const std::unique_ptr& cloneState); 150 | 151 | const hle::HanabiGame game_; 152 | const hle::CanonicalObservationEncoder obsEncoder_; 153 | std::unique_ptr state_; 154 | const std::vector epsList_; 155 | const int maxLen_; 156 | const bool sad_; 157 | const bool shuffleObs_; 158 | const bool shuffleColor_; 159 | const bool verbose_; 160 | 161 | std::vector playerEps_; 162 | 163 | int numStep_; 164 | std::vector> colorPermutes_; 165 | std::vector> invColorPermutes_; 166 | 167 | mutable int lastScore_; 168 | }; 169 | -------------------------------------------------------------------------------- /pyhanabi/common_utils/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from typing import Dict 8 | 9 | 10 | def get_all_files(root, file_extension, contain=None): 11 | files = [] 12 | for folder, _, fs in os.walk(root): 13 | for f in fs: 14 | if file_extension is not None: 15 | if f.endswith(file_extension): 16 | if contain is None or contain in os.path.join(folder, f): 17 | files.append(os.path.join(folder, f)) 18 | else: 19 | if contain in f: 20 | files.append(os.path.join(folder, f)) 21 | return files 22 | 23 | 24 | def moving_average(data, period): 25 | # padding 26 | left_pad = [data[0] for _ in range(period // 2)] 27 | right_pad = data[-period // 2 + 1 :] 28 | data = left_pad + data + right_pad 29 | weights = np.ones(period) / period 30 | return np.convolve(data, weights, mode="valid") 31 | 32 | 33 | def mem2str(num_bytes): 34 | assert num_bytes >= 0 35 | if num_bytes >= 2 ** 30: # GB 36 | val = float(num_bytes) / (2 ** 30) 37 | result = "%.3f GB" % val 38 | elif num_bytes >= 2 ** 20: # MB 39 | val = float(num_bytes) / (2 ** 20) 40 | result = "%.3f MB" % val 41 | elif num_bytes >= 2 ** 10: # KB 42 | val = float(num_bytes) / (2 ** 10) 43 | result = "%.3f KB" % val 44 | else: 45 | result = "%d bytes" % num_bytes 46 | return result 47 | 48 | 49 | def sec2str(seconds): 50 | seconds = int(seconds) 51 | hour = seconds // 3600 52 | seconds = seconds % (24 * 3600) 53 | seconds %= 3600 54 | minutes = seconds // 60 55 | seconds %= 60 56 | return "%dH %02dM %02dS" % (hour, minutes, seconds) 57 | 58 | 59 | def num2str(n): 60 | if n < 1e3: 61 | s = str(n) 62 | unit = "" 63 | elif n < 1e6: 64 | n /= 1e3 65 | s = "%.3f" % n 66 | unit = "K" 67 | else: 68 | n /= 1e6 69 | s = "%.3f" % n 70 | unit = "M" 71 | 72 | s = s.rstrip("0").rstrip(".") 73 | return s + unit 74 | 75 | 76 | def get_mem_usage(): 77 | import psutil 78 | 79 | mem = psutil.virtual_memory() 80 | result = "" 81 | result += "available: %s, " % (mem2str(mem.available)) 82 | result += "used: %s, " % (mem2str(mem.used)) 83 | result += "free: %s" % (mem2str(mem.free)) 84 | return result 85 | 86 | 87 | def flatten_first2dim(batch): 88 | if isinstance(batch, torch.Tensor): 89 | size = batch.size()[2:] 90 | batch = batch.view(-1, *size) 91 | return batch 92 | elif isinstance(batch, dict): 93 | return {key: flatten_first2dim(batch[key]) for key in batch} 94 | else: 95 | assert False, "unsupported type: %s" % type(batch) 96 | 97 | 98 | def _tensor_slice(t, dim, b, e): 99 | if dim == 0: 100 | return t[b:e] 101 | elif dim == 1: 102 | return t[:, b:e] 103 | elif dim == 2: 104 | return t[:, :, b:e] 105 | 
else: 106 | raise ValueError("unsupported %d in tensor_slice" % dim) 107 | 108 | 109 | def tensor_slice(t, dim, b, e): 110 | if isinstance(t, dict): 111 | return {key: tensor_slice(t[key], dim, b, e) for key in t} 112 | elif isinstance(t, torch.Tensor): 113 | return _tensor_slice(t, dim, b, e).contiguous() 114 | else: 115 | assert False, "Error: unsupported type: %s" % (type(t)) 116 | 117 | 118 | def tensor_index(t, dim, i): 119 | if isinstance(t, dict): 120 | return {key: tensor_index(t[key], dim, i) for key in t} 121 | elif isinstance(t, torch.Tensor): 122 | return _tensor_slice(t, dim, i, i + 1).squeeze(dim).contiguous() 123 | else: 124 | assert False, "Error: unsupported type: %s" % (type(t)) 125 | 126 | 127 | def one_hot(x, n): 128 | assert x.dim() == 2 and x.size(1) == 1 129 | one_hot_x = torch.zeros(x.size(0), n, device=x.device) 130 | one_hot_x.scatter_(1, x, 1) 131 | return one_hot_x 132 | 133 | 134 | def set_all_seeds(rand_seed): 135 | random.seed(rand_seed) 136 | np.random.seed(rand_seed + 1) 137 | torch.manual_seed(rand_seed + 2) 138 | torch.cuda.manual_seed(rand_seed + 3) 139 | 140 | 141 | def weights_init(m): 142 | """custom weights initialization""" 143 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 144 | # nn.init.kaiming_normal(m.weight.data) 145 | nn.init.orthogonal_(m.weight.data) 146 | else: 147 | print("%s is not custom-initialized." % m.__class__) 148 | 149 | 150 | def init_net(net, net_file): 151 | if net_file: 152 | net.load_state_dict(torch.load(net_file)) 153 | else: 154 | net.apply(weights_init) 155 | 156 | 157 | def count_output_size(input_shape, model): 158 | fake_input = torch.FloatTensor(*input_shape) 159 | output_size = model.forward(fake_input).view(-1).size()[0] 160 | return output_size 161 | -------------------------------------------------------------------------------- /rela/batcher.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | // 7 | #pragma once 8 | 9 | #include "rela/utils.h" 10 | 11 | namespace rela { 12 | 13 | TensorDict allocateBatchStorage(const TensorDict& data, int size) { 14 | TensorDict storage; 15 | for (const auto& kv : data) { 16 | auto t = kv.second.sizes(); 17 | std::vector sizes; 18 | // for (int i = 0; i < batchdim_; ++i) { 19 | // sizes.push_back(t[i]); 20 | // } 21 | sizes.push_back(size); 22 | for (size_t i = 0; i < t.size(); ++i) { 23 | sizes.push_back(t[i]); 24 | } 25 | 26 | storage[kv.first] = torch::zeros(sizes, kv.second.dtype()); 27 | } 28 | return storage; 29 | } 30 | 31 | class FutureReply { 32 | public: 33 | FutureReply() 34 | : ready_(false) { 35 | } 36 | 37 | TensorDict get(int slot) { 38 | // std::cout << "getting slot: " << slot << std::endl; 39 | std::unique_lock lk(mReady_); 40 | cvReady_.wait(lk, [this] { return ready_; }); 41 | lk.unlock(); 42 | 43 | TensorDict e; 44 | for (const auto& kv : data_) { 45 | assert(slot >= 0 && slot < kv.second.size(0)); 46 | e[kv.first] = kv.second[slot]; 47 | // std::cout << kv.first << "\n" << e[kv.first] << std::endl; 48 | } 49 | return e; 50 | // return data_[slot]; 51 | } 52 | 53 | void set(TensorDict&& t) { 54 | // assert(t.device().is_cpu()); 55 | { 56 | std::lock_guard lk(mReady_); 57 | ready_ = true; 58 | data_ = std::move(t); 59 | } 60 | cvReady_.notify_all(); 61 | } 62 | 63 | private: 64 | // no need for protection, only set() can set it 65 | TensorDict data_; 66 | 67 | std::mutex mReady_; 68 | bool ready_; 69 | std::condition_variable cvReady_; 70 | }; 71 | 72 | class Batcher { 73 | public: 74 | Batcher(int batchsize) 75 | : batchsize_(batchsize) 76 | , nextSlot_(0) 77 | , numActiveWrite_(0) 78 | , fillingReply_(std::make_shared()) 79 | , filledReply_(nullptr) { 80 | } 81 | 82 | Batcher(const Batcher&) = delete; 83 | Batcher& operator=(const Batcher&) = delete; 84 | 85 | ~Batcher() { 86 | if (!exit_) { 87 | exit(); 88 | } 89 | } 90 | 91 | void exit() { 92 | { 93 | std::unique_lock lk(mNextSlot_); 94 | exit_ = true; 95 | } 96 | cvGetBatch_.notify_all(); 97 | } 98 | 99 | void reset() { 100 | assert(exit_ == true); 101 | exit_ = false; 102 | } 103 | 104 | bool terminated() { 105 | return exit_; 106 | } 107 | 108 | // send data into batcher 109 | std::shared_ptr send(const TensorDict& t, int* slot) { 110 | std::unique_lock lk(mNextSlot_); 111 | 112 | // init buffer 113 | if (fillingBuffer_.empty()) { 114 | assert(filledBuffer_.empty()); 115 | fillingBuffer_ = allocateBatchStorage(t, batchsize_); 116 | filledBuffer_ = allocateBatchStorage(t, batchsize_); 117 | } 118 | 119 | assert(nextSlot_ <= batchsize_); 120 | // wait if current batch is full and not extracted 121 | cvNextSlot_.wait(lk, [this] { return nextSlot_ < batchsize_; }); 122 | 123 | *slot = nextSlot_; 124 | ++nextSlot_; 125 | ++numActiveWrite_; 126 | lk.unlock(); 127 | 128 | // this will copy 129 | for (const auto& kv : t) { 130 | fillingBuffer_[kv.first][*slot] = kv.second; 131 | } 132 | 133 | // batch has not been extracted yet 134 | assert(numActiveWrite_ > 0); 135 | assert(fillingReply_ != nullptr); 136 | auto reply = fillingReply_; 137 | lk.lock(); 138 | --numActiveWrite_; 139 | lk.unlock(); 140 | if (numActiveWrite_ == 0) { 141 | cvGetBatch_.notify_one(); 142 | } 143 | return reply; 144 | } 145 | 146 | // get batch input from batcher 147 | TensorDict get() { 148 | std::unique_lock lk(mNextSlot_); 149 | cvGetBatch_.wait(lk, [this] { 150 | return (nextSlot_ > 0 && numActiveWrite_ == 0) || exit_; 151 | }); 152 | 153 | if (exit_) { 154 | return TensorDict(); 155 | } 156 | 
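    // Double-buffer handoff: swap the buffer/reply pair that send() has been
    // filling with the "filled" pair consumed below, and create a fresh
    // FutureReply so producers can keep writing while the model runs on this
    // batch; bsize remembers how many slots were filled before the swap.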
157 | // TensorDict batch; 158 | // for (const auto& kv : buffer_) { 159 | // batch[kv.first] = kv.second.narrow_copy(batchdim_, 0, nextSlot_).contiguous(); 160 | // } 161 | int bsize = nextSlot_; 162 | nextSlot_ = 0; 163 | // assert previous reply has been handled 164 | assert(filledReply_ == nullptr); 165 | std::swap(fillingBuffer_, filledBuffer_); 166 | std::swap(fillingReply_, filledReply_); 167 | fillingReply_ = std::make_shared(); 168 | 169 | // assert currentReply has been handled 170 | // assert(currentReply_ == nullptr); 171 | // currentreply_ = std::move(nextReply_); 172 | // nextReply_ = std::make_shared(batchdim_); 173 | 174 | lk.unlock(); 175 | cvNextSlot_.notify_all(); 176 | 177 | TensorDict batch; 178 | for (const auto& kv : filledBuffer_) { 179 | batch[kv.first] = kv.second.narrow(0, 0, bsize).contiguous(); 180 | // batch[kv.first] = kv.second.narrow_copy(0, 0, batchsize_).contiguous(); 181 | } 182 | 183 | sumBatchsize_ += bsize; 184 | batchCount_ += 1; 185 | if (batchCount_ % 5000 == 0) { 186 | /* 187 | if (sumBatchsize_ / batchCount_ > 100) { 188 | std::cout << ">>>>>>>>>>>>>>>.batchcount: " << (int64_t)this << std::endl; 189 | std::cout << sumBatchsize_ / (float)batchCount_ << std::endl; 190 | std::cout << ">>>>>>>>>>>>>>>>>>>>>>>>>>>"<< std::endl; 191 | } 192 | */ 193 | sumBatchsize_ = 0; 194 | batchCount_ = 0; 195 | } 196 | 197 | return batch; 198 | } 199 | 200 | // set batch reply for batcher 201 | void set(TensorDict&& t) { 202 | // auto start = high_resolution_clock::now(); 203 | 204 | for (const auto& kv : t) { 205 | assert(kv.second.device().is_cpu()); 206 | } 207 | // assert(currentReply_ != nullptr); 208 | filledReply_->set(std::move(t)); 209 | filledReply_ = nullptr; 210 | } 211 | 212 | private: 213 | // const int batchsize_; 214 | int batchsize_; 215 | 216 | int sumBatchsize_ = 0; 217 | int batchCount_ = 0; 218 | 219 | int nextSlot_; 220 | int numActiveWrite_; 221 | std::condition_variable cvNextSlot_; 222 | 223 | TensorDict fillingBuffer_; 224 | std::shared_ptr fillingReply_; 225 | 226 | TensorDict filledBuffer_; 227 | std::shared_ptr filledReply_; 228 | 229 | bool exit_ = false; 230 | std::condition_variable cvGetBatch_; 231 | std::mutex mNextSlot_; 232 | // std::map timer_; 233 | // int counter_ = 0; 234 | }; 235 | 236 | } // namespace rela 237 | -------------------------------------------------------------------------------- /rela/r2d2_actor.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | #pragma once 3 | 4 | #include "rela/batch_runner.h" 5 | #include "rela/transition_buffer.h" 6 | #include "rela/prioritized_replay.h" 7 | 8 | namespace rela { 9 | 10 | torch::Tensor aggregatePriority(torch::Tensor priority, torch::Tensor seqLen, float eta) { 11 | assert(priority.device().is_cpu() && seqLen.device().is_cpu()); 12 | auto mask = torch::arange(0, priority.size(0)); 13 | mask = (mask.unsqueeze(1) < seqLen.unsqueeze(0)).to(torch::kFloat32); 14 | assert(priority.sizes() == mask.sizes()); 15 | priority = priority * mask; 16 | 17 | auto pMean = priority.sum(0) / seqLen; 18 | auto pMax = std::get<0>(priority.max(0)); 19 | auto aggPriority = eta * pMax + (1.0 - eta) * pMean; 20 | return aggPriority.detach(); 21 | } 22 | 23 | class R2D2Actor { 24 | public: 25 | R2D2Actor( 26 | std::shared_ptr runner, 27 | int multiStep, 28 | int numEnvs, 29 | float gamma, 30 | float eta, 31 | int seqLen, 32 | int numPlayer, 33 | std::shared_ptr replayBuffer) 34 | : runner_(std::move(runner)) 35 | , numEnvs_(numEnvs) 36 | , numPlayer_(numPlayer) 37 | , r2d2Buffer_(std::make_unique(numEnvs, numPlayer, multiStep, seqLen)) 38 | , multiStepBuffer_(std::make_unique(multiStep, numEnvs, gamma)) 39 | , replayBuffer_(std::move(replayBuffer)) 40 | , eta_(eta) 41 | , hidden_(getH0(numEnvs, numPlayer)) 42 | , numAct_(0) { 43 | } 44 | 45 | R2D2Actor(std::shared_ptr runner, int numPlayer) 46 | : runner_(std::move(runner)) 47 | , numEnvs_(1) 48 | , numPlayer_(numPlayer) 49 | , r2d2Buffer_(nullptr) 50 | , multiStepBuffer_(nullptr) 51 | , replayBuffer_(nullptr) 52 | , eta_(0) 53 | , hidden_(getH0(1, numPlayer)) 54 | , numAct_(0) { 55 | } 56 | 57 | int numAct() const { 58 | return numAct_; 59 | } 60 | 61 | virtual TensorDict act(const TensorDict& obs) { 62 | // std::cout << ":: start c++ act ::" << std::endl; 63 | torch::NoGradGuard ng; 64 | assert(!hidden_.empty()); 65 | 66 | if (replayBuffer_ != nullptr) { 67 | historyHidden_.push_back(hidden_); 68 | } 69 | 70 | // to avoid adding hid into obs; 71 | auto input = obs; 72 | for (auto& kv : hidden_) { 73 | // convert to batch_first 74 | auto ret = input.emplace(kv.first, kv.second.transpose(0, 1)); 75 | assert(ret.second); 76 | } 77 | 78 | int slot = -1; 79 | auto futureReply = runner_->call("act", input, &slot); 80 | auto reply = futureReply->get(slot); 81 | 82 | for (auto& kv : hidden_) { 83 | auto newHidIt = reply.find(kv.first); 84 | assert(newHidIt != reply.end()); 85 | assert(newHidIt->second.sizes() == kv.second.transpose(0, 1).sizes()); 86 | hidden_[kv.first] = newHidIt->second.transpose(0, 1); 87 | reply.erase(newHidIt); 88 | } 89 | 90 | // for (auto& kv : reply) { 91 | // std::cout << "reply: " << kv.first << ", " << kv.second << std::endl; 92 | // } 93 | 94 | if (replayBuffer_ != nullptr) { 95 | multiStepBuffer_->pushObsAndAction(obs, reply); 96 | } 97 | 98 | numAct_ += numEnvs_; 99 | return reply; 100 | } 101 | 102 | // r is float32 tensor, t is byte tensor 103 | virtual void postAct(const torch::Tensor& r, const torch::Tensor& t) { 104 | if (replayBuffer_ == nullptr) { 105 | return; 106 | } 107 | 108 | // assert(replayBuffer_ != nullptr); 109 | multiStepBuffer_->pushRewardAndTerminal(r, t); 110 | 111 | // if ith state is terminal, reset hidden states 112 | // h0: [num_layers * num_directions, batch, hidden_size] 113 | TensorDict h0 = getH0(1, numPlayer_); 114 | auto terminal = t.accessor(); 115 | // std::cout << "terminal size: " << t.sizes() << std::endl; 116 | // std::cout << "hid size: " << hidden_["h0"].sizes() << std::endl; 117 | 
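    // hidden_ tensors are laid out as [numLayer, numEnvs_ * numPlayer_, hidDim],
    // so env i owns the numPlayer_ consecutive slots of the batch dimension
    // starting at i * numPlayer_; the narrow(1, ...) below resets only those
    // slots back to h0.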
for (int i = 0; i < terminal.size(0); i++) { 118 | if (!terminal[i]) { 119 | continue; 120 | } 121 | for (auto& kv : hidden_) { 122 | // [numLayer, numEnvs, hidDim] 123 | // [numLayer, numEnvs, numPlayer (>1), hidDim] 124 | kv.second.narrow(1, i * numPlayer_, numPlayer_) = h0.at(kv.first); 125 | } 126 | } 127 | 128 | if (replayBuffer_ == nullptr) { 129 | return; 130 | } 131 | assert(multiStepBuffer_->size() == historyHidden_.size()); 132 | 133 | if (!multiStepBuffer_->canPop()) { 134 | assert(!r2d2Buffer_->canPop()); 135 | return; 136 | } 137 | 138 | { 139 | FFTransition transition = multiStepBuffer_->popTransition(); 140 | TensorDict hid = historyHidden_.front(); 141 | TensorDict nextHid = historyHidden_.back(); 142 | historyHidden_.pop_front(); 143 | 144 | auto input = transition.toDict(); 145 | for (auto& kv : hid) { 146 | auto ret = input.emplace(kv.first, kv.second.transpose(0, 1)); 147 | assert(ret.second); 148 | } 149 | for (auto& kv : nextHid) { 150 | auto ret = input.emplace("next_" + kv.first, kv.second.transpose(0, 1)); 151 | assert(ret.second); 152 | } 153 | 154 | int slot = -1; 155 | auto futureReply = runner_->call("compute_priority", input, &slot); 156 | auto priority = futureReply->get(slot)["priority"]; 157 | 158 | r2d2Buffer_->push(transition, priority, hid); 159 | } 160 | 161 | if (!r2d2Buffer_->canPop()) { 162 | return; 163 | } 164 | 165 | std::vector batch; 166 | torch::Tensor seqBatchPriority; 167 | torch::Tensor batchLen; 168 | 169 | std::tie(batch, seqBatchPriority, batchLen) = r2d2Buffer_->popTransition(); 170 | auto priority = aggregatePriority(seqBatchPriority, batchLen, eta_); 171 | replayBuffer_->add(batch, priority); 172 | } 173 | 174 | private: 175 | TensorDict getH0(int numEnvs, int numPlayer) { 176 | std::vector input{numEnvs * numPlayer}; 177 | auto model = runner_->jitModel(); 178 | auto output = model.get_method("get_h0")(input); 179 | auto h0 = tensor_dict::fromIValue(output, torch::kCPU, true); 180 | // for (auto& kv : h0) { 181 | // h0[kv.first] = kv.second.transpose(0, 1); 182 | // } 183 | return h0; 184 | } 185 | 186 | std::shared_ptr runner_; 187 | const int numEnvs_; 188 | const int numPlayer_; 189 | 190 | std::deque historyHidden_; 191 | std::unique_ptr r2d2Buffer_; 192 | std::unique_ptr multiStepBuffer_; 193 | std::shared_ptr replayBuffer_; 194 | 195 | const float eta_; 196 | 197 | TensorDict hidden_; 198 | std::atomic numAct_; 199 | 200 | }; 201 | } 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lifelong Hanabi 2 | 3 | ## Introduction 4 | 5 | This repo contains code and models for [Continuous Coordination As a Realistic Scenario for Lifelong Learning](https://arxiv.org/pdf/2103.03216.pdf), a multi-agent lifelong learning testbed that supports both zero-shot and few-shot settings. Our setup is based on [hanabi](https://github.com/deepmind/hanabi-learning-environment) — a partially-observable, fully cooperative multi-agent game. 6 | 7 | 8 |
9 | 10 | ![LifelongHanabi_setup](https://user-images.githubusercontent.com/43013139/107289273-c4f17680-6a32-11eb-93c2-0a70a9e342f3.png) 11 | 12 |
13 | 14 | 15 | 16 | Lifelong Hanabi consists of 3 phases: 1- [Pre-training](https://github.com/chandar-lab/Lifelong-Hanabi/blob/master/README.md#1--pre-trained-agents), 2- [Continual training](https://github.com/chandar-lab/Lifelong-Hanabi#2--continual-training), 3- [Testing](https://github.com/chandar-lab/Lifelong-Hanabi#3--testing). 17 | 18 | The code is built on top of the [Other-Play & Simplified Action Decoder in Hanabi](https://github.com/facebookresearch/hanabi_SAD) repo. 19 | 20 | 21 | 22 | ## Requirements and Installation 23 | The build process is tested with Python 3.7, PyTorch 1.5.1, CUDA 10.1, cudnn 7.6, and nccl 2.4 24 | 25 | ```bash 26 | # clone the repo 27 | git clone --recursive git@github.com:chandar-lab/Lifelong-Hanabi.git 28 | cd Lifelong-Hanabi 29 | 30 | # create new conda env 31 | conda create -n lifelong_hanabi python=3.7 32 | conda activate lifelong_hanabi 33 | pip install -r requirements.txt 34 | 35 | # build 36 | mkdir build 37 | cd build 38 | cmake .. 39 | make 40 | mv hanalearn.cpython-37m-x86_64-linux-gnu.so .. 41 | mv rela/rela.cpython-37m-x86_64-linux-gnu.so .. 42 | mv hanabi-learning-environment/libpyhanabi.so ../hanabi-learning-environment/ 43 | 44 | ``` 45 | Once the building is done and the `.so` files are moved to their required places as mentioned above, every subsequent time you just need to run: 46 | ```bash 47 | conda activate lifelong_hanabi 48 | export PYTHONPATH=/path/to/lifelong_hanabi:$PYTHONPATH 49 | export OMP_NUM_THREADS=1 50 | ``` 51 | ## Run 52 | 53 | ### 1- Pre-Trained Agents 54 | 55 | Run the following command to download the pre-trained agents used in the paper. 56 | ```bash 57 | pip install gdown 58 | gdown --id 1rpmTPIT-g026pdQfAwHoE4i8tP7Qj2vI 59 | ``` 60 | You can find a detailed description of each agent's configs and architectures here: 61 | `results/Pre-trained agents pool for Continual Hanabi.xlsx` 62 | 63 | `all_pretrained_pool.zip` contains the pre-trained agents we used in our experiments (this can be extended by further training more expert Hanabi players). 64 | 65 | To run any `.sh` file, update `` and ``, accordingly. 66 | Important flags are: 67 | |Flags | Description| 68 | |:-------------|:-------------| 69 | | `--sad` |enables Simplified Action Decoder| 70 | | `--pred_weight` |weight for auxiliary task (typically 0.25)| 71 | | `--shuffle_color` |enable other-play| 72 | | `--seed` |seed| 73 | 74 | For details of other hyperparameters refer code and/or paper. 75 | 76 | #### * Pre-train a new agent through self-play: 77 | A sample script is provided in `pyhanabi/tools/pretrain.sh` that can be run: 78 | ```bash 79 | cd pyhanabi 80 | sh tools/pretrain.sh 81 | ``` 82 | 83 | #### * Reproduce the cross-play matrix: 84 | To evaluate all the agents with each other, run: 85 | ```bash 86 | cd pyhanabi 87 | sh generate_cp.sh 88 | ``` 89 | Cross-play matrix from our runs can be found in `results/scores_data_100_nrun5.csv` (`results/sem_data_100_nrun5.csv` contains s.e.m) 90 | 91 | ### 2- Continual Training 92 | To train the learner with a set of 5 partners using for eg. [ER](https://arxiv.org/abs/1902.10486) method, run: 93 | ```bash 94 | cd pyhanabi 95 | sh tools/continual_learning_scripts/ER_easy_interactive.sh 96 | ``` 97 | Zero-shot and few-shot checkpoints will be stored in ``. 98 | Similar scripts are available for all the other algorithms described in paper. 
99 | 100 | In order to log the continual training results (from the above checkpoints stored in ``), run: 101 | 102 | ```bash 103 | cd pyhanabi 104 | sh tools/continual_evaluation.sh 105 | ``` 106 | 107 | #### * Add your lifelong algorithm: 108 | In order to implement a new lifelong learning algorithm, depending on the type of the algorithm you can modify one of the following: 109 | 110 | **Memory based methods:** [episodic_memory](https://github.com/chandar-lab/Lifelong-Hanabi/blob/1c79a5349e70419f45b34e13b90fb003109e85ec/pyhanabi/continual_training.py#L378) is a list of the replay buffers from previous tasks. You can change the way the batch is collected like [here](https://github.com/chandar-lab/Lifelong-Hanabi/blob/1c79a5349e70419f45b34e13b90fb003109e85ec/pyhanabi/utils.py#L264) or the way this replayed batch constrains the current gradients [code](https://github.com/chandar-lab/Lifelong-Hanabi/blob/1c79a5349e70419f45b34e13b90fb003109e85ec/pyhanabi/continual_training.py#L567). 111 | 112 | **Regularization based methods:** [Here](https://github.com/chandar-lab/Lifelong-Hanabi/blob/1c79a5349e70419f45b34e13b90fb003109e85ec/pyhanabi/continual_training.py#L387) is where the fisher information matrix at the end of each task is estimated. You can modify the way corresponding regularization loss is calculated and added to the original loss [here](https://github.com/chandar-lab/Lifelong-Hanabi/blob/1c79a5349e70419f45b34e13b90fb003109e85ec/pyhanabi/continual_training.py#L561). 113 | 114 | **Training regimes:** These are a list of hyper-parameters which has been shown [here](https://arxiv.org/abs/2006.06958) that have high impact on the performance of the lifelong learning algorithms. 115 | |Flags | Description| 116 | |:-------------|:-------------| 117 | | `--optim_name` |optimizer| 118 | | `--batchsize` |batch size| 119 | | `--decay_lr` |learning rate decay| 120 | | `--initial_lr` |initial learning rate| 121 | 122 | ### 3- Testing 123 | To evaluate the learner against a set of unseen agents, run: 124 | ```bash 125 | cd pyhanabi 126 | sh tools/testing.sh 127 | ``` 128 | Logging continual training results and testing requires a [wandb](https://wandb.ai/home) account to plot the results. 129 | 130 | ## Plot results 131 | All the plots and experiment details are available at [wandb report](https://wandb.ai/akileshbadrinaaraayanan/ContPlay_Hanabi_complete/reports/Lifelong-Hanabi-Experiments--VmlldzozOTk2NjY). 132 | 133 | * Other code used to reproduce figures in the paper can be found in `results` 134 | 135 | ## Citation: 136 | 137 | If you found this work useful, please consider citing our [paper](https://arxiv.org/abs/2103.03216). 
138 | ``` 139 | @misc{nekoei2021continuous, 140 | title={Continuous Coordination As a Realistic Scenario for Lifelong Learning}, 141 | author={Hadi Nekoei and Akilesh Badrinaaraayanan and Aaron Courville and Sarath Chandar}, 142 | year={2021}, 143 | eprint={2103.03216}, 144 | archivePrefix={arXiv}, 145 | primaryClass={cs.LG} 146 | } 147 | ``` 148 | -------------------------------------------------------------------------------- /pyhanabi/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | import numpy as np 5 | import torch 6 | import random 7 | from create import * 8 | import rela 9 | import utils 10 | 11 | 12 | def evaluate(agents, num_game, seed, bomb, eps, sad, *, hand_size=5, runners=None): 13 | """ 14 | evaluate agents as long as they have a "act" function 15 | """ 16 | assert agents is None or runners is None 17 | if agents is not None: 18 | runners = [rela.BatchRunner(agent, "cuda:0", 1000, ["act"]) for agent in agents] 19 | num_player = len(runners) 20 | 21 | context = rela.Context() 22 | games = create_envs( 23 | num_game, 24 | seed, 25 | num_player, 26 | hand_size, 27 | bomb, 28 | [eps], 29 | -1, 30 | sad, 31 | False, 32 | False, 33 | ) 34 | 35 | for g in games: 36 | env = hanalearn.HanabiVecEnv() 37 | env.append(g) 38 | actors = [] 39 | for i in range(num_player): 40 | actors.append(rela.R2D2Actor(runners[i], 1)) 41 | thread = hanalearn.HanabiThreadLoop(actors, env, True) 42 | context.push_env_thread(thread) 43 | 44 | for runner in runners: 45 | runner.start() 46 | 47 | context.start() 48 | while not context.terminated(): 49 | time.sleep(0.5) 50 | context.terminate() 51 | while not context.terminated(): 52 | time.sleep(0.5) 53 | 54 | for runner in runners: 55 | runner.stop() 56 | 57 | scores = [g.last_score() for g in games] 58 | num_perfect = np.sum([1 for s in scores if s == 25]) 59 | return np.mean(scores), num_perfect / len(scores), scores, num_perfect 60 | 61 | 62 | def evaluate_legacy_model( 63 | weight_files, 64 | num_game, 65 | seed, 66 | bomb, 67 | agent_args, 68 | args, 69 | num_run=1, 70 | gen_cross_play=False, 71 | verbose=True, 72 | ): 73 | agents = [] 74 | num_player = len(weight_files) 75 | assert num_player > 1, "1 weight file per player" 76 | 77 | env_sad = False 78 | for i, weight_file in enumerate(weight_files): 79 | if verbose: 80 | print( 81 | "evaluating: %s\n\tfor %dx%d games" % (weight_file, num_run, num_game) 82 | ) 83 | if "sad" in weight_file: 84 | sad = True 85 | env_sad = True 86 | else: 87 | sad = False 88 | 89 | device = "cuda:0" 90 | 91 | state_dict = torch.load(weight_file) 92 | input_dim = state_dict["net.0.weight"].size()[1] 93 | output_dim = state_dict["fc_a.weight"].size()[0] 94 | 95 | if gen_cross_play: 96 | agent_name = weight_file.split("/")[-1].split(".")[0] 97 | 98 | with open(f"{args.weight_1_dir}/{agent_name}.txt", "r") as f: 99 | agent_args = {**json.load(f)} 100 | else: 101 | learnable_pretrain = True 102 | 103 | if i == 0: 104 | learnable_agent_name = agent_args["load_learnable_model"] 105 | if learnable_agent_name != "": 106 | agent_args_file = f"{learnable_agent_name[:-4]}txt" 107 | else: 108 | learnable_pretrain = False 109 | else: 110 | agent_args_file = f"{weight_file[:-4]}txt" 111 | 112 | if learnable_pretrain == True: 113 | with open(agent_args_file, "r") as f: 114 | agent_args = {**json.load(f)} 115 | 116 | rnn_type = agent_args["rnn_type"] 117 | rnn_hid_dim = agent_args["rnn_hid_dim"] 118 | num_fflayer = agent_args["num_fflayer"] 119 | 
num_rnn_layer = agent_args["num_rnn_layer"] 120 | 121 | if rnn_type == "lstm": 122 | import r2d2_lstm as r2d2 123 | elif rnn_type == "gru": 124 | import r2d2_gru as r2d2 125 | 126 | agent = r2d2.R2D2Agent( 127 | False, 128 | 3, 129 | 0.999, 130 | 0.9, 131 | device, 132 | input_dim, 133 | rnn_hid_dim, 134 | output_dim, 135 | num_fflayer, 136 | num_rnn_layer, 137 | 5, 138 | False, 139 | sad=sad, 140 | ).to(device) 141 | 142 | utils.load_weight(agent.online_net, weight_file, device) 143 | agents.append(agent) 144 | 145 | scores = [] 146 | perfect = 0 147 | for i in range(num_run): 148 | if args.is_rand: 149 | random.shuffle(agents) 150 | 151 | _, _, score, p = evaluate( 152 | agents, 153 | num_game, 154 | num_game * i + seed, 155 | bomb, 156 | 0, 157 | env_sad, 158 | ) 159 | scores.extend(score) 160 | perfect += p 161 | 162 | mean = np.mean(scores) 163 | sem = np.std(scores) / np.sqrt(len(scores)) 164 | perfect_rate = perfect / (num_game * num_run) 165 | if verbose: 166 | print("score: %f +/- %f" % (mean, sem), "; perfect: ", perfect_rate) 167 | return mean, sem, perfect_rate 168 | 169 | 170 | def evaluate_saved_model( 171 | weight_files, 172 | num_game, 173 | seed, 174 | bomb, 175 | *, 176 | overwrite=None, 177 | num_run=1, 178 | verbose=True, 179 | ): 180 | agents = [] 181 | sad = [] 182 | hide_action = [] 183 | if overwrite is None: 184 | overwrite = {} 185 | overwrite["vdn"] = False 186 | overwrite["device"] = "cuda:0" 187 | overwrite["boltzmann_act"] = False 188 | 189 | for weight_file in weight_files: 190 | agent, cfg = utils.load_agent( 191 | weight_file, 192 | overwrite, 193 | ) 194 | agents.append(agent) 195 | sad.append(cfg["sad"] if "sad" in cfg else cfg["greedy_extra"]) 196 | hide_action.append(bool(cfg["hide_action"])) 197 | 198 | hand_size = cfg.get("hand_size", 5) 199 | 200 | assert all(s == sad[0] for s in sad) 201 | sad = sad[0] 202 | if all(h == hide_action[0] for h in hide_action): 203 | hide_action = hide_action[0] 204 | process_game = None 205 | else: 206 | hide_actions = hide_action 207 | process_game = lambda g: g.set_hide_actions(hide_actions) 208 | hide_action = False 209 | 210 | scores = [] 211 | perfect = 0 212 | for i in range(num_run): 213 | _, _, score, p, _ = evaluate( 214 | agents, 215 | num_game, 216 | num_game * i + seed, 217 | bomb, 218 | 0, # eps 219 | sad, 220 | hide_action, 221 | process_game=process_game, 222 | hand_size=hand_size, 223 | ) 224 | scores.extend(score) 225 | perfect += p 226 | 227 | mean = np.mean(scores) 228 | sem = np.std(scores) / np.sqrt(len(scores)) 229 | perfect_rate = perfect / (num_game * num_run) 230 | if verbose: 231 | print( 232 | "score: %f +/- %f" % (mean, sem), "; perfect: %.2f%%" % (100 * perfect_rate) 233 | ) 234 | return mean, sem, perfect_rate, scores 235 | -------------------------------------------------------------------------------- /rela/transition.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | #include "rela/transition.h" 4 | #include "rela/utils.h" 5 | 6 | using namespace rela; 7 | 8 | FFTransition FFTransition::index(int i) const { 9 | FFTransition element; 10 | 11 | for (auto& name2tensor : obs) { 12 | element.obs.insert({name2tensor.first, name2tensor.second[i]}); 13 | } 14 | for (auto& name2tensor : action) { 15 | element.action.insert({name2tensor.first, name2tensor.second[i]}); 16 | } 17 | 18 | element.reward = reward[i]; 19 | element.terminal = terminal[i]; 20 | element.bootstrap = bootstrap[i]; 21 | 22 | for (auto& name2tensor : nextObs) { 23 | element.nextObs.insert({name2tensor.first, name2tensor.second[i]}); 24 | } 25 | 26 | return element; 27 | } 28 | 29 | FFTransition FFTransition::padLike() const { 30 | FFTransition pad; 31 | 32 | pad.obs = tensor_dict::zerosLike(obs); 33 | pad.action = tensor_dict::zerosLike(action); 34 | pad.reward = torch::zeros_like(reward); 35 | pad.terminal = torch::ones_like(terminal); 36 | pad.bootstrap = torch::zeros_like(bootstrap); 37 | pad.nextObs = tensor_dict::zerosLike(nextObs); 38 | 39 | return pad; 40 | } 41 | 42 | std::vector FFTransition::toVectorIValue( 43 | const torch::Device& device) const { 44 | std::vector vec; 45 | vec.push_back(tensor_dict::toIValue(obs, device)); 46 | vec.push_back(tensor_dict::toIValue(action, device)); 47 | vec.push_back(reward.to(device)); 48 | vec.push_back(terminal.to(device)); 49 | vec.push_back(bootstrap.to(device)); 50 | vec.push_back(tensor_dict::toIValue(nextObs, device)); 51 | return vec; 52 | } 53 | 54 | TensorDict FFTransition::toDict() { 55 | auto dict = obs; 56 | for (auto& kv : nextObs) { 57 | dict["next_" + kv.first] = kv.second; 58 | } 59 | 60 | for (auto& kv : action) { 61 | auto ret = dict.emplace(kv.first, kv.second); 62 | assert(ret.second); 63 | } 64 | 65 | auto ret = dict.emplace("reward", reward); 66 | assert(ret.second); 67 | ret = dict.emplace("terminal", terminal); 68 | assert(ret.second); 69 | ret = dict.emplace("bootstrap", bootstrap); 70 | assert(ret.second); 71 | return dict; 72 | } 73 | 74 | RNNTransition::RNNTransition( 75 | const std::vector& transitions, TensorDict h0, torch::Tensor seqLen) 76 | : h0(h0) 77 | , seqLen(seqLen) { 78 | std::vector obsVec; 79 | std::vector actionVec; 80 | std::vector rewardVec; 81 | std::vector terminalVec; 82 | std::vector bootstrapVec; 83 | 84 | for (size_t i = 0; i < transitions.size(); i++) { 85 | obsVec.push_back(transitions[i].obs); 86 | actionVec.push_back(transitions[i].action); 87 | rewardVec.push_back(transitions[i].reward); 88 | terminalVec.push_back(transitions[i].terminal); 89 | bootstrapVec.push_back(transitions[i].bootstrap); 90 | } 91 | 92 | obs = tensor_dict::stack(obsVec, 0); 93 | action = tensor_dict::stack(actionVec, 0); 94 | reward = torch::stack(rewardVec, 0); 95 | terminal = torch::stack(terminalVec, 0); 96 | bootstrap = torch::stack(bootstrapVec, 0); 97 | } 98 | 99 | RNNTransition RNNTransition::index(int i) const { 100 | RNNTransition element; 101 | 102 | for (auto& name2tensor : obs) { 103 | element.obs.insert({name2tensor.first, name2tensor.second[i]}); 104 | } 105 | for (auto& name2tensor : h0) { 106 | auto t = name2tensor.second.narrow(1, i, 1).squeeze(1); 107 | element.h0.insert({name2tensor.first, t}); 108 | } 109 | for (auto& name2tensor : action) { 110 | element.action.insert({name2tensor.first, name2tensor.second[i]}); 111 | } 112 | 113 | element.reward = reward[i]; 114 | element.terminal = terminal[i]; 115 | element.bootstrap = bootstrap[i]; 116 | element.seqLen = seqLen[i]; 
117 | return element; 118 | } 119 | 120 | FFTransition FFTransition::makeBatch( 121 | const std::vector& transitions, const std::string& device) { 122 | std::vector obsVec; 123 | std::vector actionVec; 124 | std::vector rewardVec; 125 | std::vector terminalVec; 126 | std::vector bootstrapVec; 127 | std::vector nextObsVec; 128 | 129 | for (size_t i = 0; i < transitions.size(); i++) { 130 | obsVec.push_back(transitions[i].obs); 131 | actionVec.push_back(transitions[i].action); 132 | rewardVec.push_back(transitions[i].reward); 133 | terminalVec.push_back(transitions[i].terminal); 134 | bootstrapVec.push_back(transitions[i].bootstrap); 135 | nextObsVec.push_back(transitions[i].nextObs); 136 | } 137 | 138 | FFTransition batch; 139 | batch.obs = tensor_dict::stack(obsVec, 0); 140 | batch.action = tensor_dict::stack(actionVec, 0); 141 | batch.reward = torch::stack(rewardVec, 0); 142 | batch.terminal = torch::stack(terminalVec, 0); 143 | batch.bootstrap = torch::stack(bootstrapVec, 0); 144 | batch.nextObs = tensor_dict::stack(nextObsVec, 0); 145 | 146 | if (device != "cpu") { 147 | auto d = torch::Device(device); 148 | auto toDevice = [&](const torch::Tensor& t) { return t.to(d); }; 149 | batch.obs = tensor_dict::apply(batch.obs, toDevice); 150 | batch.action = tensor_dict::apply(batch.action, toDevice); 151 | batch.reward = batch.reward.to(d); 152 | batch.terminal = batch.terminal.to(d); 153 | batch.bootstrap = batch.bootstrap.to(d); 154 | batch.nextObs = tensor_dict::apply(batch.nextObs, toDevice); 155 | } 156 | 157 | return batch; 158 | } 159 | 160 | RNNTransition RNNTransition::makeBatch( 161 | const std::vector& transitions, const std::string& device) { 162 | std::vector obsVec; 163 | // TensorVecDict h0Vec; 164 | std::vector actionVec; 165 | std::vector rewardVec; 166 | std::vector terminalVec; 167 | std::vector bootstrapVec; 168 | std::vector seqLenVec; 169 | 170 | for (size_t i = 0; i < transitions.size(); i++) { 171 | obsVec.push_back(transitions[i].obs); 172 | // utils::tensorVecDictAppend(h0Vec, transitions[i].h0); 173 | actionVec.push_back(transitions[i].action); 174 | rewardVec.push_back(transitions[i].reward); 175 | terminalVec.push_back(transitions[i].terminal); 176 | bootstrapVec.push_back(transitions[i].bootstrap); 177 | seqLenVec.push_back(transitions[i].seqLen); 178 | } 179 | 180 | RNNTransition batch; 181 | batch.obs = tensor_dict::stack(obsVec, 1); 182 | // batch.h0 = tensor_dict::stack(h0Vec, 1); // 1 is batch for rnn hid 183 | batch.action = tensor_dict::stack(actionVec, 1); 184 | batch.reward = torch::stack(rewardVec, 1); 185 | batch.terminal = torch::stack(terminalVec, 1); 186 | batch.bootstrap = torch::stack(bootstrapVec, 1); 187 | batch.seqLen = torch::stack(seqLenVec, 0); //.squeeze(1); 188 | 189 | if (device != "cpu") { 190 | auto d = torch::Device(device); 191 | auto toDevice = [&](const torch::Tensor& t) { return t.to(d); }; 192 | batch.obs = tensor_dict::apply(batch.obs, toDevice); 193 | // batch.h0 = tensor_dict::apply(batch.h0, toDevice); 194 | batch.action = tensor_dict::apply(batch.action, toDevice); 195 | batch.reward = batch.reward.to(d); 196 | batch.terminal = batch.terminal.to(d); 197 | batch.bootstrap = batch.bootstrap.to(d); 198 | batch.seqLen = batch.seqLen.to(d); 199 | } 200 | 201 | return batch; 202 | } 203 | -------------------------------------------------------------------------------- /rela/tensor_dict.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | #pragma once 3 | 4 | #include 5 | #include 6 | 7 | namespace rela { 8 | 9 | using TensorDict = std::unordered_map; 10 | 11 | namespace tensor_dict { 12 | 13 | inline void compareShape(const TensorDict& src, const TensorDict& dest) { 14 | if (src.size() != dest.size()) { 15 | std::cout << "src.size()[" << src.size() << "] != dest.size()[" << dest.size() << "]" 16 | << std::endl; 17 | std::cout << "src keys: "; 18 | for (const auto& p : src) 19 | std::cout << p.first << " "; 20 | std::cout << "dest keys: "; 21 | for (const auto& p : dest) 22 | std::cout << p.first << " "; 23 | std::cout << std::endl; 24 | assert(false); 25 | } 26 | 27 | for (const auto& name2tensor : src) { 28 | const auto& name = name2tensor.first; 29 | const auto& srcTensor = name2tensor.second; 30 | // std::cout << "in copy: trying to get: " << name << std::endl; 31 | // std::cout << "dest map keys" << std::endl; 32 | // printMapKey(dest); 33 | const auto& destTensor = dest.at(name); 34 | // if (destTensor.sizes() != srcTensor.sizes()) { 35 | // std::cout << "copy size-mismatch: " 36 | // << destTensor.sizes() << ", " << srcTensor.sizes() << 37 | // std::endl; 38 | // } 39 | if (destTensor.sizes() != srcTensor.sizes()) { 40 | std::cout << name << ", dstSize: " << destTensor.sizes() 41 | << ", srcSize: " << srcTensor.sizes() << std::endl; 42 | assert(false); 43 | } 44 | 45 | // if (destTensor.dtype() != srcTensor.dtype()) { 46 | // std::cout << name << ", dstType: " << destTensor.dtype() 47 | // << ", srcType: " << srcTensor.dtype() << std::endl; 48 | // assert(false); 49 | // } 50 | } 51 | } 52 | 53 | inline void copy(const TensorDict& src, TensorDict& dest) { 54 | compareShape(src, dest); 55 | for (const auto& name2tensor : src) { 56 | const auto& name = name2tensor.first; 57 | const auto& srcTensor = name2tensor.second; 58 | // std::cout << "in copy: trying to get: " << name << std::endl; 59 | // std::cout << "dest map keys" << std::endl; 60 | // printMapKey(dest); 61 | auto& destTensor = dest.at(name); 62 | // if (destTensor.sizes() != srcTensor.sizes()) { 63 | // std::cout << "copy size-mismatch: " 64 | // << destTensor.sizes() << ", " << srcTensor.sizes() << 65 | // std::endl; 66 | // } 67 | destTensor.copy_(srcTensor); 68 | } 69 | } 70 | 71 | // // TODO: maybe merge these two functions? 
72 | // inline void copyTensors( 73 | // const std::unordered_map& src, 74 | // std::unordered_map& dest, 75 | // std::vector& index) { 76 | // assert(src.size() == dest.size()); 77 | // assert(!index.empty()); 78 | // torch::Tensor indexTensor = 79 | // torch::from_blob(index.data(), {(int64_t)index.size()}, torch::kInt64); 80 | 81 | // for (const auto& name2tensor : src) { 82 | // const auto& name = name2tensor.first; 83 | // const auto& srcTensor = name2tensor.second; 84 | // auto& destTensor = dest.at(name); 85 | // // assert(destTensor.sizes() == srcTensor.sizes()); 86 | // assert(destTensor.dtype() == srcTensor.dtype()); 87 | // assert(indexTensor.size(0) == srcTensor.size(0)); 88 | // destTensor.index_copy_(0, indexTensor, srcTensor); 89 | // } 90 | // } 91 | 92 | inline void copy(const TensorDict& src, TensorDict& dest, const torch::Tensor& index) { 93 | assert(src.size() == dest.size()); 94 | assert(index.size(0) > 0); 95 | for (const auto& name2tensor : src) { 96 | const auto& name = name2tensor.first; 97 | const auto& srcTensor = name2tensor.second; 98 | auto& destTensor = dest.at(name); 99 | assert(destTensor.dtype() == srcTensor.dtype()); 100 | assert(index.size(0) == srcTensor.size(0)); 101 | destTensor.index_copy_(0, index, srcTensor); 102 | } 103 | } 104 | 105 | inline bool eq(const TensorDict& d0, const TensorDict& d1) { 106 | if (d0.size() != d1.size()) { 107 | return false; 108 | } 109 | 110 | for (const auto& name2tensor : d0) { 111 | auto key = name2tensor.first; 112 | if ((d1.at(key) != name2tensor.second).all().item()) { 113 | return false; 114 | } 115 | } 116 | return true; 117 | } 118 | 119 | /* 120 | * indexes into a TensorDict 121 | */ 122 | inline TensorDict index(const TensorDict& batch, size_t i) { 123 | TensorDict result; 124 | for (const auto& name2tensor : batch) { 125 | result.insert({name2tensor.first, name2tensor.second[i]}); 126 | } 127 | return result; 128 | } 129 | 130 | inline TensorDict narrow( 131 | const TensorDict& batch, size_t dim, size_t i, size_t len, bool squeeze) { 132 | TensorDict result; 133 | for (auto& name2tensor : batch) { 134 | auto t = name2tensor.second.narrow(dim, i, len); 135 | if (squeeze) { 136 | assert(len == 1); 137 | t = t.squeeze(dim); 138 | } 139 | result.insert({name2tensor.first, std::move(t)}); 140 | } 141 | return result; 142 | } 143 | 144 | inline TensorDict clone(const TensorDict& input) { 145 | TensorDict output; 146 | for (auto& name2tensor : input) { 147 | output.insert({name2tensor.first, name2tensor.second.clone()}); 148 | } 149 | return output; 150 | } 151 | 152 | inline TensorDict zerosLike(const TensorDict& input) { 153 | TensorDict output; 154 | for (auto& name2tensor : input) { 155 | output.insert({name2tensor.first, torch::zeros_like(name2tensor.second)}); 156 | } 157 | return output; 158 | } 159 | 160 | // TODO: rewrite the above functions with this template 161 | template 162 | inline TensorDict apply(TensorDict& dict, Func f) { 163 | TensorDict output; 164 | for (const auto& name2tensor : dict) { 165 | auto tensor = f(name2tensor.second); 166 | output.insert({name2tensor.first, tensor}); 167 | } 168 | return output; 169 | } 170 | 171 | inline TensorDict stack(const std::vector& vec, int stackdim) { 172 | assert(vec.size() >= 1); 173 | TensorDict ret; 174 | for (auto& name2tensor : vec[0]) { 175 | std::vector buffer(vec.size()); 176 | for (size_t i = 0; i < vec.size(); ++i) { 177 | buffer[i] = vec[i].at(name2tensor.first); 178 | } 179 | ret[name2tensor.first] = torch::stack(buffer, stackdim); 180 | } 181 
| return ret; 182 | } 183 | 184 | inline TensorDict fromIValue( 185 | const torch::jit::IValue& value, torch::DeviceType device, bool detach) { 186 | std::unordered_map map; 187 | auto dict = value.toGenericDict(); 188 | // auto ivalMap = dict->elements(); 189 | for (auto& name2tensor : dict) { 190 | auto name = name2tensor.key().toString(); 191 | torch::Tensor tensor = name2tensor.value().toTensor(); 192 | tensor = tensor.to(device); 193 | if (detach) { 194 | tensor = tensor.detach(); 195 | } 196 | map.insert({name->string(), tensor}); 197 | } 198 | return map; 199 | } 200 | 201 | // TODO: this may be simplified with constructor in the future version 202 | inline torch::jit::IValue toIValue( 203 | const TensorDict& tensorDict, const torch::Device& device) { 204 | torch::Dict dict; 205 | for (const auto& name2tensor : tensorDict) { 206 | dict.insert(name2tensor.first, name2tensor.second.to(device)); 207 | } 208 | return torch::jit::IValue(dict); 209 | } 210 | } // namespace tdict 211 | } // namespace rela 212 | -------------------------------------------------------------------------------- /cpp/hanabi_env.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | // 7 | #include "hanabi_env.h" 8 | 9 | rela::TensorDict HanabiEnv::reset() { 10 | assert(terminated()); 11 | state_ = std::make_unique(&game_); 12 | // chance player 13 | while (state_->CurPlayer() == hle::kChancePlayerId) { 14 | state_->ApplyRandomChance(); 15 | } 16 | numStep_ = 0; 17 | 18 | for (int pid = 0; pid < game_.NumPlayers(); ++pid) { 19 | playerEps_[pid] = epsList_[game_.rng()->operator()() % epsList_.size()]; 20 | } 21 | 22 | if (shuffleColor_) { 23 | // assert(game_.NumPlayers() == 2); 24 | int fixColorPlayer = game_.rng()->operator()() % game_.NumPlayers(); 25 | for (int pid = 0; pid < game_.NumPlayers(); ++pid) { 26 | auto& colorPermute = colorPermutes_[pid]; 27 | auto& invColorPermute = invColorPermutes_[pid]; 28 | colorPermute.clear(); 29 | invColorPermute.clear(); 30 | for (int i = 0; i < game_.NumColors(); ++i) { 31 | colorPermute.push_back(i); 32 | invColorPermute.push_back(i); 33 | } 34 | if (pid != fixColorPlayer) { 35 | std::shuffle(colorPermute.begin(), colorPermute.end(), *game_.rng()); 36 | std::sort(invColorPermute.begin(), invColorPermute.end(), [&](int i, int j) { 37 | return colorPermute[i] < colorPermute[j]; 38 | }); 39 | } 40 | for (int i = 0; i < (int)colorPermute.size(); ++i) { 41 | assert(invColorPermute[colorPermute[i]] == i); 42 | } 43 | } 44 | } 45 | 46 | return computeFeatureAndLegalMove(state_); 47 | } 48 | 49 | std::tuple HanabiEnv::step( 50 | const rela::TensorDict& action) { 51 | assert(!terminated()); 52 | 53 | numStep_ += 1; 54 | 55 | float prevScore = state_->Score(); 56 | 57 | // perform action for only current player 58 | int curPlayer = state_->CurPlayer(); 59 | int actionUid = action.at("a")[curPlayer].item(); 60 | hle::HanabiMove move = game_.GetMove(actionUid); 61 | maybeInversePermuteColor_(move, curPlayer); 62 | 63 | if (!state_->MoveIsLegal(move)) { 64 | std::cout << "Error: move is not legal" << std::endl; 65 | std::cout << "UID: " << actionUid << std::endl; 66 | std::cout << "legal move:" << std::endl; 67 | std::cout << "numStep: " << numStep_ - 1 << std::endl; 68 | 69 | auto legalMoves = state_->LegalMoves(curPlayer); 70 | 
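    // Dump the legal move uids before aborting; when shuffleColor_ is on,
    // reveal-color moves are first mapped through the acting player's color
    // permutation, matching the legal-move encoding used below.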
for (auto move : legalMoves) { 71 | if (shuffleColor_ && 72 | move.MoveType() == hle::HanabiMove::Type::kRevealColor) { 73 | int permColor = colorPermutes_[curPlayer][move.Color()]; 74 | move.SetColor(permColor); 75 | } 76 | auto uid = game_.GetMoveUid(move); 77 | std::cout << "legal_move: " << uid << std::endl; 78 | } 79 | assert(false); 80 | } 81 | 82 | std::unique_ptr cloneState = nullptr; 83 | if (sad_) { 84 | cloneState = std::make_unique(*state_); 85 | int greedyActionUid = action.at("greedy_a")[curPlayer].item(); 86 | hle::HanabiMove greedyMove = game_.GetMove(greedyActionUid); 87 | maybeInversePermuteColor_(greedyMove, curPlayer); 88 | 89 | assert(state_->MoveIsLegal(greedyMove)); 90 | cloneState->ApplyMove(greedyMove); 91 | } 92 | state_->ApplyMove(move); 93 | 94 | bool terminal = state_->IsTerminal(); 95 | float reward = state_->Score() - prevScore; 96 | 97 | // forced termination, lose all points 98 | if (maxLen_ > 0 && numStep_ == maxLen_) { 99 | terminal = true; 100 | reward = 0 - prevScore; 101 | } 102 | 103 | if (!terminal) { 104 | // chance player 105 | while (state_->CurPlayer() == hle::kChancePlayerId) { 106 | state_->ApplyRandomChance(); 107 | } 108 | } 109 | 110 | // std::cout << "score: " << state_->Score() << std::endl; 111 | auto obs = computeFeatureAndLegalMove(cloneState); 112 | return std::make_tuple(obs, reward, terminal); 113 | } 114 | 115 | rela::TensorDict HanabiEnv::computeFeatureAndLegalMove( 116 | const std::unique_ptr& cloneState) { 117 | std::vector privS; 118 | std::vector privS_gen; 119 | // std::vector publS; 120 | // std::vector superS; 121 | std::vector legalMove; 122 | std::vector legalMatrix; 123 | // auto epsAccessor = eps_.accessor(); 124 | // std::vector eps; 125 | std::vector ownHand; 126 | // std::vector ownHandARIn; 127 | // std::vector allHand; 128 | // std::vector allHandARIn; 129 | 130 | // std::vector privCardCount; 131 | 132 | for (int i = 0; i < game_.NumPlayers(); ++i) { 133 | auto obs = hle::HanabiObservation(*state_, i, false); 134 | std::vector shuffleOrder; 135 | if (shuffleObs_) { 136 | // hacked for 2 players 137 | assert(game_.NumPlayers() == 2); 138 | // [1] for partner's hand 139 | int partnerHandSize = obs.Hands()[1].Cards().size(); 140 | for (int i = 0; i < partnerHandSize; ++i) { 141 | shuffleOrder.push_back(i); 142 | } 143 | std::shuffle(shuffleOrder.begin(), shuffleOrder.end(), *game_.rng()); 144 | } 145 | 146 | std::vector vS = obsEncoder_.Encode( 147 | obs, 148 | false, 149 | shuffleOrder, 150 | shuffleColor_, 151 | colorPermutes_[i], 152 | invColorPermutes_[i], 153 | false); 154 | 155 | privS_gen.push_back(torch::tensor(vS)); 156 | 157 | if (sad_) { 158 | assert(cloneState != nullptr); 159 | auto extraObs = hle::HanabiObservation(*cloneState, i, false); 160 | std::vector vGreedyAction = obsEncoder_.EncodeLastAction( 161 | extraObs, shuffleOrder, shuffleColor_, colorPermutes_[i]); 162 | vS.insert(vS.end(), vGreedyAction.begin(), vGreedyAction.end()); 163 | } 164 | 165 | privS.push_back(torch::tensor(vS)); 166 | 167 | { 168 | auto cheatObs = hle::HanabiObservation(*state_, i, true); 169 | auto vOwnHand = obsEncoder_.EncodeOwnHandTrinary(cheatObs); 170 | ownHand.push_back(torch::tensor(vOwnHand)); 171 | } 172 | 173 | // legal moves 174 | auto legalMoves = state_->LegalMoves(i); 175 | std::vector moveUids(numAction(), 0); 176 | // auto moveUids = torch::zeros({numAction()}); 177 | // auto moveAccessor = moveUids.accessor(); 178 | for (auto move : legalMoves) { 179 | if (shuffleColor_ && 180 | // fixColorPlayer_ == i && 181 
| move.MoveType() == hle::HanabiMove::Type::kRevealColor) { 182 | int permColor = colorPermutes_[i][move.Color()]; 183 | move.SetColor(permColor); 184 | } 185 | auto uid = game_.GetMoveUid(move); 186 | if (uid >= noOpUid()) { 187 | std::cout << "Error: legal move id should be < " << numAction() - 1 << std::endl; 188 | assert(false); 189 | } 190 | moveUids[uid] = 1; 191 | } 192 | if (legalMoves.size() == 0) { 193 | moveUids[noOpUid()] = 1; 194 | } 195 | 196 | legalMove.push_back(torch::tensor(moveUids)); 197 | // epsAccessor[i] = playerEps_[i]; 198 | } 199 | 200 | rela::TensorDict dict = { 201 | {"priv_s", torch::stack(privS, 0)}, 202 | {"priv_s_gen", torch::stack(privS_gen, 0)}, 203 | {"legal_move", torch::stack(legalMove, 0)}, 204 | {"eps", torch::tensor(playerEps_)}, 205 | {"own_hand", torch::stack(ownHand, 0)}, 206 | }; 207 | 208 | return dict; 209 | } 210 | -------------------------------------------------------------------------------- /rela/transition_buffer.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | #pragma once 3 | 4 | #include "transition.h" 5 | 6 | namespace rela { 7 | 8 | class MultiStepBuffer { 9 | public: 10 | MultiStepBuffer(int multiStep, int batchsize, float gamma) 11 | : multiStep_(multiStep) 12 | , batchsize_(batchsize) 13 | , gamma_(gamma) { 14 | } 15 | 16 | void pushObsAndAction(const TensorDict& obs, const TensorDict& action) { 17 | assert((int)obsHistory_.size() <= multiStep_); 18 | assert((int)actionHistory_.size() <= multiStep_); 19 | 20 | obsHistory_.push_back(obs); 21 | actionHistory_.push_back(action); 22 | } 23 | 24 | void pushRewardAndTerminal(const torch::Tensor& reward, const torch::Tensor& terminal) { 25 | assert(rewardHistory_.size() == terminalHistory_.size()); 26 | assert(rewardHistory_.size() == obsHistory_.size() - 1); 27 | assert(reward.dim() == 1 && terminal.dim() == 1); 28 | assert(reward.size(0) == terminal.size(0) && reward.size(0) == batchsize_); 29 | 30 | rewardHistory_.push_back(reward); 31 | terminalHistory_.push_back(terminal); 32 | } 33 | 34 | size_t size() { 35 | return obsHistory_.size(); 36 | } 37 | 38 | bool canPop() { 39 | return (int)obsHistory_.size() == multiStep_ + 1; 40 | } 41 | 42 | /* assumes that: 43 | * obsHistory contains s_t, s_t+1, ..., s_t+n 44 | * actionHistory contains a_t, a_t+1, ..., a_t+n 45 | * rewardHistory contains r_t, r_t+1, ..., r_t+n 46 | * terminalHistory contains T_t, T_t+1, ..., T_t+n 47 | * 48 | * returns: 49 | * obs, action, cumulative reward, terminal, "next"_obs 50 | */ 51 | FFTransition popTransition() { 52 | assert((int)obsHistory_.size() == multiStep_ + 1); 53 | assert((int)actionHistory_.size() == multiStep_ + 1); 54 | assert((int)rewardHistory_.size() == multiStep_ + 1); 55 | assert((int)terminalHistory_.size() == multiStep_ + 1); 56 | 57 | TensorDict obs = obsHistory_.front(); 58 | TensorDict action = actionHistory_.front(); 59 | torch::Tensor terminal = terminalHistory_.front(); 60 | torch::Tensor bootstrap = torch::ones(batchsize_, torch::kFloat32); 61 | auto bootstrapAccessor = bootstrap.accessor(); 62 | 63 | std::vector nextObsIndices(batchsize_); 64 | // calculate bootstrap and nextState indices 65 | for (int i = 0; i < batchsize_; i++) { 66 | for (int step = 0; step < multiStep_; step++) { 67 | // next state is step (shouldn't be used anyways) 68 | if (terminalHistory_[step][i].item()) { 69 | bootstrapAccessor[i] = 0.0; 70 | nextObsIndices[i] = step; 71 | break; 72 | } 73 | } 
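      // If a terminal was found at relative step k < multiStep_, this sample does not
      // bootstrap and its "next" obs index points at that step; otherwise (handled just
      // below) it bootstraps from s_{t+multiStep_}. For example, with multiStep_ = 3 and
      // gamma_ = 0.99 the reward accumulated further down is
      //   r_t + 0.99 * r_{t+1} + 0.99^2 * r_{t+2}.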
74 | // next state is step+n 75 | if (bootstrapAccessor[i] > 1e-6) { 76 | nextObsIndices[i] = multiStep_; 77 | } 78 | } 79 | 80 | // calculate discounted rewards 81 | auto reward = torch::zeros_like(rewardHistory_.front()); 82 | auto accessor = reward.accessor(); 83 | for (int i = 0; i < batchsize_; i++) { 84 | // if bootstrap, we only use the first nsAccessor[i]-1 (i.e. multiStep_-1) 85 | int initial = bootstrapAccessor[i] ? multiStep_ - 1 : nextObsIndices[i]; 86 | for (int step = initial; step >= 0; step--) { 87 | float stepReward = rewardHistory_[step][i].item(); 88 | accessor[i] = stepReward + gamma_ * accessor[i]; 89 | } 90 | } 91 | 92 | TensorDict nextObs = obsHistory_.back(); 93 | 94 | obsHistory_.pop_front(); 95 | actionHistory_.pop_front(); 96 | rewardHistory_.pop_front(); 97 | terminalHistory_.pop_front(); 98 | return FFTransition(obs, action, reward, terminal, bootstrap, nextObs); 99 | } 100 | 101 | void clear() { 102 | obsHistory_.clear(); 103 | actionHistory_.clear(); 104 | rewardHistory_.clear(); 105 | terminalHistory_.clear(); 106 | } 107 | 108 | private: 109 | const int multiStep_; 110 | // const int batchsize_; 111 | int batchsize_; 112 | const float gamma_; 113 | 114 | std::deque obsHistory_; 115 | std::deque actionHistory_; 116 | std::deque rewardHistory_; 117 | std::deque terminalHistory_; 118 | }; 119 | 120 | class R2D2Buffer { 121 | public: 122 | R2D2Buffer(int batchsize, int numPlayer, int multiStep, int seqLen) 123 | : batchsize(batchsize) 124 | , numPlayer(numPlayer) 125 | , multiStep(multiStep) 126 | , seqLen(seqLen) 127 | , batchNextIdx_(batchsize, 0) 128 | , batchH0_(batchsize) 129 | , batchSeqTransition_(batchsize, std::vector(seqLen)) 130 | , batchSeqPriority_(batchsize, std::vector(seqLen)) 131 | , batchLen_(batchsize, 0) 132 | , canPop_(false) { 133 | } 134 | 135 | void push( 136 | const FFTransition& transition, 137 | const torch::Tensor& priority, 138 | const TensorDict& /*hid*/) { 139 | assert(priority.size(0) == batchsize); 140 | 141 | auto priorityAccessor = priority.accessor(); 142 | for (int i = 0; i < batchsize; ++i) { 143 | int nextIdx = batchNextIdx_[i]; 144 | assert(nextIdx < seqLen && nextIdx >= 0); 145 | if (nextIdx == 0) { 146 | // TODO: !!! simplification for unconditional h0 147 | // batchH0_[i] = 148 | // utils::tensorDictNarrow(hid, 1, i * numPlayer, numPlayer, false); 149 | } 150 | 151 | auto t = transition.index(i); 152 | // some sanity check for termination 153 | if (nextIdx != 0) { 154 | // should not append after terminal 155 | // terminal should be processed when it is pushed 156 | assert(!batchSeqTransition_[i][nextIdx - 1].terminal.item()); 157 | assert(batchLen_[i] == 0); 158 | } 159 | 160 | batchSeqTransition_[i][nextIdx] = t; 161 | batchSeqPriority_[i][nextIdx] = priorityAccessor[i]; 162 | 163 | ++batchNextIdx_[i]; 164 | if (!t.terminal.item()) { 165 | continue; 166 | } 167 | 168 | // pad the rest of the seq in case of terminal 169 | batchLen_[i] = batchNextIdx_[i]; 170 | while (batchNextIdx_[i] < seqLen) { 171 | batchSeqTransition_[i][batchNextIdx_[i]] = t.padLike(); 172 | batchSeqPriority_[i][batchNextIdx_[i]] = 0; 173 | ++batchNextIdx_[i]; 174 | } 175 | canPop_ = true; 176 | } 177 | } 178 | 179 | bool canPop() { 180 | return canPop_; 181 | } 182 | 183 | std::tuple, torch::Tensor, torch::Tensor> popTransition() { 184 | assert(canPop_); 185 | 186 | std::vector batchTransition; 187 | std::vector batchSeqPriority; 188 | std::vector batchLen; 189 | 190 | // std::cout << "batch size inside popTransiton inside transition buffer ... 
" << batchsize << std::endl; 191 | for (int i = 0; i < batchsize; ++i) { 192 | if (batchLen_[i] == 0) { 193 | continue; 194 | } 195 | assert(batchNextIdx_[i] == seqLen); 196 | 197 | batchSeqPriority.push_back(torch::tensor(batchSeqPriority_[i])); 198 | batchLen.push_back((float)batchLen_[i]); 199 | auto t = RNNTransition( 200 | batchSeqTransition_[i], batchH0_[i], torch::tensor(float(batchLen_[i]))); 201 | batchTransition.push_back(t); 202 | 203 | batchLen_[i] = 0; 204 | batchNextIdx_[i] = 0; 205 | } 206 | canPop_ = false; 207 | assert(batchTransition.size() > 0); 208 | 209 | return std::make_tuple( 210 | batchTransition, 211 | torch::stack(batchSeqPriority, 1), // batchdim = 1 212 | torch::tensor(batchLen)); 213 | } 214 | 215 | // const int batchsize; 216 | int batchsize; 217 | const int numPlayer; 218 | const int multiStep; 219 | const int seqLen; 220 | 221 | private: 222 | std::vector batchNextIdx_; 223 | std::vector batchH0_; 224 | 225 | std::vector> batchSeqTransition_; 226 | std::vector> batchSeqPriority_; 227 | std::vector batchLen_; 228 | 229 | bool canPop_; 230 | }; 231 | } 232 | -------------------------------------------------------------------------------- /pyhanabi/create.py: -------------------------------------------------------------------------------- 1 | import set_path 2 | 3 | set_path.append_sys_path() 4 | 5 | import os 6 | import pprint 7 | import time 8 | import copy 9 | 10 | import numpy as np 11 | import torch 12 | import rela 13 | import hanalearn 14 | 15 | assert rela.__file__.endswith(".so") 16 | assert hanalearn.__file__.endswith(".so") 17 | 18 | 19 | def create_envs( 20 | num_env, 21 | seed, 22 | num_player, 23 | hand_size, 24 | bomb, 25 | explore_eps, 26 | max_len, 27 | sad, 28 | shuffle_obs, 29 | shuffle_color, 30 | ): 31 | games = [] 32 | for game_idx in range(num_env): 33 | params = { 34 | "players": str(num_player), 35 | "hand_size": str(hand_size), 36 | "seed": str(seed + game_idx), 37 | "bomb": str(bomb), 38 | } 39 | game = hanalearn.HanabiEnv( 40 | params, 41 | explore_eps, 42 | max_len, 43 | sad, 44 | shuffle_obs, 45 | shuffle_color, 46 | False, 47 | ) 48 | games.append(game) 49 | return games 50 | 51 | 52 | def create_threads( 53 | num_thread, 54 | num_game_per_thread, 55 | actors, 56 | games, 57 | ): 58 | context = rela.Context() 59 | threads = [] 60 | for thread_idx in range(num_thread): 61 | env = hanalearn.HanabiVecEnv() 62 | for game_idx in range(num_game_per_thread): 63 | env.append(games[thread_idx * num_game_per_thread + game_idx]) 64 | thread = hanalearn.HanabiThreadLoop(actors[thread_idx], env, False) 65 | threads.append(thread) 66 | context.push_env_thread(thread) 67 | print( 68 | "Finished creating %d threads with %d games and %d actors" 69 | % (len(threads), len(games), len(actors)) 70 | ) 71 | return context, threads 72 | 73 | 74 | class ActGroup: 75 | """ 76 | Creates actors given the agents. 
Starts to stores transitions in the replay_buffer by calling ActGroup.start() 77 | Args: 78 | method(str): iql or vdn 79 | devices(str): cuda1 80 | agent(object): an R2d2 agent object 81 | num_thread(int): default=10 82 | num_game_per_thread(int): default=80 83 | multi_step(int): default=3 84 | gamma(float): discount factor 85 | eta(float): eta for aggregate priority 86 | max_len(int): max seq len 87 | num_player(int): default=2 88 | replay_buffer(object): a replay buffer onject 89 | Returns: 90 | None 91 | """ 92 | 93 | def __init__( 94 | self, 95 | method, 96 | devices, 97 | agent, 98 | num_thread, 99 | num_game_per_thread, 100 | multi_step, 101 | gamma, 102 | eta, 103 | max_len, 104 | num_player, 105 | replay_buffer, 106 | ): 107 | self.devices = devices.split(",") 108 | 109 | self.model_runners = [] 110 | for dev in self.devices: 111 | runner = rela.BatchRunner( 112 | agent.clone(dev), dev, 100, ["act", "compute_priority"] 113 | ) 114 | self.model_runners.append(runner) 115 | 116 | self.num_runners = len(self.model_runners) 117 | 118 | self.actors = [] 119 | self.eval_actors = [] 120 | if method == "vdn": 121 | for i in range(num_thread): 122 | actor = rela.R2D2Actor( 123 | self.model_runners[i % self.num_runners], 124 | multi_step, 125 | num_game_per_thread, 126 | gamma, 127 | eta, 128 | max_len, 129 | num_player, 130 | replay_buffer, 131 | ) 132 | self.actors.append(actor) 133 | elif method == "iql": 134 | for i in range(num_thread): 135 | thread_actors = [] 136 | for _ in range(num_player): 137 | actor = rela.R2D2Actor( 138 | self.model_runners[i % self.num_runners], 139 | multi_step, 140 | num_game_per_thread, 141 | gamma, 142 | eta, 143 | max_len, 144 | 1, 145 | replay_buffer, 146 | ) 147 | thread_actors.append(actor) 148 | self.actors.append(thread_actors) 149 | print("ActGroup created") 150 | self.state_dicts = [] 151 | 152 | def start(self): 153 | for runner in self.model_runners: 154 | runner.start() 155 | 156 | def update_model(self, agent): 157 | for runner in self.model_runners: 158 | runner.update_model(agent) 159 | 160 | 161 | class ContActGroup: 162 | """ 163 | Creates actors given the agents. 
Starts to stores transitions in the replay_buffer by calling ContActGroup.start() 164 | Args: 165 | method(str): iql or vdn 166 | devices(str): cuda1 167 | agent_list(list): list of a learner and its partners 168 | num_thread(int): default=10 169 | num_game_per_thread(int): default=80 170 | multi_step(int): default=3 171 | gamma(float): discount factor 172 | eta(float): eta for aggregate priority 173 | max_len(int): max seq len 174 | num_player(int): default=2 175 | is_rand(bool): To randomize ordering of the learner and its partners or not 176 | replay_buffer(object): a replay buffer onject 177 | Returns: 178 | None 179 | """ 180 | 181 | def __init__( 182 | self, 183 | method, 184 | devices, 185 | agent_list, 186 | num_thread, 187 | num_game_per_thread, 188 | multi_step, 189 | gamma, 190 | eta, 191 | max_len, 192 | num_player, 193 | is_rand, 194 | replay_buffer, 195 | ): 196 | self.devices = devices.split(",") 197 | self.flags = [] 198 | self.model_runners = [] 199 | self.is_rand = is_rand 200 | 201 | for dev in self.devices: 202 | learnable_runner = rela.BatchRunner( 203 | agent_list[0].clone(dev), dev, 100, ["act", "compute_priority"] 204 | ) 205 | fixed_runner = rela.BatchRunner( 206 | agent_list[1].clone(dev), dev, 100, ["act", "compute_priority"] 207 | ) 208 | if self.is_rand: 209 | flag = np.random.randint(0, num_player) 210 | if flag == 0: 211 | self.model_runners.append([learnable_runner, fixed_runner]) 212 | elif flag == 1: 213 | self.model_runners.append([fixed_runner, learnable_runner]) 214 | 215 | self.flags.append(flag) 216 | else: 217 | self.model_runners.append([learnable_runner, fixed_runner]) 218 | 219 | self.num_runners = len(self.model_runners) 220 | 221 | self.actors = [] 222 | if method == "vdn": 223 | for i in range(num_thread): 224 | actor = rela.R2D2Actor( 225 | self.model_runners[i % self.num_runners], 226 | multi_step, 227 | num_game_per_thread, 228 | gamma, 229 | eta, 230 | max_len, 231 | num_player, 232 | replay_buffer, 233 | ) 234 | self.actors.append(actor) 235 | elif method == "iql": 236 | for i in range(num_thread): 237 | thread_actors = [] 238 | for n in range(num_player): 239 | actor = rela.R2D2Actor( 240 | self.model_runners[i % self.num_runners][n], 241 | multi_step, 242 | num_game_per_thread, 243 | gamma, 244 | eta, 245 | max_len, 246 | 1, 247 | replay_buffer, 248 | ) 249 | thread_actors.append(actor) 250 | self.actors.append(thread_actors) 251 | self.state_dicts = [] 252 | 253 | def start(self): 254 | for runner in self.model_runners: 255 | runner[0].start() 256 | runner[1].start() 257 | 258 | def update_model(self, agent): 259 | for idx, runner in enumerate(self.model_runners): 260 | if self.is_rand: 261 | if self.flags[idx] == 0: 262 | runner[0].update_model(agent) 263 | elif self.flags[idx] == 1: 264 | runner[1].update_model(agent) 265 | else: 266 | runner[0].update_model(agent) 267 | -------------------------------------------------------------------------------- /pyhanabi/selfplay.py: -------------------------------------------------------------------------------- 1 | # 2 | import time 3 | import os 4 | import sys 5 | import argparse 6 | import pprint 7 | import json 8 | 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | 13 | from create import create_envs, create_threads, ActGroup 14 | from eval import evaluate 15 | import common_utils 16 | import rela 17 | import r2d2_lstm as r2d2_lstm 18 | import r2d2_gru as r2d2_gru 19 | import utils 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description="train 
dqn on hanabi") 24 | parser.add_argument("--save_dir", type=str, default="exps/exp1") 25 | parser.add_argument("--method", type=str, default="vdn") 26 | parser.add_argument("--shuffle_obs", type=int, default=0) 27 | parser.add_argument("--shuffle_color", type=int, default=0) 28 | parser.add_argument("--pred_weight", type=float, default=0) 29 | parser.add_argument("--num_eps", type=int, default=80) 30 | 31 | parser.add_argument("--load_model", type=str, default="") 32 | 33 | parser.add_argument("--seed", type=int, default=10001) 34 | parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") 35 | parser.add_argument( 36 | "--eta", type=float, default=0.9, help="eta for aggregate priority" 37 | ) 38 | parser.add_argument("--train_bomb", type=int, default=0) 39 | parser.add_argument("--eval_bomb", type=int, default=0) 40 | parser.add_argument("--sad", type=int, default=0) 41 | parser.add_argument("--num_player", type=int, default=2) 42 | parser.add_argument("--hand_size", type=int, default=5) 43 | 44 | # optimization/training settings 45 | parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate") 46 | parser.add_argument("--eps", type=float, default=1.5e-4, help="Adam epsilon") 47 | parser.add_argument("--grad_clip", type=float, default=50, help="max grad norm") 48 | 49 | parser.add_argument("--rnn_type", type=str, default="lstm") 50 | parser.add_argument("--num_fflayer", type=int, default=1) 51 | parser.add_argument("--num_rnn_layer", type=int, default=2) 52 | parser.add_argument("--rnn_hid_dim", type=int, default=512) 53 | 54 | parser.add_argument("--train_device", type=str, default="cuda:0") 55 | parser.add_argument("--batchsize", type=int, default=128) 56 | parser.add_argument("--num_epoch", type=int, default=5000) 57 | parser.add_argument("--epoch_len", type=int, default=1000) 58 | parser.add_argument("--num_update_between_sync", type=int, default=2500) 59 | 60 | # DQN settings 61 | parser.add_argument("--multi_step", type=int, default=3) 62 | 63 | # replay buffer settings 64 | parser.add_argument("--burn_in_frames", type=int, default=80000) 65 | parser.add_argument("--replay_buffer_size", type=int, default=2 ** 20) 66 | parser.add_argument( 67 | "--priority_exponent", 68 | type=float, 69 | default=0.6, 70 | help="prioritized replay alpha", 71 | ) 72 | parser.add_argument( 73 | "--priority_weight", 74 | type=float, 75 | default=0.4, 76 | help="prioritized replay beta", 77 | ) 78 | parser.add_argument("--max_len", type=int, default=80, help="max seq len") 79 | parser.add_argument("--prefetch", type=int, default=3, help="#prefetch batch") 80 | 81 | # thread setting 82 | parser.add_argument("--num_thread", type=int, default=40, help="#thread_loop") 83 | parser.add_argument("--num_game_per_thread", type=int, default=20) 84 | 85 | # actor setting 86 | parser.add_argument("--act_base_eps", type=float, default=0.4) 87 | parser.add_argument("--act_eps_alpha", type=float, default=7) 88 | parser.add_argument("--act_device", type=str, default="cuda:1") 89 | parser.add_argument("--actor_sync_freq", type=int, default=10) 90 | 91 | ## args dump settings 92 | parser.add_argument("--args_dump_name", type=str, default="iql_2p.txt") 93 | 94 | args = parser.parse_args() 95 | assert args.method in ["vdn", "iql"] 96 | return args 97 | 98 | 99 | if __name__ == "__main__": 100 | torch.backends.cudnn.benchmark = True 101 | args = parse_args() 102 | 103 | if not os.path.exists(args.save_dir): 104 | os.makedirs(args.save_dir) 105 | 106 | args.args_dump_name = 
f"{args.method}_2p_{args.seed}.txt" 107 | 108 | with open(f"{args.save_dir}/{args.args_dump_name}", "w") as f: 109 | json.dump(args.__dict__, f, indent=2) 110 | 111 | logger_path = os.path.join(args.save_dir, "train.log") 112 | sys.stdout = common_utils.Logger(logger_path) 113 | saver = common_utils.TopkSaver(args.save_dir, 5) 114 | 115 | common_utils.set_all_seeds(args.seed) 116 | pprint.pprint(vars(args)) 117 | 118 | if args.method == "vdn": 119 | args.batchsize = int(np.round(args.batchsize / args.num_player)) 120 | args.replay_buffer_size //= args.num_player 121 | args.burn_in_frames //= args.num_player 122 | 123 | explore_eps = utils.generate_explore_eps( 124 | args.act_base_eps, args.act_eps_alpha, args.num_eps 125 | ) 126 | expected_eps = np.mean(explore_eps) 127 | print("explore eps:", explore_eps) 128 | print("avg explore eps:", np.mean(explore_eps)) 129 | 130 | games = create_envs( 131 | args.num_thread * args.num_game_per_thread, 132 | args.seed, 133 | args.num_player, 134 | args.hand_size, 135 | args.train_bomb, 136 | explore_eps, 137 | args.max_len, 138 | args.sad, 139 | args.shuffle_obs, 140 | args.shuffle_color, 141 | ) 142 | 143 | if args.rnn_type == "lstm": 144 | agent = r2d2_lstm.R2D2Agent( 145 | (args.method == "vdn"), 146 | args.multi_step, 147 | args.gamma, 148 | args.eta, 149 | args.train_device, 150 | games[0].feature_size(), 151 | args.rnn_hid_dim, 152 | games[0].num_action(), 153 | args.num_fflayer, 154 | args.num_rnn_layer, 155 | args.hand_size, 156 | False, # uniform priority 157 | ) 158 | elif args.rnn_type == "gru": 159 | agent = r2d2_gru.R2D2Agent( 160 | (args.method == "vdn"), 161 | args.multi_step, 162 | args.gamma, 163 | args.eta, 164 | args.train_device, 165 | games[0].feature_size(), 166 | args.rnn_hid_dim, 167 | games[0].num_action(), 168 | args.num_fflayer, 169 | args.num_rnn_layer, 170 | args.hand_size, 171 | False, # uniform priority 172 | ) 173 | 174 | agent.sync_target_with_online() 175 | 176 | if args.load_model: 177 | print("*****loading pretrained model*****") 178 | utils.load_weight(agent.online_net, args.load_model, args.train_device) 179 | print("*****done*****") 180 | 181 | agent = agent.to(args.train_device) 182 | optim = torch.optim.Adam(agent.online_net.parameters(), lr=args.lr, eps=args.eps) 183 | print(agent) 184 | eval_agent = agent.clone(args.train_device, {"vdn": False}) 185 | 186 | replay_buffer = rela.RNNPrioritizedReplay( 187 | args.replay_buffer_size, 188 | args.seed, 189 | args.priority_exponent, 190 | args.priority_weight, 191 | args.prefetch, 192 | ) 193 | 194 | act_group = ActGroup( 195 | args.method, 196 | args.act_device, 197 | agent, 198 | args.num_thread, 199 | args.num_game_per_thread, 200 | args.multi_step, 201 | args.gamma, 202 | args.eta, 203 | args.max_len, 204 | args.num_player, 205 | replay_buffer, 206 | ) 207 | 208 | assert args.shuffle_obs == False, "not working with 2nd order aux" 209 | context, threads = create_threads( 210 | args.num_thread, 211 | args.num_game_per_thread, 212 | act_group.actors, 213 | games, 214 | ) 215 | act_group.start() 216 | context.start() 217 | while replay_buffer.size() < args.burn_in_frames: 218 | print("warming up replay buffer:", replay_buffer.size()) 219 | time.sleep(1) 220 | 221 | print("Success, Done") 222 | print("=======================") 223 | 224 | frame_stat = dict() 225 | frame_stat["num_acts"] = 0 226 | frame_stat["num_buffer"] = 0 227 | 228 | stat = common_utils.MultiCounter(args.save_dir) 229 | tachometer = utils.Tachometer() 230 | stopwatch = common_utils.Stopwatch() 231 
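    # Main training loop: each epoch runs epoch_len updates. Every
    # num_update_between_sync updates the target network is synced with the online
    # network, and every actor_sync_freq updates the actor-side model runners are
    # refreshed. Each update samples a prioritized batch, backprops the R2D2 loss,
    # clips gradients, steps Adam, and writes the aggregated priorities back to the
    # replay buffer. After each epoch the actors are paused and the agent is
    # evaluated in self-play for 1000 games before training resumes.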
| 232 | for epoch in range(args.num_epoch): 233 | print("beginning of epoch: ", epoch) 234 | print(common_utils.get_mem_usage()) 235 | tachometer.start() 236 | stat.reset() 237 | stopwatch.reset() 238 | 239 | for batch_idx in range(args.epoch_len): 240 | num_update = batch_idx + epoch * args.epoch_len 241 | if num_update % args.num_update_between_sync == 0: 242 | agent.sync_target_with_online() 243 | if num_update % args.actor_sync_freq == 0: 244 | act_group.update_model(agent) 245 | 246 | torch.cuda.synchronize() 247 | stopwatch.time("sync and updating") 248 | 249 | batch, weight = replay_buffer.sample(args.batchsize, args.train_device) 250 | stopwatch.time("sample data") 251 | 252 | loss, priority = agent.loss(batch, args.pred_weight, stat) 253 | priority = rela.aggregate_priority( 254 | priority.cpu(), batch.seq_len.cpu(), args.eta 255 | ) 256 | loss = (loss * weight).mean() 257 | loss.backward() 258 | 259 | torch.cuda.synchronize() 260 | stopwatch.time("forward & backward") 261 | 262 | g_norm = torch.nn.utils.clip_grad_norm_( 263 | agent.online_net.parameters(), args.grad_clip 264 | ) 265 | optim.step() 266 | optim.zero_grad() 267 | 268 | torch.cuda.synchronize() 269 | stopwatch.time("update model") 270 | 271 | replay_buffer.update_priority(priority) 272 | stopwatch.time("updating priority") 273 | 274 | stat["loss"].feed(loss.detach().item()) 275 | stat["grad_norm"].feed(g_norm) 276 | 277 | count_factor = args.num_player if args.method == "vdn" else 1 278 | print("EPOCH: %d" % epoch) 279 | tachometer.lap( 280 | act_group.actors, 281 | replay_buffer, 282 | args.epoch_len * args.batchsize, 283 | count_factor, 284 | ) 285 | stopwatch.summary() 286 | stat.summary(epoch) 287 | 288 | context.pause() 289 | eval_seed = (9917 + epoch * 999999) % 7777777 290 | eval_agent.load_state_dict(agent.state_dict()) 291 | score, perfect, *_ = evaluate( 292 | [eval_agent for _ in range(args.num_player)], 293 | 1000, 294 | eval_seed, 295 | args.eval_bomb, 296 | 0, # explore eps 297 | args.sad, 298 | ) 299 | if epoch > 0 and epoch % 200 == 0: 300 | force_save_name = "model_epoch%d" % epoch 301 | else: 302 | force_save_name = None 303 | model_saved = saver.save( 304 | None, agent.online_net.state_dict(), score, force_save_name=force_save_name 305 | ) 306 | print( 307 | "epoch %d, eval score: %.4f, perfect: %.2f, model saved: %s" 308 | % (epoch, score, perfect * 100, model_saved) 309 | ) 310 | context.resume() 311 | print("==========") 312 | -------------------------------------------------------------------------------- /pyhanabi/continual_evaluation.py: -------------------------------------------------------------------------------- 1 | """ Evaluating all the checkpoints saved periodically during train args.eval_freq 2 | Requires only 1 GPU. 3 | Sample usage: 4 | python continual_evaluation.py 5 | --weight_1_dir 6 | --weight_2 i.e a.pthw b.pthw ... 7 | --num_player 2 8 | note the last arg of --weight_2 is the self-play agent that is the agent that was being trained in continual fashion... 
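A concrete invocation might look like the following (the directory and checkpoint
paths are placeholders for illustration only):
    python continual_evaluation.py \
        --weight_1_dir exps/continual_run1 \
        --weight_2 pool/iql_2p_310.pthw pool/vdn_2p_720.pthw exps/continual_run1/selfplay_agent.pthw \
        --num_player 2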
9 | """ 10 | 11 | import argparse 12 | import os 13 | import sys 14 | import glob 15 | import wandb 16 | import json 17 | 18 | lib_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 19 | sys.path.append(lib_path) 20 | 21 | import numpy as np 22 | import torch 23 | import utils 24 | from eval import evaluate_legacy_model 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--weight_1_dir", default=None, type=str, required=True) 30 | parser.add_argument("--weight_2", default=None, type=str, nargs="+", required=True) 31 | parser.add_argument("--is_rand", action="store_true", default=True) 32 | parser.add_argument("--num_player", default=None, type=int, required=True) 33 | args = parser.parse_args() 34 | 35 | test_dir = f"{args.weight_1_dir}_test_models" 36 | 37 | if not os.path.exists(test_dir): 38 | os.makedirs(test_dir) 39 | 40 | cont_train_args_txt = glob.glob(f"{args.weight_1_dir}/*.txt") 41 | 42 | # move cont_args.txt to test_dir 43 | move_cont_args = f"cp {cont_train_args_txt[0]} {test_dir}/" 44 | os.system(move_cont_args) 45 | 46 | with open(cont_train_args_txt[0], "r") as f: 47 | agent_args = {**json.load(f)} 48 | 49 | ## move learnable model to test_dir 50 | if agent_args["load_learnable_model"] != "": 51 | initial_learnable_model = agent_args["load_learnable_model"] 52 | move_model_0 = ( 53 | f"cp {initial_learnable_model} {test_dir}/model_epoch0_zero_shot.pthw" 54 | ) 55 | os.system(move_model_0) 56 | 57 | exp_name = agent_args["save_dir"].split("/")[-1] 58 | wandb.init(project="Lifelong_Hanabi_project", name=exp_name) 59 | wandb.config.update(agent_args) 60 | 61 | assert os.path.exists(args.weight_1_dir) 62 | weight_1 = [] 63 | weight_1 = glob.glob(args.weight_1_dir + "/*.pthw") 64 | weight_1.sort(key=os.path.getmtime) 65 | 66 | ## check if everything in weights_2 exist 67 | for ag2 in args.weight_2: 68 | assert os.path.exists(ag2) 69 | 70 | slice_epoch = int(agent_args["num_epoch"]) * (len(args.weight_2) - 1) 71 | act_steps = utils.get_act_steps(args.weight_1_dir, slice_epoch) 72 | 73 | cur_task = 0 74 | prev_max = [0] * len(args.weight_2) 75 | prev_task_max = [0] * len(args.weight_2) 76 | prev_max_fs = [0] * len(args.weight_2) 77 | prev_task_max_fs = [0] * len(args.weight_2) 78 | avg_fs_score = 0 79 | avg_fs_future_score = 0 80 | avg_fs_forgetting = 0 81 | all_done = 0 82 | total_prev_act_steps = 0 83 | 84 | for ag1_idx, ag1 in enumerate(weight_1): 85 | ag1_name = ag1.split("/")[-1].split("_")[-1] 86 | act_epoch_cnt = int(ag1.split("/")[-1].split("_")[1][5:]) 87 | ### move zs ckpts after every task to test dir 88 | if act_epoch_cnt % int(agent_args["num_epoch"]) == 0: 89 | if ag1_name == "shot.pthw": 90 | move_zs_ckpt = f"cp {ag1} {test_dir}/" 91 | os.system(move_zs_ckpt) 92 | 93 | ### this is for different zero-shot evaluations... 
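    # Zero-shot bookkeeping: pair this "*_zero_shot" checkpoint with every fixed
    # partner in --weight_2, keep each partner's running best score (used below to
    # compute forgetting), and log the average score over tasks seen so far as well
    # as over future tasks.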
94 | total_tasks = len(args.weight_2) 95 | if ag1_name == "shot.pthw": 96 | all_done += 1 97 | avg_score = 0 98 | avg_future_score = 0 99 | avg_forgetting = 0 100 | 101 | for fixed_agent_idx in range(len(args.weight_2)): 102 | weight_files = [ag1, args.weight_2[fixed_agent_idx]] 103 | mean_score, sem, perfect_rate = evaluate_legacy_model( 104 | weight_files, 1000, 1, 0, agent_args, args, num_run=5 105 | ) 106 | 107 | if mean_score > prev_max[fixed_agent_idx]: 108 | prev_max[fixed_agent_idx] = mean_score 109 | wandb.log( 110 | { 111 | "epoch_zeroshot": act_epoch_cnt, 112 | "eval_score_zeroshot_" + str(fixed_agent_idx): mean_score, 113 | "perfect_zeroshot_" + str(fixed_agent_idx): perfect_rate, 114 | "sem_zeroshot_" + str(fixed_agent_idx): sem, 115 | "total_act_steps": ( 116 | total_prev_act_steps + act_steps[act_epoch_cnt - 1] 117 | ), 118 | } 119 | ) 120 | if fixed_agent_idx == cur_task: 121 | wandb.log( 122 | { 123 | "epoch_zs_curtask": act_epoch_cnt, 124 | "eval_score_zs_curtask": mean_score, 125 | "perfect_zs_curtask": perfect_rate, 126 | "sem_zs_curtask": sem, 127 | "total_act_steps": ( 128 | total_prev_act_steps + act_steps[act_epoch_cnt - 1] 129 | ), 130 | } 131 | ) 132 | 133 | if fixed_agent_idx <= cur_task: 134 | avg_score += mean_score 135 | if fixed_agent_idx > cur_task: 136 | avg_future_score += mean_score 137 | if cur_task > 0: 138 | forgetting = prev_task_max[fixed_agent_idx] - mean_score 139 | if fixed_agent_idx < cur_task: 140 | avg_forgetting += forgetting 141 | wandb.log( 142 | { 143 | "epoch_zs_forgetting": act_epoch_cnt, 144 | "forgetting_zs_" + str(fixed_agent_idx): forgetting, 145 | "total_act_steps": ( 146 | total_prev_act_steps + act_steps[act_epoch_cnt - 1] 147 | ), 148 | } 149 | ) 150 | 151 | avg_score = avg_score / (cur_task + 1) 152 | avg_future_score = avg_future_score / (total_tasks - (cur_task + 1)) 153 | wandb.log( 154 | { 155 | "epoch_zs_avg_score": act_epoch_cnt, 156 | "avg_zs_score": avg_score, 157 | "avg_future_zs_score": avg_future_score, 158 | "total_act_steps": ( 159 | total_prev_act_steps + act_steps[act_epoch_cnt - 1] 160 | ), 161 | } 162 | ) 163 | 164 | if cur_task > 0: 165 | avg_forgetting = avg_forgetting / (cur_task) 166 | wandb.log( 167 | { 168 | "epoch_zs_avg_forgetting": act_epoch_cnt, 169 | "avg_zs_forgetting": avg_forgetting, 170 | "total_act_steps": ( 171 | total_prev_act_steps + act_steps[act_epoch_cnt - 1] 172 | ), 173 | } 174 | ) 175 | 176 | else: 177 | ## for different few shot evaluations ... 
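    # Few-shot bookkeeping: a checkpoint whose name ends in "<i>.pthw" is evaluated
    # only against fixed partner i, with the same running-max and forgetting
    # tracking as in the zero-shot case above.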
178 | for i in range(len(args.weight_2)): 179 | if ag1_name == str(i) + ".pthw": 180 | all_done += 1 181 | weight_files = [ag1, args.weight_2[i]] 182 | 183 | cur_ag_id = ag1_name.split(".")[0] 184 | 185 | mean_score, sem, perfect_rate = evaluate_legacy_model( 186 | weight_files, 1000, 1, 0, agent_args, args, num_run=5 187 | ) 188 | if mean_score > prev_max_fs[int(cur_ag_id)]: 189 | prev_max_fs[int(cur_ag_id)] = mean_score 190 | 191 | wandb.log( 192 | { 193 | "epoch_fewshot": act_epoch_cnt, 194 | "eval_score_fewshot_" + cur_ag_id: mean_score, 195 | "perfect_fewshot_" + cur_ag_id: perfect_rate, 196 | "sem_fewshot_" + cur_ag_id: sem, 197 | "total_act_steps": ( 198 | total_prev_act_steps + act_steps[act_epoch_cnt - 1] 199 | ), 200 | } 201 | ) 202 | 203 | if int(cur_ag_id) <= cur_task: 204 | avg_fs_score += mean_score 205 | if int(cur_ag_id) > cur_task: 206 | avg_fs_future_score += mean_score 207 | if int(cur_ag_id) == cur_task: 208 | wandb.log( 209 | { 210 | "epoch_fs_curtask": act_epoch_cnt, 211 | "eval_score_fs_curtask": mean_score, 212 | "perfect_fs_curtask": perfect_rate, 213 | "sem_fs_curtask": sem, 214 | "total_act_steps": ( 215 | total_prev_act_steps + act_steps[act_epoch_cnt - 1] 216 | ), 217 | } 218 | ) 219 | 220 | if cur_task > 0: 221 | forgetting_fs = prev_task_max_fs[int(cur_ag_id)] - mean_score 222 | if int(cur_ag_id) < cur_task: 223 | avg_fs_forgetting += forgetting_fs 224 | wandb.log( 225 | { 226 | "epoch_fs_forgetting": act_epoch_cnt, 227 | "forgetting_fs_" + cur_ag_id: forgetting_fs, 228 | "total_act_steps": ( 229 | total_prev_act_steps + act_steps[act_epoch_cnt - 1] 230 | ), 231 | } 232 | ) 233 | 234 | if all_done % (total_tasks + 1) == 0: 235 | avg_fs_score = avg_fs_score / (cur_task + 1) 236 | avg_fs_future_score = avg_fs_future_score / (total_tasks - (cur_task + 1)) 237 | wandb.log( 238 | { 239 | "epoch_fs_avg_score": act_epoch_cnt, 240 | "avg_fs_score": avg_fs_score, 241 | "avg_fs_future_score": avg_fs_future_score, 242 | "total_act_steps": ( 243 | total_prev_act_steps + act_steps[act_epoch_cnt - 1] 244 | ), 245 | } 246 | ) 247 | 248 | avg_fs_score = 0 249 | avg_fs_future_score = 0 250 | 251 | if cur_task > 0: 252 | avg_fs_forgetting = avg_fs_forgetting / cur_task 253 | wandb.log( 254 | { 255 | "epoch_fs_avg_forgetting": act_epoch_cnt, 256 | "avg_fs_forgetting": avg_fs_forgetting, 257 | "total_act_steps": ( 258 | total_prev_act_steps + act_steps[act_epoch_cnt - 1] 259 | ), 260 | } 261 | ) 262 | 263 | avg_fs_forgetting = 0 264 | 265 | if ( 266 | act_epoch_cnt == agent_args["num_epoch"] * (cur_task + 1) 267 | and all_done % (total_tasks + 1) == 0 268 | ): 269 | cur_task += 1 270 | for fixed_agent_idx in range(len(args.weight_2)): 271 | prev_task_max[fixed_agent_idx] = prev_max[fixed_agent_idx] 272 | prev_task_max_fs[fixed_agent_idx] = prev_max_fs[fixed_agent_idx] 273 | all_done = 0 274 | total_prev_act_steps += act_steps[act_epoch_cnt - 1] 275 | -------------------------------------------------------------------------------- /rela/prioritized_replay.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "rela/tensor_dict.h" 11 | #include "rela/transition.h" 12 | 13 | namespace rela { 14 | 15 | template 16 | class ConcurrentQueue { 17 | public: 18 | ConcurrentQueue(int capacity) 19 | : capacity(capacity) 20 | , head_(0) 21 | , tail_(0) 22 | , size_(0) 23 | , safeTail_(0) 24 | , safeSize_(0) 25 | , sum_(0) 26 | , evicted_(capacity, false) 27 | , elements_(capacity) 28 | , weights_(capacity, 0) { 29 | } 30 | 31 | int safeSize(float* sum) const { 32 | std::unique_lock lk(m_); 33 | if (sum != nullptr) { 34 | *sum = sum_; 35 | } 36 | return safeSize_; 37 | } 38 | 39 | int size() const { 40 | std::unique_lock lk(m_); 41 | return size_; 42 | } 43 | 44 | void blockAppend(const std::vector& block, const torch::Tensor& weights) { 45 | int blockSize = block.size(); 46 | 47 | std::unique_lock lk(m_); 48 | cvSize_.wait(lk, [=] { return size_ + blockSize <= capacity; }); 49 | 50 | int start = tail_; 51 | int end = (tail_ + blockSize) % capacity; 52 | 53 | tail_ = end; 54 | size_ += blockSize; 55 | checkSize(head_, tail_, size_); 56 | 57 | lk.unlock(); 58 | 59 | float sum = 0; 60 | auto weightAcc = weights.accessor(); 61 | assert(weightAcc.size(0) == blockSize); 62 | for (int i = 0; i < blockSize; ++i) { 63 | int j = (start + i) % capacity; 64 | elements_[j] = block[i]; 65 | weights_[j] = weightAcc[i]; 66 | sum += weightAcc[i]; 67 | } 68 | 69 | lk.lock(); 70 | 71 | cvTail_.wait(lk, [=] { return safeTail_ == start; }); 72 | safeTail_ = end; 73 | safeSize_ += blockSize; 74 | sum_ += sum; 75 | checkSize(head_, safeTail_, safeSize_); 76 | 77 | lk.unlock(); 78 | cvTail_.notify_all(); 79 | } 80 | 81 | // ------------------------------------------------------------- // 82 | // blockPop, update are thread-safe against blockAppend 83 | // but they are NOT thread-safe against each other 84 | 85 | void blockPop(int blockSize) { 86 | double diff = 0; 87 | int head = head_; 88 | for (int i = 0; i < blockSize; ++i) { 89 | diff -= weights_[head]; 90 | evicted_[head] = true; 91 | head = (head + 1) % capacity; 92 | } 93 | 94 | { 95 | std::lock_guard lk(m_); 96 | sum_ += diff; 97 | head_ = head; 98 | safeSize_ -= blockSize; 99 | size_ -= blockSize; 100 | assert(safeSize_ >= 0); 101 | checkSize(head_, safeTail_, safeSize_); 102 | } 103 | cvSize_.notify_all(); 104 | } 105 | 106 | void update(const std::vector& ids, const torch::Tensor& weights) { 107 | double diff = 0; 108 | auto weightAcc = weights.accessor(); 109 | for (int i = 0; i < (int)ids.size(); ++i) { 110 | auto id = ids[i]; 111 | if (evicted_[id]) { 112 | continue; 113 | } 114 | diff += (weightAcc[i] - weights_[id]); 115 | weights_[id] = weightAcc[i]; 116 | } 117 | 118 | std::lock_guard lk_(m_); 119 | sum_ += diff; 120 | } 121 | 122 | // ------------------------------------------------------------- // 123 | // accessing elements is never locked, operate safely! 
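  // (The sampler only reads indices < safeSize_, and blockAppend publishes
  // safeTail_/safeSize_ only after the corresponding elements and weights have
  // been written, so the unsynchronized reads below stay within initialized entries.)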
124 | 125 | DataType get(int idx) { 126 | int id = (head_ + idx) % capacity; 127 | return elements_[id]; 128 | } 129 | 130 | DataType getElementAndMark(int idx) { 131 | int id = (head_ + idx) % capacity; 132 | evicted_[id] = false; 133 | return elements_[id]; 134 | } 135 | 136 | float getWeight(int idx, int* id) { 137 | assert(id != nullptr); 138 | *id = (head_ + idx) % capacity; 139 | return weights_[*id]; 140 | } 141 | 142 | const int capacity; 143 | 144 | private: 145 | void checkSize(int head, int tail, int size) { 146 | if (size == 0) { 147 | assert(tail == head); 148 | } else if (tail > head) { 149 | if (tail - head != size) { 150 | std::cout << "tail-head: " << tail - head << " vs size: " << size << std::endl; 151 | } 152 | assert(tail - head == size); 153 | } else { 154 | if (tail + capacity - head != size) { 155 | std::cout << "tail-head: " << tail + capacity - head << " vs size: " << size 156 | << std::endl; 157 | } 158 | assert(tail + capacity - head == size); 159 | } 160 | } 161 | 162 | mutable std::mutex m_; 163 | std::condition_variable cvSize_; 164 | std::condition_variable cvTail_; 165 | 166 | int head_; 167 | int tail_; 168 | int size_; 169 | 170 | int safeTail_; 171 | int safeSize_; 172 | double sum_; 173 | std::vector evicted_; 174 | 175 | std::vector elements_; 176 | std::vector weights_; 177 | }; 178 | 179 | template 180 | class PrioritizedReplay { 181 | public: 182 | PrioritizedReplay(int capacity, int seed, float alpha, float beta, int prefetch) 183 | : alpha_(alpha) // priority exponent 184 | , beta_(beta) // importance sampling exponent 185 | , prefetch_(prefetch) 186 | , capacity_(capacity) 187 | , storage_(int(1.25 * capacity)) 188 | , numAdd_(0) { 189 | rng_.seed(seed); 190 | } 191 | 192 | void add(const std::vector& sample, const torch::Tensor& priority) { 193 | assert(priority.dim() == 1); 194 | auto weights = torch::pow(priority, alpha_); 195 | storage_.blockAppend(sample, weights); 196 | numAdd_ += priority.size(0); 197 | } 198 | 199 | void add(const DataType& sample, const torch::Tensor& priority) { 200 | std::vector vec; 201 | int n = priority.size(0); 202 | for (int i = 0; i < n; ++i) { 203 | vec.push_back(sample.index(i)); 204 | } 205 | add(vec, priority); 206 | } 207 | 208 | std::tuple sample(int batchsize, const std::string& device) { 209 | if (!sampledIds_.empty()) { 210 | std::cout << "Error: previous samples' priority has not been updated." 
<< std::endl; 211 | assert(false); 212 | } 213 | // std::cout << "Batch size inside sample is : " << batchsize << std::endl; 214 | 215 | DataType batch; 216 | torch::Tensor priority; 217 | if (prefetch_ == 0) { 218 | std::tie(batch, priority, sampledIds_) = sample_(batchsize, device); 219 | return std::make_tuple(batch, priority); 220 | } 221 | 222 | if (futures_.empty()) { 223 | std::tie(batch, priority, sampledIds_) = sample_(batchsize, device); 224 | } else { 225 | // assert(futures_.size() == 1); 226 | std::tie(batch, priority, sampledIds_) = futures_.front().get(); 227 | futures_.pop(); 228 | } 229 | 230 | while ((int)futures_.size() < prefetch_) { 231 | auto f = std::async( 232 | std::launch::async, 233 | &PrioritizedReplay::sample_, 234 | this, 235 | batchsize, 236 | device); 237 | futures_.push(std::move(f)); 238 | } 239 | 240 | return std::make_tuple(batch, priority); 241 | } 242 | 243 | void updatePriority(const torch::Tensor& priority) { 244 | if (priority.size(0) == 0) { 245 | sampledIds_.clear(); 246 | return; 247 | } 248 | // std::cout << "inside updatePriority sample Ids size is " << (int)sampledIds_.size() << std::endl; 249 | // std::cout << "inside updatePriority priority size is " << priority.size(0) << std::endl; 250 | 251 | assert(priority.dim() == 1); 252 | assert((int)sampledIds_.size() == priority.size(0)); 253 | 254 | auto weights = torch::pow(priority, alpha_); 255 | { 256 | std::lock_guard lk(mSampler_); 257 | storage_.update(sampledIds_, weights); 258 | } 259 | sampledIds_.clear(); 260 | } 261 | 262 | void slice(int drop_size){ 263 | storage_.blockPop(drop_size); 264 | } 265 | 266 | DataType get(int idx) { 267 | return storage_.get(idx); 268 | } 269 | 270 | int size() const { 271 | return storage_.safeSize(nullptr); 272 | } 273 | 274 | int numAdd() const { 275 | return numAdd_; 276 | } 277 | 278 | private: 279 | using SampleWeightIds = std::tuple>; 280 | 281 | SampleWeightIds sample_(int batchsize, const std::string& device) { 282 | std::unique_lock lk(mSampler_); 283 | 284 | // std::cout << "batch size inside sample_ function is " << batchsize << std::endl; 285 | float sum; 286 | int size = storage_.safeSize(&sum); 287 | assert(size >= batchsize); 288 | // std::cout << "size: "<< size << ", sum: " << sum << std::endl; 289 | // storage_ [0, size) remains static in the subsequent section 290 | 291 | float segment = sum / batchsize; 292 | std::uniform_real_distribution dist(0.0, segment); 293 | 294 | std::vector samples; 295 | auto weights = torch::zeros({batchsize}, torch::kFloat32); 296 | auto weightAcc = weights.accessor(); 297 | std::vector ids(batchsize); 298 | 299 | double accSum = 0; 300 | int nextIdx = 0; 301 | float w = 0; 302 | int id = 0; 303 | for (int i = 0; i < batchsize; i++) { 304 | float rand = dist(rng_) + i * segment; 305 | rand = std::min(sum - (float)0.1, rand); 306 | // std::cout << "looking for " << i << "th/" << batchsize << " sample" << 307 | // std::endl; 308 | // std::cout << "\ttarget: " << rand << std::endl; 309 | 310 | while (nextIdx <= size) { 311 | if (accSum > 0 && accSum >= rand) { 312 | assert(nextIdx >= 1); 313 | // std::cout << "\tfound: " << nextIdx - 1 << ", " << id << ", " << 314 | // accSum << std::endl; 315 | DataType element = storage_.getElementAndMark(nextIdx - 1); 316 | samples.push_back(element); 317 | weightAcc[i] = w; 318 | ids[i] = id; 319 | break; 320 | } 321 | 322 | if (nextIdx == size) { 323 | std::cout << "nextIdx: " << nextIdx << "/" << size << std::endl; 324 | std::cout << std::setprecision(10) << "accSum: " 
/pyhanabi/r2d2_gru.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Tuple, Dict
4 | import common_utils
5 |
6 |
7 | class R2D2Net(torch.jit.ScriptModule):
8 |     __constants__ = ["hid_dim", "out_dim", "num_fflayer", "num_rnn_layer", "hand_size"]
9 |
10 |     def __init__(
11 |         self, device, in_dim, hid_dim, out_dim, num_fflayer, num_rnn_layer, hand_size
12 |     ):
13 |         super().__init__()
14 |         self.in_dim = in_dim
15 |         self.hid_dim = hid_dim
16 |         self.out_dim = out_dim
17 |         self.num_fflayer = num_fflayer
18 |         self.num_rnn_layer = num_rnn_layer
19 |         self.hand_size = hand_size
20 |
21 |         layers = [nn.Linear(self.in_dim, self.hid_dim), nn.ReLU()]
22 |         for i in range(1, self.num_fflayer):
23 |             layers += [nn.Linear(self.hid_dim, self.hid_dim), nn.ReLU()]
24 |
25 |         self.net = nn.Sequential(*layers)
26 |
27 |         self.gru = nn.GRU(
28 |             self.hid_dim,
29 |             self.hid_dim,
30 |             num_layers=self.num_rnn_layer,  # , batch_first=True
31 |         ).to(device)
32 |
33 |         self.gru.flatten_parameters()
34 |
35 |         self.fc_v = nn.Linear(self.hid_dim, 1)
36 |         self.fc_a = nn.Linear(self.hid_dim, self.out_dim)
37 |
38 |         # for aux task
39 |         self.pred = nn.Linear(self.hid_dim, self.hand_size * 3)
40 |
41 |     @torch.jit.script_method
42 |     def get_h0(self, batchsize: int) -> Dict[str, torch.Tensor]:
43 |         shape = (self.num_rnn_layer, batchsize, self.hid_dim)
44 |         hid = {"h0": torch.zeros(*shape)}
45 |         return hid
46 |
47 |     @torch.jit.script_method
48 |     def act(
49 |         self, priv_s: torch.Tensor, hid: Dict[str, torch.Tensor]
50 |     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
51 |         assert priv_s.dim() == 2, "dim should be 2, [batch, dim], get %d" % priv_s.dim()
52 |
53 |         priv_s = priv_s.unsqueeze(0)
54 |         x = self.net(priv_s)
55 |         o, (h) = self.gru(x, (hid["h0"]))
56 |         a = self.fc_a(o)
57 |         a = a.squeeze(0)
58 |
return a, {"h0": h} # , t_pred 59 | 60 | @torch.jit.script_method 61 | def forward( 62 | self, 63 | priv_s: torch.Tensor, 64 | legal_move: torch.Tensor, 65 | action: torch.Tensor, 66 | hid: Dict[str, torch.Tensor], 67 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 68 | assert ( 69 | priv_s.dim() == 3 or priv_s.dim() == 2 70 | ), "dim = 3/2, [seq_len(optional), batch, dim]" 71 | 72 | one_step = False 73 | if priv_s.dim() == 2: 74 | priv_s = priv_s.unsqueeze(0) 75 | legal_move = legal_move.unsqueeze(0) 76 | action = action.unsqueeze(0) 77 | one_step = True 78 | 79 | x = self.net(priv_s) 80 | if len(hid) == 0: 81 | o, (h) = self.gru(x) 82 | else: 83 | o, (h) = self.gru(x, hid["h0"]) 84 | a = self.fc_a(o) 85 | v = self.fc_v(o) 86 | q = self._duel(v, a, legal_move) 87 | 88 | # q: [seq_len, batch, num_action] 89 | # action: [seq_len, batch] 90 | qa = q.gather(2, action.unsqueeze(2)).squeeze(2) 91 | 92 | assert q.size() == legal_move.size() 93 | legal_q = (1 + q - q.min()) * legal_move 94 | # greedy_action: [seq_len, batch] 95 | greedy_action = legal_q.argmax(2).detach() 96 | 97 | if one_step: 98 | qa = qa.squeeze(0) 99 | greedy_action = greedy_action.squeeze(0) 100 | o = o.squeeze(0) 101 | q = q.squeeze(0) 102 | return qa, greedy_action, q, o 103 | 104 | @torch.jit.script_method 105 | def _duel( 106 | self, v: torch.Tensor, a: torch.Tensor, legal_move: torch.Tensor 107 | ) -> torch.Tensor: 108 | assert a.size() == legal_move.size() 109 | legal_a = a * legal_move 110 | q = v + legal_a - legal_a.mean(2, keepdim=True) 111 | return q 112 | 113 | def cross_entropy(self, net, lstm_o, target_p, hand_slot_mask, seq_len): 114 | # target_p: [seq_len, batch, num_player, 5, 3] 115 | # hand_slot_mask: [seq_len, batch, num_player, 5] 116 | logit = net(lstm_o).view(target_p.size()) 117 | q = nn.functional.softmax(logit, -1) 118 | logq = nn.functional.log_softmax(logit, -1) 119 | plogq = (target_p * logq).sum(-1) 120 | xent = -(plogq * hand_slot_mask).sum(-1) / hand_slot_mask.sum(-1).clamp( 121 | min=1e-6 122 | ) 123 | 124 | if xent.dim() == 3: 125 | # [seq, batch, num_player] 126 | xent = xent.mean(2) 127 | 128 | # save before sum out 129 | seq_xent = xent 130 | xent = xent.sum(0) 131 | assert xent.size() == seq_len.size() 132 | avg_xent = (xent / seq_len).mean().item() 133 | return xent, avg_xent, q, seq_xent.detach() 134 | 135 | def pred_loss_1st(self, lstm_o, target, hand_slot_mask, seq_len): 136 | return self.cross_entropy(self.pred, lstm_o, target, hand_slot_mask, seq_len) 137 | 138 | 139 | class R2D2Agent(torch.jit.ScriptModule): 140 | __constants__ = ["vdn", "multi_step", "gamma", "eta", "uniform_priority"] 141 | 142 | def __init__( 143 | self, 144 | vdn, 145 | multi_step, 146 | gamma, 147 | eta, 148 | device, 149 | in_dim, 150 | hid_dim, 151 | out_dim, 152 | num_fflayer, 153 | num_rnn_layer, 154 | hand_size, 155 | uniform_priority, 156 | sad=False, 157 | ): 158 | super().__init__() 159 | self.online_net = R2D2Net( 160 | device, in_dim, hid_dim, out_dim, num_fflayer, num_rnn_layer, hand_size 161 | ).to(device) 162 | self.target_net = R2D2Net( 163 | device, in_dim, hid_dim, out_dim, num_fflayer, num_rnn_layer, hand_size 164 | ).to(device) 165 | self.vdn = vdn 166 | self.multi_step = multi_step 167 | self.gamma = gamma 168 | self.eta = eta 169 | self.uniform_priority = uniform_priority 170 | self.sad = sad 171 | 172 | @torch.jit.script_method 173 | def get_h0(self, batchsize: int) -> Dict[str, torch.Tensor]: 174 | return self.online_net.get_h0(batchsize) 175 | 176 | def clone(self, 
device, overwrite=None): 177 | if overwrite is None: 178 | overwrite = {} 179 | cloned = type(self)( 180 | overwrite.get("vdn", self.vdn), 181 | self.multi_step, 182 | self.gamma, 183 | self.eta, 184 | device, 185 | self.online_net.in_dim, 186 | self.online_net.hid_dim, 187 | self.online_net.out_dim, 188 | self.online_net.num_fflayer, 189 | self.online_net.num_rnn_layer, 190 | self.online_net.hand_size, 191 | self.uniform_priority, 192 | self.sad, 193 | ) 194 | cloned.load_state_dict(self.state_dict()) 195 | return cloned.to(device) 196 | 197 | def sync_target_with_online(self): 198 | self.target_net.load_state_dict(self.online_net.state_dict()) 199 | 200 | @torch.jit.script_method 201 | def greedy_act( 202 | self, 203 | priv_s: torch.Tensor, 204 | legal_move: torch.Tensor, 205 | hid: Dict[str, torch.Tensor], 206 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 207 | adv, new_hid = self.online_net.act(priv_s, hid) 208 | legal_adv = (1 + adv - adv.min()) * legal_move 209 | greedy_action = legal_adv.argmax(1).detach() 210 | return greedy_action, new_hid 211 | 212 | @torch.jit.script_method 213 | def act(self, obs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 214 | """ 215 | Acts on the given obs, with eps-greedy policy. 216 | output: {'a' : actions}, a long Tensor of shape 217 | [batchsize] or [batchsize, num_player] 218 | """ 219 | obsize, ibsize, num_player = 0, 0, 0 220 | if self.sad: 221 | _priv_s = obs["priv_s"] 222 | else: 223 | _priv_s = obs["priv_s_gen"] 224 | 225 | if self.vdn: 226 | obsize, ibsize, num_player = _priv_s.size()[:3] 227 | priv_s = _priv_s.flatten(0, 2) 228 | legal_move = obs["legal_move"].flatten(0, 2) 229 | eps = obs["eps"].flatten(0, 2) 230 | else: 231 | obsize, ibsize = _priv_s.size()[:2] 232 | num_player = 1 233 | priv_s = _priv_s.flatten(0, 1) 234 | legal_move = obs["legal_move"].flatten(0, 1) 235 | eps = obs["eps"].flatten(0, 1) 236 | 237 | hid = { 238 | "h0": obs["h0"].flatten(0, 1).transpose(0, 1).contiguous(), 239 | } 240 | 241 | greedy_action, new_hid = self.greedy_act(priv_s, legal_move, hid) 242 | 243 | random_action = legal_move.multinomial(1).squeeze(1) 244 | rand = torch.rand(greedy_action.size(), device=greedy_action.device) 245 | assert rand.size() == eps.size() 246 | rand = (rand < eps).long() 247 | action = (greedy_action * (1 - rand) + random_action * rand).detach().long() 248 | 249 | if self.vdn: 250 | action = action.view(obsize, ibsize, num_player) 251 | greedy_action = greedy_action.view(obsize, ibsize, num_player) 252 | rand = rand.view(obsize, ibsize, num_player) 253 | else: 254 | action = action.view(obsize, ibsize) 255 | greedy_action = greedy_action.view(obsize, ibsize) 256 | rand = rand.view(obsize, ibsize) 257 | 258 | hid_shape = ( 259 | obsize, 260 | ibsize * num_player, 261 | self.online_net.num_rnn_layer, 262 | self.online_net.hid_dim, 263 | ) 264 | h0 = new_hid["h0"].transpose(0, 1).view(*hid_shape) 265 | 266 | reply = { 267 | "a": action.detach().cpu(), 268 | "greedy_a": greedy_action.detach().cpu(), 269 | "h0": h0.contiguous().detach().cpu(), 270 | } 271 | return reply 272 | 273 | @torch.jit.script_method 274 | def compute_priority( 275 | self, input_: Dict[str, torch.Tensor] 276 | ) -> Dict[str, torch.Tensor]: 277 | """ 278 | compute priority for one batch 279 | """ 280 | if self.uniform_priority: 281 | return {"priority": torch.ones_like(input_["reward"]).detach().cpu()} 282 | 283 | obsize, ibsize, num_player = 0, 0, 0 284 | flatten_end = 0 285 | if self.sad: 286 | _priv_s = input_["priv_s"] 287 | _next_priv_s = 
input_["next_priv_s"] 288 | else: 289 | _priv_s = input_["priv_s_gen"] 290 | _next_priv_s = input_["next_priv_s_gen"] 291 | 292 | if self.vdn: 293 | obsize, ibsize, num_player = _priv_s.size()[:3] 294 | flatten_end = 2 295 | else: 296 | obsize, ibsize = _priv_s.size()[:2] 297 | num_player = 1 298 | flatten_end = 1 299 | 300 | priv_s = _priv_s.flatten(0, flatten_end) 301 | legal_move = input_["legal_move"].flatten(0, flatten_end) 302 | online_a = input_["a"].flatten(0, flatten_end) 303 | 304 | next_priv_s = _next_priv_s.flatten(0, flatten_end) 305 | next_legal_move = input_["next_legal_move"].flatten(0, flatten_end) 306 | temperature = input_["temperature"].flatten(0, flatten_end) 307 | 308 | hid = { 309 | "h0": input_["h0"].flatten(0, 1).transpose(0, 1).contiguous(), 310 | } 311 | next_hid = { 312 | "h0": input_["next_h0"].flatten(0, 1).transpose(0, 1).contiguous(), 313 | } 314 | reward = input_["reward"].flatten(0, 1) 315 | bootstrap = input_["bootstrap"].flatten(0, 1) 316 | 317 | online_qa = self.online_net(priv_s, legal_move, online_a, hid)[0] 318 | next_a, _ = self.greedy_act(next_priv_s, next_legal_move, next_hid) 319 | target_qa, _, _, _ = self.target_net( 320 | next_priv_s, 321 | next_legal_move, 322 | next_a, 323 | next_hid, 324 | ) 325 | 326 | bsize = obsize * ibsize 327 | if self.vdn: 328 | # sum over action & player 329 | online_qa = online_qa.view(bsize, num_player).sum(1) 330 | target_qa = target_qa.view(bsize, num_player).sum(1) 331 | 332 | assert reward.size() == bootstrap.size() 333 | assert reward.size() == target_qa.size() 334 | target = reward + bootstrap * (self.gamma ** self.multi_step) * target_qa 335 | priority = (target - online_qa).abs() 336 | priority = priority.view(obsize, ibsize).detach().cpu() 337 | return {"priority": priority} 338 | 339 | ############# python only functions ############# 340 | def flat_4d(self, data): 341 | """ 342 | rnn_hid: [num_layer, batch, num_player, dim] -> [num_player, batch, dim] 343 | seq_obs: [seq_len, batch, num_player, dim] -> [seq_len, batch, dim] 344 | """ 345 | bsize = 0 346 | num_player = 0 347 | for k, v in data.items(): 348 | if num_player == 0: 349 | bsize, num_player = v.size()[1:3] 350 | 351 | if v.dim() == 4: 352 | d0, d1, d2, d3 = v.size() 353 | data[k] = v.view(d0, d1 * d2, d3) 354 | elif v.dim() == 3: 355 | d0, d1, d2 = v.size() 356 | data[k] = v.view(d0, d1 * d2) 357 | return bsize, num_player 358 | 359 | def td_error(self, obs, hid, action, reward, terminal, bootstrap, seq_len, stat): 360 | max_seq_len = obs["priv_s"].size(0) 361 | 362 | bsize, num_player = 0, 1 363 | if self.vdn: 364 | bsize, num_player = self.flat_4d(obs) 365 | self.flat_4d(action) 366 | 367 | if self.sad: 368 | priv_s = obs["priv_s"] 369 | else: 370 | priv_s = obs["priv_s_gen"] 371 | 372 | legal_move = obs["legal_move"] 373 | action = action["a"] 374 | 375 | hid = {} 376 | 377 | # this only works because the trajectories are padded, 378 | # i.e. 
no terminal in the middle 379 | online_qa, greedy_a, _, lstm_o = self.online_net( 380 | priv_s, legal_move, action, hid 381 | ) 382 | 383 | with torch.no_grad(): 384 | target_qa, _, _, _ = self.target_net(priv_s, legal_move, greedy_a, hid) 385 | # assert target_q.size() == pa.size() 386 | # target_qe = (pa * target_q).sum(-1).detach() 387 | assert online_qa.size() == target_qa.size() 388 | 389 | if self.vdn: 390 | online_qa = online_qa.view(max_seq_len, bsize, num_player).sum(-1) 391 | target_qa = target_qa.view(max_seq_len, bsize, num_player).sum(-1) 392 | lstm_o = lstm_o.view(max_seq_len, bsize, num_player, -1) 393 | 394 | terminal = terminal.float() 395 | bootstrap = bootstrap.float() 396 | 397 | errs = [] 398 | target_qa = torch.cat( 399 | [target_qa[self.multi_step :], target_qa[: self.multi_step]], 0 400 | ) 401 | target_qa[-self.multi_step :] = 0 402 | 403 | assert target_qa.size() == reward.size() 404 | target = reward + bootstrap * (self.gamma ** self.multi_step) * target_qa 405 | mask = torch.arange(0, max_seq_len, device=seq_len.device) 406 | mask = (mask.unsqueeze(1) < seq_len.unsqueeze(0)).float() 407 | err = (target.detach() - online_qa) * mask 408 | return err, lstm_o 409 | 410 | def aux_task_iql(self, lstm_o, hand, seq_len, rl_loss_size, stat): 411 | seq_size, bsize, _ = hand.size() 412 | own_hand = hand.view(seq_size, bsize, self.online_net.hand_size, 3) 413 | own_hand_slot_mask = own_hand.sum(3) 414 | pred_loss1, avg_xent1, _, _ = self.online_net.pred_loss_1st( 415 | lstm_o, own_hand, own_hand_slot_mask, seq_len 416 | ) 417 | assert pred_loss1.size() == rl_loss_size 418 | 419 | stat["aux1"].feed(avg_xent1) 420 | return pred_loss1 421 | 422 | def aux_task_vdn(self, lstm_o, hand, seq_len, rl_loss_size, stat): 423 | """1st and 2nd order aux task used in VDN""" 424 | seq_size, bsize, num_player, _ = hand.size() 425 | own_hand = hand.view(seq_size, bsize, num_player, self.online_net.hand_size, 3) 426 | own_hand_slot_mask = own_hand.sum(4) 427 | pred_loss1, avg_xent1, belief1, _ = self.online_net.pred_loss_1st( 428 | lstm_o, own_hand, own_hand_slot_mask, seq_len 429 | ) 430 | assert pred_loss1.size() == rl_loss_size 431 | 432 | rotate = [num_player - 1] 433 | rotate.extend(list(range(num_player - 1))) 434 | partner_hand = own_hand[:, :, rotate, :, :] 435 | partner_hand_slot_mask = partner_hand.sum(4) 436 | partner_belief1 = belief1[:, :, rotate, :, :].detach() 437 | 438 | stat["aux1"].feed(avg_xent1) 439 | return pred_loss1 440 | 441 | def loss(self, batch, pred_weight, stat): 442 | err, lstm_o = self.td_error( 443 | batch.obs, 444 | batch.h0, 445 | batch.action, 446 | batch.reward, 447 | batch.terminal, 448 | batch.bootstrap, 449 | batch.seq_len, 450 | stat, 451 | ) 452 | rl_loss = nn.functional.smooth_l1_loss( 453 | err, torch.zeros_like(err), reduction="none" 454 | ) 455 | rl_loss = rl_loss.sum(0) 456 | stat["rl_loss"].feed((rl_loss / batch.seq_len).mean().item()) 457 | 458 | priority = err.abs() 459 | # priority = self.aggregate_priority(p, batch.seq_len) 460 | 461 | if pred_weight > 0: 462 | if self.vdn: 463 | pred_loss1 = self.aux_task_vdn( 464 | lstm_o, 465 | batch.obs["own_hand"], 466 | batch.seq_len, 467 | rl_loss.size(), 468 | stat, 469 | ) 470 | loss = rl_loss + pred_weight * pred_loss1 471 | else: 472 | pred_loss = self.aux_task_iql( 473 | lstm_o, 474 | batch.obs["own_hand"], 475 | batch.seq_len, 476 | rl_loss.size(), 477 | stat, 478 | ) 479 | loss = rl_loss + pred_weight * pred_loss 480 | else: 481 | loss = rl_loss 482 | return loss, priority 483 | 
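# --- Illustrative sketch (not part of the original r2d2_gru.py) ---------------
# td_error() above builds its n-step bootstrap term by rolling target_qa forward
# by `multi_step` steps and zeroing the wrapped-around tail, so that effectively
#     target[t] = reward[t] + bootstrap[t] * gamma**multi_step * Q_target[t + multi_step]
# with Q_target treated as 0 past the end of the (padded) sequence. A standalone
# demonstration of that indexing trick:
import torch

def nstep_shift(target_qa: torch.Tensor, multi_step: int) -> torch.Tensor:
    # target_qa: [seq_len, batch]; row t of the result holds the value at t + multi_step
    shifted = torch.cat([target_qa[multi_step:], target_qa[:multi_step]], 0)
    shifted[-multi_step:] = 0  # no bootstrap value beyond the end of the sequence
    return shifted

if __name__ == "__main__":
    q = torch.arange(12.0).view(6, 2)               # toy [seq_len=6, batch=2] values
    assert torch.equal(nstep_shift(q, 3)[0], q[3])  # row 0 now holds step 3
    assert nstep_shift(q, 3)[-1].sum() == 0         # padded tail is zeroed
# -------------------------------------------------------------------------------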
--------------------------------------------------------------------------------
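As a reading aid, here is a minimal sketch of driving R2D2Agent from r2d2_gru.py for a single acting step. The tensor sizes are toy values chosen for illustration (not the real Hanabi observation and action dimensions), and the snippet assumes it is run from the pyhanabi/ directory so that r2d2_gru and common_utils are importable.

import torch
from r2d2_gru import R2D2Agent

# Toy dimensions, for illustration only.
in_dim, hid_dim, out_dim = 100, 64, 21

agent = R2D2Agent(
    vdn=False,             # IQL-style agent: one stream per player
    multi_step=3,
    gamma=0.99,
    eta=0.9,
    device="cpu",
    in_dim=in_dim,
    hid_dim=hid_dim,
    out_dim=out_dim,
    num_fflayer=1,
    num_rnn_layer=2,
    hand_size=5,
    uniform_priority=True,
    sad=True,              # act() then reads obs["priv_s"] instead of obs["priv_s_gen"]
)

# One observation slot (obsize = ibsize = 1) with a fresh recurrent state.
h0 = agent.get_h0(1)["h0"]                     # [num_rnn_layer, 1, hid_dim]
obs = {
    "priv_s": torch.rand(1, 1, in_dim),
    "legal_move": torch.ones(1, 1, out_dim),
    "eps": torch.zeros(1, 1),                  # eps = 0 -> fully greedy
    "h0": h0.transpose(0, 1).unsqueeze(0),     # [obsize, ibsize, num_layer, hid_dim]
}
reply = agent.act(obs)
print(reply["a"].shape, reply["h0"].shape)     # torch.Size([1, 1]) and the new hidden state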