├── core ├── __init__.py ├── ctree │ ├── make.sh │ ├── setup.py │ ├── cminimax.h │ ├── cminimax.cpp │ ├── ctree.pxd │ ├── cnode.h │ ├── cytree.pyx │ └── cnode.cpp ├── dataset.py ├── mcts.py ├── storage.py ├── model.py ├── replay_buffer.py ├── test.py ├── game.py ├── log.py ├── utils.py ├── config.py ├── selfplay_worker.py ├── train.py └── reanalyze_worker.py ├── static └── imgs │ ├── archi.png │ └── total_results.png ├── requirements.txt ├── .gitignore ├── test.sh ├── train.sh ├── config └── atari │ ├── env_wrapper.py │ ├── __init__.py │ └── model.py ├── README.md └── main.py /core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/ctree/make.sh: -------------------------------------------------------------------------------- 1 | python setup.py build_ext --inplace -------------------------------------------------------------------------------- /static/imgs/archi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeWR/EfficientZero/HEAD/static/imgs/archi.png -------------------------------------------------------------------------------- /static/imgs/total_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YeWR/EfficientZero/HEAD/static/imgs/total_results.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.5 2 | ray==1.0.0 3 | gym[atari,roms,accept-rom-license]==0.15.7 4 | cython==0.29.23 5 | tensorboard 6 | opencv-python==4.5.1.48 7 | kornia==0.6.6 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.pyc 4 | *$py.class 5 | 6 | 7 | # Pycharm 8 | .idea/* 9 | 10 | # results dir 11 | /results 12 | /files 13 | -------------------------------------------------------------------------------- /core/ctree/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from Cython.Build import cythonize 3 | import numpy as np 4 | 5 | setup(ext_modules=cythonize('cytree.pyx'), extra_compile_args=['-O3'], include_dirs=[np.get_include()]) 6 | 7 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | export CUDA_DEVICE_ORDER='PCI_BUS_ID' 3 | export CUDA_VISIBLE_DEVICES=0 4 | 5 | python main.py --env BreakoutNoFrameskip-v4 --case atari --opr test --seed 0 --num_gpus 1 --num_cpus 20 --force \ 6 | --test_episodes 32 \ 7 | --load_model \ 8 | --amp_type 'torch_amp' \ 9 | --model_path 'model.p' \ 10 | --info 'Test' -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | export CUDA_DEVICE_ORDER='PCI_BUS_ID' 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | python main.py --env BreakoutNoFrameskip-v4 --case atari --opr train --force \ 6 | --num_gpus 4 --num_cpus 96 --cpu_actor 14 --gpu_actor 20 \ 7 | --seed 0 \ 8 | --p_mcts_num 4 \ 9 | --use_priority \ 10 | --use_max_priority 
\ 11 | --amp_type 'torch_amp' \ 12 | --info 'EfficientZero-V1' 13 | -------------------------------------------------------------------------------- /core/ctree/cminimax.h: -------------------------------------------------------------------------------- 1 | #ifndef CMINIMAX_H 2 | #define CMINIMAX_H 3 | 4 | #include 5 | #include 6 | 7 | const float FLOAT_MAX = 1000000.0; 8 | const float FLOAT_MIN = -FLOAT_MAX; 9 | 10 | namespace tools { 11 | 12 | class CMinMaxStats { 13 | public: 14 | float maximum, minimum, value_delta_max; 15 | 16 | CMinMaxStats(); 17 | ~CMinMaxStats(); 18 | 19 | void set_delta(float value_delta_max); 20 | void update(float value); 21 | void clear(); 22 | float normalize(float value); 23 | }; 24 | 25 | class CMinMaxStatsList { 26 | public: 27 | int num; 28 | std::vector stats_lst; 29 | 30 | CMinMaxStatsList(); 31 | CMinMaxStatsList(int num); 32 | ~CMinMaxStatsList(); 33 | 34 | void set_delta(float value_delta_max); 35 | }; 36 | } 37 | 38 | #endif -------------------------------------------------------------------------------- /config/atari/env_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from core.game import Game 3 | from core.utils import arr_to_str 4 | 5 | 6 | class AtariWrapper(Game): 7 | def __init__(self, env, discount: float, cvt_string=True): 8 | """Atari Wrapper 9 | Parameters 10 | ---------- 11 | env: Any 12 | another env wrapper 13 | discount: float 14 | discount of env 15 | cvt_string: bool 16 | True -> convert the observation into string in the replay buffer 17 | """ 18 | super().__init__(env, env.action_space.n, discount) 19 | self.cvt_string = cvt_string 20 | 21 | def legal_actions(self): 22 | return [_ for _ in range(self.env.action_space.n)] 23 | 24 | def get_max_episode_steps(self): 25 | return self.env.get_max_episode_steps() 26 | 27 | def step(self, action): 28 | observation, reward, done, info = self.env.step(action) 29 | observation = observation.astype(np.uint8) 30 | 31 | if self.cvt_string: 32 | observation = arr_to_str(observation) 33 | 34 | return observation, reward, done, info 35 | 36 | def reset(self, **kwargs): 37 | observation = self.env.reset(**kwargs) 38 | observation = observation.astype(np.uint8) 39 | 40 | if self.cvt_string: 41 | observation = arr_to_str(observation) 42 | 43 | return observation 44 | 45 | def close(self): 46 | self.env.close() 47 | -------------------------------------------------------------------------------- /core/ctree/cminimax.cpp: -------------------------------------------------------------------------------- 1 | #include "cminimax.h" 2 | 3 | namespace tools{ 4 | 5 | CMinMaxStats::CMinMaxStats(){ 6 | this->maximum = FLOAT_MIN; 7 | this->minimum = FLOAT_MAX; 8 | this->value_delta_max = 0.; 9 | } 10 | 11 | CMinMaxStats::~CMinMaxStats(){} 12 | 13 | void CMinMaxStats::set_delta(float value_delta_max){ 14 | this->value_delta_max = value_delta_max; 15 | } 16 | 17 | void CMinMaxStats::update(float value){ 18 | if(value > this->maximum){ 19 | this->maximum = value; 20 | } 21 | if(value < this->minimum){ 22 | this->minimum = value; 23 | } 24 | } 25 | 26 | void CMinMaxStats::clear(){ 27 | this->maximum = FLOAT_MIN; 28 | this->minimum = FLOAT_MAX; 29 | } 30 | 31 | float CMinMaxStats::normalize(float value){ 32 | float norm_value = value; 33 | float delta = this->maximum - this->minimum; 34 | if(delta > 0){ 35 | if(delta < this->value_delta_max){ 36 | norm_value = (norm_value - this->minimum) / this->value_delta_max; 37 | } 38 | else{ 39 | norm_value 
= (norm_value - this->minimum) / delta; 40 | } 41 | } 42 | return norm_value; 43 | } 44 | 45 | //********************************************************* 46 | 47 | CMinMaxStatsList::CMinMaxStatsList(){ 48 | this->num = 0; 49 | } 50 | 51 | CMinMaxStatsList::CMinMaxStatsList(int num){ 52 | this->num = num; 53 | for(int i = 0; i < num; ++i){ 54 | this->stats_lst.push_back(CMinMaxStats()); 55 | } 56 | } 57 | 58 | CMinMaxStatsList::~CMinMaxStatsList(){} 59 | 60 | void CMinMaxStatsList::set_delta(float value_delta_max){ 61 | for(int i = 0; i < this->num; ++i){ 62 | this->stats_lst[i].set_delta(value_delta_max); 63 | } 64 | } 65 | 66 | } -------------------------------------------------------------------------------- /core/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from kornia.augmentation import RandomAffine, RandomCrop, CenterCrop, RandomResizedCrop 5 | from kornia.filters import GaussianBlur2d 6 | 7 | 8 | class Transforms(object): 9 | """ Reference : Data-Efficient Reinforcement Learning with Self-Predictive Representations 10 | Thanks to Repo: https://github.com/mila-iqia/spr.git 11 | """ 12 | def __init__(self, augmentation, shift_delta=4, image_shape=(96, 96)): 13 | self.augmentation = augmentation 14 | 15 | self.transforms = [] 16 | for aug in self.augmentation: 17 | if aug == "affine": 18 | transformation = RandomAffine(5, (.14, .14), (.9, 1.1), (-5, 5)) 19 | elif aug == "crop": 20 | transformation = RandomCrop(image_shape) 21 | elif aug == "rrc": 22 | transformation = RandomResizedCrop((100, 100), (0.8, 1)) 23 | elif aug == "blur": 24 | transformation = GaussianBlur2d((5, 5), (1.5, 1.5)) 25 | elif aug == "shift": 26 | transformation = nn.Sequential(nn.ReplicationPad2d(shift_delta), RandomCrop(image_shape)) 27 | elif aug == "intensity": 28 | transformation = Intensity(scale=0.05) 29 | elif aug == "none": 30 | transformation = nn.Identity() 31 | else: 32 | raise NotImplementedError() 33 | self.transforms.append(transformation) 34 | 35 | def apply_transforms(self, transforms, image): 36 | for transform in transforms: 37 | image = transform(image) 38 | return image 39 | 40 | @torch.no_grad() 41 | def transform(self, images): 42 | # images = images.float() / 255. 
if images.dtype == torch.uint8 else images 43 | flat_images = images.reshape(-1, *images.shape[-3:]) 44 | processed_images = self.apply_transforms(self.transforms, flat_images) 45 | 46 | processed_images = processed_images.view(*images.shape[:-3], 47 | *processed_images.shape[1:]) 48 | return processed_images 49 | 50 | 51 | class Intensity(nn.Module): 52 | def __init__(self, scale): 53 | super().__init__() 54 | self.scale = scale 55 | 56 | def forward(self, x): 57 | r = torch.randn((x.size(0), 1, 1, 1), device=x.device) 58 | noise = 1.0 + (self.scale * r.clamp(-2.0, 2.0)) 59 | return x * noise 60 | -------------------------------------------------------------------------------- /core/ctree/ctree.pxd: -------------------------------------------------------------------------------- 1 | # distutils: language=c++ 2 | from libcpp.vector cimport vector 3 | 4 | 5 | cdef extern from "cminimax.cpp": 6 | pass 7 | 8 | 9 | cdef extern from "cminimax.h" namespace "tools": 10 | cdef cppclass CMinMaxStats: 11 | CMinMaxStats() except + 12 | float maximum, minimum, value_delta_max 13 | 14 | void set_delta(float value_delta_max) 15 | void update(float value) 16 | void clear() 17 | float normalize(float value) 18 | 19 | cdef cppclass CMinMaxStatsList: 20 | CMinMaxStatsList() except + 21 | CMinMaxStatsList(int num) except + 22 | int num 23 | vector[CMinMaxStats] stats_lst 24 | 25 | void set_delta(float value_delta_max) 26 | 27 | cdef extern from "cnode.cpp": 28 | pass 29 | 30 | 31 | cdef extern from "cnode.h" namespace "tree": 32 | cdef cppclass CNode: 33 | CNode() except + 34 | CNode(float prior, int action_num, vector[CNode]* ptr_node_pool) except + 35 | int visit_count, to_play, action_num, hidden_state_index_x, hidden_state_index_y, best_action 36 | float value_prefixs, prior, value_sum 37 | vector[int] children_index; 38 | vector[CNode]* ptr_node_pool; 39 | 40 | void expand(int to_play, int hidden_state_index_x, int hidden_state_index_y, float value_prefixs, vector[float] policy_logits) 41 | void add_exploration_noise(float exploration_fraction, vector[float] noises) 42 | float get_mean_q(int isRoot, float parent_q, float discount) 43 | 44 | int expanded() 45 | float value() 46 | vector[int] get_trajectory() 47 | vector[int] get_children_distribution() 48 | CNode* get_child(int action) 49 | 50 | cdef cppclass CRoots: 51 | CRoots() except + 52 | CRoots(int root_num, int action_num, int pool_size) except + 53 | int root_num, action_num, pool_size 54 | vector[CNode] roots 55 | vector[vector[CNode]] node_pools 56 | 57 | void prepare(float root_exploration_fraction, const vector[vector[float]] &noises, const vector[float] &value_prefixs, const vector[vector[float]] &policies) 58 | void prepare_no_noise(const vector[float] &value_prefixs, const vector[vector[float]] &policies) 59 | void clear() 60 | vector[vector[int]] get_trajectories() 61 | vector[vector[int]] get_distributions() 62 | vector[float] get_values() 63 | 64 | cdef cppclass CSearchResults: 65 | CSearchResults() except + 66 | CSearchResults(int num) except + 67 | int num 68 | vector[int] hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions, search_lens 69 | vector[CNode*] nodes 70 | 71 | cdef void cback_propagate(vector[CNode*] &search_path, CMinMaxStats &min_max_stats, int to_play, float value, float discount) 72 | void cbatch_back_propagate(int hidden_state_index_x, float discount, vector[float] value_prefixs, vector[float] values, vector[vector[float]] policies, 73 | CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, 
vector[int] is_reset_lst) 74 | void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results) 75 | -------------------------------------------------------------------------------- /core/ctree/cnode.h: -------------------------------------------------------------------------------- 1 | #ifndef CNODE_H 2 | #define CNODE_H 3 | 4 | #include "cminimax.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | const int DEBUG_MODE = 0; 15 | 16 | namespace tree { 17 | 18 | class CNode { 19 | public: 20 | int visit_count, to_play, action_num, hidden_state_index_x, hidden_state_index_y, best_action, is_reset; 21 | float value_prefix, prior, value_sum; 22 | std::vector children_index; 23 | std::vector* ptr_node_pool; 24 | 25 | CNode(); 26 | CNode(float prior, int action_num, std::vector *ptr_node_pool); 27 | ~CNode(); 28 | 29 | void expand(int to_play, int hidden_state_index_x, int hidden_state_index_y, float value_prefix, const std::vector &policy_logits); 30 | void add_exploration_noise(float exploration_fraction, const std::vector &noises); 31 | float get_mean_q(int isRoot, float parent_q, float discount); 32 | void print_out(); 33 | 34 | int expanded(); 35 | 36 | float value(); 37 | 38 | std::vector get_trajectory(); 39 | std::vector get_children_distribution(); 40 | CNode* get_child(int action); 41 | }; 42 | 43 | class CRoots{ 44 | public: 45 | int root_num, action_num, pool_size; 46 | std::vector roots; 47 | std::vector> node_pools; 48 | 49 | CRoots(); 50 | CRoots(int root_num, int action_num, int pool_size); 51 | ~CRoots(); 52 | 53 | void prepare(float root_exploration_fraction, const std::vector> &noises, const std::vector &value_prefixs, const std::vector> &policies); 54 | void prepare_no_noise(const std::vector &value_prefixs, const std::vector> &policies); 55 | void clear(); 56 | std::vector> get_trajectories(); 57 | std::vector> get_distributions(); 58 | std::vector get_values(); 59 | 60 | }; 61 | 62 | class CSearchResults{ 63 | public: 64 | int num; 65 | std::vector hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions, search_lens; 66 | std::vector nodes; 67 | std::vector> search_paths; 68 | 69 | CSearchResults(); 70 | CSearchResults(int num); 71 | ~CSearchResults(); 72 | 73 | }; 74 | 75 | 76 | //********************************************************* 77 | void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount); 78 | void cback_propagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount); 79 | void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector &value_prefixs, const std::vector &values, const std::vector> &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_lst); 80 | int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount, float mean_q); 81 | float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount); 82 | void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results); 83 | } 84 | 85 | #endif -------------------------------------------------------------------------------- 
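The Cython module below (`cytree.pyx`) is what exposes these C++ classes to Python. For orientation, the sketch here mirrors how `core/mcts.py` and `core/test.py` drive one simulation step through that wrapper; the pUCT constants and all numeric inputs are illustrative placeholders rather than values taken from this repo's config.
```
import numpy as np
import core.ctree.cytree as cytree

num_roots, action_space, num_simulations = 4, 6, 50

# batched roots plus one min-max value normalizer per root
roots = cytree.Roots(num_roots, action_space, num_simulations)
min_max_stats_lst = cytree.MinMaxStatsList(num_roots)
min_max_stats_lst.set_delta(0.01)  # value_delta_max in the Atari config

# root statistics normally come from model.initial_inference (placeholders here)
value_prefix_pool = [0.0] * num_roots
policy_logits_pool = np.zeros((num_roots, action_space)).tolist()
roots.prepare_no_noise(value_prefix_pool, policy_logits_pool)

# one simulation: select a leaf for every root inside the C++ tree ...
results = cytree.ResultsWrapper(num_roots)
ix_lst, iy_lst, last_actions = cytree.batch_traverse(
    roots, 19652, 1.25, 0.997, min_max_stats_lst, results)

# ... evaluate the leaves with the model, then write the outputs back into the tree
cytree.batch_back_propagate(1, 0.997, value_prefix_pool, [0.0] * num_roots,
                            policy_logits_pool, min_max_stats_lst, results,
                            [0] * num_roots)
```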
/core/ctree/cytree.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language=c++ 2 | import ctypes 3 | cimport cython 4 | from ctree cimport CMinMaxStatsList, CNode, CRoots, CSearchResults, cbatch_back_propagate, cbatch_traverse 5 | from libcpp.vector cimport vector 6 | from libc.stdlib cimport malloc, free 7 | from libcpp.list cimport list as cpplist 8 | 9 | import numpy as np 10 | cimport numpy as np 11 | 12 | ctypedef np.npy_float FLOAT 13 | ctypedef np.npy_intp INTP 14 | 15 | 16 | cdef class MinMaxStatsList: 17 | cdef CMinMaxStatsList *cmin_max_stats_lst 18 | 19 | def __cinit__(self, int num): 20 | self.cmin_max_stats_lst = new CMinMaxStatsList(num) 21 | 22 | def set_delta(self, float value_delta_max): 23 | self.cmin_max_stats_lst[0].set_delta(value_delta_max) 24 | 25 | def __dealloc__(self): 26 | del self.cmin_max_stats_lst 27 | 28 | 29 | cdef class ResultsWrapper: 30 | cdef CSearchResults cresults 31 | 32 | def __cinit__(self, int num): 33 | self.cresults = CSearchResults(num) 34 | 35 | def get_search_len(self): 36 | return self.cresults.search_lens 37 | 38 | 39 | cdef class Roots: 40 | cdef int root_num 41 | cdef int pool_size 42 | cdef CRoots *roots 43 | 44 | def __cinit__(self, int root_num, int action_num, int tree_nodes): 45 | self.root_num = root_num 46 | self.pool_size = action_num * (tree_nodes + 2) 47 | self.roots = new CRoots(root_num, action_num, self.pool_size) 48 | 49 | def prepare(self, float root_exploration_fraction, list noises, list value_prefix_pool, list policy_logits_pool): 50 | self.roots[0].prepare(root_exploration_fraction, noises, value_prefix_pool, policy_logits_pool) 51 | 52 | def prepare_no_noise(self, list value_prefix_pool, list policy_logits_pool): 53 | self.roots[0].prepare_no_noise(value_prefix_pool, policy_logits_pool) 54 | 55 | def get_trajectories(self): 56 | return self.roots[0].get_trajectories() 57 | 58 | def get_distributions(self): 59 | return self.roots[0].get_distributions() 60 | 61 | def get_values(self): 62 | return self.roots[0].get_values() 63 | 64 | def clear(self): 65 | self.roots[0].clear() 66 | 67 | def __dealloc__(self): 68 | del self.roots 69 | 70 | @property 71 | def num(self): 72 | return self.root_num 73 | 74 | 75 | cdef class Node: 76 | cdef CNode cnode 77 | 78 | def __cinit__(self): 79 | pass 80 | 81 | def __cinit__(self, float prior, int action_num): 82 | # self.cnode = CNode(prior, action_num) 83 | pass 84 | 85 | def expand(self, int to_play, int hidden_state_index_x, int hidden_state_index_y, float value_prefix, list policy_logits): 86 | cdef vector[float] cpolicy = policy_logits 87 | self.cnode.expand(to_play, hidden_state_index_x, hidden_state_index_y, value_prefix, cpolicy) 88 | 89 | def batch_back_propagate(int hidden_state_index_x, float discount, list value_prefixs, list values, list policies, MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list is_reset_lst): 90 | cdef int i 91 | cdef vector[float] cvalue_prefixs = value_prefixs 92 | cdef vector[float] cvalues = values 93 | cdef vector[vector[float]] cpolicies = policies 94 | 95 | cbatch_back_propagate(hidden_state_index_x, discount, cvalue_prefixs, cvalues, cpolicies, 96 | min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_lst) 97 | 98 | 99 | def batch_traverse(Roots roots, int pb_c_base, float pb_c_init, float discount, MinMaxStatsList min_max_stats_lst, ResultsWrapper results): 100 | 101 | cbatch_traverse(roots.roots, pb_c_base, pb_c_init, discount, min_max_stats_lst.cmin_max_stats_lst, 
results.cresults) 102 | 103 | return results.cresults.hidden_state_index_x_lst, results.cresults.hidden_state_index_y_lst, results.cresults.last_actions 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EfficientZero (NeurIPS 2021) 2 | Open-source codebase for EfficientZero, from ["Mastering Atari Games with Limited Data"](https://arxiv.org/abs/2111.00210) at NeurIPS 2021. 3 | 4 | ## Environments 5 | EfficientZero requires python3 (>=3.6) and pytorch (>=1.8.0) with the development headers. 6 | 7 | We recommend using torch amp (`--amp_type torch_amp`) to accelerate training. 8 | 9 | ### Prerequisites 10 | Before starting training, you need to build the C++/Cython external packages (GCC version 7.5+ is required): 11 | ``` 12 | cd core/ctree 13 | bash make.sh 14 | ``` 15 | The distributed framework of this codebase is built on [ray](https://docs.ray.io/en/releases-1.0.0/auto_examples/overview.html). 16 | 17 | ### Installation 18 | To install the other packages required by this codebase, run `pip install -r requirements.txt`. 19 | 20 | ## Usage 21 | ### Quick start 22 | * Train: `python main.py --env BreakoutNoFrameskip-v4 --case atari --opr train --amp_type torch_amp --num_gpus 1 --num_cpus 10 --cpu_actor 1 --gpu_actor 1 --force` 23 | * Test: `python main.py --env BreakoutNoFrameskip-v4 --case atari --opr test --amp_type torch_amp --num_gpus 1 --load_model --model_path model.p` 24 | ### Bash file 25 | We provide `train.sh` and `test.sh` for training and evaluation. 26 | * Train: 27 | * With 4 GPUs (3090): `bash train.sh` 28 | * Test: `bash test.sh` 29 | 30 | |Required Arguments | Description| 31 | |:-------------|:-------------| 32 | | `--env` |Name of the environment| 33 | | `--case {atari}` |used for switching between different domains (default: atari)| 34 | | `--opr {train,test}` |select the operation to be performed| 35 | | `--amp_type {torch_amp,none}` |use torch amp for acceleration| 36 | 37 | |Other Arguments | Description| 38 | |:-------------|:-------------| 39 | | `--force` |overwrite the existing result directory 40 | | `--num_gpus 4` |how many GPUs are available 41 | | `--num_cpus 96` |how many CPUs are available 42 | | `--cpu_actor 14` |how many CPU workers 43 | | `--gpu_actor 20` |how many GPU workers 44 | | `--seed 0` |the random seed 45 | | `--use_priority` |use priority in replay buffer sampling 46 | | `--use_max_priority` |use the max priority for the newly collected data 47 | | `--amp_type 'torch_amp'` |use torch amp for acceleration 48 | | `--info 'EZ-V0'` |some tags for your experiments 49 | | `--p_mcts_num 8` |set the number of parallel envs in self-play 50 | | `--revisit_policy_search_rate 0.99` |set the rate of reanalyzing policies 51 | | `--use_root_value` |use root values in value targets (requires more GPU actors) 52 | | `--render` |render during evaluation 53 | | `--save_video` |save videos during evaluation 54 | 55 | ## Architecture Designs 56 | The architecture of the training pipeline is shown as follows: 57 | ![](static/imgs/archi.png) 58 | 59 | ### Some suggestions 60 | * To use a smaller model, you can choose smaller dimensions for the projection layers (e.g. 256/64) and the LSTM hidden layer (e.g. 64) in the config. 61 | * For GPUs with 10G memory instead of 20G memory, you can allocate 0.25 gpu for each GPU batch maker (`@ray.remote(num_gpus=0.25)`) in `core/reanalyze_worker.py` (see the sketch below).
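A minimal sketch of the two suggestions above. The config fields exist in `config/atari/__init__.py`, but the smaller values and the worker class name below are illustrative assumptions, not tested settings:
```
# 1. smaller projection / LSTM sizes, e.g. in AtariConfig.__init__
lstm_hidden_size=64,           # default: 512
proj_hid=256, proj_out=256,    # defaults: 1024 / 1024
pred_hid=64, pred_out=256,     # defaults: 512 / 1024

# 2. smaller GPU share for each GPU batch maker in core/reanalyze_worker.py
@ray.remote(num_gpus=0.25)       # instead of a larger per-actor share
class BatchWorker_GPU(object):   # class name assumed for illustration
    ...
```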
62 | 63 | ### New environment registration 64 | If you want to apply EfficientZero to a new environment such as `mujoco`, here are the steps for registration: 65 | 1. Follow the directory `config/atari` and create a directory for the env at `config/mujoco`. 66 | 2. Implement your `MujocoConfig(BaseConfig)` class and implement the models as well as your environment wrapper. 67 | 3. Register the case at `main.py`. 68 | 69 | ## Results 70 | Evaluation with 32 seeds for each of 3 training runs (each run uses a different seed). 71 | ![](static/imgs/total_results.png) 72 | 73 | ## Citation 74 | If you find this repo useful, please cite our paper: 75 | ``` 76 | @inproceedings{ye2021mastering, 77 | title={Mastering Atari Games with Limited Data}, 78 | author={Weirui Ye and Shaohuai Liu and Thanard Kurutach and Pieter Abbeel and Yang Gao}, 79 | booktitle={NeurIPS}, 80 | year={2021} 81 | } 82 | ``` 83 | 84 | ## Contact 85 | If you have any questions or want to use the code, please contact ywr20@mails.tsinghua.edu.cn . 86 | 87 | ## Acknowledgement 88 | We appreciate the following GitHub repos a lot for their valuable codebase implementations: 89 | 90 | https://github.com/koulanurag/muzero-pytorch 91 | 92 | https://github.com/werner-duvaud/muzero-general 93 | 94 | https://github.com/pytorch/ELF 95 | -------------------------------------------------------------------------------- /core/mcts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | import core.ctree.cytree as tree 5 | 6 | from torch.cuda.amp import autocast as autocast 7 | 8 | 9 | class MCTS(object): 10 | def __init__(self, config): 11 | self.config = config 12 | 13 | def search(self, roots, model, hidden_state_roots, reward_hidden_roots): 14 | """Do MCTS for the roots (a batch of root nodes in parallel).
Parallel in model inference 15 | Parameters 16 | ---------- 17 | roots: Any 18 | a batch of expanded root nodes 19 | hidden_state_roots: list 20 | the hidden states of the roots 21 | reward_hidden_roots: list 22 | the value prefix hidden states in LSTM of the roots 23 | """ 24 | with torch.no_grad(): 25 | model.eval() 26 | 27 | # preparation 28 | num = roots.num 29 | device = self.config.device 30 | pb_c_base, pb_c_init, discount = self.config.pb_c_base, self.config.pb_c_init, self.config.discount 31 | # the data storage of hidden states: storing the states of all the tree nodes 32 | hidden_state_pool = [hidden_state_roots] 33 | # 1 x batch x 64 34 | # the data storage of value prefix hidden states in LSTM 35 | reward_hidden_c_pool = [reward_hidden_roots[0]] 36 | reward_hidden_h_pool = [reward_hidden_roots[1]] 37 | # the index of each layer in the tree 38 | hidden_state_index_x = 0 39 | # minimax value storage 40 | min_max_stats_lst = tree.MinMaxStatsList(num) 41 | min_max_stats_lst.set_delta(self.config.value_delta_max) 42 | horizons = self.config.lstm_horizon_len 43 | 44 | for index_simulation in range(self.config.num_simulations): 45 | hidden_states = [] 46 | hidden_states_c_reward = [] 47 | hidden_states_h_reward = [] 48 | 49 | # prepare a result wrapper to transport results between python and c++ parts 50 | results = tree.ResultsWrapper(num) 51 | # traverse to select actions for each root 52 | # hidden_state_index_x_lst: the first index of leaf node states in hidden_state_pool 53 | # hidden_state_index_y_lst: the second index of leaf node states in hidden_state_pool 54 | # the hidden state of the leaf node is hidden_state_pool[x, y]; value prefix states are the same 55 | hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions = tree.batch_traverse(roots, pb_c_base, pb_c_init, discount, min_max_stats_lst, results) 56 | # obtain the search horizon for leaf nodes 57 | search_lens = results.get_search_len() 58 | 59 | # obtain the states for leaf nodes 60 | for ix, iy in zip(hidden_state_index_x_lst, hidden_state_index_y_lst): 61 | hidden_states.append(hidden_state_pool[ix][iy]) 62 | hidden_states_c_reward.append(reward_hidden_c_pool[ix][0][iy]) 63 | hidden_states_h_reward.append(reward_hidden_h_pool[ix][0][iy]) 64 | 65 | hidden_states = torch.from_numpy(np.asarray(hidden_states)).to(device).float() 66 | hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(device).unsqueeze(0) 67 | hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(device).unsqueeze(0) 68 | 69 | last_actions = torch.from_numpy(np.asarray(last_actions)).to(device).unsqueeze(1).long() 70 | 71 | # evaluation for leaf nodes 72 | if self.config.amp_type == 'torch_amp': 73 | with autocast(): 74 | network_output = model.recurrent_inference(hidden_states, (hidden_states_c_reward, hidden_states_h_reward), last_actions) 75 | else: 76 | network_output = model.recurrent_inference(hidden_states, (hidden_states_c_reward, hidden_states_h_reward), last_actions) 77 | 78 | hidden_state_nodes = network_output.hidden_state 79 | value_prefix_pool = network_output.value_prefix.reshape(-1).tolist() 80 | value_pool = network_output.value.reshape(-1).tolist() 81 | policy_logits_pool = network_output.policy_logits.tolist() 82 | reward_hidden_nodes = network_output.reward_hidden 83 | 84 | hidden_state_pool.append(hidden_state_nodes) 85 | # reset 0 86 | # reset the hidden states in LSTM every horizon steps in search 87 | # only need to predict the value prefix in a range (eg: s0 -> 
s5) 88 | assert horizons > 0 89 | reset_idx = (np.array(search_lens) % horizons == 0) 90 | assert len(reset_idx) == num 91 | reward_hidden_nodes[0][:, reset_idx, :] = 0 92 | reward_hidden_nodes[1][:, reset_idx, :] = 0 93 | is_reset_lst = reset_idx.astype(np.int32).tolist() 94 | 95 | reward_hidden_c_pool.append(reward_hidden_nodes[0]) 96 | reward_hidden_h_pool.append(reward_hidden_nodes[1]) 97 | hidden_state_index_x += 1 98 | 99 | # backpropagation along the search path to update the attributes 100 | tree.batch_back_propagate(hidden_state_index_x, discount, 101 | value_prefix_pool, value_pool, policy_logits_pool, 102 | min_max_stats_lst, results, is_reset_lst) 103 | -------------------------------------------------------------------------------- /core/storage.py: -------------------------------------------------------------------------------- 1 | import ray 2 | 3 | from ray.util.queue import Queue 4 | 5 | 6 | class QueueStorage(object): 7 | def __init__(self, threshold=15, size=20): 8 | """Queue storage 9 | Parameters 10 | ---------- 11 | threshold: int 12 | if the current size if larger than threshold, the data won't be collected 13 | size: int 14 | the size of the queue 15 | """ 16 | self.threshold = threshold 17 | self.queue = Queue(maxsize=size) 18 | 19 | def push(self, batch): 20 | if self.queue.qsize() <= self.threshold: 21 | self.queue.put(batch) 22 | 23 | def pop(self): 24 | if self.queue.qsize() > 0: 25 | return self.queue.get() 26 | else: 27 | return None 28 | 29 | def get_len(self): 30 | return self.queue.qsize() 31 | 32 | 33 | @ray.remote 34 | class SharedStorage(object): 35 | def __init__(self, model, target_model): 36 | """Shared storage for models and others 37 | Parameters 38 | ---------- 39 | model: any 40 | models for self-play (update every checkpoint_interval) 41 | target_model: any 42 | models for reanalyzing (update every target_model_interval) 43 | """ 44 | self.step_counter = 0 45 | self.test_counter = 0 46 | self.model = model 47 | self.target_model = target_model 48 | self.ori_reward_log = [] 49 | self.reward_log = [] 50 | self.reward_max_log = [] 51 | self.test_dict_log = {} 52 | self.eps_lengths = [] 53 | self.eps_lengths_max = [] 54 | self.temperature_log = [] 55 | self.visit_entropies_log = [] 56 | self.priority_self_play_log = [] 57 | self.distributions_log = {} 58 | self.start = False 59 | 60 | def set_start_signal(self): 61 | self.start = True 62 | 63 | def get_start_signal(self): 64 | return self.start 65 | 66 | def get_weights(self): 67 | return self.model.get_weights() 68 | 69 | def set_weights(self, weights): 70 | return self.model.set_weights(weights) 71 | 72 | def get_target_weights(self): 73 | return self.target_model.get_weights() 74 | 75 | def set_target_weights(self, weights): 76 | return self.target_model.set_weights(weights) 77 | 78 | def incr_counter(self): 79 | self.step_counter += 1 80 | 81 | def get_counter(self): 82 | return self.step_counter 83 | 84 | def set_data_worker_logs(self, eps_len, eps_len_max, eps_ori_reward, eps_reward, eps_reward_max, temperature, visit_entropy, priority_self_play, distributions): 85 | self.eps_lengths.append(eps_len) 86 | self.eps_lengths_max.append(eps_len_max) 87 | self.ori_reward_log.append(eps_ori_reward) 88 | self.reward_log.append(eps_reward) 89 | self.reward_max_log.append(eps_reward_max) 90 | self.temperature_log.append(temperature) 91 | self.visit_entropies_log.append(visit_entropy) 92 | self.priority_self_play_log.append(priority_self_play) 93 | 94 | for key, val in distributions.items(): 95 | if key 
not in self.distributions_log.keys(): 96 | self.distributions_log[key] = [] 97 | self.distributions_log[key] += val 98 | 99 | def add_test_log(self, test_counter, test_dict): 100 | self.test_counter = test_counter 101 | for key, val in test_dict.items(): 102 | if key not in self.test_dict_log.keys(): 103 | self.test_dict_log[key] = [] 104 | self.test_dict_log[key].append(val) 105 | 106 | def get_worker_logs(self): 107 | if len(self.reward_log) > 0: 108 | ori_reward = sum(self.ori_reward_log) / len(self.ori_reward_log) 109 | reward = sum(self.reward_log) / len(self.reward_log) 110 | reward_max = sum(self.reward_max_log) / len(self.reward_max_log) 111 | eps_lengths = sum(self.eps_lengths) / len(self.eps_lengths) 112 | eps_lengths_max = sum(self.eps_lengths_max) / len(self.eps_lengths_max) 113 | temperature = sum(self.temperature_log) / len(self.temperature_log) 114 | visit_entropy = sum(self.visit_entropies_log) / len(self.visit_entropies_log) 115 | priority_self_play = sum(self.priority_self_play_log) / len(self.priority_self_play_log) 116 | distributions = self.distributions_log 117 | 118 | self.ori_reward_log = [] 119 | self.reward_log = [] 120 | self.reward_max_log = [] 121 | self.eps_lengths = [] 122 | self.eps_lengths_max = [] 123 | self.temperature_log = [] 124 | self.visit_entropies_log = [] 125 | self.priority_self_play_log = [] 126 | self.distributions_log = {} 127 | 128 | else: 129 | ori_reward = None 130 | reward = None 131 | reward_max = None 132 | eps_lengths = None 133 | eps_lengths_max = None 134 | temperature = None 135 | visit_entropy = None 136 | priority_self_play = None 137 | distributions = None 138 | 139 | if len(self.test_dict_log) > 0: 140 | test_dict = self.test_dict_log 141 | 142 | self.test_dict_log = {} 143 | test_counter = self.test_counter 144 | else: 145 | test_dict = None 146 | test_counter = None 147 | 148 | return ori_reward, reward, reward_max, eps_lengths, eps_lengths_max, test_counter, test_dict, temperature, visit_entropy, priority_self_play, distributions 149 | -------------------------------------------------------------------------------- /core/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import typing 3 | 4 | import numpy as np 5 | import torch.nn as nn 6 | 7 | from typing import List 8 | 9 | 10 | class NetworkOutput(typing.NamedTuple): 11 | # output format of the model 12 | value: float 13 | value_prefix: float 14 | policy_logits: List[float] 15 | hidden_state: List[float] 16 | reward_hidden: object 17 | 18 | 19 | def concat_output_value(output_lst): 20 | # concat the values of the model output list 21 | value_lst = [] 22 | for output in output_lst: 23 | value_lst.append(output.value) 24 | 25 | value_lst = np.concatenate(value_lst) 26 | 27 | return value_lst 28 | 29 | 30 | def concat_output(output_lst): 31 | # concat the model output 32 | value_lst, reward_lst, policy_logits_lst, hidden_state_lst = [], [], [], [] 33 | reward_hidden_c_lst, reward_hidden_h_lst =[], [] 34 | for output in output_lst: 35 | value_lst.append(output.value) 36 | reward_lst.append(output.value_prefix) 37 | policy_logits_lst.append(output.policy_logits) 38 | hidden_state_lst.append(output.hidden_state) 39 | reward_hidden_c_lst.append(output.reward_hidden[0].squeeze(0)) 40 | reward_hidden_h_lst.append(output.reward_hidden[1].squeeze(0)) 41 | 42 | value_lst = np.concatenate(value_lst) 43 | reward_lst = np.concatenate(reward_lst) 44 | policy_logits_lst = np.concatenate(policy_logits_lst) 45 | # hidden_state_lst = 
torch.cat(hidden_state_lst, 0) 46 | hidden_state_lst = np.concatenate(hidden_state_lst) 47 | reward_hidden_c_lst = np.expand_dims(np.concatenate(reward_hidden_c_lst), axis=0) 48 | reward_hidden_h_lst = np.expand_dims(np.concatenate(reward_hidden_h_lst), axis=0) 49 | 50 | return value_lst, reward_lst, policy_logits_lst, hidden_state_lst, (reward_hidden_c_lst, reward_hidden_h_lst) 51 | 52 | 53 | class BaseNet(nn.Module): 54 | def __init__(self, inverse_value_transform, inverse_reward_transform, lstm_hidden_size): 55 | """Base Network 56 | schedule_timesteps. After this many timesteps pass final_p is 57 | returned. 58 | Parameters 59 | ---------- 60 | inverse_value_transform: Any 61 | A function that maps value supports into value scalars 62 | inverse_reward_transform: Any 63 | A function that maps reward supports into value scalars 64 | lstm_hidden_size: int 65 | dim of lstm hidden 66 | """ 67 | super(BaseNet, self).__init__() 68 | self.inverse_value_transform = inverse_value_transform 69 | self.inverse_reward_transform = inverse_reward_transform 70 | self.lstm_hidden_size = lstm_hidden_size 71 | 72 | def prediction(self, state): 73 | raise NotImplementedError 74 | 75 | def representation(self, obs_history): 76 | raise NotImplementedError 77 | 78 | def dynamics(self, state, reward_hidden, action): 79 | raise NotImplementedError 80 | 81 | def initial_inference(self, obs) -> NetworkOutput: 82 | num = obs.size(0) 83 | 84 | state = self.representation(obs) 85 | actor_logit, value = self.prediction(state) 86 | 87 | if not self.training: 88 | # if not in training, obtain the scalars of the value/reward 89 | value = self.inverse_value_transform(value).detach().cpu().numpy() 90 | state = state.detach().cpu().numpy() 91 | actor_logit = actor_logit.detach().cpu().numpy() 92 | # zero initialization for reward (value prefix) hidden states 93 | reward_hidden = (torch.zeros(1, num, self.lstm_hidden_size).detach().cpu().numpy(), 94 | torch.zeros(1, num, self.lstm_hidden_size).detach().cpu().numpy()) 95 | else: 96 | # zero initialization for reward (value prefix) hidden states 97 | reward_hidden = (torch.zeros(1, num, self.lstm_hidden_size).to('cuda'), torch.zeros(1, num, self.lstm_hidden_size).to('cuda')) 98 | 99 | return NetworkOutput(value, [0. 
for _ in range(num)], actor_logit, state, reward_hidden) 100 | 101 | def recurrent_inference(self, hidden_state, reward_hidden, action) -> NetworkOutput: 102 | state, reward_hidden, value_prefix = self.dynamics(hidden_state, reward_hidden, action) 103 | actor_logit, value = self.prediction(state) 104 | 105 | if not self.training: 106 | # if not in training, obtain the scalars of the value/reward 107 | value = self.inverse_value_transform(value).detach().cpu().numpy() 108 | value_prefix = self.inverse_reward_transform(value_prefix).detach().cpu().numpy() 109 | state = state.detach().cpu().numpy() 110 | reward_hidden = (reward_hidden[0].detach().cpu().numpy(), reward_hidden[1].detach().cpu().numpy()) 111 | actor_logit = actor_logit.detach().cpu().numpy() 112 | 113 | return NetworkOutput(value, value_prefix, actor_logit, state, reward_hidden) 114 | 115 | def get_weights(self): 116 | return {k: v.cpu() for k, v in self.state_dict().items()} 117 | 118 | def set_weights(self, weights): 119 | self.load_state_dict(weights) 120 | 121 | def get_gradients(self): 122 | grads = [] 123 | for p in self.parameters(): 124 | grad = None if p.grad is None else p.grad.data.cpu().numpy() 125 | grads.append(grad) 126 | return grads 127 | 128 | def set_gradients(self, gradients): 129 | for g, p in zip(gradients, self.parameters()): 130 | if g is not None: 131 | p.grad = torch.from_numpy(g) 132 | 133 | 134 | def renormalize(tensor, first_dim=1): 135 | # normalize the tensor (states) 136 | if first_dim < 0: 137 | first_dim = len(tensor.shape) + first_dim 138 | flat_tensor = tensor.view(*tensor.shape[:first_dim], -1) 139 | max = torch.max(flat_tensor, first_dim, keepdim=True).values 140 | min = torch.min(flat_tensor, first_dim, keepdim=True).values 141 | flat_tensor = (flat_tensor - min) / (max - min) 142 | 143 | return flat_tensor.view(*tensor.shape) 144 | -------------------------------------------------------------------------------- /core/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import time 3 | 4 | import numpy as np 5 | 6 | 7 | @ray.remote 8 | class ReplayBuffer(object): 9 | """Reference : DISTRIBUTED PRIORITIZED EXPERIENCE REPLAY 10 | Algo. 1 and Algo. 2 in Page-3 of (https://arxiv.org/pdf/1803.00933.pdf 11 | """ 12 | def __init__(self, config=None): 13 | self.config = config 14 | self.batch_size = config.batch_size 15 | self.keep_ratio = 1 16 | 17 | self.model_index = 0 18 | self.model_update_interval = 10 19 | 20 | self.buffer = [] 21 | self.priorities = [] 22 | self.game_look_up = [] 23 | 24 | self._eps_collected = 0 25 | self.base_idx = 0 26 | self._alpha = config.priority_prob_alpha 27 | self.transition_top = int(config.transition_num * 10 ** 6) 28 | self.clear_time = 0 29 | 30 | def save_pools(self, pools, gap_step): 31 | # save a list of game histories 32 | for (game, priorities) in pools: 33 | # Only append end game 34 | # if end_tag: 35 | if len(game) > 0: 36 | self.save_game(game, True, gap_step, priorities) 37 | 38 | def save_game(self, game, end_tag, gap_steps, priorities=None): 39 | """Save a game history block 40 | Parameters 41 | ---------- 42 | game: Any 43 | a game history block 44 | end_tag: bool 45 | True -> the game is finished. 
(always True) 46 | gap_steps: int 47 | if the game is not finished, we only save the transitions that can be computed 48 | priorities: list 49 | the priorities corresponding to the transitions in the game history 50 | """ 51 | if self.get_total_len() >= self.config.total_transitions: 52 | return 53 | 54 | if end_tag: 55 | self._eps_collected += 1 56 | valid_len = len(game) 57 | else: 58 | valid_len = len(game) - gap_steps 59 | 60 | if priorities is None: 61 | max_prio = self.priorities.max() if self.buffer else 1 62 | self.priorities = np.concatenate((self.priorities, [max_prio for _ in range(valid_len)] + [0. for _ in range(valid_len, len(game))])) 63 | else: 64 | assert len(game) == len(priorities), " priorities should be of same length as the game steps" 65 | priorities = priorities.copy().reshape(-1) 66 | # priorities[valid_len:len(game)] = 0. 67 | self.priorities = np.concatenate((self.priorities, priorities)) 68 | 69 | self.buffer.append(game) 70 | self.game_look_up += [(self.base_idx + len(self.buffer) - 1, step_pos) for step_pos in range(len(game))] 71 | 72 | def get_game(self, idx): 73 | # return a game 74 | game_id, game_pos = self.game_look_up[idx] 75 | game_id -= self.base_idx 76 | game = self.buffer[game_id] 77 | return game 78 | 79 | def prepare_batch_context(self, batch_size, beta): 80 | """Prepare a batch context that contains: 81 | game_lst: a list of game histories 82 | game_pos_lst: transition index in game (relative index) 83 | indices_lst: transition index in replay buffer 84 | weights_lst: the weight concering the priority 85 | make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) 86 | Parameters 87 | ---------- 88 | batch_size: int 89 | batch size 90 | beta: float 91 | the parameter in PER for calculating the priority 92 | """ 93 | assert beta > 0 94 | 95 | total = self.get_total_len() 96 | 97 | probs = self.priorities ** self._alpha 98 | 99 | probs /= probs.sum() 100 | # sample data 101 | indices_lst = np.random.choice(total, batch_size, p=probs, replace=False) 102 | 103 | weights_lst = (total * probs[indices_lst]) ** (-beta) 104 | weights_lst /= weights_lst.max() 105 | 106 | game_lst = [] 107 | game_pos_lst = [] 108 | 109 | for idx in indices_lst: 110 | game_id, game_pos = self.game_look_up[idx] 111 | game_id -= self.base_idx 112 | game = self.buffer[game_id] 113 | 114 | game_lst.append(game) 115 | game_pos_lst.append(game_pos) 116 | 117 | make_time = [time.time() for _ in range(len(indices_lst))] 118 | 119 | context = (game_lst, game_pos_lst, indices_lst, weights_lst, make_time) 120 | return context 121 | 122 | def update_priorities(self, batch_indices, batch_priorities, make_time): 123 | # update the priorities for data still in replay buffer 124 | for i in range(len(batch_indices)): 125 | if make_time[i] > self.clear_time: 126 | idx, prio = batch_indices[i], batch_priorities[i] 127 | self.priorities[idx] = prio 128 | 129 | def remove_to_fit(self): 130 | # remove some old data if the replay buffer is full. 
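# Eviction detail: games are removed oldest-first. The loop below walks the buffer until enough
# of the oldest games are counted out for the remaining transition count to fall to
# transition_top * keep_ratio; _remove() then deletes those whole games, shifts base_idx,
# trims priorities / game_look_up, and records clear_time so that priority updates computed
# from already-deleted samples are ignored. The deletion is skipped if it would leave fewer
# than batch_size transitions.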
131 | current_size = self.size() 132 | total_transition = self.get_total_len() 133 | if total_transition > self.transition_top: 134 | index = 0 135 | for i in range(current_size): 136 | total_transition -= len(self.buffer[i]) 137 | if total_transition <= self.transition_top * self.keep_ratio: 138 | index = i 139 | break 140 | 141 | if total_transition >= self.config.batch_size: 142 | self._remove(index + 1) 143 | 144 | def _remove(self, num_excess_games): 145 | # delete game histories 146 | excess_games_steps = sum([len(game) for game in self.buffer[:num_excess_games]]) 147 | del self.buffer[:num_excess_games] 148 | self.priorities = self.priorities[excess_games_steps:] 149 | del self.game_look_up[:excess_games_steps] 150 | self.base_idx += num_excess_games 151 | 152 | self.clear_time = time.time() 153 | 154 | def clear_buffer(self): 155 | del self.buffer[:] 156 | 157 | def size(self): 158 | # number of games 159 | return len(self.buffer) 160 | 161 | def episodes_collected(self): 162 | # number of collected histories 163 | return self._eps_collected 164 | 165 | def get_batch_size(self): 166 | return self.batch_size 167 | 168 | def get_priorities(self): 169 | return self.priorities 170 | 171 | def get_total_len(self): 172 | # number of transitions 173 | return len(self.priorities) 174 | -------------------------------------------------------------------------------- /core/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ray 3 | import time 4 | import torch 5 | 6 | import numpy as np 7 | import core.ctree.cytree as cytree 8 | 9 | from tqdm.auto import tqdm 10 | from torch.cuda.amp import autocast as autocast 11 | from core.mcts import MCTS 12 | from core.game import GameHistory 13 | from core.utils import select_action, prepare_observation_lst 14 | 15 | 16 | @ray.remote(num_gpus=0.25) 17 | def _test(config, shared_storage): 18 | test_model = config.get_uniform_network() 19 | best_test_score = float('-inf') 20 | episodes = 0 21 | while True: 22 | counter = ray.get(shared_storage.get_counter.remote()) 23 | if counter >= config.training_steps + config.last_steps: 24 | time.sleep(30) 25 | break 26 | if counter >= config.test_interval * episodes: 27 | episodes += 1 28 | test_model.set_weights(ray.get(shared_storage.get_weights.remote())) 29 | test_model.eval() 30 | 31 | test_score, eval_steps, _ = test(config, test_model, counter, config.test_episodes, config.device, False, save_video=False) 32 | mean_score = test_score.mean() 33 | std_score = test_score.std() 34 | print('Start evaluation at step {}.'.format(counter)) 35 | if mean_score >= best_test_score: 36 | best_test_score = mean_score 37 | torch.save(test_model.state_dict(), config.model_path) 38 | 39 | test_log = { 40 | 'mean_score': mean_score, 41 | 'std_score': std_score, 42 | 'max_score': test_score.max(), 43 | 'min_score': test_score.min(), 44 | } 45 | 46 | shared_storage.add_test_log.remote(counter, test_log) 47 | print('Training step {}, test scores: \n{} of {} eval steps.'.format(counter, test_score, eval_steps)) 48 | 49 | time.sleep(30) 50 | 51 | 52 | def test(config, model, counter, test_episodes, device, render, save_video=False, final_test=False, use_pb=False): 53 | """evaluation test 54 | Parameters 55 | ---------- 56 | model: any 57 | models for evaluation 58 | counter: int 59 | current training step counter 60 | test_episodes: int 61 | number of test episodes 62 | device: str 63 | 'cuda' or 'cpu' 64 | render: bool 65 | True -> render the image during evaluation 
66 | save_video: bool 67 | True -> save the videos during evaluation 68 | final_test: bool 69 | True -> this test is the final test, and the max moves would be 108k/skip 70 | use_pb: bool 71 | True -> use tqdm bars 72 | """ 73 | model.to(device) 74 | model.eval() 75 | save_path = os.path.join(config.exp_path, 'recordings', 'step_{}'.format(counter)) 76 | 77 | with torch.no_grad(): 78 | # new games 79 | envs = [config.new_game(seed=i, save_video=save_video, save_path=save_path, test=True, final_test=final_test, 80 | video_callable=lambda episode_id: True, uid=i) for i in range(test_episodes)] 81 | max_episode_steps = envs[0].get_max_episode_steps() 82 | if use_pb: 83 | pb = tqdm(np.arange(max_episode_steps), leave=True) 84 | # initializations 85 | init_obses = [env.reset() for env in envs] 86 | dones = np.array([False for _ in range(test_episodes)]) 87 | game_histories = [GameHistory(envs[_].env.action_space, max_length=max_episode_steps, config=config) for _ in range(test_episodes)] 88 | for i in range(test_episodes): 89 | game_histories[i].init([init_obses[i] for _ in range(config.stacked_observations)]) 90 | 91 | step = 0 92 | ep_ori_rewards = np.zeros(test_episodes) 93 | ep_clip_rewards = np.zeros(test_episodes) 94 | # loop 95 | while not dones.all(): 96 | if render: 97 | for i in range(test_episodes): 98 | envs[i].render() 99 | 100 | if config.image_based: 101 | stack_obs = [] 102 | for game_history in game_histories: 103 | stack_obs.append(game_history.step_obs()) 104 | stack_obs = prepare_observation_lst(stack_obs) 105 | stack_obs = torch.from_numpy(stack_obs).to(device).float() / 255.0 106 | else: 107 | stack_obs = [game_history.step_obs() for game_history in game_histories] 108 | stack_obs = torch.from_numpy(np.array(stack_obs)).to(device) 109 | 110 | with autocast(): 111 | network_output = model.initial_inference(stack_obs.float()) 112 | hidden_state_roots = network_output.hidden_state 113 | reward_hidden_roots = network_output.reward_hidden 114 | value_prefix_pool = network_output.value_prefix 115 | policy_logits_pool = network_output.policy_logits.tolist() 116 | 117 | roots = cytree.Roots(test_episodes, config.action_space_size, config.num_simulations) 118 | roots.prepare_no_noise(value_prefix_pool, policy_logits_pool) 119 | # do MCTS for a policy (argmax in testing) 120 | MCTS(config).search(roots, model, hidden_state_roots, reward_hidden_roots) 121 | 122 | roots_distributions = roots.get_distributions() 123 | roots_values = roots.get_values() 124 | for i in range(test_episodes): 125 | if dones[i]: 126 | continue 127 | 128 | distributions, value, env = roots_distributions[i], roots_values[i], envs[i] 129 | # select the argmax, not sampling 130 | action, _ = select_action(distributions, temperature=1, deterministic=True) 131 | 132 | obs, ori_reward, done, info = env.step(action) 133 | if config.clip_reward: 134 | clip_reward = np.sign(ori_reward) 135 | else: 136 | clip_reward = ori_reward 137 | 138 | game_histories[i].store_search_stats(distributions, value) 139 | game_histories[i].append(action, obs, clip_reward) 140 | 141 | dones[i] = done 142 | ep_ori_rewards[i] += ori_reward 143 | ep_clip_rewards[i] += clip_reward 144 | 145 | step += 1 146 | if use_pb: 147 | pb.set_description('{} In step {}, scores: {}(max: {}, min: {}) currently.' 
148 | ''.format(config.env_name, counter, 149 | ep_ori_rewards.mean(), ep_ori_rewards.max(), ep_ori_rewards.min())) 150 | pb.update(1) 151 | 152 | for env in envs: 153 | env.close() 154 | 155 | return ep_ori_rewards, step, save_path 156 | -------------------------------------------------------------------------------- /config/atari/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from core.config import BaseConfig 4 | from core.utils import make_atari, WarpFrame, EpisodicLifeEnv 5 | from core.dataset import Transforms 6 | from .env_wrapper import AtariWrapper 7 | from .model import EfficientZeroNet 8 | 9 | 10 | class AtariConfig(BaseConfig): 11 | def __init__(self): 12 | super(AtariConfig, self).__init__( 13 | training_steps=100000, 14 | last_steps=20000, 15 | test_interval=10000, 16 | log_interval=1000, 17 | vis_interval=1000, 18 | test_episodes=32, 19 | checkpoint_interval=100, 20 | target_model_interval=200, 21 | save_ckpt_interval=10000, 22 | max_moves=12000, 23 | test_max_moves=12000, 24 | history_length=400, 25 | discount=0.997, 26 | dirichlet_alpha=0.3, 27 | value_delta_max=0.01, 28 | num_simulations=50, 29 | batch_size=256, 30 | td_steps=5, 31 | num_actors=1, 32 | # network initialization/ & normalization 33 | episode_life=True, 34 | init_zero=True, 35 | clip_reward=True, 36 | # storage efficient 37 | cvt_string=True, 38 | image_based=True, 39 | # lr scheduler 40 | lr_warm_up=0.01, 41 | lr_init=0.2, 42 | lr_decay_rate=0.1, 43 | lr_decay_steps=100000, 44 | auto_td_steps_ratio=0.3, 45 | # replay window 46 | start_transitions=8, 47 | total_transitions=100 * 1000, 48 | transition_num=1, 49 | # frame skip & stack observation 50 | frame_skip=4, 51 | stacked_observations=4, 52 | # coefficient 53 | reward_loss_coeff=1, 54 | value_loss_coeff=0.25, 55 | policy_loss_coeff=1, 56 | consistency_coeff=2, 57 | # reward sum 58 | lstm_hidden_size=512, 59 | lstm_horizon_len=5, 60 | # siamese 61 | proj_hid=1024, 62 | proj_out=1024, 63 | pred_hid=512, 64 | pred_out=1024,) 65 | self.discount **= self.frame_skip 66 | self.max_moves //= self.frame_skip 67 | self.test_max_moves //= self.frame_skip 68 | 69 | self.start_transitions = self.start_transitions * 1000 // self.frame_skip 70 | self.start_transitions = max(1, self.start_transitions) 71 | 72 | self.bn_mt = 0.1 73 | self.blocks = 1 # Number of blocks in the ResNet 74 | self.channels = 64 # Number of channels in the ResNet 75 | if self.gray_scale: 76 | self.channels = 32 77 | self.reduced_channels_reward = 16 # x36 Number of channels in reward head 78 | self.reduced_channels_value = 16 # x36 Number of channels in value head 79 | self.reduced_channels_policy = 16 # x36 Number of channels in policy head 80 | self.resnet_fc_reward_layers = [32] # Define the hidden layers in the reward head of the dynamic network 81 | self.resnet_fc_value_layers = [32] # Define the hidden layers in the value head of the prediction network 82 | self.resnet_fc_policy_layers = [32] # Define the hidden layers in the policy head of the prediction network 83 | self.downsample = True # Downsample observations before representation network (See paper appendix Network Architecture) 84 | 85 | def visit_softmax_temperature_fn(self, num_moves, trained_steps): 86 | if self.change_temperature: 87 | if trained_steps < 0.5 * (self.training_steps): 88 | return 1.0 89 | elif trained_steps < 0.75 * (self.training_steps): 90 | return 0.5 91 | else: 92 | return 0.25 93 | else: 94 | return 1.0 95 | 96 | def set_game(self, 
env_name, save_video=False, save_path=None, video_callable=None): 97 | self.env_name = env_name 98 | # gray scale 99 | if self.gray_scale: 100 | self.image_channel = 1 101 | obs_shape = (self.image_channel, 96, 96) 102 | self.obs_shape = (obs_shape[0] * self.stacked_observations, obs_shape[1], obs_shape[2]) 103 | 104 | game = self.new_game() 105 | self.action_space_size = game.action_space_size 106 | 107 | def get_uniform_network(self): 108 | return EfficientZeroNet( 109 | self.obs_shape, 110 | self.action_space_size, 111 | self.blocks, 112 | self.channels, 113 | self.reduced_channels_reward, 114 | self.reduced_channels_value, 115 | self.reduced_channels_policy, 116 | self.resnet_fc_reward_layers, 117 | self.resnet_fc_value_layers, 118 | self.resnet_fc_policy_layers, 119 | self.reward_support.size, 120 | self.value_support.size, 121 | self.downsample, 122 | self.inverse_value_transform, 123 | self.inverse_reward_transform, 124 | self.lstm_hidden_size, 125 | bn_mt=self.bn_mt, 126 | proj_hid=self.proj_hid, 127 | proj_out=self.proj_out, 128 | pred_hid=self.pred_hid, 129 | pred_out=self.pred_out, 130 | init_zero=self.init_zero, 131 | state_norm=self.state_norm) 132 | 133 | def new_game(self, seed=None, save_video=False, save_path=None, video_callable=None, uid=None, test=False, final_test=False): 134 | if test: 135 | if final_test: 136 | max_moves = 108000 // self.frame_skip 137 | else: 138 | max_moves = self.test_max_moves 139 | env = make_atari(self.env_name, skip=self.frame_skip, max_episode_steps=max_moves) 140 | else: 141 | env = make_atari(self.env_name, skip=self.frame_skip, max_episode_steps=self.max_moves) 142 | 143 | if self.episode_life and not test: 144 | env = EpisodicLifeEnv(env) 145 | env = WarpFrame(env, width=self.obs_shape[1], height=self.obs_shape[2], grayscale=self.gray_scale) 146 | 147 | if seed is not None: 148 | env.seed(seed) 149 | 150 | if save_video: 151 | from gym.wrappers import Monitor 152 | env = Monitor(env, directory=save_path, force=True, video_callable=video_callable, uid=uid) 153 | return AtariWrapper(env, discount=self.discount, cvt_string=self.cvt_string) 154 | 155 | def scalar_reward_loss(self, prediction, target): 156 | return -(torch.log_softmax(prediction, dim=1) * target).sum(1) 157 | 158 | def scalar_value_loss(self, prediction, target): 159 | return -(torch.log_softmax(prediction, dim=1) * target).sum(1) 160 | 161 | def set_transforms(self): 162 | if self.use_augmentation: 163 | self.transforms = Transforms(self.augmentation, image_shape=(self.obs_shape[1], self.obs_shape[2])) 164 | 165 | def transform(self, images): 166 | return self.transforms.transform(images) 167 | 168 | 169 | game_config = AtariConfig() 170 | -------------------------------------------------------------------------------- /core/game.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import copy 3 | 4 | import numpy as np 5 | 6 | from core.utils import str_to_arr 7 | 8 | 9 | class Game: 10 | def __init__(self, env, action_space_size: int, discount: float, config=None): 11 | self.env = env 12 | self.action_space_size = action_space_size 13 | self.discount = discount 14 | self.config = config 15 | 16 | def legal_actions(self): 17 | raise NotImplementedError 18 | 19 | def step(self, action): 20 | raise NotImplementedError 21 | 22 | def reset(self): 23 | raise NotImplementedError() 24 | 25 | def close(self, *args, **kwargs): 26 | self.env.close(*args, **kwargs) 27 | 28 | def render(self, *args, **kwargs): 29 | 
self.env.render(*args, **kwargs) 30 | 31 | 32 | class GameHistory: 33 | """ 34 | A block of game history from a full trajectories. 35 | The horizons of Atari games are quite large. Split the whole trajectory into several history blocks. 36 | """ 37 | def __init__(self, action_space, max_length=200, config=None): 38 | """ 39 | Parameters 40 | ---------- 41 | action_space: int 42 | action space 43 | max_length: int 44 | max transition number of the history block 45 | """ 46 | self.action_space = action_space 47 | self.max_length = max_length 48 | self.config = config 49 | 50 | self.stacked_observations = config.stacked_observations 51 | self.discount = config.discount 52 | self.action_space_size = config.action_space_size 53 | self.zero_obs_shape = (config.obs_shape[-2], config.obs_shape[-1], config.image_channel) 54 | 55 | self.child_visits = [] 56 | self.root_values = [] 57 | 58 | self.actions = [] 59 | self.obs_history = [] 60 | self.rewards = [] 61 | 62 | def init(self, init_observations): 63 | """Initialize a history block, stack the previous stacked_observations frames. 64 | Parameters 65 | ---------- 66 | init_observations: list 67 | list of the stack observations in the previous time steps 68 | """ 69 | self.child_visits = [] 70 | self.root_values = [] 71 | 72 | self.actions = [] 73 | self.obs_history = [] 74 | self.rewards = [] 75 | self.target_values = [] 76 | self.target_rewards = [] 77 | self.target_policies = [] 78 | 79 | assert len(init_observations) == self.stacked_observations 80 | 81 | for observation in init_observations: 82 | self.obs_history.append(copy.deepcopy(observation)) 83 | 84 | def pad_over(self, next_block_observations, next_block_rewards, next_block_root_values, next_block_child_visits): 85 | """To make sure the correction of value targets, we need to add (o_t, r_t, etc) from the next history block 86 | , which is necessary for the bootstrapped values at the end states of this history block. 87 | Eg: len = 100; target value v_100 = r_100 + gamma^1 r_101 + ... + gamma^4 r_104 + gamma^5 v_105, 88 | but r_101, r_102, ... are from the next history block. 
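With the Atari defaults (num_unroll_steps = 5, td_steps = 5) this means copying at most 5 observations and 5 child-visit distributions, 10 root values and 9 rewards from the next block, which matches the asserts below and the slices taken in put_last_trajectory (core/selfplay_worker.py).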
89 | Parameters 90 | ---------- 91 | next_block_observations: list 92 | o_t from the next history block 93 | next_block_rewards: list 94 | r_t from the next history block 95 | next_block_root_values: list 96 | root values of MCTS from the next history block 97 | next_block_child_visits: list 98 | root visit count distributions of MCTS from the next history block 99 | """ 100 | assert len(next_block_observations) <= self.config.num_unroll_steps 101 | assert len(next_block_child_visits) <= self.config.num_unroll_steps 102 | assert len(next_block_root_values) <= self.config.num_unroll_steps + self.config.td_steps 103 | assert len(next_block_rewards) <= self.config.num_unroll_steps + self.config.td_steps - 1 104 | 105 | # notice: next block observation should start from (stacked_observation - 1) in next trajectory 106 | for observation in next_block_observations: 107 | self.obs_history.append(copy.deepcopy(observation)) 108 | 109 | for reward in next_block_rewards: 110 | self.rewards.append(reward) 111 | 112 | for value in next_block_root_values: 113 | self.root_values.append(value) 114 | 115 | for child_visits in next_block_child_visits: 116 | self.child_visits.append(child_visits) 117 | 118 | def is_full(self): 119 | # history block is full 120 | return self.__len__() >= self.max_length 121 | 122 | def legal_actions(self): 123 | return [_ for _ in range(self.action_space.n)] 124 | 125 | def append(self, action, obs, reward): 126 | # append a transition tuple 127 | self.actions.append(action) 128 | self.obs_history.append(obs) 129 | self.rewards.append(reward) 130 | 131 | def obs(self, i, extra_len=0, padding=False): 132 | """To obtain an observation of correct format: o[t, t + stack frames + extra len] 133 | Parameters 134 | ---------- 135 | i: int 136 | time step i 137 | extra_len: int 138 | extra len of the obs frames 139 | padding: bool 140 | True -> padding frames if (t + stack frames) are out of trajectory 141 | """ 142 | frames = ray.get(self.obs_history)[i:i + self.stacked_observations + extra_len] 143 | if padding: 144 | pad_len = self.stacked_observations + extra_len - len(frames) 145 | if pad_len > 0: 146 | pad_frames = [frames[-1] for _ in range(pad_len)] 147 | frames = np.concatenate((frames, pad_frames)) 148 | if self.config.cvt_string: 149 | frames = [str_to_arr(obs, self.config.gray_scale) for obs in frames] 150 | return frames 151 | 152 | def zero_obs(self): 153 | # return a zero frame 154 | return [np.zeros(self.zero_obs_shape, dtype=np.uint8) for _ in range(self.stacked_observations)] 155 | 156 | def step_obs(self): 157 | # return an observation of correct format for model inference 158 | index = len(self.rewards) 159 | frames = self.obs_history[index:index + self.stacked_observations] 160 | if self.config.cvt_string: 161 | frames = [str_to_arr(obs, self.config.gray_scale) for obs in frames] 162 | return frames 163 | 164 | def get_targets(self, i): 165 | # return the value/rewrad/policy targets at step i 166 | return self.target_values[i], self.target_rewards[i], self.target_policies[i] 167 | 168 | def game_over(self): 169 | # post processing the data when a history block is full 170 | # obs_history should be sent into the ray memory. Otherwise, it will cost large amounts of time in copying obs. 
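# A minimal sketch of the object-store pattern relied on here (it assumes a
# running ray instance; the array shape is made up for illustration):
#
#     import numpy as np
#     import ray
#
#     ray.init()
#     frames = np.zeros((400, 96, 96, 3), dtype=np.uint8)  # one history block of observations
#     ref = ray.put(frames)         # stored once in shared memory; only the ObjectRef is kept
#     frames_again = ray.get(ref)   # workers read it back without copying per sampled transition
#
# This is why obs() above wraps self.obs_history in ray.get() once game_over() has run.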
171 | self.rewards = np.array(self.rewards) 172 | self.obs_history = ray.put(np.array(self.obs_history)) 173 | self.actions = np.array(self.actions) 174 | self.child_visits = np.array(self.child_visits) 175 | self.root_values = np.array(self.root_values) 176 | 177 | def store_search_stats(self, visit_counts, root_value, idx: int = None): 178 | # store the visit count distributions and value of the root node after MCTS 179 | sum_visits = sum(visit_counts) 180 | if idx is None: 181 | self.child_visits.append([visit_count / sum_visits for visit_count in visit_counts]) 182 | self.root_values.append(root_value) 183 | else: 184 | self.child_visits[idx] = [visit_count / sum_visits for visit_count in visit_counts] 185 | self.root_values[idx] = root_value 186 | 187 | def __len__(self): 188 | return len(self.actions) 189 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging.config 3 | import os 4 | 5 | import numpy as np 6 | import ray 7 | import torch 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | from core.test import test 11 | from core.train import train 12 | from core.utils import init_logger, make_results_dir, set_seed 13 | if __name__ == '__main__': 14 | # Lets gather arguments 15 | parser = argparse.ArgumentParser(description='EfficientZero') 16 | parser.add_argument('--env', required=True, help='Name of the environment') 17 | parser.add_argument('--result_dir', default=os.path.join(os.getcwd(), 'results'), 18 | help="Directory Path to store results (default: %(default)s)") 19 | parser.add_argument('--case', required=True, choices=['atari'], 20 | help="It's used for switching between different domains(default: %(default)s)") 21 | parser.add_argument('--opr', required=True, choices=['train', 'test']) 22 | parser.add_argument('--amp_type', required=True, choices=['torch_amp', 'none'], 23 | help='choose automated mixed precision type') 24 | parser.add_argument('--no_cuda', action='store_true', default=False, help='no cuda usage (default: %(default)s)') 25 | parser.add_argument('--debug', action='store_true', default=False, 26 | help='If enabled, logs additional values ' 27 | '(gradients, target value, reward distribution, etc.) 
(default: %(default)s)') 28 | parser.add_argument('--render', action='store_true', default=False, 29 | help='Renders the environment (default: %(default)s)') 30 | parser.add_argument('--save_video', action='store_true', default=False, help='save video in test.') 31 | parser.add_argument('--force', action='store_true', default=False, 32 | help='Overrides past results (default: %(default)s)') 33 | parser.add_argument('--cpu_actor', type=int, default=14, help='batch cpu actor') 34 | parser.add_argument('--gpu_actor', type=int, default=20, help='batch bpu actor') 35 | parser.add_argument('--p_mcts_num', type=int, default=4, help='number of parallel mcts') 36 | parser.add_argument('--seed', type=int, default=0, help='seed (default: %(default)s)') 37 | parser.add_argument('--num_gpus', type=int, default=4, help='gpus available') 38 | parser.add_argument('--num_cpus', type=int, default=80, help='cpus available') 39 | parser.add_argument('--revisit_policy_search_rate', type=float, default=0.99, 40 | help='Rate at which target policy is re-estimated (default: %(default)s)') 41 | parser.add_argument('--use_root_value', action='store_true', default=False, 42 | help='choose to use root value in reanalyzing') 43 | parser.add_argument('--use_priority', action='store_true', default=False, 44 | help='Uses priority for data sampling in replay buffer. ' 45 | 'Also, priority for new data is calculated based on loss (default: False)') 46 | parser.add_argument('--use_max_priority', action='store_true', default=False, help='max priority') 47 | parser.add_argument('--test_episodes', type=int, default=10, help='Evaluation episode count (default: %(default)s)') 48 | parser.add_argument('--use_augmentation', action='store_true', default=True, help='use augmentation') 49 | parser.add_argument('--augmentation', type=str, default=['shift', 'intensity'], nargs='+', 50 | choices=['none', 'rrc', 'affine', 'crop', 'blur', 'shift', 'intensity'], 51 | help='Style of augmentation') 52 | parser.add_argument('--info', type=str, default='none', help='debug string') 53 | parser.add_argument('--load_model', action='store_true', default=False, help='choose to load model') 54 | parser.add_argument('--model_path', type=str, default='./results/test_model.p', help='load model path') 55 | parser.add_argument('--object_store_memory', type=int, default=150 * 1024 * 1024 * 1024, help='object store memory') 56 | 57 | # Process arguments 58 | args = parser.parse_args() 59 | args.device = 'cuda' if (not args.no_cuda) and torch.cuda.is_available() else 'cpu' 60 | assert args.revisit_policy_search_rate is None or 0 <= args.revisit_policy_search_rate <= 1, \ 61 | ' Revisit policy search rate should be in [0,1]' 62 | 63 | if args.opr == 'train': 64 | ray.init(num_gpus=args.num_gpus, num_cpus=args.num_cpus, 65 | object_store_memory=args.object_store_memory) 66 | else: 67 | ray.init() 68 | 69 | # seeding random iterators 70 | set_seed(args.seed) 71 | 72 | # import corresponding configuration , neural networks and envs 73 | if args.case == 'atari': 74 | from config.atari import game_config 75 | else: 76 | raise Exception('Invalid --case option') 77 | 78 | # set config as per arguments 79 | exp_path = game_config.set_config(args) 80 | exp_path, log_base_path = make_results_dir(exp_path, args) 81 | 82 | # set-up logger 83 | init_logger(log_base_path) 84 | logging.getLogger('train').info('Path: {}'.format(exp_path)) 85 | logging.getLogger('train').info('Param: {}'.format(game_config.get_hparams())) 86 | 87 | device = game_config.device 88 | try: 89 | 
if args.opr == 'train': 90 | summary_writer = SummaryWriter(exp_path, flush_secs=10) 91 | if args.load_model and os.path.exists(args.model_path): 92 | model_path = args.model_path 93 | else: 94 | model_path = None 95 | model, weights = train(game_config, summary_writer, model_path) 96 | model.set_weights(weights) 97 | total_steps = game_config.training_steps + game_config.last_steps 98 | test_score, _, test_path = test(game_config, model.to(device), total_steps, game_config.test_episodes, device, render=False, save_video=args.save_video, final_test=True, use_pb=True) 99 | mean_score = test_score.mean() 100 | std_score = test_score.std() 101 | 102 | test_log = { 103 | 'mean_score': mean_score, 104 | 'std_score': std_score, 105 | } 106 | for key, val in test_log.items(): 107 | summary_writer.add_scalar('train/{}'.format(key), np.mean(val), total_steps) 108 | 109 | test_msg = '#{:<10} Test Mean Score of {}: {:<10} (max: {:<10}, min:{:<10}, std: {:<10})' \ 110 | ''.format(total_steps, game_config.env_name, mean_score, test_score.max(), test_score.min(), std_score) 111 | logging.getLogger('train_test').info(test_msg) 112 | if args.save_video: 113 | logging.getLogger('train_test').info('Saving video in path: {}'.format(test_path)) 114 | elif args.opr == 'test': 115 | assert args.load_model 116 | if args.model_path is None: 117 | model_path = game_config.model_path 118 | else: 119 | model_path = args.model_path 120 | assert os.path.exists(model_path), 'model not found at {}'.format(model_path) 121 | 122 | model = game_config.get_uniform_network().to(device) 123 | model.load_state_dict(torch.load(model_path, map_location=torch.device(device))) 124 | test_score, _, test_path = test(game_config, model, 0, args.test_episodes, device=device, render=args.render, save_video=args.save_video, final_test=True, use_pb=True) 125 | mean_score = test_score.mean() 126 | std_score = test_score.std() 127 | logging.getLogger('test').info('Test Mean Score: {} (max: {}, min: {})'.format(mean_score, test_score.max(), test_score.min())) 128 | logging.getLogger('test').info('Test Std Score: {}'.format(std_score)) 129 | if args.save_video: 130 | logging.getLogger('test').info('Saving video in path: {}'.format(test_path)) 131 | else: 132 | raise Exception('Please select a valid operation(--opr) to be performed') 133 | ray.shutdown() 134 | except Exception as e: 135 | logging.getLogger('root').error(e, exc_info=True) 136 | -------------------------------------------------------------------------------- /core/log.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import logging 3 | 4 | import numpy as np 5 | 6 | 7 | train_logger = logging.getLogger('train') 8 | test_logger = logging.getLogger('train_test') 9 | 10 | 11 | def _log(config, step_count, log_data, model, replay_buffer, lr, shared_storage, summary_writer, vis_result): 12 | loss_data, td_data, priority_data = log_data 13 | total_loss, weighted_loss, loss, reg_loss, policy_loss, value_prefix_loss, value_loss, consistency_loss = loss_data 14 | if vis_result: 15 | new_priority, target_value_prefix, target_value, trans_target_value_prefix, trans_target_value, target_value_prefix_phi, target_value_phi, \ 16 | pred_value_prefix, pred_value, target_policies, predicted_policies, state_lst, other_loss, other_log, other_dist = td_data 17 | batch_weights, batch_indices = priority_data 18 | 19 | replay_episodes_collected, replay_buffer_size, priorities, total_num, worker_logs = ray.get([ 20 | 
replay_buffer.episodes_collected.remote(), replay_buffer.size.remote(), 21 | replay_buffer.get_priorities.remote(), replay_buffer.get_total_len.remote(), 22 | shared_storage.get_worker_logs.remote()]) 23 | 24 | worker_ori_reward, worker_reward, worker_reward_max, worker_eps_len, worker_eps_len_max, test_counter, test_dict, temperature, visit_entropy, priority_self_play, distributions = worker_logs 25 | 26 | _msg = '#{:<10} Total Loss: {:<8.3f} [weighted Loss:{:<8.3f} Policy Loss: {:<8.3f} Value Loss: {:<8.3f} ' \ 27 | 'Reward Sum Loss: {:<8.3f} Consistency Loss: {:<8.3f} ] ' \ 28 | 'Replay Episodes Collected: {:<10d} Buffer Size: {:<10d} Transition Number: {:<8.3f}k ' \ 29 | 'Batch Size: {:<10d} Lr: {:<8.3f}' 30 | _msg = _msg.format(step_count, total_loss, weighted_loss, policy_loss, value_loss, value_prefix_loss, consistency_loss, 31 | replay_episodes_collected, replay_buffer_size, total_num / 1000, config.batch_size, lr) 32 | train_logger.info(_msg) 33 | 34 | if test_dict is not None: 35 | mean_score = np.mean(test_dict['mean_score']) 36 | max_score = np.mean(test_dict['max_score']) 37 | min_score = np.mean(test_dict['min_score']) 38 | std_score = np.mean(test_dict['std_score']) 39 | test_msg = '#{:<10} Test Mean Score of {}: {:<10} (max: {:<10}, min:{:<10}, std: {:<10})' \ 40 | ''.format(test_counter, config.env_name, mean_score, max_score, min_score, std_score) 41 | test_logger.info(test_msg) 42 | 43 | if summary_writer is not None: 44 | if config.debug: 45 | for name, W in model.named_parameters(): 46 | summary_writer.add_histogram('after_grad_clip' + '/' + name + '_grad', W.grad.data.cpu().numpy(), 47 | step_count) 48 | summary_writer.add_histogram('network_weights' + '/' + name, W.data.cpu().numpy(), step_count) 49 | pass 50 | tag = 'Train' 51 | if vis_result: 52 | summary_writer.add_histogram('{}_replay_data/replay_buffer_priorities'.format(tag), 53 | priorities, 54 | step_count) 55 | summary_writer.add_histogram('{}_replay_data/batch_weight'.format(tag), batch_weights, step_count) 56 | summary_writer.add_histogram('{}_replay_data/batch_indices'.format(tag), batch_indices, step_count) 57 | target_value_prefix = target_value_prefix.flatten() 58 | pred_value_prefix = pred_value_prefix.flatten() 59 | target_value = target_value.flatten() 60 | pred_value = pred_value.flatten() 61 | new_priority = new_priority.flatten() 62 | 63 | summary_writer.add_scalar('{}_statistics/new_priority_mean'.format(tag), new_priority.mean(), step_count) 64 | summary_writer.add_scalar('{}_statistics/new_priority_std'.format(tag), new_priority.std(), step_count) 65 | 66 | summary_writer.add_scalar('{}_statistics/target_value_prefix_mean'.format(tag), target_value_prefix.mean(), step_count) 67 | summary_writer.add_scalar('{}_statistics/target_value_prefix_std'.format(tag), target_value_prefix.std(), step_count) 68 | summary_writer.add_scalar('{}_statistics/pre_value_prefix_mean'.format(tag), pred_value_prefix.mean(), step_count) 69 | summary_writer.add_scalar('{}_statistics/pre_value_prefix_std'.format(tag), pred_value_prefix.std(), step_count) 70 | 71 | summary_writer.add_scalar('{}_statistics/target_value_mean'.format(tag), target_value.mean(), step_count) 72 | summary_writer.add_scalar('{}_statistics/target_value_std'.format(tag), target_value.std(), step_count) 73 | summary_writer.add_scalar('{}_statistics/pre_value_mean'.format(tag), pred_value.mean(), step_count) 74 | summary_writer.add_scalar('{}_statistics/pre_value_std'.format(tag), pred_value.std(), step_count) 75 | 76 | 
summary_writer.add_histogram('{}_data_dist/new_priority'.format(tag), new_priority, step_count) 77 | summary_writer.add_histogram('{}_data_dist/target_value_prefix'.format(tag), target_value_prefix - 1e-5, step_count) 78 | summary_writer.add_histogram('{}_data_dist/target_value'.format(tag), target_value - 1e-5, step_count) 79 | summary_writer.add_histogram('{}_data_dist/transformed_target_value_prefix'.format(tag), trans_target_value_prefix, 80 | step_count) 81 | summary_writer.add_histogram('{}_data_dist/transformed_target_value'.format(tag), trans_target_value, 82 | step_count) 83 | summary_writer.add_histogram('{}_data_dist/pred_value_prefix'.format(tag), pred_value_prefix - 1e-5, step_count) 84 | summary_writer.add_histogram('{}_data_dist/pred_value'.format(tag), pred_value - 1e-5, step_count) 85 | summary_writer.add_histogram('{}_data_dist/pred_policies'.format(tag), predicted_policies.flatten(), 86 | step_count) 87 | summary_writer.add_histogram('{}_data_dist/target_policies'.format(tag), target_policies.flatten(), 88 | step_count) 89 | 90 | summary_writer.add_histogram('{}_data_dist/hidden_state'.format(tag), state_lst.flatten(), step_count) 91 | 92 | for key, val in other_loss.items(): 93 | if val >= 0: 94 | summary_writer.add_scalar('{}_metric/'.format(tag) + key, val, step_count) 95 | 96 | for key, val in other_log.items(): 97 | summary_writer.add_scalar('{}_weight/'.format(tag) + key, val, step_count) 98 | 99 | for key, val in other_dist.items(): 100 | summary_writer.add_histogram('{}_dist/'.format(tag) + key, val, step_count) 101 | 102 | summary_writer.add_scalar('{}/total_loss'.format(tag), total_loss, step_count) 103 | summary_writer.add_scalar('{}/loss'.format(tag), loss, step_count) 104 | summary_writer.add_scalar('{}/weighted_loss'.format(tag), weighted_loss, step_count) 105 | summary_writer.add_scalar('{}/reg_loss'.format(tag), reg_loss, step_count) 106 | summary_writer.add_scalar('{}/policy_loss'.format(tag), policy_loss, step_count) 107 | summary_writer.add_scalar('{}/value_loss'.format(tag), value_loss, step_count) 108 | summary_writer.add_scalar('{}/value_prefix_loss'.format(tag), value_prefix_loss, step_count) 109 | summary_writer.add_scalar('{}/consistency_loss'.format(tag), consistency_loss, step_count) 110 | summary_writer.add_scalar('{}/episodes_collected'.format(tag), replay_episodes_collected, 111 | step_count) 112 | summary_writer.add_scalar('{}/replay_buffer_len'.format(tag), replay_buffer_size, step_count) 113 | summary_writer.add_scalar('{}/total_node_num'.format(tag), total_num, step_count) 114 | summary_writer.add_scalar('{}/lr'.format(tag), lr, step_count) 115 | 116 | if worker_reward is not None: 117 | summary_writer.add_scalar('workers/ori_reward', worker_ori_reward, step_count) 118 | summary_writer.add_scalar('workers/clip_reward', worker_reward, step_count) 119 | summary_writer.add_scalar('workers/clip_reward_max', worker_reward_max, step_count) 120 | summary_writer.add_scalar('workers/eps_len', worker_eps_len, step_count) 121 | summary_writer.add_scalar('workers/eps_len_max', worker_eps_len_max, step_count) 122 | summary_writer.add_scalar('workers/temperature', temperature, step_count) 123 | summary_writer.add_scalar('workers/visit_entropy', visit_entropy, step_count) 124 | summary_writer.add_scalar('workers/priority_self_play', priority_self_play, step_count) 125 | for key, val in distributions.items(): 126 | if len(val) == 0: 127 | continue 128 | 129 | val = np.array(val).flatten() 130 | summary_writer.add_histogram('workers/{}'.format(key), val, 
step_count) 131 | 132 | if test_dict is not None: 133 | for key, val in test_dict.items(): 134 | summary_writer.add_scalar('train/{}'.format(key), np.mean(val), test_counter) -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import gym 4 | import torch 5 | import random 6 | import shutil 7 | import logging 8 | 9 | import numpy as np 10 | 11 | from scipy.stats import entropy 12 | 13 | 14 | class LinearSchedule(object): 15 | def __init__(self, schedule_timesteps, final_p, initial_p=1.0): 16 | """Linear interpolation between initial_p and final_p over 17 | schedule_timesteps. After this many timesteps pass final_p is 18 | returned. 19 | Parameters 20 | ---------- 21 | schedule_timesteps: int 22 | Number of timesteps for which to linearly anneal initial_p 23 | to final_p 24 | initial_p: float 25 | initial output value 26 | final_p: float 27 | final output value 28 | """ 29 | self.schedule_timesteps = schedule_timesteps 30 | self.final_p = final_p 31 | self.initial_p = initial_p 32 | 33 | def value(self, t): 34 | """See Schedule.value""" 35 | fraction = min(float(t) / self.schedule_timesteps, 1.0) 36 | return self.initial_p + fraction * (self.final_p - self.initial_p) 37 | 38 | 39 | class TimeLimit(gym.Wrapper): 40 | def __init__(self, env, max_episode_steps=None): 41 | super(TimeLimit, self).__init__(env) 42 | self._max_episode_steps = max_episode_steps 43 | self._elapsed_steps = 0 44 | 45 | def step(self, ac): 46 | observation, reward, done, info = self.env.step(ac) 47 | self._elapsed_steps += 1 48 | if self._elapsed_steps >= self._max_episode_steps: 49 | done = True 50 | info['TimeLimit.truncated'] = True 51 | return observation, reward, done, info 52 | 53 | def get_max_episode_steps(self): 54 | return self._max_episode_steps 55 | 56 | def reset(self, **kwargs): 57 | self._elapsed_steps = 0 58 | return self.env.reset(**kwargs) 59 | 60 | 61 | class NoopResetEnv(gym.Wrapper): 62 | def __init__(self, env, noop_max=30): 63 | """Sample initial states by taking random number of no-ops on reset. 64 | No-op is assumed to be action 0. 65 | """ 66 | gym.Wrapper.__init__(self, env) 67 | self.noop_max = noop_max 68 | self.override_num_noops = None 69 | self.noop_action = 0 70 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 71 | 72 | def reset(self, **kwargs): 73 | """ Do no-op action for a number of steps in [1, noop_max].""" 74 | self.env.reset(**kwargs) 75 | if self.override_num_noops is not None: 76 | noops = self.override_num_noops 77 | else: 78 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101 79 | assert noops > 0 80 | obs = None 81 | for _ in range(noops): 82 | obs, _, done, _ = self.env.step(self.noop_action) 83 | if done: 84 | obs = self.env.reset(**kwargs) 85 | return obs 86 | 87 | def step(self, ac): 88 | return self.env.step(ac) 89 | 90 | 91 | class EpisodicLifeEnv(gym.Wrapper): 92 | def __init__(self, env): 93 | """Make end-of-life == end-of-episode, but only reset on true game over. 94 | Done by DeepMind for the DQN and co. since it helps value estimation. 
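In this repo the wrapper is only applied during training: AtariConfig.new_game adds EpisodicLifeEnv only when episode_life is enabled and test is False, so evaluation episodes always run to the true game over.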
95 | """ 96 | gym.Wrapper.__init__(self, env) 97 | self.lives = 0 98 | self.was_real_done = True 99 | 100 | def step(self, action): 101 | obs, reward, done, info = self.env.step(action) 102 | self.was_real_done = done 103 | # check current lives, make loss of life terminal, 104 | # then update lives to handle bonus lives 105 | lives = self.env.unwrapped.ale.lives() 106 | if lives < self.lives and lives > 0: 107 | # for Qbert sometimes we stay in lives == 0 condition for a few frames 108 | # so it's important to keep lives > 0, so that we only reset once 109 | # the environment advertises done. 110 | done = True 111 | self.lives = lives 112 | return obs, reward, done, info 113 | 114 | def reset(self, **kwargs): 115 | """Reset only when lives are exhausted. 116 | This way all states are still reachable even though lives are episodic, 117 | and the learner need not know about any of this behind-the-scenes. 118 | """ 119 | if self.was_real_done: 120 | obs = self.env.reset(**kwargs) 121 | else: 122 | # no-op step to advance from terminal/lost life state 123 | obs, _, _, _ = self.env.step(0) 124 | self.lives = self.env.unwrapped.ale.lives() 125 | return obs 126 | 127 | 128 | class MaxAndSkipEnv(gym.Wrapper): 129 | def __init__(self, env, skip=4): 130 | """Return only every `skip`-th frame""" 131 | gym.Wrapper.__init__(self, env) 132 | # most recent raw observations (for max pooling across time steps) 133 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) 134 | self._skip = skip 135 | self.max_frame = np.zeros(env.observation_space.shape, dtype=np.uint8) 136 | 137 | def step(self, action): 138 | """Repeat action, sum reward, and max over last observations.""" 139 | total_reward = 0.0 140 | done = None 141 | for i in range(self._skip): 142 | obs, reward, done, info = self.env.step(action) 143 | if i == self._skip - 2: self._obs_buffer[0] = obs 144 | if i == self._skip - 1: self._obs_buffer[1] = obs 145 | total_reward += reward 146 | if done: 147 | break 148 | # Note that the observation on the done=True frame 149 | # doesn't matter 150 | self.max_frame = self._obs_buffer.max(axis=0) 151 | 152 | return self.max_frame, total_reward, done, info 153 | 154 | def reset(self, **kwargs): 155 | return self.env.reset(**kwargs) 156 | 157 | def render(self, mode='human', **kwargs): 158 | img = self.max_frame 159 | img = cv2.resize(img, (400, 400), interpolation=cv2.INTER_AREA).astype(np.uint8) 160 | if mode == 'rgb_array': 161 | return img 162 | elif mode == 'human': 163 | from gym.envs.classic_control import rendering 164 | if self.viewer is None: 165 | self.viewer = rendering.SimpleImageViewer() 166 | self.viewer.imshow(img) 167 | return self.viewer.isopen 168 | 169 | 170 | class WarpFrame(gym.ObservationWrapper): 171 | def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None): 172 | """ 173 | Warp frames to 84x84 as done in the Nature paper and later work. 174 | If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which 175 | observation should be warped. 
176 | """ 177 | super().__init__(env) 178 | self._width = width 179 | self._height = height 180 | self._grayscale = grayscale 181 | self._key = dict_space_key 182 | if self._grayscale: 183 | num_colors = 1 184 | else: 185 | num_colors = 3 186 | 187 | new_space = gym.spaces.Box( 188 | low=0, 189 | high=255, 190 | shape=(self._height, self._width, num_colors), 191 | dtype=np.uint8, 192 | ) 193 | if self._key is None: 194 | original_space = self.observation_space 195 | self.observation_space = new_space 196 | else: 197 | original_space = self.observation_space.spaces[self._key] 198 | self.observation_space.spaces[self._key] = new_space 199 | assert original_space.dtype == np.uint8 and len(original_space.shape) == 3 200 | 201 | def observation(self, obs): 202 | if self._key is None: 203 | frame = obs 204 | else: 205 | frame = obs[self._key] 206 | 207 | if self._grayscale: 208 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 209 | frame = cv2.resize( 210 | frame, (self._width, self._height), interpolation=cv2.INTER_AREA 211 | ) 212 | if self._grayscale: 213 | frame = np.expand_dims(frame, -1) 214 | 215 | if self._key is None: 216 | obs = frame 217 | else: 218 | obs = obs.copy() 219 | obs[self._key] = frame 220 | return obs 221 | 222 | 223 | def make_atari(env_id, skip=4, max_episode_steps=None): 224 | """Make Atari games 225 | Parameters 226 | ---------- 227 | env_id: str 228 | name of environment 229 | skip: int 230 | frame skip 231 | max_episode_steps: int 232 | max moves for an episode 233 | """ 234 | env = gym.make(env_id) 235 | assert 'NoFrameskip' in env.spec.id 236 | env = NoopResetEnv(env, noop_max=30) 237 | env = MaxAndSkipEnv(env, skip=skip) 238 | if max_episode_steps is not None: 239 | env = TimeLimit(env, max_episode_steps=max_episode_steps) 240 | return env 241 | 242 | 243 | def set_seed(seed): 244 | # set seed 245 | random.seed(seed) 246 | np.random.seed(seed) 247 | torch.manual_seed(seed) 248 | torch.cuda.manual_seed(seed) 249 | torch.backends.cudnn.deterministic = True 250 | 251 | 252 | def make_results_dir(exp_path, args): 253 | # make the result directory 254 | os.makedirs(exp_path, exist_ok=True) 255 | if args.opr == 'train' and os.path.exists(exp_path) and os.listdir(exp_path): 256 | if not args.force: 257 | raise FileExistsError('{} is not empty. Please use --force to overwrite it'.format(exp_path)) 258 | else: 259 | print('Warning, path exists! Rewriting...') 260 | shutil.rmtree(exp_path) 261 | os.makedirs(exp_path) 262 | log_path = os.path.join(exp_path, 'logs') 263 | os.makedirs(log_path, exist_ok=True) 264 | os.makedirs(os.path.join(exp_path, 'model'), exist_ok=True) 265 | return exp_path, log_path 266 | 267 | 268 | def init_logger(base_path): 269 | # initialize the logger 270 | formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s][%(filename)s>%(funcName)s] ==> %(message)s') 271 | for mode in ['train', 'test', 'train_test', 'root']: 272 | file_path = os.path.join(base_path, mode + '.log') 273 | logger = logging.getLogger(mode) 274 | handler = logging.StreamHandler() 275 | handler.setFormatter(formatter) 276 | logger.addHandler(handler) 277 | handler = logging.FileHandler(file_path, mode='a') 278 | handler.setFormatter(formatter) 279 | logger.addHandler(handler) 280 | logger.setLevel(logging.DEBUG) 281 | 282 | 283 | def select_action(visit_counts, temperature=1, deterministic=True): 284 | """select action from the root visit counts. 
285 | Parameters 286 | ---------- 287 | temperature: float 288 | the temperature for the distribution 289 | deterministic: bool 290 | True -> select the argmax 291 | False -> sample from the distribution 292 | """ 293 | action_probs = [visit_count_i ** (1 / temperature) for visit_count_i in visit_counts] 294 | total_count = sum(action_probs) 295 | action_probs = [x / total_count for x in action_probs] 296 | if deterministic: 297 | # best_actions = np.argwhere(visit_counts == np.amax(visit_counts)).flatten() 298 | # action_pos = np.random.choice(best_actions) 299 | action_pos = np.argmax([v for v in visit_counts]) 300 | else: 301 | action_pos = np.random.choice(len(visit_counts), p=action_probs) 302 | 303 | count_entropy = entropy(action_probs, base=2) 304 | return action_pos, count_entropy 305 | 306 | 307 | def prepare_observation_lst(observation_lst): 308 | """Prepare the observations to satisfy the input fomat of torch 309 | [B, S, W, H, C] -> [B, S x C, W, H] 310 | batch, stack num, width, height, channel 311 | """ 312 | # B, S, W, H, C 313 | observation_lst = np.array(observation_lst, dtype=np.uint8) 314 | observation_lst = np.moveaxis(observation_lst, -1, 2) 315 | 316 | shape = observation_lst.shape 317 | observation_lst = observation_lst.reshape((shape[0], -1, shape[-2], shape[-1])) 318 | 319 | return observation_lst 320 | 321 | 322 | def arr_to_str(arr): 323 | """To reduce memory usage, we choose to store the jpeg strings of image instead of the numpy array in the buffer. 324 | This function encodes the observation numpy arr to the jpeg strings 325 | """ 326 | img_str = cv2.imencode('.jpg', arr)[1].tobytes() 327 | 328 | return img_str 329 | 330 | 331 | def str_to_arr(s, gray_scale=False): 332 | """To reduce memory usage, we choose to store the jpeg strings of image instead of the numpy array in the buffer. 333 | This function decodes the observation numpy arr from the jpeg strings 334 | Parameters 335 | ---------- 336 | s: string 337 | the inputs 338 | gray_scale: bool 339 | True -> the inputs observation is gray not RGB. 
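Note: this is the inverse of arr_to_str above up to JPEG compression loss, so the round trip returns an image close to, but not bitwise identical to, the original array.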
340 | """ 341 | nparr = np.frombuffer(s, np.uint8) 342 | if gray_scale: 343 | arr = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE) 344 | arr = np.expand_dims(arr, -1) 345 | else: 346 | arr = cv2.imdecode(nparr, cv2.IMREAD_COLOR) 347 | 348 | return arr 349 | -------------------------------------------------------------------------------- /core/ctree/cnode.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cnode.h" 3 | 4 | namespace tree{ 5 | 6 | CSearchResults::CSearchResults(){ 7 | this->num = 0; 8 | } 9 | 10 | CSearchResults::CSearchResults(int num){ 11 | this->num = num; 12 | for(int i = 0; i < num; ++i){ 13 | this->search_paths.push_back(std::vector()); 14 | } 15 | } 16 | 17 | CSearchResults::~CSearchResults(){} 18 | 19 | //********************************************************* 20 | 21 | CNode::CNode(){ 22 | this->prior = 0; 23 | this->action_num = 0; 24 | this->best_action = -1; 25 | 26 | this->is_reset = 0; 27 | this->visit_count = 0; 28 | this->value_sum = 0; 29 | this->to_play = 0; 30 | this->value_prefix = 0.0; 31 | this->ptr_node_pool = nullptr; 32 | } 33 | 34 | CNode::CNode(float prior, int action_num, std::vector* ptr_node_pool){ 35 | this->prior = prior; 36 | this->action_num = action_num; 37 | 38 | this->is_reset = 0; 39 | this->visit_count = 0; 40 | this->value_sum = 0; 41 | this->best_action = -1; 42 | this->to_play = 0; 43 | this->value_prefix = 0.0; 44 | this->ptr_node_pool = ptr_node_pool; 45 | this->hidden_state_index_x = -1; 46 | this->hidden_state_index_y = -1; 47 | } 48 | 49 | CNode::~CNode(){} 50 | 51 | void CNode::expand(int to_play, int hidden_state_index_x, int hidden_state_index_y, float value_prefix, const std::vector &policy_logits){ 52 | this->to_play = to_play; 53 | this->hidden_state_index_x = hidden_state_index_x; 54 | this->hidden_state_index_y = hidden_state_index_y; 55 | this->value_prefix = value_prefix; 56 | 57 | int action_num = this->action_num; 58 | float temp_policy; 59 | float policy_sum = 0.0; 60 | float policy[action_num]; 61 | float policy_max = FLOAT_MIN; 62 | for(int a = 0; a < action_num; ++a){ 63 | if(policy_max < policy_logits[a]){ 64 | policy_max = policy_logits[a]; 65 | } 66 | } 67 | 68 | for(int a = 0; a < action_num; ++a){ 69 | temp_policy = exp(policy_logits[a] - policy_max); 70 | policy_sum += temp_policy; 71 | policy[a] = temp_policy; 72 | } 73 | 74 | float prior; 75 | std::vector* ptr_node_pool = this->ptr_node_pool; 76 | for(int a = 0; a < action_num; ++a){ 77 | prior = policy[a] / policy_sum; 78 | int index = ptr_node_pool->size(); 79 | this->children_index.push_back(index); 80 | 81 | ptr_node_pool->push_back(CNode(prior, action_num, ptr_node_pool)); 82 | } 83 | } 84 | 85 | void CNode::add_exploration_noise(float exploration_fraction, const std::vector &noises){ 86 | float noise, prior; 87 | for(int a = 0; a < this->action_num; ++a){ 88 | noise = noises[a]; 89 | CNode* child = this->get_child(a); 90 | 91 | prior = child->prior; 92 | child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction; 93 | } 94 | } 95 | 96 | float CNode::get_mean_q(int isRoot, float parent_q, float discount){ 97 | float total_unsigned_q = 0.0; 98 | int total_visits = 0; 99 | float parent_value_prefix = this->value_prefix; 100 | for(int a = 0; a < this->action_num; ++a){ 101 | CNode* child = this->get_child(a); 102 | if(child->visit_count > 0){ 103 | float true_reward = child->value_prefix - parent_value_prefix; 104 | if(this->is_reset == 1){ 105 | true_reward = 
child->value_prefix; 106 | } 107 | float qsa = true_reward + discount * child->value(); 108 | total_unsigned_q += qsa; 109 | total_visits += 1; 110 | } 111 | } 112 | 113 | float mean_q = 0.0; 114 | if(isRoot && total_visits > 0){ 115 | mean_q = (total_unsigned_q) / (total_visits); 116 | } 117 | else{ 118 | mean_q = (parent_q + total_unsigned_q) / (total_visits + 1); 119 | } 120 | return mean_q; 121 | } 122 | 123 | void CNode::print_out(){ 124 | return; 125 | } 126 | 127 | int CNode::expanded(){ 128 | int child_num = this->children_index.size(); 129 | if(child_num > 0) { 130 | return 1; 131 | } 132 | else { 133 | return 0; 134 | } 135 | } 136 | 137 | float CNode::value(){ 138 | float true_value = 0.0; 139 | if(this->visit_count == 0){ 140 | return true_value; 141 | } 142 | else{ 143 | true_value = this->value_sum / this->visit_count; 144 | return true_value; 145 | } 146 | } 147 | 148 | std::vector CNode::get_trajectory(){ 149 | std::vector traj; 150 | 151 | CNode* node = this; 152 | int best_action = node->best_action; 153 | while(best_action >= 0){ 154 | traj.push_back(best_action); 155 | 156 | node = node->get_child(best_action); 157 | best_action = node->best_action; 158 | } 159 | return traj; 160 | } 161 | 162 | std::vector CNode::get_children_distribution(){ 163 | std::vector distribution; 164 | if(this->expanded()){ 165 | for(int a = 0; a < this->action_num; ++a){ 166 | CNode* child = this->get_child(a); 167 | distribution.push_back(child->visit_count); 168 | } 169 | } 170 | return distribution; 171 | } 172 | 173 | CNode* CNode::get_child(int action){ 174 | int index = this->children_index[action]; 175 | return &((*(this->ptr_node_pool))[index]); 176 | } 177 | 178 | //********************************************************* 179 | 180 | CRoots::CRoots(){ 181 | this->root_num = 0; 182 | this->action_num = 0; 183 | this->pool_size = 0; 184 | } 185 | 186 | CRoots::CRoots(int root_num, int action_num, int pool_size){ 187 | this->root_num = root_num; 188 | this->action_num = action_num; 189 | this->pool_size = pool_size; 190 | 191 | this->node_pools.reserve(root_num); 192 | this->roots.reserve(root_num); 193 | 194 | for(int i = 0; i < root_num; ++i){ 195 | this->node_pools.push_back(std::vector()); 196 | this->node_pools[i].reserve(pool_size); 197 | 198 | this->roots.push_back(CNode(0, action_num, &this->node_pools[i])); 199 | } 200 | } 201 | 202 | CRoots::~CRoots(){} 203 | 204 | void CRoots::prepare(float root_exploration_fraction, const std::vector> &noises, const std::vector &value_prefixs, const std::vector> &policies){ 205 | for(int i = 0; i < this->root_num; ++i){ 206 | this->roots[i].expand(0, 0, i, value_prefixs[i], policies[i]); 207 | this->roots[i].add_exploration_noise(root_exploration_fraction, noises[i]); 208 | 209 | this->roots[i].visit_count += 1; 210 | } 211 | } 212 | 213 | void CRoots::prepare_no_noise(const std::vector &value_prefixs, const std::vector> &policies){ 214 | for(int i = 0; i < this->root_num; ++i){ 215 | this->roots[i].expand(0, 0, i, value_prefixs[i], policies[i]); 216 | 217 | this->roots[i].visit_count += 1; 218 | } 219 | } 220 | 221 | void CRoots::clear(){ 222 | this->node_pools.clear(); 223 | this->roots.clear(); 224 | } 225 | 226 | std::vector> CRoots::get_trajectories(){ 227 | std::vector> trajs; 228 | trajs.reserve(this->root_num); 229 | 230 | for(int i = 0; i < this->root_num; ++i){ 231 | trajs.push_back(this->roots[i].get_trajectory()); 232 | } 233 | return trajs; 234 | } 235 | 236 | std::vector> CRoots::get_distributions(){ 237 | std::vector> 
distributions; 238 | distributions.reserve(this->root_num); 239 | 240 | for(int i = 0; i < this->root_num; ++i){ 241 | distributions.push_back(this->roots[i].get_children_distribution()); 242 | } 243 | return distributions; 244 | } 245 | 246 | std::vector CRoots::get_values(){ 247 | std::vector values; 248 | for(int i = 0; i < this->root_num; ++i){ 249 | values.push_back(this->roots[i].value()); 250 | } 251 | return values; 252 | } 253 | 254 | //********************************************************* 255 | 256 | void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount){ 257 | std::stack node_stack; 258 | node_stack.push(root); 259 | float parent_value_prefix = 0.0; 260 | int is_reset = 0; 261 | while(node_stack.size() > 0){ 262 | CNode* node = node_stack.top(); 263 | node_stack.pop(); 264 | 265 | if(node != root){ 266 | float true_reward = node->value_prefix - parent_value_prefix; 267 | if(is_reset == 1){ 268 | true_reward = node->value_prefix; 269 | } 270 | float qsa = true_reward + discount * node->value(); 271 | min_max_stats.update(qsa); 272 | } 273 | 274 | for(int a = 0; a < node->action_num; ++a){ 275 | CNode* child = node->get_child(a); 276 | if(child->expanded()){ 277 | node_stack.push(child); 278 | } 279 | } 280 | 281 | parent_value_prefix = node->value_prefix; 282 | is_reset = node->is_reset; 283 | } 284 | } 285 | 286 | void cback_propagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount){ 287 | float bootstrap_value = value; 288 | int path_len = search_path.size(); 289 | for(int i = path_len - 1; i >= 0; --i){ 290 | CNode* node = search_path[i]; 291 | node->value_sum += bootstrap_value; 292 | node->visit_count += 1; 293 | 294 | float parent_value_prefix = 0.0; 295 | int is_reset = 0; 296 | if(i >= 1){ 297 | CNode* parent = search_path[i - 1]; 298 | parent_value_prefix = parent->value_prefix; 299 | is_reset = parent->is_reset; 300 | // float qsa = (node->value_prefix - parent_value_prefix) + discount * node->value(); 301 | // min_max_stats.update(qsa); 302 | } 303 | 304 | float true_reward = node->value_prefix - parent_value_prefix; 305 | if(is_reset == 1){ 306 | // parent is reset 307 | true_reward = node->value_prefix; 308 | } 309 | 310 | bootstrap_value = true_reward + discount * bootstrap_value; 311 | } 312 | min_max_stats.clear(); 313 | CNode* root = search_path[0]; 314 | update_tree_q(root, min_max_stats, discount); 315 | } 316 | 317 | void cbatch_back_propagate(int hidden_state_index_x, float discount, const std::vector &value_prefixs, const std::vector &values, const std::vector> &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_lst){ 318 | for(int i = 0; i < results.num; ++i){ 319 | results.nodes[i]->expand(0, hidden_state_index_x, i, value_prefixs[i], policies[i]); 320 | // reset 321 | results.nodes[i]->is_reset = is_reset_lst[i]; 322 | 323 | cback_propagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], 0, values[i], discount); 324 | } 325 | } 326 | 327 | int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount, float mean_q){ 328 | float max_score = FLOAT_MIN; 329 | const float epsilon = 0.000001; 330 | std::vector max_index_lst; 331 | for(int a = 0; a < root->action_num; ++a){ 332 | CNode* child = root->get_child(a); 333 | float temp_score = cucb_score(child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, 
discount); 334 | 335 | if(max_score < temp_score){ 336 | max_score = temp_score; 337 | 338 | max_index_lst.clear(); 339 | max_index_lst.push_back(a); 340 | } 341 | else if(temp_score >= max_score - epsilon){ 342 | max_index_lst.push_back(a); 343 | } 344 | } 345 | 346 | int action = 0; 347 | if(max_index_lst.size() > 0){ 348 | int rand_index = rand() % max_index_lst.size(); 349 | action = max_index_lst[rand_index]; 350 | } 351 | return action; 352 | } 353 | 354 | float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount){ 355 | float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; 356 | pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init; 357 | pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1)); 358 | 359 | prior_score = pb_c * child->prior; 360 | if (child->visit_count == 0){ 361 | value_score = parent_mean_q; 362 | } 363 | else { 364 | float true_reward = child->value_prefix - parent_value_prefix; 365 | if(is_reset == 1){ 366 | true_reward = child->value_prefix; 367 | } 368 | value_score = true_reward + discount * child->value(); 369 | } 370 | 371 | value_score = min_max_stats.normalize(value_score); 372 | 373 | if (value_score < 0) value_score = 0; 374 | if (value_score > 1) value_score = 1; 375 | 376 | float ucb_value = prior_score + value_score; 377 | return ucb_value; 378 | } 379 | 380 | void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results){ 381 | // set seed 382 | timeval t1; 383 | gettimeofday(&t1, NULL); 384 | srand(t1.tv_usec); 385 | 386 | int last_action = -1; 387 | float parent_q = 0.0; 388 | results.search_lens = std::vector(); 389 | for(int i = 0; i < results.num; ++i){ 390 | CNode *node = &(roots->roots[i]); 391 | int is_root = 1; 392 | int search_len = 0; 393 | results.search_paths[i].push_back(node); 394 | 395 | while(node->expanded()){ 396 | float mean_q = node->get_mean_q(is_root, parent_q, discount); 397 | is_root = 0; 398 | parent_q = mean_q; 399 | 400 | int action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount, mean_q); 401 | node->best_action = action; 402 | // next 403 | node = node->get_child(action); 404 | last_action = action; 405 | results.search_paths[i].push_back(node); 406 | search_len += 1; 407 | } 408 | 409 | CNode* parent = results.search_paths[i][results.search_paths[i].size() - 2]; 410 | 411 | results.hidden_state_index_x_lst.push_back(parent->hidden_state_index_x); 412 | results.hidden_state_index_y_lst.push_back(parent->hidden_state_index_y); 413 | 414 | results.last_actions.push_back(last_action); 415 | results.search_lens.push_back(search_len); 416 | results.nodes.push_back(node); 417 | } 418 | } 419 | 420 | } -------------------------------------------------------------------------------- /core/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | 5 | import numpy as np 6 | 7 | from core.game import Game 8 | 9 | 10 | class DiscreteSupport(object): 11 | def __init__(self, min: int, max: int, delta=1.): 12 | assert min < max 13 | self.min = min 14 | self.max = max 15 | self.range = np.arange(min, max + 1, delta) 16 | self.size = len(self.range) 17 | self.delta = delta 18 | 19 | 20 | class BaseConfig(object): 21 | 22 | def 
__init__(self, 23 | training_steps: int, 24 | last_steps: int, 25 | test_interval: int, 26 | test_episodes: int, 27 | checkpoint_interval: int, 28 | target_model_interval: int, 29 | save_ckpt_interval: int, 30 | log_interval: int, 31 | vis_interval: int, 32 | max_moves: int, 33 | test_max_moves: int, 34 | history_length: int, 35 | discount: float, 36 | dirichlet_alpha: float, 37 | value_delta_max: float, 38 | num_simulations: int, 39 | batch_size: int, 40 | td_steps: int, 41 | num_actors: int, 42 | lr_warm_up: float, 43 | lr_init: float, 44 | lr_decay_rate: float, 45 | lr_decay_steps: float, 46 | start_transitions: int, 47 | auto_td_steps_ratio: float = 0.3, 48 | total_transitions: int = 100 * 1000, 49 | transition_num: float = 25, 50 | do_consistency: bool = True, 51 | use_value_prefix: bool = True, 52 | off_correction: bool = True, 53 | gray_scale: bool = False, 54 | episode_life: bool = False, 55 | change_temperature: bool = True, 56 | init_zero: bool = False, 57 | state_norm: bool = False, 58 | clip_reward: bool = False, 59 | random_start: bool = True, 60 | cvt_string: bool = False, 61 | image_based: bool = False, 62 | frame_skip: int = 1, 63 | stacked_observations: int = 16, 64 | lstm_hidden_size: int = 64, 65 | lstm_horizon_len: int = 1, 66 | reward_loss_coeff: float = 1, 67 | value_loss_coeff: float = 1, 68 | policy_loss_coeff: float = 1, 69 | consistency_coeff: float = 1, 70 | proj_hid: int = 256, 71 | proj_out: int = 256, 72 | pred_hid: int = 64, 73 | pred_out: int = 256, 74 | value_support: DiscreteSupport = DiscreteSupport(-300, 300, delta=1), 75 | reward_support: DiscreteSupport = DiscreteSupport(-300, 300, delta=1)): 76 | """Base Config for EfficietnZero 77 | Parameters 78 | ---------- 79 | training_steps: int 80 | training steps while collecting data 81 | last_steps: int 82 | training steps without collecting data after @training_steps; 83 | So total training steps = training_steps + last_steps 84 | test_interval: int 85 | interval of testing 86 | test_episodes: int 87 | episodes of testing 88 | checkpoint_interval: int 89 | interval of updating the models for self-play 90 | target_model_interval: int 91 | interval of updating the target models for reanalyzing 92 | save_ckpt_interval: int 93 | interval of saving models 94 | log_interval: int 95 | interval of logging 96 | vis_interval: int 97 | interval of visualizations for some distributions and loggings 98 | max_moves: int 99 | max number of moves for an episode 100 | test_max_moves: int 101 | max number of moves for an episode during testing (in training stage), 102 | set this small to make sure the game will end faster. 103 | history_length: int 104 | horizons of each stored history trajectory. 105 | The horizons of Atari games are quite large. Split the whole trajectory into several history blocks. 106 | discount: float 107 | discount of env 108 | dirichlet_alpha: float 109 | dirichlet alpha of exploration noise in MCTS. 110 | Smaller -> more exploration 111 | value_delta_max: float 112 | the threshold in the minmax normalization of Q-values in MCTS. 113 | See the soft minimum-maximum updates in Appendix. 
114 | num_simulations: int 115 | number of simulations in MCTS 116 | batch_size: int, 117 | batch size 118 | td_steps: int 119 | td steps for bootstrapped value targets 120 | num_actors: int 121 | number of self-play actors 122 | lr_warm_up: float 123 | rate of learning rate warm up 124 | lr_init: float 125 | initial lr 126 | lr_decay_rate: float 127 | how much lr drops every time 128 | lr -> lr * lr_decay_rate 129 | lr_decay_steps: float 130 | lr drops every lr_decay_steps 131 | start_transitions: int 132 | least transition numbers to start the training steps ( larger than batch size) 133 | auto_td_steps_ratio: float 134 | ratio of short td steps, samller td steps for older trajectories. 135 | auto_td_steps = auto_td_steps_ratio * training_steps 136 | See the details of off-policy correction in Appendix. 137 | total_transitions: int 138 | total number of collected transitions. (100k setting) 139 | transition_num: float 140 | capacity of transitions in replay buffer 141 | do_consistency: bool 142 | True -> use temporal consistency 143 | use_value_prefix: bool = True, 144 | True -> predict value prefix 145 | off_correction: bool 146 | True -> use off-policy correction 147 | gray_scale: bool 148 | True -> use gray image observation 149 | episode_life: bool 150 | True -> one life in atari games 151 | change_temperature: bool 152 | True -> change temperature of visit count distributions 153 | init_zero: bool 154 | True -> zero initialization for the last layer of mlps 155 | state_norm: bool 156 | True -> normalization for hidden states 157 | clip_reward: bool 158 | True -> clip the reward, reward -> sign(reward) 159 | random_start: bool 160 | True -> random actions in self-play before startng training 161 | cvt_string: bool 162 | True -> convert the observation into string in the replay buffer 163 | image_based: bool 164 | True -> observation is image based 165 | frame_skip: int 166 | number of frame skip 167 | stacked_observations: int 168 | number of frame stack 169 | lstm_hidden_size: int 170 | dim of lstm hidden 171 | lstm_horizon_len: int 172 | horizons of value prefix prediction, 1 <= lstm_horizon_len <= num_unroll_steps 173 | reward_loss_coeff: float 174 | coefficient of reward loss 175 | value_loss_coeff: float 176 | coefficient of value loss 177 | policy_loss_coeff: float 178 | coefficient of policy loss 179 | consistency_coeff: float 180 | coefficient of consistency loss 181 | proj_hid: int 182 | dim of projection hidden layer 183 | proj_out: int 184 | dim of projection output layer 185 | pred_hid: int 186 | dim of projection head (prediction) hidden layer 187 | pred_out: int 188 | dim of projection head (prediction) output layer 189 | value_support: DiscreteSupport 190 | support of value to represent the value scalars 191 | reward_support: DiscreteSupport 192 | support of reward to represent the reward scalars 193 | """ 194 | # Self-Play 195 | self.action_space_size = None 196 | self.num_actors = num_actors 197 | self.do_consistency = do_consistency 198 | self.use_value_prefix = use_value_prefix 199 | self.off_correction = off_correction 200 | self.gray_scale = gray_scale 201 | self.auto_td_steps_ratio = auto_td_steps_ratio 202 | self.episode_life = episode_life 203 | self.change_temperature = change_temperature 204 | self.init_zero = init_zero 205 | self.state_norm = state_norm 206 | self.clip_reward = clip_reward 207 | self.random_start = random_start 208 | self.cvt_string = cvt_string 209 | self.image_based = image_based 210 | 211 | self.max_moves = max_moves 212 | 
self.test_max_moves = test_max_moves 213 | self.history_length = history_length 214 | self.num_simulations = num_simulations 215 | self.discount = discount 216 | self.max_grad_norm = 5 217 | 218 | # testing arguments 219 | self.test_interval = test_interval 220 | self.test_episodes = test_episodes 221 | 222 | # Root prior exploration noise. 223 | self.value_delta_max = value_delta_max 224 | self.root_dirichlet_alpha = dirichlet_alpha 225 | self.root_exploration_fraction = 0.25 226 | 227 | # UCB formula 228 | self.pb_c_base = 19652 229 | self.pb_c_init = 1.25 230 | 231 | # Training 232 | self.training_steps = training_steps 233 | self.last_steps = last_steps 234 | self.checkpoint_interval = checkpoint_interval 235 | self.target_model_interval = target_model_interval 236 | self.save_ckpt_interval = save_ckpt_interval 237 | self.log_interval = log_interval 238 | self.vis_interval = vis_interval 239 | self.start_transitions = start_transitions 240 | self.total_transitions = total_transitions 241 | self.transition_num = transition_num 242 | self.batch_size = batch_size 243 | # unroll steps 244 | self.num_unroll_steps = 5 245 | self.td_steps = td_steps 246 | self.frame_skip = frame_skip 247 | self.stacked_observations = stacked_observations 248 | self.lstm_hidden_size = lstm_hidden_size 249 | self.lstm_horizon_len = lstm_horizon_len 250 | self.reward_loss_coeff = reward_loss_coeff 251 | self.value_loss_coeff = value_loss_coeff 252 | self.policy_loss_coeff = policy_loss_coeff 253 | self.consistency_coeff = consistency_coeff 254 | self.device = 'cuda' 255 | self.exp_path = None # experiment path 256 | self.debug = False 257 | self.model_path = None 258 | self.seed = None 259 | self.transforms = None 260 | self.value_support = value_support 261 | self.reward_support = reward_support 262 | 263 | # optimization control 264 | self.weight_decay = 1e-4 265 | self.momentum = 0.9 266 | self.lr_warm_up = lr_warm_up 267 | self.lr_warm_step = int(self.training_steps * self.lr_warm_up) 268 | self.lr_init = lr_init 269 | self.lr_decay_rate = lr_decay_rate 270 | self.lr_decay_steps = lr_decay_steps 271 | self.mini_infer_size = 64 272 | 273 | # replay buffer, priority related 274 | self.priority_prob_alpha = 0.6 275 | self.priority_prob_beta = 0.4 276 | self.prioritized_replay_eps = 1e-6 277 | 278 | # env 279 | self.image_channel = 3 280 | 281 | # contrastive arch 282 | self.proj_hid = proj_hid 283 | self.proj_out = proj_out 284 | self.pred_hid = pred_hid 285 | self.pred_out = pred_out 286 | 287 | def visit_softmax_temperature_fn(self, num_moves, trained_steps): 288 | raise NotImplementedError 289 | 290 | def set_game(self, env_name): 291 | raise NotImplementedError 292 | 293 | def new_game(self, seed=None, save_video=False, save_path=None, video_callable=None, uid=None, test=False) -> Game: 294 | """ returns a new instance of the game""" 295 | raise NotImplementedError 296 | 297 | def get_uniform_network(self): 298 | raise NotImplementedError 299 | 300 | def scalar_loss(self, prediction, target): 301 | raise NotImplementedError 302 | 303 | def scalar_transform(self, x): 304 | """ Reference from MuZerp: Appendix F => Network Architecture 305 | & Appendix A : Proposition A.2 in https://arxiv.org/pdf/1805.11593.pdf (Page-11) 306 | """ 307 | delta = self.value_support.delta 308 | assert delta == 1 309 | epsilon = 0.001 310 | sign = torch.ones(x.shape).float().to(x.device) 311 | sign[x < 0] = -1.0 312 | output = sign * (torch.sqrt(torch.abs(x / delta) + 1) - 1) + epsilon * x / delta 313 | return output 314 | 315 | 
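# A small worked example of scalar_transform above and the inverse used below,
# assuming epsilon = 0.001 and delta = 1 as in the code (values rounded):
#
#     h(5) = sign(5) * (sqrt(|5| + 1) - 1) + 0.001 * 5
#          = (sqrt(6) - 1) + 0.005
#          ~ 1.454
#
# inverse_scalar_transform applies the closed-form inverse
#     h^-1(y) = sign(y) * (((sqrt(1 + 4 * eps * (|y| + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1)
# to the expectation over the categorical support, so h^-1(1.454) ~ 5 is recovered.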
def inverse_reward_transform(self, reward_logits): 316 | return self.inverse_scalar_transform(reward_logits, self.reward_support) 317 | 318 | def inverse_value_transform(self, value_logits): 319 | return self.inverse_scalar_transform(value_logits, self.value_support) 320 | 321 | def inverse_scalar_transform(self, logits, scalar_support): 322 | """ Reference from MuZerp: Appendix F => Network Architecture 323 | & Appendix A : Proposition A.2 in https://arxiv.org/pdf/1805.11593.pdf (Page-11) 324 | """ 325 | delta = self.value_support.delta 326 | value_probs = torch.softmax(logits, dim=1) 327 | value_support = torch.ones(value_probs.shape) 328 | value_support[:, :] = torch.from_numpy(np.array([x for x in scalar_support.range])) 329 | value_support = value_support.to(device=value_probs.device) 330 | value = (value_support * value_probs).sum(1, keepdim=True) / delta 331 | 332 | epsilon = 0.001 333 | sign = torch.ones(value.shape).float().to(value.device) 334 | sign[value < 0] = -1.0 335 | output = (((torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon)) ** 2 - 1) 336 | output = sign * output * delta 337 | 338 | nan_part = torch.isnan(output) 339 | output[nan_part] = 0. 340 | output[torch.abs(output) < epsilon] = 0. 341 | return output 342 | 343 | def value_phi(self, x): 344 | return self._phi(x, self.value_support.min, self.value_support.max, self.value_support.size) 345 | 346 | def reward_phi(self, x): 347 | return self._phi(x, self.reward_support.min, self.reward_support.max, self.reward_support.size) 348 | 349 | def _phi(self, x, min, max, set_size: int): 350 | delta = self.value_support.delta 351 | 352 | x.clamp_(min, max) 353 | x_low = x.floor() 354 | x_high = x.ceil() 355 | p_high = x - x_low 356 | p_low = 1 - p_high 357 | 358 | target = torch.zeros(x.shape[0], x.shape[1], set_size).to(x.device) 359 | x_high_idx, x_low_idx = x_high - min / delta, x_low - min / delta 360 | target.scatter_(2, x_high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1)) 361 | target.scatter_(2, x_low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1)) 362 | return target 363 | 364 | def get_hparams(self): 365 | # get all the hyper-parameters 366 | hparams = {} 367 | for k, v in self.__dict__.items(): 368 | if 'path' not in k and (v is not None): 369 | hparams[k] = v 370 | return hparams 371 | 372 | def set_config(self, args): 373 | # reset config from the args 374 | self.set_game(args.env) 375 | self.case = args.case 376 | self.seed = args.seed 377 | if not args.use_priority: 378 | self.priority_prob_alpha = 0 379 | self.amp_type = args.amp_type 380 | self.use_priority = args.use_priority 381 | self.use_max_priority = args.use_max_priority if self.use_priority else False 382 | self.debug = args.debug 383 | self.device = args.device 384 | self.cpu_actor = args.cpu_actor 385 | self.gpu_actor = args.gpu_actor 386 | self.p_mcts_num = args.p_mcts_num 387 | self.use_root_value = args.use_root_value 388 | 389 | if not self.do_consistency: 390 | self.consistency_coeff = 0 391 | self.augmentation = None 392 | self.use_augmentation = False 393 | 394 | if not self.use_value_prefix: 395 | self.lstm_horizon_len = 1 396 | 397 | if not self.off_correction: 398 | self.auto_td_steps = self.training_steps 399 | else: 400 | self.auto_td_steps = self.auto_td_steps_ratio * self.training_steps 401 | 402 | assert 0 <= self.lr_warm_up <= 0.1 403 | assert 1 <= self.lstm_horizon_len <= self.num_unroll_steps 404 | assert self.start_transitions >= self.batch_size 405 | 406 | # augmentation 407 | if 
self.consistency_coeff > 0 and args.use_augmentation: 408 | self.use_augmentation = True 409 | self.augmentation = args.augmentation 410 | else: 411 | self.use_augmentation = False 412 | 413 | if args.revisit_policy_search_rate is not None: 414 | self.revisit_policy_search_rate = args.revisit_policy_search_rate 415 | 416 | localtime = time.asctime(time.localtime(time.time())) 417 | seed_tag = 'seed={}'.format(self.seed) 418 | self.exp_path = os.path.join(args.result_dir, args.case, args.info, args.env, seed_tag, localtime) 419 | 420 | self.model_path = os.path.join(self.exp_path, 'model.p') 421 | self.model_dir = os.path.join(self.exp_path, 'model') 422 | return self.exp_path 423 | -------------------------------------------------------------------------------- /core/selfplay_worker.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import time 3 | import torch 4 | 5 | import numpy as np 6 | import core.ctree.cytree as cytree 7 | 8 | from torch.nn import L1Loss 9 | from torch.cuda.amp import autocast as autocast 10 | from core.mcts import MCTS 11 | from core.game import GameHistory 12 | from core.utils import select_action, prepare_observation_lst 13 | 14 | 15 | @ray.remote(num_gpus=0.125) 16 | class DataWorker(object): 17 | def __init__(self, rank, replay_buffer, storage, config): 18 | """Data Worker for collecting data through self-play 19 | Parameters 20 | ---------- 21 | rank: int 22 | id of the worker 23 | replay_buffer: Any 24 | Replay buffer 25 | storage: Any 26 | The model storage 27 | """ 28 | self.rank = rank 29 | self.config = config 30 | self.storage = storage 31 | self.replay_buffer = replay_buffer 32 | # double buffering when data is sufficient 33 | self.trajectory_pool = [] 34 | self.pool_size = 1 35 | self.device = self.config.device 36 | self.gap_step = self.config.num_unroll_steps + self.config.td_steps 37 | self.last_model_index = -1 38 | 39 | def put(self, data): 40 | # put a game history into the pool 41 | self.trajectory_pool.append(data) 42 | 43 | def len_pool(self): 44 | # current pool size 45 | return len(self.trajectory_pool) 46 | 47 | def free(self): 48 | # save the game histories and clear the pool 49 | if self.len_pool() >= self.pool_size: 50 | self.replay_buffer.save_pools.remote(self.trajectory_pool, self.gap_step) 51 | del self.trajectory_pool[:] 52 | 53 | def put_last_trajectory(self, i, last_game_histories, last_game_priorities, game_histories): 54 | """put the last game history into the pool if the current game is finished 55 | Parameters 56 | ---------- 57 | last_game_histories: list 58 | list of the last game histories 59 | last_game_priorities: list 60 | list of the last game priorities 61 | game_histories: list 62 | list of the current game histories 63 | """ 64 | # pad over last block trajectory 65 | beg_index = self.config.stacked_observations 66 | end_index = beg_index + self.config.num_unroll_steps 67 | 68 | pad_obs_lst = game_histories[i].obs_history[beg_index:end_index] 69 | pad_child_visits_lst = game_histories[i].child_visits[beg_index:end_index] 70 | 71 | beg_index = 0 72 | end_index = beg_index + self.gap_step - 1 73 | 74 | pad_reward_lst = game_histories[i].rewards[beg_index:end_index] 75 | 76 | beg_index = 0 77 | end_index = beg_index + self.gap_step 78 | 79 | pad_root_values_lst = game_histories[i].root_values[beg_index:end_index] 80 | 81 | # pad over and save 82 | last_game_histories[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) 83 | 
last_game_histories[i].game_over() 84 | 85 | self.put((last_game_histories[i], last_game_priorities[i])) 86 | self.free() 87 | 88 | # reset last block 89 | last_game_histories[i] = None 90 | last_game_priorities[i] = None 91 | 92 | def get_priorities(self, i, pred_values_lst, search_values_lst): 93 | # obtain the priorities at index i 94 | if self.config.use_priority and not self.config.use_max_priority: 95 | pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.device).float() 96 | search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.device).float() 97 | priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + self.config.prioritized_replay_eps 98 | else: 99 | # priorities is None -> use the max priority for all newly collected data 100 | priorities = None 101 | 102 | return priorities 103 | 104 | def run(self): 105 | # number of parallel mcts 106 | env_nums = self.config.p_mcts_num 107 | model = self.config.get_uniform_network() 108 | model.to(self.device) 109 | model.eval() 110 | 111 | start_training = False 112 | envs = [self.config.new_game(self.config.seed + (self.rank + 1) * i) for i in range(env_nums)] 113 | 114 | def _get_max_entropy(action_space): 115 | p = 1.0 / action_space 116 | ep = - action_space * p * np.log2(p) 117 | return ep 118 | max_visit_entropy = _get_max_entropy(self.config.action_space_size) 119 | # 100k benchmark 120 | total_transitions = 0 121 | # max transition to collect for this data worker 122 | max_transitions = self.config.total_transitions // self.config.num_actors 123 | with torch.no_grad(): 124 | while True: 125 | trained_steps = ray.get(self.storage.get_counter.remote()) 126 | # training finished 127 | if trained_steps >= self.config.training_steps + self.config.last_steps: 128 | time.sleep(30) 129 | break 130 | 131 | init_obses = [env.reset() for env in envs] 132 | dones = np.array([False for _ in range(env_nums)]) 133 | game_histories = [GameHistory(envs[_].env.action_space, max_length=self.config.history_length, 134 | config=self.config) for _ in range(env_nums)] 135 | last_game_histories = [None for _ in range(env_nums)] 136 | last_game_priorities = [None for _ in range(env_nums)] 137 | 138 | # stack observation windows in boundary: s398, s399, s400, current s1 -> for not init trajectory 139 | stack_obs_windows = [[] for _ in range(env_nums)] 140 | 141 | for i in range(env_nums): 142 | stack_obs_windows[i] = [init_obses[i] for _ in range(self.config.stacked_observations)] 143 | game_histories[i].init(stack_obs_windows[i]) 144 | 145 | # for priorities in self-play 146 | search_values_lst = [[] for _ in range(env_nums)] 147 | pred_values_lst = [[] for _ in range(env_nums)] 148 | 149 | # some logs 150 | eps_ori_reward_lst, eps_reward_lst, eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums), np.zeros(env_nums), np.zeros(env_nums) 151 | step_counter = 0 152 | 153 | self_play_rewards = 0. 154 | self_play_ori_rewards = 0. 155 | self_play_moves = 0. 156 | self_play_episodes = 0. 
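The visit-count entropies accumulated below are normalised by max_visit_entropy (computed by _get_max_entropy above), i.e. by the entropy of a uniform distribution over the action space, so the logged value lies in [0, 1]. A small standalone check with an illustrative 6-action space:

import numpy as np

def visit_entropy(visit_counts):
    # entropy (in bits) of the normalised visit-count distribution of one MCTS root
    probs = np.asarray(visit_counts, dtype=np.float64)
    probs = probs / probs.sum()
    probs = probs[probs > 0]
    return float(-(probs * np.log2(probs)).sum())

action_space = 6  # illustrative
max_entropy = -action_space * (1.0 / action_space) * np.log2(1.0 / action_space)  # log2(6)

print(visit_entropy([10, 10, 10, 10, 10, 10]) / max_entropy)  # ~1.0  (search is undecided)
print(visit_entropy([47, 1, 1, 1, 0, 0]) / max_entropy)       # ~0.16 (search is confident)
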
157 | 158 | self_play_rewards_max = - np.inf 159 | self_play_moves_max = 0 160 | 161 | self_play_visit_entropy = [] 162 | other_dist = {} 163 | 164 | # play games until max moves 165 | while not dones.all() and (step_counter <= self.config.max_moves): 166 | if not start_training: 167 | start_training = ray.get(self.storage.get_start_signal.remote()) 168 | 169 | # get model 170 | trained_steps = ray.get(self.storage.get_counter.remote()) 171 | if trained_steps >= self.config.training_steps + self.config.last_steps: 172 | # training is finished 173 | time.sleep(30) 174 | return 175 | if start_training and (total_transitions / max_transitions) > (trained_steps / self.config.training_steps): 176 | # self-play is faster than training speed or finished 177 | time.sleep(1) 178 | continue 179 | 180 | # set temperature for distributions 181 | _temperature = np.array( 182 | [self.config.visit_softmax_temperature_fn(num_moves=0, trained_steps=trained_steps) for env in 183 | envs]) 184 | 185 | # update the models in self-play every checkpoint_interval 186 | new_model_index = trained_steps // self.config.checkpoint_interval 187 | if new_model_index > self.last_model_index: 188 | self.last_model_index = new_model_index 189 | # update model 190 | weights = ray.get(self.storage.get_weights.remote()) 191 | model.set_weights(weights) 192 | model.to(self.device) 193 | model.eval() 194 | 195 | # log if more than 1 env in parallel because env will reset in this loop. 196 | if env_nums > 1: 197 | if len(self_play_visit_entropy) > 0: 198 | visit_entropies = np.array(self_play_visit_entropy).mean() 199 | visit_entropies /= max_visit_entropy 200 | else: 201 | visit_entropies = 0. 202 | 203 | if self_play_episodes > 0: 204 | log_self_play_moves = self_play_moves / self_play_episodes 205 | log_self_play_rewards = self_play_rewards / self_play_episodes 206 | log_self_play_ori_rewards = self_play_ori_rewards / self_play_episodes 207 | else: 208 | log_self_play_moves = 0 209 | log_self_play_rewards = 0 210 | log_self_play_ori_rewards = 0 211 | 212 | self.storage.set_data_worker_logs.remote(log_self_play_moves, self_play_moves_max, 213 | log_self_play_ori_rewards, log_self_play_rewards, 214 | self_play_rewards_max, _temperature.mean(), 215 | visit_entropies, 0, 216 | other_dist) 217 | self_play_rewards_max = - np.inf 218 | 219 | step_counter += 1 220 | for i in range(env_nums): 221 | # reset env if finished 222 | if dones[i]: 223 | 224 | # pad over last block trajectory 225 | if last_game_histories[i] is not None: 226 | self.put_last_trajectory(i, last_game_histories, last_game_priorities, game_histories) 227 | 228 | # store current block trajectory 229 | priorities = self.get_priorities(i, pred_values_lst, search_values_lst) 230 | game_histories[i].game_over() 231 | 232 | self.put((game_histories[i], priorities)) 233 | self.free() 234 | 235 | # reset the finished env and new a env 236 | envs[i].close() 237 | init_obs = envs[i].reset() 238 | game_histories[i] = GameHistory(env.env.action_space, max_length=self.config.history_length, 239 | config=self.config) 240 | last_game_histories[i] = None 241 | last_game_priorities[i] = None 242 | stack_obs_windows[i] = [init_obs for _ in range(self.config.stacked_observations)] 243 | game_histories[i].init(stack_obs_windows[i]) 244 | 245 | # log 246 | self_play_rewards_max = max(self_play_rewards_max, eps_reward_lst[i]) 247 | self_play_moves_max = max(self_play_moves_max, eps_steps_lst[i]) 248 | self_play_rewards += eps_reward_lst[i] 249 | self_play_ori_rewards += 
eps_ori_reward_lst[i] 250 | self_play_visit_entropy.append(visit_entropies_lst[i] / eps_steps_lst[i]) 251 | self_play_moves += eps_steps_lst[i] 252 | self_play_episodes += 1 253 | 254 | pred_values_lst[i] = [] 255 | search_values_lst[i] = [] 256 | # end_tags[i] = False 257 | eps_steps_lst[i] = 0 258 | eps_reward_lst[i] = 0 259 | eps_ori_reward_lst[i] = 0 260 | visit_entropies_lst[i] = 0 261 | 262 | # stack obs for model inference 263 | stack_obs = [game_history.step_obs() for game_history in game_histories] 264 | if self.config.image_based: 265 | stack_obs = prepare_observation_lst(stack_obs) 266 | stack_obs = torch.from_numpy(stack_obs).to(self.device).float() / 255.0 267 | else: 268 | stack_obs = [game_history.step_obs() for game_history in game_histories] 269 | stack_obs = torch.from_numpy(np.array(stack_obs)).to(self.device) 270 | 271 | if self.config.amp_type == 'torch_amp': 272 | with autocast(): 273 | network_output = model.initial_inference(stack_obs.float()) 274 | else: 275 | network_output = model.initial_inference(stack_obs.float()) 276 | hidden_state_roots = network_output.hidden_state 277 | reward_hidden_roots = network_output.reward_hidden 278 | value_prefix_pool = network_output.value_prefix 279 | policy_logits_pool = network_output.policy_logits.tolist() 280 | 281 | roots = cytree.Roots(env_nums, self.config.action_space_size, self.config.num_simulations) 282 | noises = [np.random.dirichlet([self.config.root_dirichlet_alpha] * self.config.action_space_size).astype(np.float32).tolist() for _ in range(env_nums)] 283 | roots.prepare(self.config.root_exploration_fraction, noises, value_prefix_pool, policy_logits_pool) 284 | # do MCTS for a policy 285 | MCTS(self.config).search(roots, model, hidden_state_roots, reward_hidden_roots) 286 | 287 | roots_distributions = roots.get_distributions() 288 | roots_values = roots.get_values() 289 | for i in range(env_nums): 290 | deterministic = False 291 | if start_training: 292 | distributions, value, temperature, env = roots_distributions[i], roots_values[i], _temperature[i], envs[i] 293 | else: 294 | # before starting training, use random policy 295 | value, temperature, env = roots_values[i], _temperature[i], envs[i] 296 | distributions = np.ones(self.config.action_space_size) 297 | 298 | action, visit_entropy = select_action(distributions, temperature=temperature, deterministic=deterministic) 299 | obs, ori_reward, done, info = env.step(action) 300 | # clip the reward 301 | if self.config.clip_reward: 302 | clip_reward = np.sign(ori_reward) 303 | else: 304 | clip_reward = ori_reward 305 | 306 | # store data 307 | game_histories[i].store_search_stats(distributions, value) 308 | game_histories[i].append(action, obs, clip_reward) 309 | 310 | eps_reward_lst[i] += clip_reward 311 | eps_ori_reward_lst[i] += ori_reward 312 | dones[i] = done 313 | visit_entropies_lst[i] += visit_entropy 314 | 315 | eps_steps_lst[i] += 1 316 | total_transitions += 1 317 | 318 | if self.config.use_priority and not self.config.use_max_priority and start_training: 319 | pred_values_lst[i].append(network_output.value[i].item()) 320 | search_values_lst[i].append(roots_values[i]) 321 | 322 | # fresh stack windows 323 | del stack_obs_windows[i][0] 324 | stack_obs_windows[i].append(obs) 325 | 326 | # if game history is full; 327 | # we will save a game history if it is the end of the game or the next game history is finished. 
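When that happens, the full block is only queued as last_game_histories[i] rather than written out immediately, because the targets near its tail need observations, rewards and root values from the head of the next block; put_last_trajectory above copies num_unroll_steps observations, gap_step - 1 rewards and gap_step root values over. A minimal standalone sketch of those slices, with illustrative sizes stacked_observations=4, num_unroll_steps=5, td_steps=5:

# same slice bounds as in put_last_trajectory, applied to stand-in data
stacked_observations, num_unroll_steps, td_steps = 4, 5, 5
gap_step = num_unroll_steps + td_steps

next_block = list(range(30))  # stand-in for obs_history / rewards / root_values of the new block

pad_obs = next_block[stacked_observations:stacked_observations + num_unroll_steps]
pad_rewards = next_block[0:gap_step - 1]
pad_root_values = next_block[0:gap_step]

print(len(pad_obs), len(pad_rewards), len(pad_root_values))  # 5 9 10
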
328 | if game_histories[i].is_full(): 329 | # pad over last block trajectory 330 | if last_game_histories[i] is not None: 331 | self.put_last_trajectory(i, last_game_histories, last_game_priorities, game_histories) 332 | 333 | # calculate priority 334 | priorities = self.get_priorities(i, pred_values_lst, search_values_lst) 335 | 336 | # save block trajectory 337 | last_game_histories[i] = game_histories[i] 338 | last_game_priorities[i] = priorities 339 | 340 | # new block trajectory 341 | game_histories[i] = GameHistory(envs[i].env.action_space, max_length=self.config.history_length, 342 | config=self.config) 343 | game_histories[i].init(stack_obs_windows[i]) 344 | 345 | for i in range(env_nums): 346 | env = envs[i] 347 | env.close() 348 | 349 | if dones[i]: 350 | # pad over last block trajectory 351 | if last_game_histories[i] is not None: 352 | self.put_last_trajectory(i, last_game_histories, last_game_priorities, game_histories) 353 | 354 | # store current block trajectory 355 | priorities = self.get_priorities(i, pred_values_lst, search_values_lst) 356 | game_histories[i].game_over() 357 | 358 | self.put((game_histories[i], priorities)) 359 | self.free() 360 | 361 | self_play_rewards_max = max(self_play_rewards_max, eps_reward_lst[i]) 362 | self_play_moves_max = max(self_play_moves_max, eps_steps_lst[i]) 363 | self_play_rewards += eps_reward_lst[i] 364 | self_play_ori_rewards += eps_ori_reward_lst[i] 365 | self_play_visit_entropy.append(visit_entropies_lst[i] / eps_steps_lst[i]) 366 | self_play_moves += eps_steps_lst[i] 367 | self_play_episodes += 1 368 | else: 369 | # if the final game history is not finished, we will not save this data. 370 | total_transitions -= len(game_histories[i]) 371 | 372 | # logs 373 | visit_entropies = np.array(self_play_visit_entropy).mean() 374 | visit_entropies /= max_visit_entropy 375 | 376 | if self_play_episodes > 0: 377 | log_self_play_moves = self_play_moves / self_play_episodes 378 | log_self_play_rewards = self_play_rewards / self_play_episodes 379 | log_self_play_ori_rewards = self_play_ori_rewards / self_play_episodes 380 | else: 381 | log_self_play_moves = 0 382 | log_self_play_rewards = 0 383 | log_self_play_ori_rewards = 0 384 | 385 | other_dist = {} 386 | # send logs 387 | self.storage.set_data_worker_logs.remote(log_self_play_moves, self_play_moves_max, 388 | log_self_play_ori_rewards, log_self_play_rewards, 389 | self_play_rewards_max, _temperature.mean(), 390 | visit_entropies, 0, 391 | other_dist) 392 | -------------------------------------------------------------------------------- /config/atari/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | import numpy as np 5 | import torch.nn as nn 6 | 7 | from core.model import BaseNet, renormalize 8 | 9 | 10 | def mlp( 11 | input_size, 12 | layer_sizes, 13 | output_size, 14 | output_activation=nn.Identity, 15 | activation=nn.ReLU, 16 | momentum=0.1, 17 | init_zero=False, 18 | ): 19 | """MLP layers 20 | Parameters 21 | ---------- 22 | input_size: int 23 | dim of inputs 24 | layer_sizes: list 25 | dim of hidden layers 26 | output_size: int 27 | dim of outputs 28 | init_zero: bool 29 | zero initialization for the last layer (including w and b). 30 | This can provide stable zero outputs in the beginning. 
31 | """ 32 | sizes = [input_size] + layer_sizes + [output_size] 33 | layers = [] 34 | for i in range(len(sizes) - 1): 35 | if i < len(sizes) - 2: 36 | act = activation 37 | layers += [nn.Linear(sizes[i], sizes[i + 1]), 38 | nn.BatchNorm1d(sizes[i + 1], momentum=momentum), 39 | act()] 40 | else: 41 | act = output_activation 42 | layers += [nn.Linear(sizes[i], sizes[i + 1]), 43 | act()] 44 | 45 | if init_zero: 46 | layers[-2].weight.data.fill_(0) 47 | layers[-2].bias.data.fill_(0) 48 | 49 | return nn.Sequential(*layers) 50 | 51 | 52 | def conv3x3(in_channels, out_channels, stride=1): 53 | return nn.Conv2d( 54 | in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False 55 | ) 56 | 57 | 58 | # Residual block 59 | class ResidualBlock(nn.Module): 60 | def __init__(self, in_channels, out_channels, downsample=None, stride=1, momentum=0.1): 61 | super().__init__() 62 | self.conv1 = conv3x3(in_channels, out_channels, stride) 63 | self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum) 64 | self.conv2 = conv3x3(out_channels, out_channels) 65 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum) 66 | self.downsample = downsample 67 | 68 | def forward(self, x): 69 | identity = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = nn.functional.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | 78 | if self.downsample is not None: 79 | identity = self.downsample(x) 80 | 81 | out += identity 82 | out = nn.functional.relu(out) 83 | return out 84 | 85 | 86 | # Downsample observations before representation network (See paper appendix Network Architecture) 87 | class DownSample(nn.Module): 88 | def __init__(self, in_channels, out_channels, momentum=0.1): 89 | super().__init__() 90 | self.conv1 = nn.Conv2d( 91 | in_channels, 92 | out_channels // 2, 93 | kernel_size=3, 94 | stride=2, 95 | padding=1, 96 | bias=False, 97 | ) 98 | self.bn1 = nn.BatchNorm2d(out_channels // 2, momentum=momentum) 99 | self.resblocks1 = nn.ModuleList( 100 | [ResidualBlock(out_channels // 2, out_channels // 2, momentum=momentum) for _ in range(1)] 101 | ) 102 | self.conv2 = nn.Conv2d( 103 | out_channels // 2, 104 | out_channels, 105 | kernel_size=3, 106 | stride=2, 107 | padding=1, 108 | bias=False, 109 | ) 110 | self.downsample_block = ResidualBlock(out_channels // 2, out_channels, momentum=momentum, stride=2, downsample=self.conv2) 111 | self.resblocks2 = nn.ModuleList( 112 | [ResidualBlock(out_channels, out_channels, momentum=momentum) for _ in range(1)] 113 | ) 114 | self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) 115 | self.resblocks3 = nn.ModuleList( 116 | [ResidualBlock(out_channels, out_channels, momentum=momentum) for _ in range(1)] 117 | ) 118 | self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) 119 | 120 | def forward(self, x): 121 | x = self.conv1(x) 122 | x = self.bn1(x) 123 | x = nn.functional.relu(x) 124 | for block in self.resblocks1: 125 | x = block(x) 126 | x = self.downsample_block(x) 127 | for block in self.resblocks2: 128 | x = block(x) 129 | x = self.pooling1(x) 130 | for block in self.resblocks3: 131 | x = block(x) 132 | x = self.pooling2(x) 133 | return x 134 | 135 | 136 | # Encode the observations into hidden states 137 | class RepresentationNetwork(nn.Module): 138 | def __init__( 139 | self, 140 | observation_shape, 141 | num_blocks, 142 | num_channels, 143 | downsample, 144 | momentum=0.1, 145 | ): 146 | """Representation network 147 | Parameters 148 | ---------- 149 | observation_shape: tuple or list 150 | 
shape of observations: [C, W, H] 151 | num_blocks: int 152 | number of res blocks 153 | num_channels: int 154 | channels of hidden states 155 | downsample: bool 156 | True -> do downsampling for observations. (For board games, do not need) 157 | """ 158 | super().__init__() 159 | self.downsample = downsample 160 | if self.downsample: 161 | self.downsample_net = DownSample( 162 | observation_shape[0], 163 | num_channels, 164 | ) 165 | self.conv = conv3x3( 166 | observation_shape[0], 167 | num_channels, 168 | ) 169 | self.bn = nn.BatchNorm2d(num_channels, momentum=momentum) 170 | self.resblocks = nn.ModuleList( 171 | [ResidualBlock(num_channels, num_channels, momentum=momentum) for _ in range(num_blocks)] 172 | ) 173 | 174 | def forward(self, x): 175 | if self.downsample: 176 | x = self.downsample_net(x) 177 | else: 178 | x = self.conv(x) 179 | x = self.bn(x) 180 | x = nn.functional.relu(x) 181 | 182 | for block in self.resblocks: 183 | x = block(x) 184 | return x 185 | 186 | def get_param_mean(self): 187 | mean = [] 188 | for name, param in self.named_parameters(): 189 | mean += np.abs(param.detach().cpu().numpy().reshape(-1)).tolist() 190 | mean = sum(mean) / len(mean) 191 | return mean 192 | 193 | 194 | # Predict next hidden states given current states and actions 195 | class DynamicsNetwork(nn.Module): 196 | def __init__( 197 | self, 198 | num_blocks, 199 | num_channels, 200 | reduced_channels_reward, 201 | fc_reward_layers, 202 | full_support_size, 203 | block_output_size_reward, 204 | lstm_hidden_size=64, 205 | momentum=0.1, 206 | init_zero=False, 207 | ): 208 | """Dynamics network 209 | Parameters 210 | ---------- 211 | num_blocks: int 212 | number of res blocks 213 | num_channels: int 214 | channels of hidden states 215 | fc_reward_layers: list 216 | hidden layers of the reward prediction head (MLP head) 217 | full_support_size: int 218 | dim of reward output 219 | block_output_size_reward: int 220 | dim of flatten hidden states 221 | lstm_hidden_size: int 222 | dim of lstm hidden 223 | init_zero: bool 224 | True -> zero initialization for the last layer of reward mlp 225 | """ 226 | super().__init__() 227 | self.num_channels = num_channels 228 | self.lstm_hidden_size = lstm_hidden_size 229 | 230 | self.conv = conv3x3(num_channels, num_channels - 1) 231 | self.bn = nn.BatchNorm2d(num_channels - 1, momentum=momentum) 232 | self.resblocks = nn.ModuleList( 233 | [ResidualBlock(num_channels - 1, num_channels - 1, momentum=momentum) for _ in range(num_blocks)] 234 | ) 235 | 236 | self.reward_resblocks = nn.ModuleList( 237 | [ResidualBlock(num_channels - 1, num_channels - 1, momentum=momentum) for _ in range(num_blocks)] 238 | ) 239 | 240 | self.conv1x1_reward = nn.Conv2d(num_channels - 1, reduced_channels_reward, 1) 241 | self.bn_reward = nn.BatchNorm2d(reduced_channels_reward, momentum=momentum) 242 | self.block_output_size_reward = block_output_size_reward 243 | self.lstm = nn.LSTM(input_size=self.block_output_size_reward, hidden_size=self.lstm_hidden_size) 244 | self.bn_value_prefix = nn.BatchNorm1d(self.lstm_hidden_size, momentum=momentum) 245 | self.fc = mlp(self.lstm_hidden_size, fc_reward_layers, full_support_size, init_zero=init_zero, momentum=momentum) 246 | 247 | def forward(self, x, reward_hidden): 248 | state = x[:,:-1,:,:] 249 | x = self.conv(x) 250 | x = self.bn(x) 251 | 252 | x += state 253 | x = nn.functional.relu(x) 254 | 255 | for block in self.resblocks: 256 | x = block(x) 257 | state = x 258 | 259 | x = self.conv1x1_reward(x) 260 | x = self.bn_reward(x) 261 | x = 
nn.functional.relu(x) 262 | 263 | x = x.view(-1, self.block_output_size_reward).unsqueeze(0) 264 | value_prefix, reward_hidden = self.lstm(x, reward_hidden) 265 | value_prefix = value_prefix.squeeze(0) 266 | value_prefix = self.bn_value_prefix(value_prefix) 267 | value_prefix = nn.functional.relu(value_prefix) 268 | value_prefix = self.fc(value_prefix) 269 | 270 | return state, reward_hidden, value_prefix 271 | 272 | def get_dynamic_mean(self): 273 | dynamic_mean = np.abs(self.conv.weight.detach().cpu().numpy().reshape(-1)).tolist() 274 | 275 | for block in self.resblocks: 276 | for name, param in block.named_parameters(): 277 | dynamic_mean += np.abs(param.detach().cpu().numpy().reshape(-1)).tolist() 278 | dynamic_mean = sum(dynamic_mean) / len(dynamic_mean) 279 | return dynamic_mean 280 | 281 | def get_reward_mean(self): 282 | reward_w_dist = self.conv1x1_reward.weight.detach().cpu().numpy().reshape(-1) 283 | 284 | for name, param in self.fc.named_parameters(): 285 | temp_weights = param.detach().cpu().numpy().reshape(-1) 286 | reward_w_dist = np.concatenate((reward_w_dist, temp_weights)) 287 | reward_mean = np.abs(reward_w_dist).mean() 288 | return reward_w_dist, reward_mean 289 | 290 | 291 | # predict the value and policy given hidden states 292 | class PredictionNetwork(nn.Module): 293 | def __init__( 294 | self, 295 | action_space_size, 296 | num_blocks, 297 | num_channels, 298 | reduced_channels_value, 299 | reduced_channels_policy, 300 | fc_value_layers, 301 | fc_policy_layers, 302 | full_support_size, 303 | block_output_size_value, 304 | block_output_size_policy, 305 | momentum=0.1, 306 | init_zero=False, 307 | ): 308 | """Prediction network 309 | Parameters 310 | ---------- 311 | action_space_size: int 312 | action space 313 | num_blocks: int 314 | number of res blocks 315 | num_channels: int 316 | channels of hidden states 317 | reduced_channels_value: int 318 | channels of value head 319 | reduced_channels_policy: int 320 | channels of policy head 321 | fc_value_layers: list 322 | hidden layers of the value prediction head (MLP head) 323 | fc_policy_layers: list 324 | hidden layers of the policy prediction head (MLP head) 325 | full_support_size: int 326 | dim of value output 327 | block_output_size_value: int 328 | dim of flatten hidden states 329 | block_output_size_policy: int 330 | dim of flatten hidden states 331 | init_zero: bool 332 | True -> zero initialization for the last layer of value/policy mlp 333 | """ 334 | super().__init__() 335 | self.resblocks = nn.ModuleList( 336 | [ResidualBlock(num_channels, num_channels, momentum=momentum) for _ in range(num_blocks)] 337 | ) 338 | 339 | self.conv1x1_value = nn.Conv2d(num_channels, reduced_channels_value, 1) 340 | self.conv1x1_policy = nn.Conv2d(num_channels, reduced_channels_policy, 1) 341 | self.bn_value = nn.BatchNorm2d(reduced_channels_value, momentum=momentum) 342 | self.bn_policy = nn.BatchNorm2d(reduced_channels_policy, momentum=momentum) 343 | self.block_output_size_value = block_output_size_value 344 | self.block_output_size_policy = block_output_size_policy 345 | self.fc_value = mlp(self.block_output_size_value, fc_value_layers, full_support_size, init_zero=init_zero, momentum=momentum) 346 | self.fc_policy = mlp(self.block_output_size_policy, fc_policy_layers, action_space_size, init_zero=init_zero, momentum=momentum) 347 | 348 | def forward(self, x): 349 | for block in self.resblocks: 350 | x = block(x) 351 | value = self.conv1x1_value(x) 352 | value = self.bn_value(value) 353 | value = 
nn.functional.relu(value) 354 | 355 | policy = self.conv1x1_policy(x) 356 | policy = self.bn_policy(policy) 357 | policy = nn.functional.relu(policy) 358 | 359 | value = value.view(-1, self.block_output_size_value) 360 | policy = policy.view(-1, self.block_output_size_policy) 361 | value = self.fc_value(value) 362 | policy = self.fc_policy(policy) 363 | return policy, value 364 | 365 | 366 | class EfficientZeroNet(BaseNet): 367 | def __init__( 368 | self, 369 | observation_shape, 370 | action_space_size, 371 | num_blocks, 372 | num_channels, 373 | reduced_channels_reward, 374 | reduced_channels_value, 375 | reduced_channels_policy, 376 | fc_reward_layers, 377 | fc_value_layers, 378 | fc_policy_layers, 379 | reward_support_size, 380 | value_support_size, 381 | downsample, 382 | inverse_value_transform, 383 | inverse_reward_transform, 384 | lstm_hidden_size, 385 | bn_mt=0.1, 386 | proj_hid=256, 387 | proj_out=256, 388 | pred_hid=64, 389 | pred_out=256, 390 | init_zero=False, 391 | state_norm=False 392 | ): 393 | """EfficientZero network 394 | Parameters 395 | ---------- 396 | observation_shape: tuple or list 397 | shape of observations: [C, W, H] 398 | action_space_size: int 399 | action space 400 | num_blocks: int 401 | number of res blocks 402 | num_channels: int 403 | channels of hidden states 404 | reduced_channels_reward: int 405 | channels of reward head 406 | reduced_channels_value: int 407 | channels of value head 408 | reduced_channels_policy: int 409 | channels of policy head 410 | fc_reward_layers: list 411 | hidden layers of the reward prediction head (MLP head) 412 | fc_value_layers: list 413 | hidden layers of the value prediction head (MLP head) 414 | fc_policy_layers: list 415 | hidden layers of the policy prediction head (MLP head) 416 | reward_support_size: int 417 | dim of reward output 418 | value_support_size: int 419 | dim of value output 420 | downsample: bool 421 | True -> do downsampling for observations. 
(For board games, do not need) 422 | inverse_value_transform: Any 423 | A function that maps value supports into value scalars 424 | inverse_reward_transform: Any 425 | A function that maps reward supports into value scalars 426 | lstm_hidden_size: int 427 | dim of lstm hidden 428 | bn_mt: float 429 | Momentum of BN 430 | proj_hid: int 431 | dim of projection hidden layer 432 | proj_out: int 433 | dim of projection output layer 434 | pred_hid: int 435 | dim of projection head (prediction) hidden layer 436 | pred_out: int 437 | dim of projection head (prediction) output layer 438 | init_zero: bool 439 | True -> zero initialization for the last layer of value/policy mlp 440 | state_norm: bool 441 | True -> normalization for hidden states 442 | """ 443 | super(EfficientZeroNet, self).__init__(inverse_value_transform, inverse_reward_transform, lstm_hidden_size) 444 | self.proj_hid = proj_hid 445 | self.proj_out = proj_out 446 | self.pred_hid = pred_hid 447 | self.pred_out = pred_out 448 | self.init_zero = init_zero 449 | self.state_norm = state_norm 450 | 451 | self.action_space_size = action_space_size 452 | block_output_size_reward = ( 453 | ( 454 | reduced_channels_reward 455 | * math.ceil(observation_shape[1] / 16) 456 | * math.ceil(observation_shape[2] / 16) 457 | ) 458 | if downsample 459 | else (reduced_channels_reward * observation_shape[1] * observation_shape[2]) 460 | ) 461 | 462 | block_output_size_value = ( 463 | ( 464 | reduced_channels_value 465 | * math.ceil(observation_shape[1] / 16) 466 | * math.ceil(observation_shape[2] / 16) 467 | ) 468 | if downsample 469 | else (reduced_channels_value * observation_shape[1] * observation_shape[2]) 470 | ) 471 | 472 | block_output_size_policy = ( 473 | ( 474 | reduced_channels_policy 475 | * math.ceil(observation_shape[1] / 16) 476 | * math.ceil(observation_shape[2] / 16) 477 | ) 478 | if downsample 479 | else (reduced_channels_policy * observation_shape[1] * observation_shape[2]) 480 | ) 481 | 482 | self.representation_network = RepresentationNetwork( 483 | observation_shape, 484 | num_blocks, 485 | num_channels, 486 | downsample, 487 | momentum=bn_mt, 488 | ) 489 | 490 | self.dynamics_network = DynamicsNetwork( 491 | num_blocks, 492 | num_channels + 1, 493 | reduced_channels_reward, 494 | fc_reward_layers, 495 | reward_support_size, 496 | block_output_size_reward, 497 | lstm_hidden_size=lstm_hidden_size, 498 | momentum=bn_mt, 499 | init_zero=self.init_zero, 500 | ) 501 | 502 | self.prediction_network = PredictionNetwork( 503 | action_space_size, 504 | num_blocks, 505 | num_channels, 506 | reduced_channels_value, 507 | reduced_channels_policy, 508 | fc_value_layers, 509 | fc_policy_layers, 510 | value_support_size, 511 | block_output_size_value, 512 | block_output_size_policy, 513 | momentum=bn_mt, 514 | init_zero=self.init_zero, 515 | ) 516 | 517 | # projection 518 | in_dim = num_channels * math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16) 519 | self.porjection_in_dim = in_dim 520 | self.projection = nn.Sequential( 521 | nn.Linear(self.porjection_in_dim, self.proj_hid), 522 | nn.BatchNorm1d(self.proj_hid), 523 | nn.ReLU(), 524 | nn.Linear(self.proj_hid, self.proj_hid), 525 | nn.BatchNorm1d(self.proj_hid), 526 | nn.ReLU(), 527 | nn.Linear(self.proj_hid, self.proj_out), 528 | nn.BatchNorm1d(self.proj_out) 529 | ) 530 | self.projection_head = nn.Sequential( 531 | nn.Linear(self.proj_out, self.pred_hid), 532 | nn.BatchNorm1d(self.pred_hid), 533 | nn.ReLU(), 534 | nn.Linear(self.pred_hid, self.pred_out), 535 | ) 536 | 
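self.projection and self.projection_head defined above feed the SimSiam-style temporal consistency loss. In core/train.py the dynamics branch is passed through both modules with gradients enabled, while the representation branch is only projected and then detached; the loss is the negative cosine similarity between the two. A minimal sketch of that asymmetry using plain Linear stand-ins (the real modules above also contain BatchNorm and ReLU):

import torch
import torch.nn.functional as F

def neg_cosine(f1, f2):
    # same form as consist_loss_func in core/train.py
    f1 = F.normalize(f1, p=2., dim=-1, eps=1e-5)
    f2 = F.normalize(f2, p=2., dim=-1, eps=1e-5)
    return -(f1 * f2).sum(dim=1)

projection = torch.nn.Linear(64, 256)        # stands in for self.projection
projection_head = torch.nn.Linear(256, 256)  # stands in for self.projection_head

dynamic_state = torch.randn(8, 64, requires_grad=True)  # from the dynamics network
observed_state = torch.randn(8, 64)                     # from the representation network

dynamic_proj = projection_head(projection(dynamic_state))  # project(..., with_grad=True)
observation_proj = projection(observed_state).detach()     # project(..., with_grad=False)

loss = neg_cosine(dynamic_proj, observation_proj).mean()
loss.backward()  # gradients flow only through the dynamics branch
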
537 | def prediction(self, encoded_state): 538 | policy, value = self.prediction_network(encoded_state) 539 | return policy, value 540 | 541 | def representation(self, observation): 542 | encoded_state = self.representation_network(observation) 543 | if not self.state_norm: 544 | return encoded_state 545 | else: 546 | encoded_state_normalized = renormalize(encoded_state) 547 | return encoded_state_normalized 548 | 549 | def dynamics(self, encoded_state, reward_hidden, action): 550 | # Stack encoded_state with a game specific one hot encoded action 551 | action_one_hot = ( 552 | torch.ones( 553 | ( 554 | encoded_state.shape[0], 555 | 1, 556 | encoded_state.shape[2], 557 | encoded_state.shape[3], 558 | ) 559 | ) 560 | .to(action.device) 561 | .float() 562 | ) 563 | action_one_hot = ( 564 | action[:, :, None, None] * action_one_hot / self.action_space_size 565 | ) 566 | x = torch.cat((encoded_state, action_one_hot), dim=1) 567 | next_encoded_state, reward_hidden, value_prefix = self.dynamics_network(x, reward_hidden) 568 | 569 | if not self.state_norm: 570 | return next_encoded_state, reward_hidden, value_prefix 571 | else: 572 | next_encoded_state_normalized = renormalize(next_encoded_state) 573 | return next_encoded_state_normalized, reward_hidden, value_prefix 574 | 575 | def get_params_mean(self): 576 | representation_mean = self.representation_network.get_param_mean() 577 | dynamic_mean = self.dynamics_network.get_dynamic_mean() 578 | reward_w_dist, reward_mean = self.dynamics_network.get_reward_mean() 579 | 580 | return reward_w_dist, representation_mean, dynamic_mean, reward_mean 581 | 582 | def project(self, hidden_state, with_grad=True): 583 | # only the branch of proj + pred can share the gradients 584 | hidden_state = hidden_state.view(-1, self.porjection_in_dim) 585 | proj = self.projection(hidden_state) 586 | 587 | # with grad, use proj_head 588 | if with_grad: 589 | proj = self.projection_head(proj) 590 | return proj 591 | else: 592 | return proj.detach() 593 | 594 | -------------------------------------------------------------------------------- /core/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ray 3 | import time 4 | import torch 5 | 6 | import numpy as np 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | 10 | from torch.nn import L1Loss 11 | from torch.cuda.amp import autocast as autocast 12 | from torch.cuda.amp import GradScaler as GradScaler 13 | from core.log import _log 14 | from core.test import _test 15 | from core.replay_buffer import ReplayBuffer 16 | from core.storage import SharedStorage, QueueStorage 17 | from core.selfplay_worker import DataWorker 18 | from core.reanalyze_worker import BatchWorker_GPU, BatchWorker_CPU 19 | 20 | 21 | def consist_loss_func(f1, f2): 22 | """Consistency loss function: similarity loss 23 | Parameters 24 | """ 25 | f1 = F.normalize(f1, p=2., dim=-1, eps=1e-5) 26 | f2 = F.normalize(f2, p=2., dim=-1, eps=1e-5) 27 | return -(f1 * f2).sum(dim=1) 28 | 29 | 30 | def adjust_lr(config, optimizer, step_count): 31 | # adjust learning rate, step lr every lr_decay_steps 32 | if step_count < config.lr_warm_step: 33 | lr = config.lr_init * step_count / config.lr_warm_step 34 | for param_group in optimizer.param_groups: 35 | param_group['lr'] = lr 36 | else: 37 | lr = config.lr_init * config.lr_decay_rate ** ((step_count - config.lr_warm_step) // config.lr_decay_steps) 38 | for param_group in optimizer.param_groups: 39 | param_group['lr'] = lr 40 | 41 | 
return lr 42 | 43 | 44 | def update_weights(model, batch, optimizer, replay_buffer, config, scaler, vis_result=False): 45 | """update models given a batch data 46 | Parameters 47 | ---------- 48 | model: Any 49 | EfficientZero models 50 | batch: Any 51 | a batch data inlcudes [inputs_batch, targets_batch] 52 | replay_buffer: Any 53 | replay buffer 54 | scaler: Any 55 | scaler for torch amp 56 | vis_result: bool 57 | True -> log some visualization data in tensorboard (some distributions, values, etc) 58 | """ 59 | inputs_batch, targets_batch = batch 60 | obs_batch_ori, action_batch, mask_batch, indices, weights_lst, make_time = inputs_batch 61 | target_value_prefix, target_value, target_policy = targets_batch 62 | 63 | # [:, 0: config.stacked_observations * 3,:,:] 64 | # obs_batch_ori is the original observations in a batch 65 | # obs_batch is the observation for hat s_t (predicted hidden states from dynamics function) 66 | # obs_target_batch is the observations for s_t (hidden states from representation function) 67 | # to save GPU memory usage, obs_batch_ori contains (stack + unroll steps) frames 68 | obs_batch_ori = torch.from_numpy(obs_batch_ori).to(config.device).float() / 255.0 69 | obs_batch = obs_batch_ori[:, 0: config.stacked_observations * config.image_channel, :, :] 70 | obs_target_batch = obs_batch_ori[:, config.image_channel:, :, :] 71 | 72 | # do augmentations 73 | if config.use_augmentation: 74 | obs_batch = config.transform(obs_batch) 75 | obs_target_batch = config.transform(obs_target_batch) 76 | 77 | # use GPU tensor 78 | action_batch = torch.from_numpy(action_batch).to(config.device).unsqueeze(-1).long() 79 | mask_batch = torch.from_numpy(mask_batch).to(config.device).float() 80 | target_value_prefix = torch.from_numpy(target_value_prefix).to(config.device).float() 81 | target_value = torch.from_numpy(target_value).to(config.device).float() 82 | target_policy = torch.from_numpy(target_policy).to(config.device).float() 83 | weights = torch.from_numpy(weights_lst).to(config.device).float() 84 | 85 | batch_size = obs_batch.size(0) 86 | assert batch_size == config.batch_size == target_value_prefix.size(0) 87 | metric_loss = torch.nn.L1Loss() 88 | 89 | # some logs preparation 90 | other_log = {} 91 | other_dist = {} 92 | 93 | other_loss = { 94 | 'l1': -1, 95 | 'l1_1': -1, 96 | 'l1_-1': -1, 97 | 'l1_0': -1, 98 | } 99 | for i in range(config.num_unroll_steps): 100 | key = 'unroll_' + str(i + 1) + '_l1' 101 | other_loss[key] = -1 102 | other_loss[key + '_1'] = -1 103 | other_loss[key + '_-1'] = -1 104 | other_loss[key + '_0'] = -1 105 | 106 | # transform targets to categorical representation 107 | transformed_target_value_prefix = config.scalar_transform(target_value_prefix) 108 | target_value_prefix_phi = config.reward_phi(transformed_target_value_prefix) 109 | 110 | transformed_target_value = config.scalar_transform(target_value) 111 | target_value_phi = config.value_phi(transformed_target_value) 112 | 113 | if config.amp_type == 'torch_amp': 114 | with autocast(): 115 | value, _, policy_logits, hidden_state, reward_hidden = model.initial_inference(obs_batch) 116 | else: 117 | value, _, policy_logits, hidden_state, reward_hidden = model.initial_inference(obs_batch) 118 | scaled_value = config.inverse_value_transform(value) 119 | 120 | if vis_result: 121 | state_lst = hidden_state.detach().cpu().numpy() 122 | 123 | predicted_value_prefixs = [] 124 | # Note: Following line is just for logging. 
125 | if vis_result: 126 | predicted_values, predicted_policies = scaled_value.detach().cpu(), torch.softmax(policy_logits, dim=1).detach().cpu() 127 | 128 | # calculate the new priorities for each transition 129 | value_priority = L1Loss(reduction='none')(scaled_value.squeeze(-1), target_value[:, 0]) 130 | value_priority = value_priority.data.cpu().numpy() + config.prioritized_replay_eps 131 | 132 | # loss of the first step 133 | value_loss = config.scalar_value_loss(value, target_value_phi[:, 0]) 134 | policy_loss = -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, 0]).sum(1) 135 | value_prefix_loss = torch.zeros(batch_size, device=config.device) 136 | consistency_loss = torch.zeros(batch_size, device=config.device) 137 | 138 | target_value_prefix_cpu = target_value_prefix.detach().cpu() 139 | gradient_scale = 1 / config.num_unroll_steps 140 | # loss of the unrolled steps 141 | if config.amp_type == 'torch_amp': 142 | # use torch amp 143 | with autocast(): 144 | for step_i in range(config.num_unroll_steps): 145 | # unroll with the dynamics function 146 | value, value_prefix, policy_logits, hidden_state, reward_hidden = model.recurrent_inference(hidden_state, reward_hidden, action_batch[:, step_i]) 147 | 148 | beg_index = config.image_channel * step_i 149 | end_index = config.image_channel * (step_i + config.stacked_observations) 150 | 151 | # consistency loss 152 | if config.consistency_coeff > 0: 153 | # obtain the oracle hidden states from representation function 154 | _, _, _, presentation_state, _ = model.initial_inference(obs_target_batch[:, beg_index:end_index, :, :]) 155 | # no grad for the presentation_state branch 156 | dynamic_proj = model.project(hidden_state, with_grad=True) 157 | observation_proj = model.project(presentation_state, with_grad=False) 158 | temp_loss = consist_loss_func(dynamic_proj, observation_proj) * mask_batch[:, step_i] 159 | 160 | other_loss['consist_' + str(step_i + 1)] = temp_loss.mean().item() 161 | consistency_loss += temp_loss 162 | 163 | policy_loss += -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, step_i + 1]).sum(1) * mask_batch[:, step_i] 164 | value_loss += config.scalar_value_loss(value, target_value_phi[:, step_i + 1]) * mask_batch[:, step_i] 165 | value_prefix_loss += config.scalar_reward_loss(value_prefix, target_value_prefix_phi[:, step_i]) * mask_batch[:, step_i] 166 | # Follow MuZero, set half gradient 167 | hidden_state.register_hook(lambda grad: grad * 0.5) 168 | 169 | # reset hidden states 170 | if (step_i + 1) % config.lstm_horizon_len == 0: 171 | reward_hidden = (torch.zeros(1, config.batch_size, config.lstm_hidden_size).to(config.device), 172 | torch.zeros(1, config.batch_size, config.lstm_hidden_size).to(config.device)) 173 | 174 | if vis_result: 175 | scaled_value_prefixs = config.inverse_reward_transform(value_prefix.detach()) 176 | scaled_value_prefixs_cpu = scaled_value_prefixs.detach().cpu() 177 | 178 | predicted_values = torch.cat((predicted_values, config.inverse_value_transform(value).detach().cpu())) 179 | predicted_value_prefixs.append(scaled_value_prefixs_cpu) 180 | predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) 181 | state_lst = np.concatenate((state_lst, hidden_state.detach().cpu().numpy())) 182 | 183 | key = 'unroll_' + str(step_i + 1) + '_l1' 184 | 185 | value_prefix_indices_0 = (target_value_prefix_cpu[:, step_i].unsqueeze(-1) == 0) 186 | value_prefix_indices_n1 = (target_value_prefix_cpu[:, step_i].unsqueeze(-1) == -1) 187 | 
value_prefix_indices_1 = (target_value_prefix_cpu[:, step_i].unsqueeze(-1) == 1) 188 | 189 | target_value_prefix_base = target_value_prefix_cpu[:, step_i].reshape(-1).unsqueeze(-1) 190 | 191 | other_loss[key] = metric_loss(scaled_value_prefixs_cpu, target_value_prefix_base) 192 | if value_prefix_indices_1.any(): 193 | other_loss[key + '_1'] = metric_loss(scaled_value_prefixs_cpu[value_prefix_indices_1], target_value_prefix_base[value_prefix_indices_1]) 194 | if value_prefix_indices_n1.any(): 195 | other_loss[key + '_-1'] = metric_loss(scaled_value_prefixs_cpu[value_prefix_indices_n1], target_value_prefix_base[value_prefix_indices_n1]) 196 | if value_prefix_indices_0.any(): 197 | other_loss[key + '_0'] = metric_loss(scaled_value_prefixs_cpu[value_prefix_indices_0], target_value_prefix_base[value_prefix_indices_0]) 198 | else: 199 | for step_i in range(config.num_unroll_steps): 200 | # unroll with the dynamics function 201 | value, value_prefix, policy_logits, hidden_state, reward_hidden = model.recurrent_inference(hidden_state, reward_hidden, action_batch[:, step_i]) 202 | 203 | beg_index = config.image_channel * step_i 204 | end_index = config.image_channel * (step_i + config.stacked_observations) 205 | 206 | # consistency loss 207 | if config.consistency_coeff > 0: 208 | # obtain the oracle hidden states from representation function 209 | _, _, _, presentation_state, _ = model.initial_inference(obs_target_batch[:, beg_index:end_index, :, :]) 210 | # no grad for the presentation_state branch 211 | dynamic_proj = model.project(hidden_state, with_grad=True) 212 | observation_proj = model.project(presentation_state, with_grad=False) 213 | temp_loss = consist_loss_func(dynamic_proj, observation_proj) * mask_batch[:, step_i] 214 | 215 | other_loss['consist_' + str(step_i + 1)] = temp_loss.mean().item() 216 | consistency_loss += temp_loss 217 | 218 | policy_loss += -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, step_i + 1]).sum(1) * mask_batch[:, step_i] 219 | value_loss += config.scalar_value_loss(value, target_value_phi[:, step_i + 1]) * mask_batch[:, step_i] 220 | value_prefix_loss += config.scalar_reward_loss(value_prefix, target_value_prefix_phi[:, step_i]) * mask_batch[:, step_i] 221 | # Follow MuZero, set half gradient 222 | hidden_state.register_hook(lambda grad: grad * 0.5) 223 | 224 | # reset hidden states 225 | if (step_i + 1) % config.lstm_horizon_len == 0: 226 | reward_hidden = (torch.zeros(1, config.batch_size, config.lstm_hidden_size).to(config.device), 227 | torch.zeros(1, config.batch_size, config.lstm_hidden_size).to(config.device)) 228 | 229 | if vis_result: 230 | scaled_value_prefixs = config.inverse_reward_transform(value_prefix.detach()) 231 | scaled_value_prefixs_cpu = scaled_value_prefixs.detach().cpu() 232 | 233 | predicted_values = torch.cat((predicted_values, config.inverse_value_transform(value).detach().cpu())) 234 | predicted_value_prefixs.append(scaled_value_prefixs_cpu) 235 | predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) 236 | state_lst = np.concatenate((state_lst, hidden_state.detach().cpu().numpy())) 237 | 238 | key = 'unroll_' + str(step_i + 1) + '_l1' 239 | 240 | value_prefix_indices_0 = (target_value_prefix_cpu[:, step_i].unsqueeze(-1) == 0) 241 | value_prefix_indices_n1 = (target_value_prefix_cpu[:, step_i].unsqueeze(-1) == -1) 242 | value_prefix_indices_1 = (target_value_prefix_cpu[:, step_i].unsqueeze(-1) == 1) 243 | 244 | target_value_prefix_base = target_value_prefix_cpu[:, 
step_i].reshape(-1).unsqueeze(-1) 245 | 246 | other_loss[key] = metric_loss(scaled_value_prefixs_cpu, target_value_prefix_base) 247 | if value_prefix_indices_1.any(): 248 | other_loss[key + '_1'] = metric_loss(scaled_value_prefixs_cpu[value_prefix_indices_1], target_value_prefix_base[value_prefix_indices_1]) 249 | if value_prefix_indices_n1.any(): 250 | other_loss[key + '_-1'] = metric_loss(scaled_value_prefixs_cpu[value_prefix_indices_n1], target_value_prefix_base[value_prefix_indices_n1]) 251 | if value_prefix_indices_0.any(): 252 | other_loss[key + '_0'] = metric_loss(scaled_value_prefixs_cpu[value_prefix_indices_0], target_value_prefix_base[value_prefix_indices_0]) 253 | # ---------------------------------------------------------------------------------- 254 | # weighted loss with masks (some invalid states which are out of trajectory.) 255 | loss = (config.consistency_coeff * consistency_loss + config.policy_loss_coeff * policy_loss + 256 | config.value_loss_coeff * value_loss + config.reward_loss_coeff * value_prefix_loss) 257 | weighted_loss = (weights * loss).mean() 258 | 259 | # backward 260 | parameters = model.parameters() 261 | if config.amp_type == 'torch_amp': 262 | with autocast(): 263 | total_loss = weighted_loss 264 | total_loss.register_hook(lambda grad: grad * gradient_scale) 265 | else: 266 | total_loss = weighted_loss 267 | total_loss.register_hook(lambda grad: grad * gradient_scale) 268 | optimizer.zero_grad() 269 | 270 | if config.amp_type == 'none': 271 | total_loss.backward() 272 | elif config.amp_type == 'torch_amp': 273 | scaler.scale(total_loss).backward() 274 | scaler.unscale_(optimizer) 275 | 276 | torch.nn.utils.clip_grad_norm_(parameters, config.max_grad_norm) 277 | if config.amp_type == 'torch_amp': 278 | scaler.step(optimizer) 279 | scaler.update() 280 | else: 281 | optimizer.step() 282 | # ---------------------------------------------------------------------------------- 283 | # update priority 284 | new_priority = value_priority 285 | replay_buffer.update_priorities.remote(indices, new_priority, make_time) 286 | 287 | # packing data for logging 288 | loss_data = (total_loss.item(), weighted_loss.item(), loss.mean().item(), 0, policy_loss.mean().item(), 289 | value_prefix_loss.mean().item(), value_loss.mean().item(), consistency_loss.mean()) 290 | if vis_result: 291 | reward_w_dist, representation_mean, dynamic_mean, reward_mean = model.get_params_mean() 292 | other_dist['reward_weights_dist'] = reward_w_dist 293 | other_log['representation_weight'] = representation_mean 294 | other_log['dynamic_weight'] = dynamic_mean 295 | other_log['reward_weight'] = reward_mean 296 | 297 | # reward l1 loss 298 | value_prefix_indices_0 = (target_value_prefix_cpu[:, :config.num_unroll_steps].reshape(-1).unsqueeze(-1) == 0) 299 | value_prefix_indices_n1 = (target_value_prefix_cpu[:, :config.num_unroll_steps].reshape(-1).unsqueeze(-1) == -1) 300 | value_prefix_indices_1 = (target_value_prefix_cpu[:, :config.num_unroll_steps].reshape(-1).unsqueeze(-1) == 1) 301 | 302 | target_value_prefix_base = target_value_prefix_cpu[:, :config.num_unroll_steps].reshape(-1).unsqueeze(-1) 303 | 304 | predicted_value_prefixs = torch.stack(predicted_value_prefixs).transpose(1, 0).squeeze(-1) 305 | predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1) 306 | other_loss['l1'] = metric_loss(predicted_value_prefixs, target_value_prefix_base) 307 | if value_prefix_indices_1.any(): 308 | other_loss['l1_1'] = metric_loss(predicted_value_prefixs[value_prefix_indices_1], 
target_value_prefix_base[value_prefix_indices_1]) 309 | if value_prefix_indices_n1.any(): 310 | other_loss['l1_-1'] = metric_loss(predicted_value_prefixs[value_prefix_indices_n1], target_value_prefix_base[value_prefix_indices_n1]) 311 | if value_prefix_indices_0.any(): 312 | other_loss['l1_0'] = metric_loss(predicted_value_prefixs[value_prefix_indices_0], target_value_prefix_base[value_prefix_indices_0]) 313 | 314 | td_data = (new_priority, target_value_prefix.detach().cpu().numpy(), target_value.detach().cpu().numpy(), 315 | transformed_target_value_prefix.detach().cpu().numpy(), transformed_target_value.detach().cpu().numpy(), 316 | target_value_prefix_phi.detach().cpu().numpy(), target_value_phi.detach().cpu().numpy(), 317 | predicted_value_prefixs.detach().cpu().numpy(), predicted_values.detach().cpu().numpy(), 318 | target_policy.detach().cpu().numpy(), predicted_policies.detach().cpu().numpy(), state_lst, 319 | other_loss, other_log, other_dist) 320 | priority_data = (weights, indices) 321 | else: 322 | td_data, priority_data = None, None 323 | 324 | return loss_data, td_data, priority_data, scaler 325 | 326 | 327 | def _train(model, target_model, replay_buffer, shared_storage, batch_storage, config, summary_writer): 328 | """training loop 329 | Parameters 330 | ---------- 331 | model: Any 332 | EfficientZero models 333 | target_model: Any 334 | EfficientZero models for reanalyzing 335 | replay_buffer: Any 336 | replay buffer 337 | shared_storage: Any 338 | model storage 339 | batch_storage: Any 340 | batch storage (queue) 341 | summary_writer: Any 342 | logging for tensorboard 343 | """ 344 | # ---------------------------------------------------------------------------------- 345 | model = model.to(config.device) 346 | target_model = target_model.to(config.device) 347 | 348 | optimizer = optim.SGD(model.parameters(), lr=config.lr_init, momentum=config.momentum, 349 | weight_decay=config.weight_decay) 350 | 351 | scaler = GradScaler() 352 | 353 | model.train() 354 | target_model.eval() 355 | # ---------------------------------------------------------------------------------- 356 | # set augmentation tools 357 | if config.use_augmentation: 358 | config.set_transforms() 359 | 360 | # wait until collecting enough data to start 361 | while not (ray.get(replay_buffer.get_total_len.remote()) >= config.start_transitions): 362 | time.sleep(1) 363 | pass 364 | print('Begin training...') 365 | # set signals for other workers 366 | shared_storage.set_start_signal.remote() 367 | 368 | step_count = 0 369 | # Note: the interval of the current model and the target model is between x and 2x. (x = target_model_interval) 370 | # recent_weights is the param of the target model 371 | recent_weights = model.get_weights() 372 | 373 | # while loop 374 | while step_count < config.training_steps + config.last_steps: 375 | # remove data if the replay buffer is full. 
(more data settings) 376 | if step_count % 1000 == 0: 377 | replay_buffer.remove_to_fit.remote() 378 | 379 | # obtain a batch 380 | batch = batch_storage.pop() 381 | if batch is None: 382 | time.sleep(0.3) 383 | continue 384 | shared_storage.incr_counter.remote() 385 | lr = adjust_lr(config, optimizer, step_count) 386 | 387 | # update model for self-play 388 | if step_count % config.checkpoint_interval == 0: 389 | shared_storage.set_weights.remote(model.get_weights()) 390 | 391 | # update model for reanalyzing 392 | if step_count % config.target_model_interval == 0: 393 | shared_storage.set_target_weights.remote(recent_weights) 394 | recent_weights = model.get_weights() 395 | 396 | if step_count % config.vis_interval == 0: 397 | vis_result = True 398 | else: 399 | vis_result = False 400 | 401 | if config.amp_type == 'torch_amp': 402 | log_data = update_weights(model, batch, optimizer, replay_buffer, config, scaler, vis_result) 403 | scaler = log_data[3] 404 | else: 405 | log_data = update_weights(model, batch, optimizer, replay_buffer, config, scaler, vis_result) 406 | 407 | if step_count % config.log_interval == 0: 408 | _log(config, step_count, log_data[0:3], model, replay_buffer, lr, shared_storage, summary_writer, vis_result) 409 | 410 | # The queue is empty. 411 | if step_count >= 100 and step_count % 50 == 0 and batch_storage.get_len() == 0: 412 | print('Warning: Batch Queue is empty (require more batch actors, or a batch actor has failed).') 413 | 414 | step_count += 1 415 | 416 | # save models 417 | if step_count % config.save_ckpt_interval == 0: 418 | model_path = os.path.join(config.model_dir, 'model_{}.p'.format(step_count)) 419 | torch.save(model.state_dict(), model_path) 420 | 421 | shared_storage.set_weights.remote(model.get_weights()) 422 | time.sleep(30) 423 | return model.get_weights() 424 | 425 | 426 | def train(config, summary_writer, model_path=None): 427 | """training process 428 | Parameters 429 | ---------- 430 | summary_writer: Any 431 | logging for tensorboard 432 | model_path: str 433 | model path for resuming 434 | default: train from scratch 435 | """ 436 | model = config.get_uniform_network() 437 | target_model = config.get_uniform_network() 438 | if model_path: 439 | print('resume model from path: ', model_path) 440 | weights = torch.load(model_path) 441 | 442 | model.load_state_dict(weights) 443 | target_model.load_state_dict(weights) 444 | 445 | storage = SharedStorage.remote(model, target_model) 446 | 447 | # prepare the batch and mcts context storage 448 | batch_storage = QueueStorage(15, 20) 449 | mcts_storage = QueueStorage(18, 25) 450 | replay_buffer = ReplayBuffer.remote(config=config) 451 | 452 | # other workers 453 | workers = [] 454 | 455 | # reanalyze workers 456 | cpu_workers = [BatchWorker_CPU.remote(idx, replay_buffer, storage, batch_storage, mcts_storage, config) for idx in range(config.cpu_actor)] 457 | workers += [cpu_worker.run.remote() for cpu_worker in cpu_workers] 458 | gpu_workers = [BatchWorker_GPU.remote(idx, replay_buffer, storage, batch_storage, mcts_storage, config) for idx in range(config.gpu_actor)] 459 | workers += [gpu_worker.run.remote() for gpu_worker in gpu_workers] 460 | 461 | # self-play workers 462 | data_workers = [DataWorker.remote(rank, replay_buffer, storage, config) for rank in range(0, config.num_actors)] 463 | workers += [worker.run.remote() for worker in data_workers] 464 | # test workers 465 | workers += [_test.remote(config, storage)] 466 | 467 | # training loop 468 | final_weights = _train(model, target_model,
replay_buffer, storage, batch_storage, config, summary_writer) 469 | 470 | ray.wait(workers) 471 | print('Training over...') 472 | 473 | return model, final_weights 474 | -------------------------------------------------------------------------------- /core/reanalyze_worker.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import time 3 | import torch 4 | 5 | import numpy as np 6 | import core.ctree.cytree as cytree 7 | 8 | from torch.cuda.amp import autocast as autocast 9 | from core.mcts import MCTS 10 | from core.model import concat_output, concat_output_value 11 | from core.utils import prepare_observation_lst, LinearSchedule 12 | 13 | 14 | @ray.remote 15 | class BatchWorker_CPU(object): 16 | def __init__(self, worker_id, replay_buffer, storage, batch_storage, mcts_storage, config): 17 | """CPU Batch Worker for reanalyzing targets, see Appendix. 18 | Prepare the context concerning CPU overhead 19 | Parameters 20 | ---------- 21 | worker_id: int 22 | id of the worker 23 | replay_buffer: Any 24 | Replay buffer 25 | storage: Any 26 | The model storage 27 | batch_storage: Any 28 | The batch storage (batch queue) 29 | mcts_storage: Any 30 | The mcts-related contexts storage 31 | """ 32 | self.worker_id = worker_id 33 | self.replay_buffer = replay_buffer 34 | self.storage = storage 35 | self.batch_storage = batch_storage 36 | self.mcts_storage = mcts_storage 37 | self.config = config 38 | 39 | self.last_model_index = -1 40 | self.batch_max_num = 20 41 | self.beta_schedule = LinearSchedule(config.training_steps + config.last_steps, initial_p=config.priority_prob_beta, final_p=1.0) 42 | 43 | def _prepare_reward_value_context(self, indices, games, state_index_lst, total_transitions): 44 | """prepare the context of rewards and values for reanalyzing part 45 | Parameters 46 | ---------- 47 | indices: list 48 | transition index in replay buffer 49 | games: list 50 | list of game histories 51 | state_index_lst: list 52 | transition index in game 53 | total_transitions: int 54 | number of collected transitions 55 | """ 56 | zero_obs = games[0].zero_obs() 57 | config = self.config 58 | value_obs_lst = [] 59 | # the value is valid or not (out of trajectory) 60 | value_mask = [] 61 | rewards_lst = [] 62 | traj_lens = [] 63 | 64 | td_steps_lst = [] 65 | for game, state_index, idx in zip(games, state_index_lst, indices): 66 | traj_len = len(game) 67 | traj_lens.append(traj_len) 68 | 69 | # off-policy correction: shorter horizon of td steps 70 | delta_td = (total_transitions - idx) // config.auto_td_steps 71 | td_steps = config.td_steps - delta_td 72 | td_steps = np.clip(td_steps, 1, 5).astype(np.int) 73 | 74 | # prepare the corresponding observations for bootstrapped values o_{t+k} 75 | game_obs = game.obs(state_index + td_steps, config.num_unroll_steps) 76 | rewards_lst.append(game.rewards) 77 | for current_index in range(state_index, state_index + config.num_unroll_steps + 1): 78 | td_steps_lst.append(td_steps) 79 | bootstrap_index = current_index + td_steps 80 | 81 | if bootstrap_index < traj_len: 82 | value_mask.append(1) 83 | beg_index = bootstrap_index - (state_index + td_steps) 84 | end_index = beg_index + config.stacked_observations 85 | obs = game_obs[beg_index:end_index] 86 | else: 87 | value_mask.append(0) 88 | obs = zero_obs 89 | 90 | value_obs_lst.append(obs) 91 | 92 | value_obs_lst = ray.put(value_obs_lst) 93 | reward_value_context = [value_obs_lst, value_mask, state_index_lst, rewards_lst, traj_lens, td_steps_lst] 94 | return
reward_value_context 95 | 96 | def _prepare_policy_non_re_context(self, indices, games, state_index_lst): 97 | """prepare the context of policies for non-reanalyzing part, just return the policy in self-play 98 | Parameters 99 | ---------- 100 | indices: list 101 | transition index in replay buffer 102 | games: list 103 | list of game histories 104 | state_index_lst: list 105 | transition index in game 106 | """ 107 | child_visits = [] 108 | traj_lens = [] 109 | 110 | for game, state_index, idx in zip(games, state_index_lst, indices): 111 | traj_len = len(game) 112 | traj_lens.append(traj_len) 113 | 114 | child_visits.append(game.child_visits) 115 | 116 | policy_non_re_context = [state_index_lst, child_visits, traj_lens] 117 | return policy_non_re_context 118 | 119 | def _prepare_policy_re_context(self, indices, games, state_index_lst): 120 | """prepare the context of policies for reanalyzing part 121 | Parameters 122 | ---------- 123 | indices: list 124 | transition index in replay buffer 125 | games: list 126 | list of game histories 127 | state_index_lst: list 128 | transition index in game 129 | """ 130 | zero_obs = games[0].zero_obs() 131 | config = self.config 132 | 133 | with torch.no_grad(): 134 | # for policy 135 | policy_obs_lst = [] 136 | policy_mask = [] # 0 -> out of traj, 1 -> new policy 137 | rewards, child_visits, traj_lens = [], [], [] 138 | for game, state_index in zip(games, state_index_lst): 139 | traj_len = len(game) 140 | traj_lens.append(traj_len) 141 | rewards.append(game.rewards) 142 | child_visits.append(game.child_visits) 143 | # prepare the corresponding observations 144 | game_obs = game.obs(state_index, config.num_unroll_steps) 145 | for current_index in range(state_index, state_index + config.num_unroll_steps + 1): 146 | 147 | if current_index < traj_len: 148 | policy_mask.append(1) 149 | beg_index = current_index - state_index 150 | end_index = beg_index + config.stacked_observations 151 | obs = game_obs[beg_index:end_index] 152 | else: 153 | policy_mask.append(0) 154 | obs = zero_obs 155 | policy_obs_lst.append(obs) 156 | 157 | policy_obs_lst = ray.put(policy_obs_lst) 158 | policy_re_context = [policy_obs_lst, policy_mask, state_index_lst, indices, child_visits, traj_lens] 159 | return policy_re_context 160 | 161 | def make_batch(self, batch_context, ratio, weights=None): 162 | """prepare the context of a batch 163 | reward_value_context: the context of reanalyzed value targets 164 | policy_re_context: the context of reanalyzed policy targets 165 | policy_non_re_context: the context of non-reanalyzed policy targets 166 | inputs_batch: the inputs of batch 167 | weights: the target model weights 168 | Parameters 169 | ---------- 170 | batch_context: Any 171 | batch context from replay buffer 172 | ratio: float 173 | ratio of reanalyzed policy (value is 100% reanalyzed) 174 | weights: Any 175 | the target model weights 176 | """ 177 | # obtain the batch context from replay buffer 178 | game_lst, game_pos_lst, indices_lst, weights_lst, make_time_lst = batch_context 179 | batch_size = len(indices_lst) 180 | obs_lst, action_lst, mask_lst = [], [], [] 181 | # prepare the inputs of a batch 182 | for i in range(batch_size): 183 | game = game_lst[i] 184 | game_pos = game_pos_lst[i] 185 | 186 | _actions = game.actions[game_pos:game_pos + self.config.num_unroll_steps].tolist() 187 | # add mask for invalid actions (out of trajectory) 188 | _mask = [1. for i in range(len(_actions))] 189 | _mask += [0. 
for _ in range(self.config.num_unroll_steps - len(_mask))] 190 | 191 | _actions += [np.random.randint(0, game.action_space_size) for _ in range(self.config.num_unroll_steps - len(_actions))] 192 | 193 | # obtain the input observations 194 | obs_lst.append(game_lst[i].obs(game_pos_lst[i], extra_len=self.config.num_unroll_steps, padding=True)) 195 | action_lst.append(_actions) 196 | mask_lst.append(_mask) 197 | 198 | re_num = int(batch_size * ratio) 199 | # formalize the input observations 200 | obs_lst = prepare_observation_lst(obs_lst) 201 | 202 | # formalize the inputs of a batch 203 | inputs_batch = [obs_lst, action_lst, mask_lst, indices_lst, weights_lst, make_time_lst] 204 | for i in range(len(inputs_batch)): 205 | inputs_batch[i] = np.asarray(inputs_batch[i]) 206 | 207 | total_transitions = ray.get(self.replay_buffer.get_total_len.remote()) 208 | 209 | # obtain the context of value targets 210 | reward_value_context = self._prepare_reward_value_context(indices_lst, game_lst, game_pos_lst, total_transitions) 211 | 212 | # 0:re_num -> reanalyzed policy, re_num:end -> non reanalyzed policy 213 | # reanalyzed policy 214 | if re_num > 0: 215 | # obtain the context of reanalyzed policy targets 216 | policy_re_context = self._prepare_policy_re_context(indices_lst[:re_num], game_lst[:re_num], game_pos_lst[:re_num]) 217 | else: 218 | policy_re_context = None 219 | 220 | # non reanalyzed policy 221 | if re_num < batch_size: 222 | # obtain the context of non-reanalyzed policy targets 223 | policy_non_re_context = self._prepare_policy_non_re_context(indices_lst[re_num:], game_lst[re_num:], game_pos_lst[re_num:]) 224 | else: 225 | policy_non_re_context = None 226 | 227 | countext = reward_value_context, policy_re_context, policy_non_re_context, inputs_batch, weights 228 | self.mcts_storage.push(countext) 229 | 230 | def run(self): 231 | # start making mcts contexts to feed the GPU batch maker 232 | start = False 233 | while True: 234 | # wait for starting 235 | if not start: 236 | start = ray.get(self.storage.get_start_signal.remote()) 237 | time.sleep(1) 238 | continue 239 | 240 | ray_data_lst = [self.storage.get_counter.remote(), self.storage.get_target_weights.remote()] 241 | trained_steps, target_weights = ray.get(ray_data_lst) 242 | 243 | beta = self.beta_schedule.value(trained_steps) 244 | # obtain the batch context from replay buffer 245 | batch_context = ray.get(self.replay_buffer.prepare_batch_context.remote(self.config.batch_size, beta)) 246 | # break 247 | if trained_steps >= self.config.training_steps + self.config.last_steps: 248 | time.sleep(30) 249 | break 250 | 251 | new_model_index = trained_steps // self.config.target_model_interval 252 | if new_model_index > self.last_model_index: 253 | self.last_model_index = new_model_index 254 | else: 255 | target_weights = None 256 | 257 | if self.mcts_storage.get_len() < 20: 258 | # Observation will be deleted if replay buffer is full. (They are stored in the ray object store) 259 | try: 260 | self.make_batch(batch_context, self.config.revisit_policy_search_rate, weights=target_weights) 261 | except: 262 | print('Data is deleted...') 263 | time.sleep(0.1) 264 | 265 | 266 | @ray.remote(num_gpus=0.125) 267 | class BatchWorker_GPU(object): 268 | def __init__(self, worker_id, replay_buffer, storage, batch_storage, mcts_storage, config): 269 | """GPU Batch Worker for reanalyzing targets, see Appendix. 
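Each of these actors reserves a 0.125 GPU share via the decorator above, so Ray can pack up to eight of them onto a single physical GPU.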
270 | receive the context from CPU maker and deal with GPU overheads 271 | Parameters 272 | ---------- 273 | worker_id: int 274 | id of the worker 275 | replay_buffer: Any 276 | Replay buffer 277 | storage: Any 278 | The model storage 279 | batch_storage: Any 280 | The batch storage (batch queue) 281 | mcts_storage: Any 282 | The mcts-related contexts storage 283 | """ 284 | self.replay_buffer = replay_buffer 285 | self.config = config 286 | self.worker_id = worker_id 287 | 288 | self.model = config.get_uniform_network() 289 | self.model.to(config.device) 290 | self.model.eval() 291 | 292 | self.mcts_storage = mcts_storage 293 | self.storage = storage 294 | self.batch_storage = batch_storage 295 | 296 | self.last_model_index = 0 297 | 298 | def _prepare_reward_value(self, reward_value_context): 299 | """prepare reward and value targets from the context of rewards and values 300 | """ 301 | value_obs_lst, value_mask, state_index_lst, rewards_lst, traj_lens, td_steps_lst = reward_value_context 302 | value_obs_lst = ray.get(value_obs_lst) 303 | device = self.config.device 304 | batch_size = len(value_obs_lst) 305 | 306 | batch_values, batch_value_prefixs = [], [] 307 | with torch.no_grad(): 308 | value_obs_lst = prepare_observation_lst(value_obs_lst) 309 | # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors 310 | m_batch = self.config.mini_infer_size 311 | slices = np.ceil(batch_size / m_batch).astype(np.int_) 312 | network_output = [] 313 | for i in range(slices): 314 | beg_index = m_batch * i 315 | end_index = m_batch * (i + 1) 316 | m_obs = torch.from_numpy(value_obs_lst[beg_index:end_index]).to(device).float() / 255.0 317 | if self.config.amp_type == 'torch_amp': 318 | with autocast(): 319 | m_output = self.model.initial_inference(m_obs) 320 | else: 321 | m_output = self.model.initial_inference(m_obs) 322 | network_output.append(m_output) 323 | 324 | # concat the output slices after model inference 325 | if self.config.use_root_value: 326 | # use the root values from MCTS 327 | # the root values have limited improvement but require many more GPU actors; 328 | _, value_prefix_pool, policy_logits_pool, hidden_state_roots, reward_hidden_roots = concat_output(network_output) 329 | value_prefix_pool = value_prefix_pool.squeeze().tolist() 330 | policy_logits_pool = policy_logits_pool.tolist() 331 | roots = cytree.Roots(batch_size, self.config.action_space_size, self.config.num_simulations) 332 | noises = [np.random.dirichlet([self.config.root_dirichlet_alpha] * self.config.action_space_size).astype(np.float32).tolist() for _ in range(batch_size)] 333 | roots.prepare(self.config.root_exploration_fraction, noises, value_prefix_pool, policy_logits_pool) 334 | MCTS(self.config).search(roots, self.model, hidden_state_roots, reward_hidden_roots) 335 | 336 | roots_values = roots.get_values() 337 | value_lst = np.array(roots_values) 338 | else: 339 | # use the predicted values 340 | value_lst = concat_output_value(network_output) 341 | 342 | # get last state value 343 | value_lst = value_lst.reshape(-1) * (np.array([self.config.discount for _ in range(batch_size)]) ** td_steps_lst) 344 | value_lst = value_lst * np.array(value_mask) 345 | value_lst = value_lst.tolist() 346 | 347 | value_index = 0 348 | for traj_len_non_re, reward_lst, state_index in zip(traj_lens, rewards_lst, state_index_lst): 349 | # traj_len = len(game) 350 | target_values = [] 351 | target_value_prefixs = [] 352 | 353 | horizon_id = 0 354 | value_prefix = 0.0 355 | base_index = state_index
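# The loop below assembles, for each unrolled step t with td horizon k = td_steps:
#   target_value(t)        = sum_{i=0}^{k-1} discount^i * r_{t+i} + discount^k * v(s_{t+k})
#   target_value_prefix(t) = sum_{j=base_index}^{t} r_j  (undiscounted; base_index is reset every lstm_horizon_len steps)
# At this point value_lst already holds the masked bootstrap term discount^k * v(s_{t+k}).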
356 | for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1): 357 | bootstrap_index = current_index + td_steps_lst[value_index] 358 | # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): 359 | for i, reward in enumerate(reward_lst[current_index:bootstrap_index]): 360 | value_lst[value_index] += reward * self.config.discount ** i 361 | 362 | # reset every lstm_horizon_len 363 | if horizon_id % self.config.lstm_horizon_len == 0: 364 | value_prefix = 0.0 365 | base_index = current_index 366 | horizon_id += 1 367 | 368 | if current_index < traj_len_non_re: 369 | target_values.append(value_lst[value_index]) 370 | # Since the horizon is small and the discount is close to 1. 371 | # Compute the reward sum to approximate the value prefix for simplification 372 | value_prefix += reward_lst[current_index] # * config.discount ** (current_index - base_index) 373 | target_value_prefixs.append(value_prefix) 374 | else: 375 | target_values.append(0) 376 | target_value_prefixs.append(value_prefix) 377 | value_index += 1 378 | 379 | batch_value_prefixs.append(target_value_prefixs) 380 | batch_values.append(target_values) 381 | 382 | batch_value_prefixs = np.asarray(batch_value_prefixs) 383 | batch_values = np.asarray(batch_values) 384 | return batch_value_prefixs, batch_values 385 | 386 | def _prepare_policy_re(self, policy_re_context): 387 | """prepare policy targets from the reanalyzed context of policies 388 | """ 389 | batch_policies_re = [] 390 | if policy_re_context is None: 391 | return batch_policies_re 392 | 393 | policy_obs_lst, policy_mask, state_index_lst, indices, child_visits, traj_lens = policy_re_context 394 | policy_obs_lst = ray.get(policy_obs_lst) 395 | batch_size = len(policy_obs_lst) 396 | device = self.config.device 397 | 398 | with torch.no_grad(): 399 | policy_obs_lst = prepare_observation_lst(policy_obs_lst) 400 | # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors 401 | m_batch = self.config.mini_infer_size 402 | slices = np.ceil(batch_size / m_batch).astype(np.int_) 403 | network_output = [] 404 | for i in range(slices): 405 | beg_index = m_batch * i 406 | end_index = m_batch * (i + 1) 407 | 408 | m_obs = torch.from_numpy(policy_obs_lst[beg_index:end_index]).to(device).float() / 255.0 409 | if self.config.amp_type == 'torch_amp': 410 | with autocast(): 411 | m_output = self.model.initial_inference(m_obs) 412 | else: 413 | m_output = self.model.initial_inference(m_obs) 414 | network_output.append(m_output) 415 | 416 | _, value_prefix_pool, policy_logits_pool, hidden_state_roots, reward_hidden_roots = concat_output(network_output) 417 | value_prefix_pool = value_prefix_pool.squeeze().tolist() 418 | policy_logits_pool = policy_logits_pool.tolist() 419 | 420 | roots = cytree.Roots(batch_size, self.config.action_space_size, self.config.num_simulations) 421 | noises = [np.random.dirichlet([self.config.root_dirichlet_alpha] * self.config.action_space_size).astype(np.float32).tolist() for _ in range(batch_size)] 422 | roots.prepare(self.config.root_exploration_fraction, noises, value_prefix_pool, policy_logits_pool) 423 | # do MCTS for a new policy with the recent target model 424 | MCTS(self.config).search(roots, self.model, hidden_state_roots, reward_hidden_roots) 425 | 426 | roots_distributions = roots.get_distributions() 427 | policy_index = 0 428 | for state_index, game_idx in zip(state_index_lst, indices): 429 | target_policies = [] 430 | 431 | for current_index in range(state_index, 
state_index + self.config.num_unroll_steps + 1): 432 | distributions = roots_distributions[policy_index] 433 | 434 | if policy_mask[policy_index] == 0: 435 | target_policies.append([0 for _ in range(self.config.action_space_size)]) 436 | else: 437 | # game.store_search_stats(distributions, value, current_index) 438 | sum_visits = sum(distributions) 439 | policy = [visit_count / sum_visits for visit_count in distributions] 440 | target_policies.append(policy) 441 | 442 | policy_index += 1 443 | 444 | batch_policies_re.append(target_policies) 445 | 446 | batch_policies_re = np.asarray(batch_policies_re) 447 | return batch_policies_re 448 | 449 | def _prepare_policy_non_re(self, policy_non_re_context): 450 | """prepare policy targets from the non-reanalyzed context of policies 451 | """ 452 | batch_policies_non_re = [] 453 | if policy_non_re_context is None: 454 | return batch_policies_non_re 455 | 456 | state_index_lst, child_visits, traj_lens = policy_non_re_context 457 | with torch.no_grad(): 458 | # for policy 459 | policy_mask = [] # 0 -> out of traj, 1 -> old policy 460 | # for game, state_index in zip(games, state_index_lst): 461 | for traj_len, child_visit, state_index in zip(traj_lens, child_visits, state_index_lst): 462 | # traj_len = len(game) 463 | target_policies = [] 464 | 465 | for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1): 466 | if current_index < traj_len: 467 | target_policies.append(child_visit[current_index]) 468 | policy_mask.append(1) 469 | else: 470 | target_policies.append([0 for _ in range(self.config.action_space_size)]) 471 | policy_mask.append(0) 472 | 473 | batch_policies_non_re.append(target_policies) 474 | batch_policies_non_re = np.asarray(batch_policies_non_re) 475 | return batch_policies_non_re 476 | 477 | def _prepare_target_gpu(self): 478 | input_countext = self.mcts_storage.pop() 479 | if input_countext is None: 480 | time.sleep(1) 481 | else: 482 | reward_value_context, policy_re_context, policy_non_re_context, inputs_batch, target_weights = input_countext 483 | if target_weights is not None: 484 | self.model.load_state_dict(target_weights) 485 | self.model.to(self.config.device) 486 | self.model.eval() 487 | 488 | # target reward, value 489 | batch_value_prefixs, batch_values = self._prepare_reward_value(reward_value_context) 490 | # target policy 491 | batch_policies_re = self._prepare_policy_re(policy_re_context) 492 | batch_policies_non_re = self._prepare_policy_non_re(policy_non_re_context) 493 | batch_policies = np.concatenate([batch_policies_re, batch_policies_non_re]) 494 | 495 | targets_batch = [batch_value_prefixs, batch_values, batch_policies] 496 | # a batch contains the inputs and the targets; inputs is prepared in CPU workers 497 | self.batch_storage.push([inputs_batch, targets_batch]) 498 | 499 | def run(self): 500 | start = False 501 | while True: 502 | # waiting for start signal 503 | if not start: 504 | start = ray.get(self.storage.get_start_signal.remote()) 505 | time.sleep(0.1) 506 | continue 507 | 508 | trained_steps = ray.get(self.storage.get_counter.remote()) 509 | if trained_steps >= self.config.training_steps + self.config.last_steps: 510 | time.sleep(30) 511 | break 512 | 513 | self._prepare_target_gpu() 514 | --------------------------------------------------------------------------------
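For reference, the target construction in BatchWorker_GPU._prepare_reward_value can be reproduced in isolation. The sketch below is illustrative only: toy_value_targets and all numbers are made up, and it simplifies the real code to a single trajectory whose unroll starts at index 0, folding the value mask into the bootstrap check.

import numpy as np

def toy_value_targets(rewards, bootstrap_values, td_steps, discount, num_unroll_steps, lstm_horizon_len):
    # n-step value targets and undiscounted value-prefix targets for one trajectory,
    # for the unrolled steps t = 0 .. num_unroll_steps
    traj_len = len(rewards)
    target_values, target_value_prefixes = [], []
    value_prefix, horizon_id = 0.0, 0
    for t in range(num_unroll_steps + 1):
        bootstrap_index = t + td_steps
        # bootstrapped term discount^k * v(s_{t+k}), dropped when it falls outside the trajectory
        value = (discount ** td_steps) * bootstrap_values[bootstrap_index] if bootstrap_index < traj_len else 0.0
        # discounted rewards between step t and the bootstrap state
        for i, r in enumerate(rewards[t:bootstrap_index]):
            value += (discount ** i) * r
        # the value prefix restarts whenever the recurrent reward head is reset
        if horizon_id % lstm_horizon_len == 0:
            value_prefix = 0.0
        horizon_id += 1
        if t < traj_len:
            value_prefix += rewards[t]
            target_values.append(value)
        else:
            target_values.append(0.0)
        target_value_prefixes.append(value_prefix)
    return np.asarray(target_values), np.asarray(target_value_prefixes)

# toy trajectory: 8 rewards, constant predicted value 0.5 at every state
values, prefixes = toy_value_targets(rewards=[1., 0., 1., 0., 0., 1., 0., 0.],
                                     bootstrap_values=[0.5] * 8,
                                     td_steps=3, discount=0.997,
                                     num_unroll_steps=5, lstm_horizon_len=5)
print(values)     # n-step targets for steps 0..5
print(prefixes)   # cumulative rewards since the last reset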