├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── base ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-37.pyc │ ├── policy.cpython-35.pyc │ ├── policy.cpython-37.pyc │ ├── replay_buffer.cpython-35.pyc │ ├── replay_buffer.cpython-37.pyc │ ├── rl.cpython-35.pyc │ └── rl.cpython-37.pyc ├── policy.py ├── replay_buffer.py └── rl.py ├── bench ├── __init__.py └── monitor.py ├── common ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── base_class.cpython-35.pyc │ ├── base_class.cpython-36.pyc │ ├── console_util.cpython-35.pyc │ ├── console_util.cpython-36.pyc │ ├── console_util.cpython-37.pyc │ ├── dataset.cpython-35.pyc │ ├── dataset.cpython-36.pyc │ ├── dataset.cpython-37.pyc │ ├── distributions.cpython-35.pyc │ ├── distributions.cpython-36.pyc │ ├── distributions.cpython-37.pyc │ ├── math_util.cpython-35.pyc │ ├── math_util.cpython-36.pyc │ ├── math_util.cpython-37.pyc │ ├── misc_util.cpython-35.pyc │ ├── misc_util.cpython-36.pyc │ ├── misc_util.cpython-37.pyc │ ├── policies.cpython-35.pyc │ ├── policies.cpython-36.pyc │ ├── replay_buffer.cpython-35.pyc │ ├── replay_buffer.cpython-36.pyc │ ├── running_mean_std.cpython-35.pyc │ ├── running_mean_std.cpython-36.pyc │ ├── running_mean_std.cpython-37.pyc │ ├── save_util.cpython-35.pyc │ ├── save_util.cpython-36.pyc │ ├── save_util.cpython-37.pyc │ ├── schedules.cpython-35.pyc │ ├── schedules.cpython-36.pyc │ ├── schedules.cpython-37.pyc │ ├── segment_tree.cpython-35.pyc │ ├── segment_tree.cpython-36.pyc │ ├── segment_tree.cpython-37.pyc │ ├── tf_util.cpython-35.pyc │ ├── tf_util.cpython-36.pyc │ ├── tf_util.cpython-37.pyc │ ├── tile_images.cpython-35.pyc │ ├── tile_images.cpython-36.pyc │ └── tile_images.cpython-37.pyc ├── atari_wrappers.py ├── console_util.py ├── dataset.py ├── distributions.py ├── math_util.py ├── misc_util.py ├── running_mean_std.py ├── save_util.py ├── schedules.py ├── segment_tree.py ├── tf_util.py ├── tile_images.py └── vec_env │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── base_vec_env.cpython-35.pyc │ ├── base_vec_env.cpython-36.pyc │ ├── base_vec_env.cpython-37.pyc │ ├── dummy_vec_env.cpython-35.pyc │ ├── dummy_vec_env.cpython-36.pyc │ ├── dummy_vec_env.cpython-37.pyc │ ├── subproc_vec_env.cpython-35.pyc │ ├── subproc_vec_env.cpython-36.pyc │ ├── subproc_vec_env.cpython-37.pyc │ ├── util.cpython-35.pyc │ ├── util.cpython-36.pyc │ ├── util.cpython-37.pyc │ ├── vec_check_nan.cpython-35.pyc │ ├── vec_check_nan.cpython-36.pyc │ ├── vec_check_nan.cpython-37.pyc │ ├── vec_frame_stack.cpython-35.pyc │ ├── vec_frame_stack.cpython-36.pyc │ ├── vec_frame_stack.cpython-37.pyc │ ├── vec_normalize.cpython-35.pyc │ ├── vec_normalize.cpython-36.pyc │ ├── vec_normalize.cpython-37.pyc │ ├── vec_video_recorder.cpython-35.pyc │ ├── vec_video_recorder.cpython-36.pyc │ └── vec_video_recorder.cpython-37.pyc │ ├── base_vec_env.py │ ├── dummy_vec_env.py │ ├── subproc_vec_env.py │ ├── util.py │ ├── vec_check_nan.py │ ├── vec_frame_stack.py │ ├── vec_normalize.py │ └── vec_video_recorder.py ├── ddpg └── ddpg.py ├── dqn ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── dqn.cpython-35.pyc │ ├── dqn.cpython-37.pyc │ ├── policy.cpython-35.pyc │ └── policy.cpython-37.pyc ├── dqn.py ├── enjoy_example.py ├── policy.py ├── run_Break_double.py ├── run_Break_duel.py ├── train_atari_Breakout.py └── train_example.py ├── environment.yml 
├── environment_gpu.yml ├── logger.py ├── sac ├── __pycache__ │ ├── sac.cpython-35.pyc │ └── sac.cpython-37.pyc ├── plot.ipynb ├── sac.py └── train_example.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | .vscode 3 | __pycache__/* 4 | __pycache__ 5 | *.pyc 6 | *.zip 7 | .rsync.sh 8 | results/* 9 | results_final/* 10 | train.sh 11 | *.pkl 12 | /dqn/run_Break_double.py 13 | /dqn/run_Break_duel.py 14 | /dqn/train_atari_Breakout.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 OpenAI (http://openai.com) 4 | Copyright (c) 2018-2019 Stable-Baselines Team 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stable Baselines with TF 2.0 2 | 3 | This repository is based on the original implementation of Stable Baselines (https://github.com/hill-a/stable-baselines). 4 | 5 | In this version, we pursue the following properties: 6 | 1. Easy to debug, using eager execution and TensorFlow 2.0. 7 | 2. Easy to read, as simple as possible. 8 | 9 | ## Quick start 10 | We recommend using an Anaconda virtual environment. 11 | Using `environment.yml` (or `environment_gpu.yml` if you plan to use GPUs), you can easily install everything you need to run the code. 12 | (To run experiments on MuJoCo simulation tasks, make sure that you have a MuJoCo license.)
13 | 14 | To start, enter the following commands in a terminal: 15 | ``` 16 | git clone https://github.com/tzs930/stable_baselines_tf2.git 17 | cd stable_baselines_tf2 18 | conda env create -f environment.yml 19 | ``` 20 | 21 | ## Progress 22 | - Abstract classes : 23 | - `base/rl.py` : Includes abstract classes for RL algorithms 24 | - `base/policy.py` : Includes abstract classes for RL policies 25 | - `base/replay_buffer.py` : Includes the replay buffer implementation (used by off-policy RL algorithms) 26 | - TBU : TensorBoard writer 27 | 28 | - Algorithms : (Not working yet / Working but not verified / Verified (reproduced)) 29 | - DQN : Working but not verified 30 | - SAC : Working but not verified 31 | - DDPG : Not working yet 32 | - Algorithms planned to be added : TRPO/PPO, GAIL 33 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/__init__.py -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from base.rl import BaseRLAlgorithm, ActorCriticRLAlgorithm, ValueBasedRLAlgorithm, TensorboardWriter 2 | -------------------------------------------------------------------------------- /base/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/base/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /base/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/base/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /base/__pycache__/policy.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/base/__pycache__/policy.cpython-35.pyc -------------------------------------------------------------------------------- /base/__pycache__/policy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/base/__pycache__/policy.cpython-37.pyc -------------------------------------------------------------------------------- /base/__pycache__/replay_buffer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/base/__pycache__/replay_buffer.cpython-35.pyc -------------------------------------------------------------------------------- /base/__pycache__/replay_buffer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/base/__pycache__/replay_buffer.cpython-37.pyc
-------------------------------------------------------------------------------- /base/__pycache__/rl.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/base/__pycache__/rl.cpython-35.pyc -------------------------------------------------------------------------------- /base/__pycache__/rl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/base/__pycache__/rl.cpython-37.pyc -------------------------------------------------------------------------------- /base/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | 5 | from common.segment_tree import SumSegmentTree, MinSegmentTree 6 | 7 | 8 | class ReplayBuffer(object): 9 | def __init__(self, size): 10 | """ 11 | Implements a ring buffer (FIFO). 12 | 13 | :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old 14 | memories are dropped. 15 | """ 16 | self._storage = [] 17 | self._maxsize = size 18 | self._next_idx = 0 19 | 20 | def __len__(self): 21 | return len(self._storage) 22 | 23 | @property 24 | def storage(self): 25 | """[(np.ndarray, float, float, np.ndarray, bool)]: content of the replay buffer""" 26 | return self._storage 27 | 28 | @property 29 | def buffer_size(self): 30 | """float: Max capacity of the buffer""" 31 | return self._maxsize 32 | 33 | def can_sample(self, n_samples): 34 | """ 35 | Check if n_samples samples can be sampled 36 | from the buffer. 37 | 38 | :param n_samples: (int) 39 | :return: (bool) 40 | """ 41 | return len(self) >= n_samples 42 | 43 | def is_full(self): 44 | """ 45 | Check whether the replay buffer is full or not. 46 | 47 | :return: (bool) 48 | """ 49 | return len(self) == self.buffer_size 50 | 51 | def add(self, obs_t, action, reward, obs_tp1, done): 52 | """ 53 | add a new transition to the buffer 54 | 55 | :param obs_t: (Any) the last observation 56 | :param action: ([float]) the action 57 | :param reward: (float) the reward of the transition 58 | :param obs_tp1: (Any) the current observation 59 | :param done: (bool) is the episode done 60 | """ 61 | data = (obs_t, action, reward, obs_tp1, done) 62 | 63 | if self._next_idx >= len(self._storage): 64 | self._storage.append(data) 65 | else: 66 | self._storage[self._next_idx] = data 67 | self._next_idx = (self._next_idx + 1) % self._maxsize 68 | 69 | def _encode_sample(self, idxes): 70 | obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] 71 | for i in idxes: 72 | data = self._storage[i] 73 | obs_t, action, reward, obs_tp1, done = data 74 | obses_t.append(np.array(obs_t, copy=False)) 75 | actions.append(np.array(action, copy=False)) 76 | rewards.append(reward) 77 | obses_tp1.append(np.array(obs_tp1, copy=False)) 78 | dones.append(done) 79 | return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones) 80 | 81 | def sample(self, batch_size, **_kwargs): 82 | """ 83 | Sample a batch of experiences. 84 | 85 | :param batch_size: (int) How many transitions to sample. 
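The `ReplayBuffer` shown here can be exercised on its own. The sketch below is only an illustration with made-up observation and action shapes, and it assumes the repository root is on the Python path, matching the `from base...` import style used throughout this repo:
```
import numpy as np
from base.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(size=1000)

# Store a handful of dummy transitions (obs, action, reward, next_obs, done).
for t in range(100):
    obs, next_obs = np.random.randn(4), np.random.randn(4)
    action = np.random.randn(2)
    buffer.add(obs, action, float(np.random.randn()), next_obs, done=(t % 20 == 19))

# Sample a minibatch once enough transitions are stored.
if buffer.can_sample(32):
    obs_b, act_b, rew_b, next_obs_b, done_b = buffer.sample(32)
    print(obs_b.shape, act_b.shape, rew_b.shape)  # (32, 4) (32, 2) (32,)
```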
86 | :return: 87 | - obs_batch: (np.ndarray) batch of observations 88 | - act_batch: (numpy float) batch of actions executed given obs_batch 89 | - rew_batch: (numpy float) rewards received as results of executing act_batch 90 | - next_obs_batch: (np.ndarray) next set of observations seen after executing act_batch 91 | - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode 92 | and 0 otherwise. 93 | """ 94 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 95 | return self._encode_sample(idxes) 96 | 97 | 98 | class PrioritizedReplayBuffer(ReplayBuffer): 99 | def __init__(self, size, alpha): 100 | """ 101 | Create Prioritized Replay buffer. 102 | 103 | See Also ReplayBuffer.__init__ 104 | 105 | :param size: (int) Max number of transitions to store in the buffer. When the buffer overflows the old memories 106 | are dropped. 107 | :param alpha: (float) how much prioritization is used (0 - no prioritization, 1 - full prioritization) 108 | """ 109 | super(PrioritizedReplayBuffer, self).__init__(size) 110 | assert alpha >= 0 111 | self._alpha = alpha 112 | 113 | it_capacity = 1 114 | while it_capacity < size: 115 | it_capacity *= 2 116 | 117 | self._it_sum = SumSegmentTree(it_capacity) 118 | self._it_min = MinSegmentTree(it_capacity) 119 | self._max_priority = 1.0 120 | 121 | def add(self, obs_t, action, reward, obs_tp1, done): 122 | """ 123 | add a new transition to the buffer 124 | 125 | :param obs_t: (Any) the last observation 126 | :param action: ([float]) the action 127 | :param reward: (float) the reward of the transition 128 | :param obs_tp1: (Any) the current observation 129 | :param done: (bool) is the episode done 130 | """ 131 | idx = self._next_idx 132 | super().add(obs_t, action, reward, obs_tp1, done) 133 | self._it_sum[idx] = self._max_priority ** self._alpha 134 | self._it_min[idx] = self._max_priority ** self._alpha 135 | 136 | def _sample_proportional(self, batch_size): 137 | res = [] 138 | for _ in range(batch_size): 139 | # TODO(szymon): should we ensure no repeats? 140 | mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) 141 | idx = self._it_sum.find_prefixsum_idx(mass) 142 | res.append(idx) 143 | return res 144 | 145 | def sample(self, batch_size, beta=0): 146 | """ 147 | Sample a batch of experiences. 148 | 149 | compared to ReplayBuffer.sample 150 | it also returns importance weights and idxes 151 | of sampled experiences. 152 | 153 | :param batch_size: (int) How many transitions to sample. 154 | :param beta: (float) To what degree to use importance weights (0 - no corrections, 1 - full correction) 155 | :return: 156 | - obs_batch: (np.ndarray) batch of observations 157 | - act_batch: (numpy float) batch of actions executed given obs_batch 158 | - rew_batch: (numpy float) rewards received as results of executing act_batch 159 | - next_obs_batch: (np.ndarray) next set of observations seen after executing act_batch 160 | - done_mask: (numpy bool) done_mask[i] = 1 if executing act_batch[i] resulted in the end of an episode 161 | and 0 otherwise. 
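`PrioritizedReplayBuffer.sample` additionally requires `beta > 0` and appends importance weights and the sampled buffer indices to the batch; those indices are later passed to `update_priorities` (defined just below) once fresh TD errors are available. A minimal sketch with dummy data, again assuming the repository root is on the Python path:
```
import numpy as np
from base.replay_buffer import PrioritizedReplayBuffer

buffer = PrioritizedReplayBuffer(size=1000, alpha=0.6)
for t in range(200):
    obs, next_obs = np.random.randn(4), np.random.randn(4)
    buffer.add(obs, np.random.randint(2), float(np.random.randn()), next_obs, False)

# beta is usually annealed from around 0.4 towards 1.0 over the course of training.
obs_b, act_b, rew_b, next_obs_b, done_b, weights, idxes = buffer.sample(64, beta=0.4)

# After computing TD errors for this batch, their magnitudes become the new
# priorities; a small epsilon keeps every priority strictly positive, as
# update_priorities requires.
td_errors = np.random.randn(64)  # placeholder for real TD errors
buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6)
```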
162 | - weights: (numpy float) Array of shape (batch_size,) and dtype np.float32 denoting importance weight of 163 | each sampled transition 164 | - idxes: (numpy int) Array of shape (batch_size,) and dtype np.int32 idexes in buffer of sampled experiences 165 | """ 166 | assert beta > 0 167 | 168 | idxes = self._sample_proportional(batch_size) 169 | 170 | weights = [] 171 | p_min = self._it_min.min() / self._it_sum.sum() 172 | max_weight = (p_min * len(self._storage)) ** (-beta) 173 | 174 | for idx in idxes: 175 | p_sample = self._it_sum[idx] / self._it_sum.sum() 176 | weight = (p_sample * len(self._storage)) ** (-beta) 177 | weights.append(weight / max_weight) 178 | weights = np.array(weights) 179 | encoded_sample = self._encode_sample(idxes) 180 | return tuple(list(encoded_sample) + [weights, idxes]) 181 | 182 | def update_priorities(self, idxes, priorities): 183 | """ 184 | Update priorities of sampled transitions. 185 | 186 | sets priority of transition at index idxes[i] in buffer 187 | to priorities[i]. 188 | 189 | :param idxes: ([int]) List of idxes of sampled transitions 190 | :param priorities: ([float]) List of updated priorities corresponding to transitions at the sampled idxes 191 | denoted by variable `idxes`. 192 | """ 193 | assert len(idxes) == len(priorities) 194 | for idx, priority in zip(idxes, priorities): 195 | assert priority > 0 196 | assert 0 <= idx < len(self._storage) 197 | self._it_sum[idx] = priority ** self._alpha 198 | self._it_min[idx] = priority ** self._alpha 199 | 200 | self._max_priority = max(self._max_priority, priority) 201 | -------------------------------------------------------------------------------- /base/rl.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import os 3 | import glob 4 | import warnings 5 | from collections import OrderedDict 6 | import json 7 | import zipfile 8 | 9 | import pickle 10 | import numpy as np 11 | import gym 12 | import tensorflow as tf 13 | 14 | from common import set_global_seeds 15 | from common.save_util import ( 16 | is_json_serializable, data_to_json, json_to_data, params_to_bytes, bytes_to_params 17 | ) 18 | from base.policy import ActorCriticPolicy 19 | from common.vec_env import VecEnvWrapper, VecEnv, DummyVecEnv 20 | 21 | 22 | class BaseRLAlgorithm(ABC): 23 | """ 24 | The base RL model 25 | 26 | :param policy: (BasePolicy) Policy object 27 | :param env: (Gym environment) The environment to learn from 28 | (if registered in Gym, can be str. Can be None for loading trained models) 29 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug 30 | :param requires_vec_env: (bool) Does this model require a vectorized environment 31 | :param policy_base: (BasePolicy) the base policy used by this method 32 | """ 33 | def __init__(self, policy_class, env, test_env): 34 | self.policy_class = policy_class 35 | self.env = env 36 | self.test_env = test_env 37 | self.observation_space = None 38 | self.action_space = None 39 | self.num_timesteps = 0 40 | self.params = None 41 | 42 | if env is not None: 43 | self.observation_space = env.observation_space 44 | self.action_space = env.action_space 45 | 46 | @abstractmethod 47 | def get_parameters(self): 48 | """ 49 | Get current model parameters as dictionary of variable name -> ndarray. 50 | 51 | :return: (OrderedDict) Dictionary of variable name -> ndarray of model's parameters. 
52 | """ 53 | 54 | return NotImplementedError 55 | 56 | @abstractmethod 57 | def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run", 58 | reset_num_timesteps=True): 59 | """ 60 | Return a trained model. 61 | 62 | :param total_timesteps: (int) The total number of samples to train on 63 | :param seed: (int) The initial seed for training, if None: keep current seed 64 | :param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm. 65 | It takes the local and global variables. If it returns False, training is aborted. 66 | :param log_interval: (int) The number of timesteps before logging. 67 | :param tb_log_name: (str) the name of the run for tensorboard log 68 | :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging) 69 | :return: (BaseRLModel) the trained model 70 | """ 71 | raise NotImplementedError 72 | 73 | @abstractmethod 74 | def predict(self, observation, state=None, mask=None, deterministic=False): 75 | """ 76 | Get the model's action from an observation 77 | 78 | :param observation: (np.ndarray) the input observation 79 | :param state: (np.ndarray) The last states (can be None, used in recurrent policies) 80 | :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) 81 | :param deterministic: (bool) Whether or not to return deterministic actions. 82 | :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) 83 | """ 84 | raise NotImplementedError 85 | 86 | def evaluate(self, num_epsiodes=5, deterministic=True): 87 | """ 88 | Test the learned model in the given environment 89 | 90 | :param observation: (np.ndarray) the input observation 91 | :return: (np.ndarray) episode returns of num_epsiodes 92 | """ 93 | 94 | episode_returns = [] 95 | print("* Evaluating...") 96 | for i in range(num_epsiodes): 97 | done = False 98 | obs = self.env.reset() 99 | ret = 0 100 | while not done: 101 | action = self.predict(np.array([obs]), deterministic=deterministic)[0] 102 | obs, rew, done, _ = self.env.step(action) 103 | ret += rew 104 | print("- Episode %3d : %6.3f" % (i+1, ret)) 105 | episode_returns.append(ret) 106 | 107 | episode_returns = np.array(episode_returns) 108 | print("\n* Evaluation Result :") 109 | print("- Average of %d epsiode returns : %6.3f" % (num_epsiodes, np.mean(episode_returns))) 110 | print("- Stddev. of %d epsiode returns : %6.3f\n" % (num_epsiodes, np.std(episode_returns))) 111 | 112 | return episode_returns 113 | 114 | @abstractmethod 115 | def action_probability(self, observation, state=None, mask=None, actions=None, logp=False): 116 | """ 117 | If ``actions`` is ``None``, then get the model's action probability distribution from a given observation. 118 | 119 | Depending on the action space the output is: 120 | - Discrete: probability for each possible action 121 | - Box: mean and standard deviation of the action output 122 | 123 | However if ``actions`` is not ``None``, this function will return the probability that the given actions are 124 | taken with the given parameters (observation, state, ...) on this model. For discrete action spaces, it 125 | returns the probability mass; for continuous action spaces, the probability density. 
This is since the 126 | probability mass will always be zero in continuous spaces, see http://blog.christianperone.com/2019/01/ 127 | for a good explanation 128 | 129 | :param observation: (np.ndarray) the input observation 130 | :param state: (np.ndarray) The last states (can be None, used in recurrent policies) 131 | :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) 132 | :param actions: (np.ndarray) (OPTIONAL) For calculating the likelihood that the given actions are chosen by 133 | the model for each of the given parameters. Must have the same number of actions and observations. 134 | (set to None to return the complete action probability distribution) 135 | :param logp: (bool) (OPTIONAL) When specified with actions, returns probability in log-space. 136 | This has no effect if actions is None. 137 | :return: (np.ndarray) the model's (log) action probability 138 | """ 139 | raise NotImplementedError 140 | 141 | def load_parameters(self, parameters): 142 | """ 143 | Load model parameters from a file or a dictionary 144 | 145 | Dictionary keys should be tensorflow variable names, which can be obtained 146 | with ``get_parameters`` function. If ``exact_match`` is True, dictionary 147 | should contain keys for all model's parameters, otherwise RunTimeError 148 | is raised. If False, only variables included in the dictionary will be updated. 149 | 150 | This does not load agent's hyper-parameters. 151 | 152 | :param parameters: (list) A list containing parameter values 153 | """ 154 | raise NotImplementedError 155 | 156 | @abstractmethod 157 | def save(self, save_path): 158 | """ 159 | Save the current parameters to file 160 | 161 | :param save_path: (str or file-like) The save location 162 | """ 163 | raise NotImplementedError 164 | 165 | @abstractmethod 166 | def load(self, load_path): 167 | """ 168 | Load the model from file 169 | 170 | :param load_path: (str or file-like) the saved parameter location 171 | """ 172 | raise NotImplementedError() 173 | 174 | 175 | class ActorCriticRLAlgorithm(tf.keras.layers.Layer, BaseRLAlgorithm): 176 | """ 177 | The base class for Actor critic model 178 | 179 | :param policy: (BasePolicy) Policy object 180 | :param env: (Gym environment) The environment to learn from 181 | (if registered in Gym, can be str. 
Can be None for loading trained models) 182 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug 183 | :param policy_base: (BasePolicy) the base policy used by this method (default=ActorCriticPolicy) 184 | :param requires_vec_env: (bool) Does this model require a vectorized environment 185 | """ 186 | 187 | def __init__(self, policy_class, env, test_env): 188 | 189 | BaseRLAlgorithm.__init__(self, policy_class=policy_class, env=env, test_env=test_env) 190 | tf.keras.layers.Layer.__init__(self) 191 | 192 | self.initial_state = None 193 | self.step = None 194 | self.proba_step = None 195 | self.params = None 196 | 197 | # Actor Network 198 | self.actor = None 199 | 200 | # Critic Network 201 | self.v = None 202 | self.q = None 203 | 204 | @abstractmethod 205 | def learn(self, total_timesteps, callback=None, seed=None, 206 | log_interval=100, tb_log_name="run", reset_num_timesteps=True): 207 | pass 208 | 209 | @abstractmethod 210 | def predict(self, observation, state=None, mask=None, deterministic=False): 211 | pass 212 | 213 | def action_probability(self, observation, state=None, mask=None, actions=None, logp=False): 214 | pass 215 | 216 | def get_parameters(self): 217 | parameters = [] 218 | weights = self.get_weights() 219 | for idx, variable in enumerate(self.trainable_variables): 220 | weight = weights[idx] 221 | parameters.append((variable.name, weight)) 222 | return parameters 223 | 224 | def load_parameters(self, parameters, exact_match=False): 225 | assert len(parameters) == len(self.weights) 226 | weights = [] 227 | for variable, parameter in zip(self.weights, parameters): 228 | name, value = parameter 229 | if exact_match: 230 | assert name == variable.name 231 | weights.append(value) 232 | self.set_weights(weights) 233 | 234 | def save(self, filepath): 235 | parameters = self.get_parameters() 236 | with open(filepath, 'wb') as f: 237 | pickle.dump(parameters, f, protocol=pickle.HIGHEST_PROTOCOL) 238 | 239 | def load(self, filepath): 240 | with open(filepath, 'rb') as f: 241 | parameters = pickle.load(f) 242 | self.load_parameters(parameters) 243 | 244 | 245 | class ValueBasedRLAlgorithm(tf.keras.layers.Layer, BaseRLAlgorithm): 246 | """ 247 | The base class for off policy RL model 248 | 249 | :param policy: (BasePolicy) Policy object 250 | :param env: (Gym environment) The environment to learn from 251 | (if registered in Gym, can be str. 
Can be None for loading trained models) 252 | :param replay_buffer: (ReplayBuffer) the type of replay buffer 253 | :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug 254 | :param requires_vec_env: (bool) Does this model require a vectorized environment 255 | :param policy_base: (BasePolicy) the base policy used by this method 256 | """ 257 | 258 | def __init__(self, policy_class, env, replay_buffer=None): 259 | super(ValueBasedRLAlgorithm, self).__init__(policy_class, env) 260 | 261 | @abstractmethod 262 | def learn(self, total_timesteps, callback=None, seed=None, 263 | log_interval=100, tb_log_name="run", reset_num_timesteps=True, replay_wrapper=None): 264 | pass 265 | 266 | @abstractmethod 267 | def predict(self, observation, state=None, mask=None, deterministic=False): 268 | pass 269 | 270 | @abstractmethod 271 | def action_probability(self, observation, state=None, mask=None, actions=None, logp=False): 272 | pass 273 | 274 | @abstractmethod 275 | def save(self, save_path): 276 | pass 277 | 278 | @abstractmethod 279 | def load(self, load_path): 280 | """ 281 | Load the model from file 282 | 283 | :param load_path: (str or file-like) the saved parameter location 284 | :param env: (Gym Envrionment) the new environment to run the loaded model on 285 | (can be None if you only need prediction from a trained model) 286 | :param custom_objects: (dict) Dictionary of objects to replace 287 | upon loading. If a variable is present in this dictionary as a 288 | key, it will not be deserialized and the corresponding item 289 | will be used instead. Similar to custom_objects in 290 | `keras.models.load_model`. Useful when you have an object in 291 | file that can not be deserialized. 292 | :param kwargs: extra arguments to change the model when loading 293 | """ 294 | pass 295 | 296 | 297 | class TensorboardWriter: 298 | def __init__(self, graph, tensorboard_log_path, tb_log_name, new_tb_log=True): 299 | """ 300 | Create a Tensorboard writer for a code segment, and saves it to the log directory as its own run 301 | 302 | :param graph: (Tensorflow Graph) the model graph 303 | :param tensorboard_log_path: (str) the save path for the log (can be None for no logging) 304 | :param tb_log_name: (str) the name of the run for tensorboard log 305 | :param new_tb_log: (bool) whether or not to create a new logging folder for tensorbaord 306 | """ 307 | self.graph = graph 308 | self.tensorboard_log_path = tensorboard_log_path 309 | self.tb_log_name = tb_log_name 310 | self.writer = None 311 | self.new_tb_log = new_tb_log 312 | 313 | def __enter__(self): 314 | if self.tensorboard_log_path is not None: 315 | latest_run_id = self._get_latest_run_id() 316 | if self.new_tb_log: 317 | latest_run_id = latest_run_id + 1 318 | save_path = os.path.join(self.tensorboard_log_path, "{}_{}".format(self.tb_log_name, latest_run_id)) 319 | self.writer = tf.compat.v1.summary.FileWriter(save_path, graph=self.graph) 320 | return self.writer 321 | 322 | def _get_latest_run_id(self): 323 | """ 324 | returns the latest run number for the given log name and log path, 325 | by finding the greatest number in the directories. 
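`TensorboardWriter` is meant to be used as a context manager: `__enter__` hands back the underlying `tf.compat.v1.summary.FileWriter` (or `None` when no log path is given) and `__exit__` adds the graph and flushes the events file. Because `FileWriter` refuses to run under eager execution, the sketch below builds it inside a graph context; the log directory and tag name are made up:
```
import tensorflow as tf
from base.rl import TensorboardWriter

graph = tf.Graph()
with graph.as_default():  # disables eager execution inside this block
    with TensorboardWriter(graph, "./tb_logs", "dqn_run") as writer:
        for step in range(3):
            summary = tf.compat.v1.Summary(value=[
                tf.compat.v1.Summary.Value(tag="episode_return", simple_value=float(step))])
            writer.add_summary(summary, global_step=step)
# Each run is written to "<tb_log_name>_<run id>", e.g. ./tb_logs/dqn_run_1;
# the run id is bumped automatically via _get_latest_run_id.
```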
326 | 327 | :return: (int) latest run number 328 | """ 329 | max_run_id = 0 330 | for path in glob.glob("{}/{}_[0-9]*".format(self.tensorboard_log_path, self.tb_log_name)): 331 | file_name = path.split(os.sep)[-1] 332 | ext = file_name.split("_")[-1] 333 | if self.tb_log_name == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id: 334 | max_run_id = int(ext) 335 | return max_run_id 336 | 337 | def __exit__(self, exc_type, exc_val, exc_tb): 338 | if self.writer is not None: 339 | self.writer.add_graph(self.graph) 340 | self.writer.flush() 341 | -------------------------------------------------------------------------------- /bench/__init__.py: -------------------------------------------------------------------------------- 1 | from bench.monitor import Monitor, load_results -------------------------------------------------------------------------------- /bench/monitor.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Monitor', 'get_monitor_files', 'load_results'] 2 | 3 | import csv 4 | import json 5 | import os 6 | import time 7 | from glob import glob 8 | from typing import Tuple, Dict, Any, List, Optional 9 | 10 | import gym 11 | import pandas 12 | import numpy as np 13 | 14 | 15 | class Monitor(gym.Wrapper): 16 | EXT = "monitor.csv" 17 | file_handler = None 18 | 19 | def __init__(self, 20 | env: gym.Env, 21 | filename: Optional[str], 22 | allow_early_resets: bool = True, 23 | reset_keywords=(), 24 | info_keywords=()): 25 | """ 26 | A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data. 27 | :param env: (gym.Env) The environment 28 | :param filename: (Optional[str]) the location to save a log file, can be None for no log 29 | :param allow_early_resets: (bool) allows the reset of the environment before it is done 30 | :param reset_keywords: (tuple) extra keywords for the reset call, if extra parameters are needed at reset 31 | :param info_keywords: (tuple) extra information to log, from the information return of environment.step 32 | """ 33 | super(Monitor, self).__init__(env=env) 34 | self.t_start = time.time() 35 | if filename is None: 36 | self.file_handler = None 37 | self.logger = None 38 | else: 39 | if not filename.endswith(Monitor.EXT): 40 | if os.path.isdir(filename): 41 | filename = os.path.join(filename, Monitor.EXT) 42 | else: 43 | filename = filename + "." + Monitor.EXT 44 | self.file_handler = open(filename, "wt") 45 | self.file_handler.write('#%s\n' % json.dumps({"t_start": self.t_start, 'env_id': env.spec and env.spec.id})) 46 | self.logger = csv.DictWriter(self.file_handler, 47 | fieldnames=('r', 'l', 't') + reset_keywords + info_keywords) 48 | self.logger.writeheader() 49 | self.file_handler.flush() 50 | 51 | self.reset_keywords = reset_keywords 52 | self.info_keywords = info_keywords 53 | self.allow_early_resets = allow_early_resets 54 | self.rewards = None 55 | self.needs_reset = True 56 | self.episode_rewards = [] 57 | self.episode_lengths = [] 58 | self.episode_times = [] 59 | self.total_steps = 0 60 | self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() 61 | 62 | def reset(self, **kwargs) -> np.ndarray: 63 | """ 64 | Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True 65 | :param kwargs: Extra keywords saved for the next episode. 
only if defined by reset_keywords 66 | :return: (np.ndarray) the first observation of the environment 67 | """ 68 | if not self.allow_early_resets and not self.needs_reset: 69 | raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, " 70 | "wrap your env with Monitor(env, path, allow_early_resets=True)") 71 | self.rewards = [] 72 | self.needs_reset = False 73 | for key in self.reset_keywords: 74 | value = kwargs.get(key) 75 | if value is None: 76 | raise ValueError('Expected you to pass kwarg {} into reset'.format(key)) 77 | self.current_reset_info[key] = value 78 | return self.env.reset(**kwargs) 79 | 80 | def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, Any]]: 81 | """ 82 | Step the environment with the given action 83 | :param action: (np.ndarray) the action 84 | :return: (Tuple[np.ndarray, float, bool, Dict[Any, Any]]) observation, reward, done, information 85 | """ 86 | if self.needs_reset: 87 | raise RuntimeError("Tried to step environment that needs reset") 88 | observation, reward, done, info = self.env.step(action) 89 | self.rewards.append(reward) 90 | if done: 91 | self.needs_reset = True 92 | ep_rew = sum(self.rewards) 93 | eplen = len(self.rewards) 94 | ep_info = {"r": round(ep_rew, 6), "l": eplen, "t": round(time.time() - self.t_start, 6)} 95 | for key in self.info_keywords: 96 | ep_info[key] = info[key] 97 | self.episode_rewards.append(ep_rew) 98 | self.episode_lengths.append(eplen) 99 | self.episode_times.append(time.time() - self.t_start) 100 | ep_info.update(self.current_reset_info) 101 | if self.logger: 102 | self.logger.writerow(ep_info) 103 | self.file_handler.flush() 104 | info['episode'] = ep_info 105 | self.total_steps += 1 106 | return observation, reward, done, info 107 | 108 | def close(self): 109 | """ 110 | Closes the environment 111 | """ 112 | super(Monitor, self).close() 113 | if self.file_handler is not None: 114 | self.file_handler.close() 115 | 116 | def get_total_steps(self) -> int: 117 | """ 118 | Returns the total number of timesteps 119 | :return: (int) 120 | """ 121 | return self.total_steps 122 | 123 | def get_episode_rewards(self) -> List[float]: 124 | """ 125 | Returns the rewards of all the episodes 126 | :return: ([float]) 127 | """ 128 | return self.episode_rewards 129 | 130 | def get_episode_lengths(self) -> List[int]: 131 | """ 132 | Returns the number of timesteps of all the episodes 133 | :return: ([int]) 134 | """ 135 | return self.episode_lengths 136 | 137 | def get_episode_times(self) -> List[float]: 138 | """ 139 | Returns the runtime in seconds of all the episodes 140 | :return: ([float]) 141 | """ 142 | return self.episode_times 143 | 144 | 145 | class LoadMonitorResultsError(Exception): 146 | """ 147 | Raised when loading the monitor log fails. 
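A typical use of `Monitor` is to wrap an environment before training so that per-episode statistics end up both in `info['episode']` and in a `*.monitor.csv` file that `load_results` (defined below) can aggregate. A small sketch with a toy Gym task and a random policy, using the classic 4-tuple `step` API this wrapper targets; the filename is arbitrary:
```
import gym
from bench.monitor import Monitor, load_results

env = Monitor(gym.make("CartPole-v1"), filename="cartpole")  # writes ./cartpole.monitor.csv

obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
print(info["episode"])            # {'r': episode return, 'l': length, 't': wall-clock time}
print(env.get_episode_rewards())  # one entry per completed episode
env.close()

print(load_results(".").head())   # aggregates every *monitor.csv under the given folder
```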
148 | """ 149 | pass 150 | 151 | 152 | def get_monitor_files(path: str) -> List[str]: 153 | """ 154 | get all the monitor files in the given path 155 | :param path: (str) the logging folder 156 | :return: ([str]) the log files 157 | """ 158 | return glob(os.path.join(path, "*" + Monitor.EXT)) 159 | 160 | 161 | def load_results(path: str) -> pandas.DataFrame: 162 | """ 163 | Load all Monitor logs from a given directory path matching ``*monitor.csv`` and ``*monitor.json`` 164 | :param path: (str) the directory path containing the log file(s) 165 | :return: (pandas.DataFrame) the logged data 166 | """ 167 | # get both csv and (old) json files 168 | monitor_files = (glob(os.path.join(path, "*monitor.json")) + get_monitor_files(path)) 169 | if not monitor_files: 170 | raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, path)) 171 | data_frames = [] 172 | headers = [] 173 | for file_name in monitor_files: 174 | with open(file_name, 'rt') as file_handler: 175 | if file_name.endswith('csv'): 176 | first_line = file_handler.readline() 177 | assert first_line[0] == '#' 178 | header = json.loads(first_line[1:]) 179 | data_frame = pandas.read_csv(file_handler, index_col=None) 180 | headers.append(header) 181 | elif file_name.endswith('json'): # Deprecated json format 182 | episodes = [] 183 | lines = file_handler.readlines() 184 | header = json.loads(lines[0]) 185 | headers.append(header) 186 | for line in lines[1:]: 187 | episode = json.loads(line) 188 | episodes.append(episode) 189 | data_frame = pandas.DataFrame(episodes) 190 | else: 191 | assert 0, 'unreachable' 192 | data_frame['t'] += header['t_start'] 193 | data_frames.append(data_frame) 194 | data_frame = pandas.concat(data_frames) 195 | data_frame.sort_values('t', inplace=True) 196 | data_frame.reset_index(inplace=True) 197 | data_frame['t'] -= min(header['t_start'] for header in headers) 198 | # data_frame.headers = headers # HACK to preserve backwards compatibility 199 | return data_frame -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F403 2 | from common.console_util import fmt_row, fmt_item, colorize 3 | from common.dataset import Dataset 4 | from common.math_util import discount, discount_with_boundaries, explained_variance, \ 5 | explained_variance_2d, flatten_arrays, unflatten_vector 6 | from common.misc_util import zipsame, set_global_seeds, boolean_flag 7 | from common.schedules import LinearSchedule -------------------------------------------------------------------------------- /common/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/base_class.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/base_class.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/base_class.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/base_class.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/console_util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/console_util.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/console_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/console_util.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/console_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/console_util.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/dataset.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/dataset.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/distributions.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/distributions.cpython-35.pyc -------------------------------------------------------------------------------- 
/common/__pycache__/distributions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/distributions.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/distributions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/distributions.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/math_util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/math_util.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/math_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/math_util.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/math_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/math_util.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/misc_util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/misc_util.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/misc_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/misc_util.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/misc_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/misc_util.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/policies.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/policies.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/policies.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/policies.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/replay_buffer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/replay_buffer.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/replay_buffer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/replay_buffer.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/running_mean_std.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/running_mean_std.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/running_mean_std.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/running_mean_std.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/running_mean_std.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/running_mean_std.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/save_util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/save_util.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/save_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/save_util.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/save_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/save_util.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/schedules.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/schedules.cpython-35.pyc 
-------------------------------------------------------------------------------- /common/__pycache__/schedules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/schedules.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/schedules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/schedules.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/segment_tree.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/segment_tree.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/segment_tree.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/segment_tree.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/segment_tree.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/segment_tree.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/tf_util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/tf_util.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/tf_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/tf_util.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/tf_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/tf_util.cpython-37.pyc -------------------------------------------------------------------------------- /common/__pycache__/tile_images.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/tile_images.cpython-35.pyc -------------------------------------------------------------------------------- /common/__pycache__/tile_images.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/tile_images.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/tile_images.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/__pycache__/tile_images.cpython-37.pyc -------------------------------------------------------------------------------- /common/atari_wrappers.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | import gym 5 | from gym import spaces 6 | import cv2 # pytype:disable=import-error 7 | cv2.ocl.setUseOpenCL(False) 8 | 9 | 10 | class NoopResetEnv(gym.Wrapper): 11 | def __init__(self, env, noop_max=30): 12 | """ 13 | Sample initial states by taking random number of no-ops on reset. 14 | No-op is assumed to be action 0. 15 | :param env: (Gym Environment) the environment to wrap 16 | :param noop_max: (int) the maximum value of no-ops to run 17 | """ 18 | gym.Wrapper.__init__(self, env) 19 | self.noop_max = noop_max 20 | self.override_num_noops = None 21 | self.noop_action = 0 22 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 23 | 24 | def reset(self, **kwargs): 25 | self.env.reset(**kwargs) 26 | if self.override_num_noops is not None: 27 | noops = self.override_num_noops 28 | else: 29 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) 30 | assert noops > 0 31 | obs = None 32 | for _ in range(noops): 33 | obs, _, done, _ = self.env.step(self.noop_action) 34 | if done: 35 | obs = self.env.reset(**kwargs) 36 | return obs 37 | 38 | def step(self, action): 39 | return self.env.step(action) 40 | 41 | 42 | class FireResetEnv(gym.Wrapper): 43 | def __init__(self, env): 44 | """ 45 | Take action on reset for environments that are fixed until firing. 46 | :param env: (Gym Environment) the environment to wrap 47 | """ 48 | gym.Wrapper.__init__(self, env) 49 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 50 | assert len(env.unwrapped.get_action_meanings()) >= 3 51 | 52 | def reset(self, **kwargs): 53 | self.env.reset(**kwargs) 54 | obs, _, done, _ = self.env.step(1) 55 | if done: 56 | self.env.reset(**kwargs) 57 | obs, _, done, _ = self.env.step(2) 58 | if done: 59 | self.env.reset(**kwargs) 60 | return obs 61 | 62 | def step(self, action): 63 | return self.env.step(action) 64 | 65 | 66 | class EpisodicLifeEnv(gym.Wrapper): 67 | def __init__(self, env): 68 | """ 69 | Make end-of-life == end-of-episode, but only reset on true game over. 70 | Done by DeepMind for the DQN and co. since it helps value estimation. 71 | :param env: (Gym Environment) the environment to wrap 72 | """ 73 | gym.Wrapper.__init__(self, env) 74 | self.lives = 0 75 | self.was_real_done = True 76 | 77 | def step(self, action): 78 | obs, reward, done, info = self.env.step(action) 79 | self.was_real_done = done 80 | # check current lives, make loss of life terminal, 81 | # then update lives to handle bonus lives 82 | lives = self.env.unwrapped.ale.lives() 83 | if 0 < lives < self.lives: 84 | # for Qbert sometimes we stay in lives == 0 condtion for a few frames 85 | # so its important to keep lives > 0, so that we only reset once 86 | # the environment advertises done. 
87 | done = True 88 | self.lives = lives 89 | return obs, reward, done, info 90 | 91 | def reset(self, **kwargs): 92 | """ 93 | Calls the Gym environment reset, only when lives are exhausted. 94 | This way all states are still reachable even though lives are episodic, 95 | and the learner need not know about any of this behind-the-scenes. 96 | :param kwargs: Extra keywords passed to env.reset() call 97 | :return: ([int] or [float]) the first observation of the environment 98 | """ 99 | if self.was_real_done: 100 | obs = self.env.reset(**kwargs) 101 | else: 102 | # no-op step to advance from terminal/lost life state 103 | obs, _, _, _ = self.env.step(0) 104 | self.lives = self.env.unwrapped.ale.lives() 105 | return obs 106 | 107 | 108 | class MaxAndSkipEnv(gym.Wrapper): 109 | def __init__(self, env, skip=4): 110 | """ 111 | Return only every `skip`-th frame (frameskipping) 112 | :param env: (Gym Environment) the environment 113 | :param skip: (int) number of `skip`-th frame 114 | """ 115 | gym.Wrapper.__init__(self, env) 116 | # most recent raw observations (for max pooling across time steps) 117 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=env.observation_space.dtype) 118 | self._skip = skip 119 | 120 | def step(self, action): 121 | """ 122 | Step the environment with the given action 123 | Repeat action, sum reward, and max over last observations. 124 | :param action: ([int] or [float]) the action 125 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information 126 | """ 127 | total_reward = 0.0 128 | done = None 129 | for i in range(self._skip): 130 | obs, reward, done, info = self.env.step(action) 131 | if i == self._skip - 2: 132 | self._obs_buffer[0] = obs 133 | if i == self._skip - 1: 134 | self._obs_buffer[1] = obs 135 | total_reward += reward 136 | if done: 137 | break 138 | # Note that the observation on the done=True frame 139 | # doesn't matter 140 | max_frame = self._obs_buffer.max(axis=0) 141 | 142 | return max_frame, total_reward, done, info 143 | 144 | def reset(self, **kwargs): 145 | return self.env.reset(**kwargs) 146 | 147 | 148 | class ClipRewardEnv(gym.RewardWrapper): 149 | def __init__(self, env): 150 | """ 151 | clips the reward to {+1, 0, -1} by its sign. 152 | :param env: (Gym Environment) the environment 153 | """ 154 | gym.RewardWrapper.__init__(self, env) 155 | 156 | def reward(self, reward): 157 | """ 158 | Bin reward to {+1, 0, -1} by its sign. 159 | :param reward: (float) 160 | """ 161 | return np.sign(reward) 162 | 163 | 164 | class WarpFrame(gym.ObservationWrapper): 165 | def __init__(self, env): 166 | """ 167 | Warp frames to 84x84 as done in the Nature paper and later work. 168 | :param env: (Gym Environment) the environment 169 | """ 170 | gym.ObservationWrapper.__init__(self, env) 171 | self.width = 84 172 | self.height = 84 173 | self.observation_space = spaces.Box(low=0, high=255, shape=(self.height, self.width, 1), 174 | dtype=env.observation_space.dtype) 175 | 176 | def observation(self, frame): 177 | """ 178 | returns the current observation from a frame 179 | :param frame: ([int] or [float]) environment frame 180 | :return: ([int] or [float]) the observation 181 | """ 182 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 183 | frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) 184 | return frame[:, :, None] 185 | 186 | 187 | class FrameStack(gym.Wrapper): 188 | def __init__(self, env, n_frames): 189 | """Stack n_frames last frames. 
190 | Returns lazy array, which is much more memory efficient. 191 | See Also 192 | -------- 193 | stable_baselines.common.atari_wrappers.LazyFrames 194 | :param env: (Gym Environment) the environment 195 | :param n_frames: (int) the number of frames to stack 196 | """ 197 | gym.Wrapper.__init__(self, env) 198 | self.n_frames = n_frames 199 | self.frames = deque([], maxlen=n_frames) 200 | shp = env.observation_space.shape 201 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * n_frames), 202 | dtype=env.observation_space.dtype) 203 | 204 | def reset(self): 205 | obs = self.env.reset() 206 | for _ in range(self.n_frames): 207 | self.frames.append(obs) 208 | return self._get_ob() 209 | 210 | def step(self, action): 211 | obs, reward, done, info = self.env.step(action) 212 | self.frames.append(obs) 213 | return self._get_ob(), reward, done, info 214 | 215 | def _get_ob(self): 216 | assert len(self.frames) == self.n_frames 217 | return LazyFrames(list(self.frames)) 218 | 219 | 220 | class ScaledFloatFrame(gym.ObservationWrapper): 221 | def __init__(self, env): 222 | gym.ObservationWrapper.__init__(self, env) 223 | self.observation_space = spaces.Box(low=0, high=1.0, shape=env.observation_space.shape, dtype=np.float32) 224 | 225 | def observation(self, observation): 226 | # careful! This undoes the memory optimization, use 227 | # with smaller replay buffers only. 228 | return np.array(observation).astype(np.float32) / 255.0 229 | 230 | 231 | class LazyFrames(object): 232 | def __init__(self, frames): 233 | """ 234 | This object ensures that common frames between the observations are only stored once. 235 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 236 | buffers. 237 | This object should only be converted to np.ndarray before being passed to the model. 238 | :param frames: ([int] or [float]) environment frames 239 | """ 240 | self._frames = frames 241 | self._out = None 242 | 243 | def _force(self): 244 | if self._out is None: 245 | self._out = np.concatenate(self._frames, axis=2) 246 | self._frames = None 247 | return self._out 248 | 249 | def __array__(self, dtype=None): 250 | out = self._force() 251 | if dtype is not None: 252 | out = out.astype(dtype) 253 | return out 254 | 255 | def __len__(self): 256 | return len(self._force()) 257 | 258 | def __getitem__(self, i): 259 | return self._force()[i] 260 | 261 | 262 | def make_atari(env_id): 263 | """ 264 | Create a wrapped atari Environment 265 | :param env_id: (str) the environment ID 266 | :return: (Gym Environment) the wrapped atari environment 267 | """ 268 | env = gym.make(env_id) 269 | assert 'NoFrameskip' in env.spec.id 270 | env = NoopResetEnv(env, noop_max=30) 271 | env = MaxAndSkipEnv(env, skip=4) 272 | return env 273 | 274 | 275 | def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False): 276 | """ 277 | Configure environment for DeepMind-style Atari. 
278 | :param env: (Gym Environment) the atari environment 279 | :param episode_life: (bool) wrap the episode life wrapper 280 | :param clip_rewards: (bool) wrap the reward clipping wrapper 281 | :param frame_stack: (bool) wrap the frame stacking wrapper 282 | :param scale: (bool) wrap the scaling observation wrapper 283 | :return: (Gym Environment) the wrapped atari environment 284 | """ 285 | if episode_life: 286 | env = EpisodicLifeEnv(env) 287 | if 'FIRE' in env.unwrapped.get_action_meanings(): 288 | env = FireResetEnv(env) 289 | env = WarpFrame(env) 290 | if scale: 291 | env = ScaledFloatFrame(env) 292 | if clip_rewards: 293 | env = ClipRewardEnv(env) 294 | if frame_stack: 295 | env = FrameStack(env, 4) 296 | return env -------------------------------------------------------------------------------- /common/console_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | 5 | 6 | # ================================================================ 7 | # Misc 8 | # ================================================================ 9 | 10 | 11 | def fmt_row(width, row, header=False): 12 | """ 13 | fits a list of items to at least a certain length 14 | 15 | :param width: (int) the minimum width of the string 16 | :param row: ([Any]) a list of object you wish to get the string representation 17 | :param header: (bool) whether or not to return the string as a header 18 | :return: (str) the string representation of all the elements in 'row', of length >= 'width' 19 | """ 20 | out = " | ".join(fmt_item(x, width) for x in row) 21 | if header: 22 | out = out + "\n" + "-" * len(out) 23 | return out 24 | 25 | 26 | def fmt_item(item, min_width): 27 | """ 28 | fits items to a given string length 29 | 30 | :param item: (Any) the item you wish to get the string representation 31 | :param min_width: (int) the minimum width of the string 32 | :return: (str) the string representation of 'x' of length >= 'l' 33 | """ 34 | if isinstance(item, np.ndarray): 35 | assert item.ndim == 0 36 | item = item.item() 37 | if isinstance(item, (float, np.float32, np.float64)): 38 | value = abs(item) 39 | if (value < 1e-4 or value > 1e+4) and value > 0: 40 | rep = "%7.2e" % item 41 | else: 42 | rep = "%7.5f" % item 43 | else: 44 | rep = str(item) 45 | return " " * (min_width - len(rep)) + rep 46 | 47 | 48 | COLOR_TO_NUM = dict( 49 | gray=30, 50 | red=31, 51 | green=32, 52 | yellow=33, 53 | blue=34, 54 | magenta=35, 55 | cyan=36, 56 | white=37, 57 | crimson=38 58 | ) 59 | 60 | 61 | def colorize(string, color, bold=False, highlight=False): 62 | """ 63 | Colorize, bold and/or highlight a string for terminal print 64 | 65 | :param string: (str) input string 66 | :param color: (str) the color, the lookup table is the dict at console_util.color2num 67 | :param bold: (bool) if the string should be bold or not 68 | :param highlight: (bool) if the string should be highlighted or not 69 | :return: (str) the stylized output string 70 | """ 71 | attr = [] 72 | num = COLOR_TO_NUM[color] 73 | if highlight: 74 | num += 10 75 | attr.append(str(num)) 76 | if bold: 77 | attr.append('1') 78 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 79 | -------------------------------------------------------------------------------- /common/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Dataset(object): 5 | def __init__(self, data_map, shuffle=True): 6 | 
""" 7 | Data loader that handles batches and shuffling. 8 | WARNING: this will alter the given data_map ordering, as dicts are mutable 9 | 10 | :param data_map: (dict) the input data, where every column is a key 11 | :param shuffle: (bool) Whether to shuffle or not the dataset 12 | Important: this should be disabled for recurrent policies 13 | """ 14 | self.data_map = data_map 15 | self.shuffle = shuffle 16 | self.n_samples = next(iter(data_map.values())).shape[0] 17 | self._next_id = 0 18 | if self.shuffle: 19 | self.shuffle_dataset() 20 | 21 | def shuffle_dataset(self): 22 | """ 23 | Shuffles the data_map 24 | """ 25 | perm = np.arange(self.n_samples) 26 | np.random.shuffle(perm) 27 | 28 | for key in self.data_map: 29 | self.data_map[key] = self.data_map[key][perm] 30 | 31 | def next_batch(self, batch_size): 32 | """ 33 | returns a batch of data of a given size 34 | 35 | :param batch_size: (int) the size of the batch 36 | :return: (dict) a batch of the input data of size 'batch_size' 37 | """ 38 | if self._next_id >= self.n_samples: 39 | self._next_id = 0 40 | if self.shuffle: 41 | self.shuffle_dataset() 42 | 43 | cur_id = self._next_id 44 | cur_batch_size = min(batch_size, self.n_samples - self._next_id) 45 | self._next_id += cur_batch_size 46 | 47 | data_map = dict() 48 | for key in self.data_map: 49 | data_map[key] = self.data_map[key][cur_id:cur_id+cur_batch_size] 50 | return data_map 51 | 52 | def iterate_once(self, batch_size): 53 | """ 54 | generator that iterates over the dataset 55 | 56 | :param batch_size: (int) the size of the batch 57 | :return: (dict) a batch of the input data of size 'batch_size' 58 | """ 59 | if self.shuffle: 60 | self.shuffle_dataset() 61 | 62 | while self._next_id <= self.n_samples - batch_size: 63 | yield self.next_batch(batch_size) 64 | self._next_id = 0 65 | 66 | def subset(self, num_elements, shuffle=True): 67 | """ 68 | Return a subset of the current dataset 69 | 70 | :param num_elements: (int) the number of element you wish to have in the subset 71 | :param shuffle: (bool) Whether to shuffle or not the dataset 72 | :return: (Dataset) a new subset of the current Dataset object 73 | """ 74 | data_map = dict() 75 | for key in self.data_map: 76 | data_map[key] = self.data_map[key][:num_elements] 77 | return Dataset(data_map, shuffle) 78 | 79 | 80 | def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True): 81 | """ 82 | Iterates over arrays in batches, must provide either num_batches or batch_size, the other must be None. 
83 | 84 | :param arrays: (tuple) a tuple of arrays 85 | :param num_batches: (int) the number of batches, must be None if batch_size is defined 86 | :param batch_size: (int) the size of the batch, must be None if num_batches is defined 87 | :param shuffle: (bool) enable auto shuffle 88 | :param include_final_partial_batch: (bool) add the last batch if not the same size as the batch_size 89 | :return: (tuples) a tuple of a batch of the arrays 90 | """ 91 | assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both' 92 | arrays = tuple(map(np.asarray, arrays)) 93 | n_samples = arrays[0].shape[0] 94 | assert all(a.shape[0] == n_samples for a in arrays[1:]) 95 | inds = np.arange(n_samples) 96 | if shuffle: 97 | np.random.shuffle(inds) 98 | sections = np.arange(0, n_samples, batch_size)[1:] if num_batches is None else num_batches 99 | for batch_inds in np.array_split(inds, sections): 100 | if include_final_partial_batch or len(batch_inds) == batch_size: 101 | yield tuple(a[batch_inds] for a in arrays) 102 | -------------------------------------------------------------------------------- /common/math_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.signal 3 | 4 | 5 | def discount(vector, gamma): 6 | """ 7 | computes discounted sums along 0th dimension of vector x. 8 | y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k], 9 | where k = len(x) - t - 1 10 | 11 | :param vector: (np.ndarray) the input vector 12 | :param gamma: (float) the discount value 13 | :return: (np.ndarray) the output vector 14 | """ 15 | assert vector.ndim >= 1 16 | return scipy.signal.lfilter([1], [1, -gamma], vector[::-1], axis=0)[::-1] 17 | 18 | 19 | def explained_variance(y_pred, y_true): 20 | """ 21 | Computes fraction of variance that ypred explains about y. 22 | Returns 1 - Var[y-ypred] / Var[y] 23 | 24 | interpretation: 25 | ev=0 => might as well have predicted zero 26 | ev=1 => perfect prediction 27 | ev<0 => worse than just predicting zero 28 | 29 | :param y_pred: (np.ndarray) the prediction 30 | :param y_true: (np.ndarray) the expected value 31 | :return: (float) explained variance of ypred and y 32 | """ 33 | assert y_true.ndim == 1 and y_pred.ndim == 1 34 | var_y = np.var(y_true) 35 | return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y 36 | 37 | 38 | def explained_variance_2d(y_pred, y_true): 39 | """ 40 | Computes fraction of variance that ypred explains about y, for 2D arrays.
41 | Returns 1 - Var[y-ypred] / Var[y] 42 | 43 | interpretation: 44 | ev=0 => might as well have predicted zero 45 | ev=1 => perfect prediction 46 | ev<0 => worse than just predicting zero 47 | 48 | :param y_pred: (np.ndarray) the prediction 49 | :param y_true: (np.ndarray) the expected value 50 | :return: (float) explained variance of ypred and y 51 | """ 52 | assert y_true.ndim == 2 and y_pred.ndim == 2 53 | var_y = np.var(y_true, axis=0) 54 | explained_var = 1 - np.var(y_true - y_pred) / var_y 55 | explained_var[var_y < 1e-10] = 0 56 | return explained_var 57 | 58 | 59 | def flatten_arrays(arrs): 60 | """ 61 | flattens a list of arrays down to 1D 62 | 63 | :param arrs: ([np.ndarray]) arrays 64 | :return: (np.ndarray) 1D flattend array 65 | """ 66 | return np.concatenate([arr.flat for arr in arrs]) 67 | 68 | 69 | def unflatten_vector(vec, shapes): 70 | """ 71 | reshape a flattened array 72 | 73 | :param vec: (np.ndarray) 1D arrays 74 | :param shapes: (tuple) 75 | :return: ([np.ndarray]) reshaped array 76 | """ 77 | i = 0 78 | arrs = [] 79 | for shape in shapes: 80 | size = np.prod(shape) 81 | arr = vec[i:i + size].reshape(shape) 82 | arrs.append(arr) 83 | i += size 84 | return arrs 85 | 86 | 87 | def discount_with_boundaries(rewards, episode_starts, gamma): 88 | """ 89 | computes discounted sums along 0th dimension of x (reward), while taking into account the start of each episode. 90 | y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k], 91 | where k = len(x) - t - 1 92 | 93 | :param rewards: (np.ndarray) the input vector (rewards) 94 | :param episode_starts: (np.ndarray) 2d array of bools, indicating when a new episode has started 95 | :param gamma: (float) the discount factor 96 | :return: (np.ndarray) the output vector (discounted rewards) 97 | """ 98 | discounted_rewards = np.zeros_like(rewards) 99 | n_samples = rewards.shape[0] 100 | discounted_rewards[n_samples - 1] = rewards[n_samples - 1] 101 | for step in range(n_samples - 2, -1, -1): 102 | discounted_rewards[step] = rewards[step] + gamma * discounted_rewards[step + 1] * (1 - episode_starts[step + 1]) 103 | return discounted_rewards 104 | -------------------------------------------------------------------------------- /common/misc_util.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import gym 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | 8 | def zipsame(*seqs): 9 | """ 10 | Performes a zip function, but asserts that all zipped elements are of the same size 11 | 12 | :param seqs: a list of arrays that are zipped together 13 | :return: the zipped arguments 14 | """ 15 | length = len(seqs[0]) 16 | assert all(len(seq) == length for seq in seqs[1:]) 17 | return zip(*seqs) 18 | 19 | 20 | def set_global_seeds(seed): 21 | """ 22 | set the seed for python random, tensorflow, numpy and gym spaces 23 | 24 | :param seed: (int) the seed 25 | """ 26 | tf.compat.v1.set_random_seed(seed) 27 | np.random.seed(seed) 28 | random.seed(seed) 29 | # prng was removed in latest gym version 30 | if hasattr(gym.spaces, 'prng'): 31 | gym.spaces.prng.seed(seed) 32 | 33 | 34 | def boolean_flag(parser, name, default=False, help_msg=None): 35 | """ 36 | Add a boolean flag to argparse parser. 
37 | 38 | :param parser: (argparse.Parser) parser to add the flag to 39 | :param name: (str) -- will enable the flag, while --no- will disable it 40 | :param default: (bool) default value of the flag 41 | :param help_msg: (str) help string for the flag 42 | """ 43 | dest = name.replace('-', '_') 44 | parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help_msg) 45 | parser.add_argument("--no-" + name, action="store_false", dest=dest) 46 | 47 | 48 | def mpi_rank_or_zero(): 49 | """ 50 | Return the MPI rank if mpi is installed. Otherwise, return 0. 51 | :return: (int) 52 | """ 53 | try: 54 | from mpi4py import MPI 55 | return MPI.COMM_WORLD.Get_rank() 56 | except ImportError: 57 | return 0 58 | -------------------------------------------------------------------------------- /common/running_mean_std.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RunningMeanStd(object): 5 | def __init__(self, epsilon=1e-4, shape=()): 6 | """ 7 | calulates the running mean and std of a data stream 8 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 9 | 10 | :param epsilon: (float) helps with arithmetic issues 11 | :param shape: (tuple) the shape of the data stream's output 12 | """ 13 | self.mean = np.zeros(shape, 'float64') 14 | self.var = np.ones(shape, 'float64') 15 | self.count = epsilon 16 | 17 | def update(self, arr): 18 | batch_mean = np.mean(arr, axis=0) 19 | batch_var = np.var(arr, axis=0) 20 | batch_count = arr.shape[0] 21 | self.update_from_moments(batch_mean, batch_var, batch_count) 22 | 23 | def update_from_moments(self, batch_mean, batch_var, batch_count): 24 | delta = batch_mean - self.mean 25 | tot_count = self.count + batch_count 26 | 27 | new_mean = self.mean + delta * batch_count / tot_count 28 | m_a = self.var * self.count 29 | m_b = batch_var * batch_count 30 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) 31 | new_var = m_2 / (self.count + batch_count) 32 | 33 | new_count = batch_count + self.count 34 | 35 | self.mean = new_mean 36 | self.var = new_var 37 | self.count = new_count 38 | -------------------------------------------------------------------------------- /common/save_util.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from collections import OrderedDict 3 | import io 4 | import json 5 | import pickle 6 | 7 | import cloudpickle 8 | import numpy as np 9 | 10 | 11 | def is_json_serializable(item): 12 | """ 13 | Test if an object is serializable into JSON 14 | 15 | :param item: (object) The object to be tested for JSON serialization. 16 | :return: (bool) True if object is JSON serializable, false otherwise. 17 | """ 18 | # Try with try-except struct. 19 | json_serializable = True 20 | try: 21 | _ = json.dumps(item) 22 | except TypeError: 23 | json_serializable = False 24 | return json_serializable 25 | 26 | 27 | def data_to_json(data): 28 | """ 29 | Turn data (class parameters) into a JSON string for storing 30 | 31 | :param data: (Dict) Dictionary of class parameters to be 32 | stored. Items that are not JSON serializable will be 33 | pickled with Cloudpickle and stored as bytearray in 34 | the JSON file 35 | :return: (str) JSON string of the data serialized. 
36 | """ 37 | # First, check what elements can not be JSONfied, 38 | # and turn them into byte-strings 39 | serializable_data = {} 40 | for data_key, data_item in data.items(): 41 | # See if object is JSON serializable 42 | if is_json_serializable(data_item): 43 | # All good, store as it is 44 | serializable_data[data_key] = data_item 45 | else: 46 | # Not serializable, cloudpickle it into 47 | # bytes and convert to base64 string for storing. 48 | # Also store type of the class for consumption 49 | # from other languages/humans, so we have an 50 | # idea what was being stored. 51 | base64_encoded = base64.b64encode( 52 | cloudpickle.dumps(data_item) 53 | ).decode() 54 | 55 | # Use ":" to make sure we do 56 | # not override these keys 57 | # when we include variables of the object later 58 | cloudpickle_serialization = { 59 | ":type:": str(type(data_item)), 60 | ":serialized:": base64_encoded 61 | } 62 | 63 | # Add first-level JSON-serializable items of the 64 | # object for further details (but not deeper than this to 65 | # avoid deep nesting). 66 | # First we check that object has attributes (not all do, 67 | # e.g. numpy scalars) 68 | if hasattr(data_item, "__dict__") or isinstance(data_item, dict): 69 | # Take elements from __dict__ for custom classes 70 | item_generator = ( 71 | data_item.items if isinstance(data_item, dict) else data_item.__dict__.items 72 | ) 73 | for variable_name, variable_item in item_generator(): 74 | # Check if serializable. If not, just include the 75 | # string-representation of the object. 76 | if is_json_serializable(variable_item): 77 | cloudpickle_serialization[variable_name] = variable_item 78 | else: 79 | cloudpickle_serialization[variable_name] = str(variable_item) 80 | 81 | serializable_data[data_key] = cloudpickle_serialization 82 | json_string = json.dumps(serializable_data, indent=4) 83 | return json_string 84 | 85 | 86 | def json_to_data(json_string, custom_objects=None): 87 | """ 88 | Turn JSON serialization of class-parameters back into dictionary. 89 | 90 | :param json_string: (str) JSON serialization of the class-parameters 91 | that should be loaded. 92 | :param custom_objects: (dict) Dictionary of objects to replace 93 | upon loading. If a variable is present in this dictionary as a 94 | key, it will not be deserialized and the corresponding item 95 | will be used instead. Similar to custom_objects in 96 | `keras.models.load_model`. Useful when you have an object in 97 | file that can not be deserialized. 98 | :return: (dict) Loaded class parameters. 99 | """ 100 | if custom_objects is not None and not isinstance(custom_objects, dict): 101 | raise ValueError("custom_objects argument must be a dict or None") 102 | 103 | json_dict = json.loads(json_string) 104 | # This will be filled with deserialized data 105 | return_data = {} 106 | for data_key, data_item in json_dict.items(): 107 | if custom_objects is not None and data_key in custom_objects.keys(): 108 | # If item is provided in custom_objects, replace 109 | # the one from JSON with the one in custom_objects 110 | return_data[data_key] = custom_objects[data_key] 111 | elif isinstance(data_item, dict) and ":serialized:" in data_item.keys(): 112 | # If item is dictionary with ":serialized:" 113 | # key, this means it is serialized with cloudpickle. 114 | serialization = data_item[":serialized:"] 115 | # Try-except deserialization in case we run into 116 | # errors. If so, we can tell bit more information to 117 | # user. 
118 | try: 119 | deserialized_object = cloudpickle.loads( 120 | base64.b64decode(serialization.encode()) 121 | ) 122 | except pickle.UnpicklingError: 123 | raise RuntimeError( 124 | "Could not deserialize object {}. ".format(data_key) + 125 | "Consider using `custom_objects` argument to replace " + 126 | "this object." 127 | ) 128 | return_data[data_key] = deserialized_object 129 | else: 130 | # Read as it is 131 | return_data[data_key] = data_item 132 | return return_data 133 | 134 | 135 | def params_to_bytes(params): 136 | """ 137 | Turn params (OrderedDict of variable name -> ndarray) into 138 | serialized bytes for storing. 139 | 140 | Note: `numpy.savez` does not save the ordering. 141 | 142 | :param params: (OrderedDict) Dictionary mapping variable 143 | names to numpy arrays of the current parameters of the 144 | model. 145 | :return: (bytes) Bytes object of the serialized content. 146 | """ 147 | # Create byte-buffer and save params with 148 | # savez function, and return the bytes. 149 | byte_file = io.BytesIO() 150 | np.savez(byte_file, **params) 151 | serialized_params = byte_file.getvalue() 152 | return serialized_params 153 | 154 | 155 | def bytes_to_params(serialized_params, param_list): 156 | """ 157 | Turn serialized parameters (bytes) back into OrderedDictionary. 158 | 159 | :param serialized_params: (byte) Serialized parameters 160 | with `numpy.savez`. 161 | :param param_list: (list) List of strings, representing 162 | the order of parameters in which they should be returned 163 | :return: (OrderedDict) Dictionary mapping variable name to 164 | numpy array of the parameters. 165 | """ 166 | byte_file = io.BytesIO(serialized_params) 167 | params = np.load(byte_file) 168 | return_dictionary = OrderedDict() 169 | # Assign parameters to return_dictionary 170 | # in the order specified by param_list 171 | for param_name in param_list: 172 | return_dictionary[param_name] = params[param_name] 173 | return return_dictionary 174 | -------------------------------------------------------------------------------- /common/schedules.py: -------------------------------------------------------------------------------- 1 | """This file is used for specifying various schedules that evolve over 2 | time throughout the execution of the algorithm, such as: 3 | 4 | - learning rate for the optimizer 5 | - exploration epsilon for the epsilon greedy exploration strategy 6 | - beta parameter for beta parameter in prioritized replay 7 | 8 | Each schedule has a function `value(t)` which returns the current value 9 | of the parameter given the timestep t of the optimization procedure. 10 | """ 11 | 12 | 13 | class Schedule(object): 14 | def value(self, step): 15 | """ 16 | Value of the schedule for a given timestep 17 | 18 | :param step: (int) the timestep 19 | :return: (float) the output value for the given timestep 20 | """ 21 | raise NotImplementedError 22 | 23 | 24 | class ConstantSchedule(Schedule): 25 | """ 26 | Value remains constant over time. 27 | 28 | :param value: (float) Constant value of the schedule 29 | """ 30 | 31 | def __init__(self, value): 32 | self._value = value 33 | 34 | def value(self, step): 35 | return self._value 36 | 37 | 38 | def linear_interpolation(left, right, alpha): 39 | """ 40 | Linear interpolation between `left` and `right`. 
41 | 42 | :param left: (float) left boundary 43 | :param right: (float) right boundary 44 | :param alpha: (float) coeff in [0, 1] 45 | :return: (float) 46 | """ 47 | 48 | return left + alpha * (right - left) 49 | 50 | 51 | class PiecewiseSchedule(Schedule): 52 | """ 53 | Piecewise schedule. 54 | 55 | :param endpoints: ([(int, int)]) 56 | list of pairs `(time, value)` meaning that the schedule should output 57 | `value` when `t==time`. All the values for time must be sorted in 58 | increasing order. When t is between two times, e.g. `(time_a, value_a)` 59 | and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs 60 | `interpolation(value_a, value_b, alpha)` where alpha is a fraction of 61 | time passed between `time_a` and `time_b` for time `t`. 62 | :param interpolation: (lambda (float, float, float): float) 63 | a function that takes value to the left and to the right of t according 64 | to the `endpoints`. Alpha is the fraction of distance from left endpoint to 65 | right endpoint that t has covered. See linear_interpolation for example. 66 | :param outside_value: (float) 67 | if the value is requested outside of all the intervals specified in 68 | `endpoints` this value is returned. If None then AssertionError is 69 | raised when outside value is requested. 70 | """ 71 | 72 | def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None): 73 | idxes = [e[0] for e in endpoints] 74 | assert idxes == sorted(idxes) 75 | self._interpolation = interpolation 76 | self._outside_value = outside_value 77 | self._endpoints = endpoints 78 | 79 | def value(self, step): 80 | for (left_t, left), (right_t, right) in zip(self._endpoints[:-1], self._endpoints[1:]): 81 | if left_t <= step < right_t: 82 | alpha = float(step - left_t) / (right_t - left_t) 83 | return self._interpolation(left, right, alpha) 84 | 85 | # t does not belong to any of the pieces, so doom. 86 | assert self._outside_value is not None 87 | return self._outside_value 88 | 89 | 90 | class LinearSchedule(Schedule): 91 | """ 92 | Linear interpolation between initial_p and final_p over 93 | schedule_timesteps. After this many timesteps pass final_p is 94 | returned. 95 | 96 | :param schedule_timesteps: (int) Number of timesteps for which to linearly anneal initial_p to final_p 97 | :param initial_p: (float) initial output value 98 | :param final_p: (float) final output value 99 | """ 100 | 101 | def __init__(self, schedule_timesteps, final_p, initial_p=1.0): 102 | self.schedule_timesteps = schedule_timesteps 103 | self.final_p = final_p 104 | self.initial_p = initial_p 105 | 106 | def value(self, step): 107 | fraction = min(float(step) / self.schedule_timesteps, 1.0) 108 | return self.initial_p + fraction * (self.final_p - self.initial_p) 109 | -------------------------------------------------------------------------------- /common/segment_tree.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | 4 | class SegmentTree(object): 5 | def __init__(self, capacity, operation, neutral_element): 6 | """ 7 | Build a Segment Tree data structure. 8 | 9 | https://en.wikipedia.org/wiki/Segment_tree 10 | 11 | Can be used as a regular array, but with two 12 | important differences: 13 | 14 | a) setting item's value is slightly slower. 15 | It is O(lg capacity) instead of O(1). 16 | b) user has access to an efficient ( O(log segment size) ) 17 | `reduce` operation which reduces `operation` over 18 | a contiguous subsequence of items in the array.
19 | 20 | :param capacity: (int) Total size of the array - must be a power of two. 21 | :param operation: (lambda (Any, Any): Any) operation for combining elements (eg. sum, max) must form a 22 | mathematical group together with the set of possible values for array elements (i.e. be associative) 23 | :param neutral_element: (Any) neutral element for the operation above. eg. float('-inf') for max and 0 for sum. 24 | """ 25 | assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2." 26 | self._capacity = capacity 27 | self._value = [neutral_element for _ in range(2 * capacity)] 28 | self._operation = operation 29 | 30 | def _reduce_helper(self, start, end, node, node_start, node_end): 31 | if start == node_start and end == node_end: 32 | return self._value[node] 33 | mid = (node_start + node_end) // 2 34 | if end <= mid: 35 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 36 | else: 37 | if mid + 1 <= start: 38 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 39 | else: 40 | return self._operation( 41 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 42 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) 43 | ) 44 | 45 | def reduce(self, start=0, end=None): 46 | """ 47 | Returns result of applying `self.operation` 48 | to a contiguous subsequence of the array. 49 | 50 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 51 | 52 | :param start: (int) beginning of the subsequence 53 | :param end: (int) end of the subsequence 54 | :return: (Any) result of reducing self.operation over the specified range of array elements. 55 | """ 56 | if end is None: 57 | end = self._capacity 58 | if end < 0: 59 | end += self._capacity 60 | end -= 1 61 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 62 | 63 | def __setitem__(self, idx, val): 64 | # index of the leaf 65 | idx += self._capacity 66 | self._value[idx] = val 67 | idx //= 2 68 | while idx >= 1: 69 | self._value[idx] = self._operation( 70 | self._value[2 * idx], 71 | self._value[2 * idx + 1] 72 | ) 73 | idx //= 2 74 | 75 | def __getitem__(self, idx): 76 | assert 0 <= idx < self._capacity 77 | return self._value[self._capacity + idx] 78 | 79 | 80 | class SumSegmentTree(SegmentTree): 81 | def __init__(self, capacity): 82 | super(SumSegmentTree, self).__init__( 83 | capacity=capacity, 84 | operation=operator.add, 85 | neutral_element=0.0 86 | ) 87 | 88 | def sum(self, start=0, end=None): 89 | """ 90 | Returns arr[start] + ... + arr[end] 91 | 92 | :param start: (int) start position of the reduction (must be >= 0) 93 | :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1) 94 | :return: (Any) reduction of SumSegmentTree 95 | """ 96 | return super(SumSegmentTree, self).reduce(start, end) 97 | 98 | def find_prefixsum_idx(self, prefixsum): 99 | """ 100 | Find the highest index `i` in the array such that 101 | sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum 102 | 103 | if array values are probabilities, this function 104 | allows sampling indices according to the discrete 105 | probability efficiently.
106 | 107 | :param prefixsum: (float) upperbound on the sum of array prefix 108 | :return: (int) highest index satisfying the prefixsum constraint 109 | """ 110 | assert 0 <= prefixsum <= self.sum() + 1e-5 111 | idx = 1 112 | while idx < self._capacity: # while non-leaf 113 | if self._value[2 * idx] > prefixsum: 114 | idx = 2 * idx 115 | else: 116 | prefixsum -= self._value[2 * idx] 117 | idx = 2 * idx + 1 118 | return idx - self._capacity 119 | 120 | 121 | class MinSegmentTree(SegmentTree): 122 | def __init__(self, capacity): 123 | super(MinSegmentTree, self).__init__( 124 | capacity=capacity, 125 | operation=min, 126 | neutral_element=float('inf') 127 | ) 128 | 129 | def min(self, start=0, end=None): 130 | """ 131 | Returns min(arr[start], ..., arr[end]) 132 | 133 | :param start: (int) start position of the reduction (must be >= 0) 134 | :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1) 135 | :return: (Any) reduction of MinSegmentTree 136 | """ 137 | return super(MinSegmentTree, self).reduce(start, end) 138 | -------------------------------------------------------------------------------- /common/tile_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def tile_images(img_nhwc): 5 | """ 6 | Tile N images into one big PxQ image 7 | (P,Q) are chosen to be as close as possible, and if N 8 | is square, then P=Q. 9 | 10 | :param img_nhwc: (list) list or array of images, ndim=4 once turned into array. img nhwc 11 | n = batch index, h = height, w = width, c = channel 12 | :return: (numpy float) img_HWc, ndim=3 13 | """ 14 | img_nhwc = np.asarray(img_nhwc) 15 | n_images, height, width, n_channels = img_nhwc.shape 16 | # new_height was named H before 17 | new_height = int(np.ceil(np.sqrt(n_images))) 18 | # new_width was named W before 19 | new_width = int(np.ceil(float(n_images) / new_height)) 20 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0] * 0 for _ in range(n_images, new_height * new_width)]) 21 | # img_HWhwc 22 | out_image = img_nhwc.reshape(new_height, new_width, height, width, n_channels) 23 | # img_HhWwc 24 | out_image = out_image.transpose(0, 2, 1, 3, 4) 25 | # img_Hh_Ww_c 26 | out_image = out_image.reshape(new_height * height, new_width * width, n_channels) 27 | return out_image 28 | 29 | -------------------------------------------------------------------------------- /common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa F401 2 | from common.vec_env.base_vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, \ 3 | CloudpickleWrapper 4 | from common.vec_env.dummy_vec_env import DummyVecEnv 5 | from common.vec_env.subproc_vec_env import SubprocVecEnv 6 | from common.vec_env.vec_frame_stack import VecFrameStack 7 | from common.vec_env.vec_normalize import VecNormalize 8 | from common.vec_env.vec_video_recorder import VecVideoRecorder 9 | from common.vec_env.vec_check_nan import VecCheckNan 10 | -------------------------------------------------------------------------------- /common/vec_env/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- 
/common/vec_env/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/base_vec_env.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/base_vec_env.cpython-35.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/base_vec_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/base_vec_env.cpython-36.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/base_vec_env.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/base_vec_env.cpython-37.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/dummy_vec_env.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/dummy_vec_env.cpython-35.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/dummy_vec_env.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/dummy_vec_env.cpython-36.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/dummy_vec_env.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/dummy_vec_env.cpython-37.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/subproc_vec_env.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/subproc_vec_env.cpython-35.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/subproc_vec_env.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/subproc_vec_env.cpython-36.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/subproc_vec_env.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/subproc_vec_env.cpython-37.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/util.cpython-35.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_check_nan.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_check_nan.cpython-35.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_check_nan.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_check_nan.cpython-36.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_check_nan.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_check_nan.cpython-37.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_frame_stack.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_frame_stack.cpython-35.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_frame_stack.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_frame_stack.cpython-36.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_frame_stack.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_frame_stack.cpython-37.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_normalize.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_normalize.cpython-35.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_normalize.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_normalize.cpython-36.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_normalize.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_normalize.cpython-37.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_video_recorder.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_video_recorder.cpython-35.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_video_recorder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_video_recorder.cpython-36.pyc -------------------------------------------------------------------------------- /common/vec_env/__pycache__/vec_video_recorder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/common/vec_env/__pycache__/vec_video_recorder.cpython-37.pyc -------------------------------------------------------------------------------- /common/vec_env/base_vec_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import inspect 3 | import pickle 4 | 5 | import cloudpickle 6 | # from stable_baselines_tf2 import logger 7 | 8 | 9 | class AlreadySteppingError(Exception): 10 | """ 11 | Raised when an asynchronous step is running while 12 | step_async() is called again. 
13 | """ 14 | 15 | def __init__(self): 16 | msg = 'already running an async step' 17 | Exception.__init__(self, msg) 18 | 19 | 20 | class NotSteppingError(Exception): 21 | """ 22 | Raised when an asynchronous step is not running but 23 | step_wait() is called. 24 | """ 25 | 26 | def __init__(self): 27 | msg = 'not running an async step' 28 | Exception.__init__(self, msg) 29 | 30 | 31 | class VecEnv(ABC): 32 | """ 33 | An abstract asynchronous, vectorized environment. 34 | 35 | :param num_envs: (int) the number of environments 36 | :param observation_space: (Gym Space) the observation space 37 | :param action_space: (Gym Space) the action space 38 | """ 39 | metadata = { 40 | 'render.modes': ['human', 'rgb_array'] 41 | } 42 | 43 | def __init__(self, num_envs, observation_space, action_space): 44 | self.num_envs = num_envs 45 | self.observation_space = observation_space 46 | self.action_space = action_space 47 | 48 | @abstractmethod 49 | def reset(self): 50 | """ 51 | Reset all the environments and return an array of 52 | observations, or a tuple of observation arrays. 53 | 54 | If step_async is still doing work, that work will 55 | be cancelled and step_wait() should not be called 56 | until step_async() is invoked again. 57 | 58 | :return: ([int] or [float]) observation 59 | """ 60 | pass 61 | 62 | @abstractmethod 63 | def step_async(self, actions): 64 | """ 65 | Tell all the environments to start taking a step 66 | with the given actions. 67 | Call step_wait() to get the results of the step. 68 | 69 | You should not call this if a step_async run is 70 | already pending. 71 | """ 72 | pass 73 | 74 | @abstractmethod 75 | def step_wait(self): 76 | """ 77 | Wait for the step taken with step_async(). 78 | 79 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information 80 | """ 81 | pass 82 | 83 | @abstractmethod 84 | def close(self): 85 | """ 86 | Clean up the environment's resources. 87 | """ 88 | pass 89 | 90 | @abstractmethod 91 | def get_attr(self, attr_name, indices=None): 92 | """ 93 | Return attribute from vectorized environment. 94 | 95 | :param attr_name: (str) The name of the attribute whose value to return 96 | :param indices: (list,int) Indices of envs to get attribute from 97 | :return: (list) List of values of 'attr_name' in all environments 98 | """ 99 | pass 100 | 101 | @abstractmethod 102 | def set_attr(self, attr_name, value, indices=None): 103 | """ 104 | Set attribute inside vectorized environments. 105 | 106 | :param attr_name: (str) The name of attribute to assign new value 107 | :param value: (obj) Value to assign to `attr_name` 108 | :param indices: (list,int) Indices of envs to assign value 109 | :return: (NoneType) 110 | """ 111 | pass 112 | 113 | @abstractmethod 114 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 115 | """ 116 | Call instance methods of vectorized environments. 117 | 118 | :param method_name: (str) The name of the environment method to invoke. 
119 | :param indices: (list,int) Indices of envs whose method to call 120 | :param method_args: (tuple) Any positional arguments to provide in the call 121 | :param method_kwargs: (dict) Any keyword arguments to provide in the call 122 | :return: (list) List of items returned by the environment's method call 123 | """ 124 | pass 125 | 126 | def step(self, actions): 127 | """ 128 | Step the environments with the given action 129 | 130 | :param actions: ([int] or [float]) the action 131 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information 132 | """ 133 | self.step_async(actions) 134 | return self.step_wait() 135 | 136 | def get_images(self): 137 | """ 138 | Return RGB images from each environment 139 | """ 140 | raise NotImplementedError 141 | 142 | def render(self, *args, **kwargs): 143 | """ 144 | Gym environment rendering 145 | 146 | :param mode: (str) the rendering type 147 | """ 148 | # logger.warn('Render not defined for %s' % self) 149 | 150 | @property 151 | def unwrapped(self): 152 | if isinstance(self, VecEnvWrapper): 153 | return self.venv.unwrapped 154 | else: 155 | return self 156 | 157 | def getattr_depth_check(self, name, already_found): 158 | """Check if an attribute reference is being hidden in a recursive call to __getattr__ 159 | 160 | :param name: (str) name of attribute to check for 161 | :param already_found: (bool) whether this attribute has already been found in a wrapper 162 | :return: (str or None) name of module whose attribute is being shadowed, if any. 163 | """ 164 | if hasattr(self, name) and already_found: 165 | return "{0}.{1}".format(type(self).__module__, type(self).__name__) 166 | else: 167 | return None 168 | 169 | def _get_indices(self, indices): 170 | """ 171 | Convert a flexibly-typed reference to environment indices to an implied list of indices. 172 | 173 | :param indices: (None,int,Iterable) refers to indices of envs. 174 | :return: (list) the implied list of indices. 
175 | """ 176 | if indices is None: 177 | indices = range(self.num_envs) 178 | elif isinstance(indices, int): 179 | indices = [indices] 180 | return indices 181 | 182 | 183 | class VecEnvWrapper(VecEnv): 184 | """ 185 | Vectorized environment base class 186 | 187 | :param venv: (VecEnv) the vectorized environment to wrap 188 | :param observation_space: (Gym Space) the observation space (can be None to load from venv) 189 | :param action_space: (Gym Space) the action space (can be None to load from venv) 190 | """ 191 | 192 | def __init__(self, venv, observation_space=None, action_space=None): 193 | self.venv = venv 194 | VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space, 195 | action_space=action_space or venv.action_space) 196 | self.class_attributes = dict(inspect.getmembers(self.__class__)) 197 | 198 | def step_async(self, actions): 199 | self.venv.step_async(actions) 200 | 201 | @abstractmethod 202 | def reset(self): 203 | pass 204 | 205 | @abstractmethod 206 | def step_wait(self): 207 | pass 208 | 209 | def close(self): 210 | return self.venv.close() 211 | 212 | def render(self, *args, **kwargs): 213 | return self.venv.render(*args, **kwargs) 214 | 215 | def get_images(self): 216 | return self.venv.get_images() 217 | 218 | def get_attr(self, attr_name, indices=None): 219 | return self.venv.get_attr(attr_name, indices) 220 | 221 | def set_attr(self, attr_name, value, indices=None): 222 | return self.venv.set_attr(attr_name, value, indices) 223 | 224 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 225 | return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs) 226 | 227 | def __getattr__(self, name): 228 | """Find attribute from wrapped venv(s) if this wrapper does not have it. 229 | Useful for accessing attributes from venvs which are wrapped with multiple wrappers 230 | which have unique attributes of interest. 231 | """ 232 | blocked_class = self.getattr_depth_check(name, already_found=False) 233 | if blocked_class is not None: 234 | own_class = "{0}.{1}".format(type(self).__module__, type(self).__name__) 235 | format_str = ("Error: Recursive attribute lookup for {0} from {1} is " 236 | "ambiguous and hides attribute from {2}") 237 | raise AttributeError(format_str.format(name, own_class, blocked_class)) 238 | 239 | return self.getattr_recursive(name) 240 | 241 | def _get_all_attributes(self): 242 | """Get all (inherited) instance and class attributes 243 | 244 | :return: (dict) all_attributes 245 | """ 246 | all_attributes = self.__dict__.copy() 247 | all_attributes.update(self.class_attributes) 248 | return all_attributes 249 | 250 | def getattr_recursive(self, name): 251 | """Recursively check wrappers to find attribute. 252 | 253 | :param name (str) name of attribute to look for 254 | :return: (object) attribute 255 | """ 256 | all_attributes = self._get_all_attributes() 257 | if name in all_attributes: # attribute is present in this wrapper 258 | attr = getattr(self, name) 259 | elif hasattr(self.venv, 'getattr_recursive'): 260 | # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr 261 | # to avoid a duplicate call to getattr_depth_check. 262 | attr = self.venv.getattr_recursive(name) 263 | else: # attribute not present, child is an unwrapped VecEnv 264 | attr = getattr(self.venv, name) 265 | 266 | return attr 267 | 268 | def getattr_depth_check(self, name, already_found): 269 | """See base class. 
270 | 271 | :return: (str or None) name of module whose attribute is being shadowed, if any. 272 | """ 273 | all_attributes = self._get_all_attributes() 274 | if name in all_attributes and already_found: 275 | # this venv's attribute is being hidden because of a higher venv. 276 | shadowed_wrapper_class = "{0}.{1}".format(type(self).__module__, type(self).__name__) 277 | elif name in all_attributes and not already_found: 278 | # we have found the first reference to the attribute. Now check for duplicates. 279 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, True) 280 | else: 281 | # this wrapper does not have the attribute. Keep searching. 282 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found) 283 | 284 | return shadowed_wrapper_class 285 | 286 | 287 | class CloudpickleWrapper(object): 288 | def __init__(self, var): 289 | """ 290 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 291 | 292 | :param var: (Any) the variable you wish to wrap for pickling with cloudpickle 293 | """ 294 | self.var = var 295 | 296 | def __getstate__(self): 297 | return cloudpickle.dumps(self.var) 298 | 299 | def __setstate__(self, obs): 300 | self.var = pickle.loads(obs) 301 | -------------------------------------------------------------------------------- /common/vec_env/dummy_vec_env.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np 3 | 4 | from common.vec_env import VecEnv 5 | from common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info 6 | 7 | 8 | class DummyVecEnv(VecEnv): 9 | """ 10 | Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current 11 | Python process. This is useful for computationally simple environment such as ``cartpole-v1``, as the overhead of 12 | multiprocess or multithread outweighs the environment computation time. This can also be used for RL methods that 13 | require a vectorized environment, but that you want a single environments to train with. 
14 | 15 | :param env_fns: ([Gym Environment]) the list of environments to vectorize 16 | """ 17 | 18 | def __init__(self, env_fns): 19 | self.envs = [fn() for fn in env_fns] 20 | env = self.envs[0] 21 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 22 | obs_space = env.observation_space 23 | self.keys, shapes, dtypes = obs_space_info(obs_space) 24 | 25 | self.buf_obs = OrderedDict([ 26 | (k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) 27 | for k in self.keys]) 28 | self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) 29 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) 30 | self.buf_infos = [{} for _ in range(self.num_envs)] 31 | self.actions = None 32 | self.metadata = env.metadata 33 | 34 | def step_async(self, actions): 35 | self.actions = actions 36 | 37 | def step_wait(self): 38 | for env_idx in range(self.num_envs): 39 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\ 40 | self.envs[env_idx].step(self.actions[env_idx]) 41 | if self.buf_dones[env_idx]: 42 | # save final observation where user can get it, then reset 43 | self.buf_infos[env_idx]['terminal_observation'] = obs 44 | obs = self.envs[env_idx].reset() 45 | self._save_obs(env_idx, obs) 46 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), 47 | self.buf_infos.copy()) 48 | 49 | def reset(self): 50 | for env_idx in range(self.num_envs): 51 | obs = self.envs[env_idx].reset() 52 | self._save_obs(env_idx, obs) 53 | return self._obs_from_buf() 54 | 55 | def close(self): 56 | for env in self.envs: 57 | env.close() 58 | 59 | def get_images(self): 60 | return [env.render(mode='rgb_array') for env in self.envs] 61 | 62 | def render(self, *args, **kwargs): 63 | if self.num_envs == 1: 64 | return self.envs[0].render(*args, **kwargs) 65 | else: 66 | return super().render(*args, **kwargs) 67 | 68 | def _save_obs(self, env_idx, obs): 69 | for key in self.keys: 70 | if key is None: 71 | self.buf_obs[key][env_idx] = obs 72 | else: 73 | self.buf_obs[key][env_idx] = obs[key] 74 | 75 | def _obs_from_buf(self): 76 | return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) 77 | 78 | def get_attr(self, attr_name, indices=None): 79 | """Return attribute from vectorized environment (see base class).""" 80 | target_envs = self._get_target_envs(indices) 81 | return [getattr(env_i, attr_name) for env_i in target_envs] 82 | 83 | def set_attr(self, attr_name, value, indices=None): 84 | """Set attribute inside vectorized environments (see base class).""" 85 | target_envs = self._get_target_envs(indices) 86 | for env_i in target_envs: 87 | setattr(env_i, attr_name, value) 88 | 89 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 90 | """Call instance methods of vectorized environments.""" 91 | target_envs = self._get_target_envs(indices) 92 | return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] 93 | 94 | def _get_target_envs(self, indices): 95 | indices = self._get_indices(indices) 96 | return [self.envs[i] for i in indices] 97 | -------------------------------------------------------------------------------- /common/vec_env/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from collections import OrderedDict 3 | 4 | import gym 5 | import numpy as np 6 | 7 | from common.vec_env import VecEnv, CloudpickleWrapper 8 | from common.tile_images import tile_images 9 | 
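# Usage sketch for the vectorized wrappers in this package: both DummyVecEnv
# (above) and SubprocVecEnv (below) are built from a list of zero-argument
# callables, one per environment. The id 'CartPole-v1' here is only an
# illustrative assumption; any registered Gym environment works the same way.
#
#     import gym
#     from common.vec_env import DummyVecEnv
#
#     env = DummyVecEnv([lambda: gym.make('CartPole-v1') for _ in range(4)])
#     obs = env.reset()                        # shape: (4,) + observation_space.shape
#     actions = [env.action_space.sample() for _ in range(4)]
#     obs, rewards, dones, infos = env.step(actions)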
10 | 11 | def _worker(remote, parent_remote, env_fn_wrapper): 12 | parent_remote.close() 13 | env = env_fn_wrapper.var() 14 | while True: 15 | try: 16 | cmd, data = remote.recv() 17 | if cmd == 'step': 18 | observation, reward, done, info = env.step(data) 19 | if done: 20 | # save final observation where user can get it, then reset 21 | info['terminal_observation'] = observation 22 | observation = env.reset() 23 | remote.send((observation, reward, done, info)) 24 | elif cmd == 'reset': 25 | observation = env.reset() 26 | remote.send(observation) 27 | elif cmd == 'render': 28 | remote.send(env.render(*data[0], **data[1])) 29 | elif cmd == 'close': 30 | remote.close() 31 | break 32 | elif cmd == 'get_spaces': 33 | remote.send((env.observation_space, env.action_space)) 34 | elif cmd == 'env_method': 35 | method = getattr(env, data[0]) 36 | remote.send(method(*data[1], **data[2])) 37 | elif cmd == 'get_attr': 38 | remote.send(getattr(env, data)) 39 | elif cmd == 'set_attr': 40 | remote.send(setattr(env, data[0], data[1])) 41 | else: 42 | raise NotImplementedError 43 | except EOFError: 44 | break 45 | 46 | 47 | class SubprocVecEnv(VecEnv): 48 | """ 49 | Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own 50 | process, allowing significant speed up when the environment is computationally complex. 51 | 52 | For performance reasons, if your environment is not IO bound, the number of environments should not exceed the 53 | number of logical cores on your CPU. 54 | 55 | .. warning:: 56 | 57 | Only 'forkserver' and 'spawn' start methods are thread-safe, 58 | which is important when TensorFlow sessions or other non thread-safe 59 | libraries are used in the parent (see issue #217). However, compared to 60 | 'fork' they incur a small start-up cost and have restrictions on 61 | global variables. With those methods, users must wrap the code in an 62 | ``if __name__ == "__main__":`` block. 63 | For more information, see the multiprocessing documentation. 64 | 65 | :param env_fns: ([Gym Environment]) Environments to run in subprocesses 66 | :param start_method: (str) method used to start the subprocesses. 67 | Must be one of the methods returned by multiprocessing.get_all_start_methods(). 68 | Defaults to 'forkserver' on available platforms, and 'spawn' otherwise. 
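    Typical construction (a sketch; ``make_env(rank)`` is an assumed helper that
    returns a zero-argument callable building one environment)::

        env = SubprocVecEnv([make_env(rank) for rank in range(4)], start_method='spawn')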
69 | """ 70 | 71 | def __init__(self, env_fns, start_method=None): 72 | self.waiting = False 73 | self.closed = False 74 | n_envs = len(env_fns) 75 | 76 | if start_method is None: 77 | # Fork is not a thread safe method (see issue #217) 78 | # but is more user friendly (does not require to wrap the code in 79 | # a `if __name__ == "__main__":`) 80 | forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods() 81 | start_method = 'forkserver' if forkserver_available else 'spawn' 82 | ctx = multiprocessing.get_context(start_method) 83 | 84 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)]) 85 | self.processes = [] 86 | for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns): 87 | args = (work_remote, remote, CloudpickleWrapper(env_fn)) 88 | # daemon=True: if the main process crashes, we should not cause things to hang 89 | process = ctx.Process(target=_worker, args=args, daemon=True) 90 | process.start() 91 | self.processes.append(process) 92 | work_remote.close() 93 | 94 | self.remotes[0].send(('get_spaces', None)) 95 | observation_space, action_space = self.remotes[0].recv() 96 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 97 | 98 | def step_async(self, actions): 99 | for remote, action in zip(self.remotes, actions): 100 | remote.send(('step', action)) 101 | self.waiting = True 102 | 103 | def step_wait(self): 104 | results = [remote.recv() for remote in self.remotes] 105 | self.waiting = False 106 | obs, rews, dones, infos = zip(*results) 107 | return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos 108 | 109 | def reset(self): 110 | for remote in self.remotes: 111 | remote.send(('reset', None)) 112 | obs = [remote.recv() for remote in self.remotes] 113 | return _flatten_obs(obs, self.observation_space) 114 | 115 | def close(self): 116 | if self.closed: 117 | return 118 | if self.waiting: 119 | for remote in self.remotes: 120 | remote.recv() 121 | for remote in self.remotes: 122 | remote.send(('close', None)) 123 | for process in self.processes: 124 | process.join() 125 | self.closed = True 126 | 127 | def render(self, mode='human', *args, **kwargs): 128 | for pipe in self.remotes: 129 | # gather images from subprocesses 130 | # `mode` will be taken into account later 131 | pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs}))) 132 | imgs = [pipe.recv() for pipe in self.remotes] 133 | # Create a big image by tiling images from subprocesses 134 | bigimg = tile_images(imgs) 135 | if mode == 'human': 136 | import cv2 137 | cv2.imshow('vecenv', bigimg[:, :, ::-1]) 138 | cv2.waitKey(1) 139 | elif mode == 'rgb_array': 140 | return bigimg 141 | else: 142 | raise NotImplementedError 143 | 144 | def get_images(self): 145 | for pipe in self.remotes: 146 | pipe.send(('render', {"mode": 'rgb_array'})) 147 | imgs = [pipe.recv() for pipe in self.remotes] 148 | return imgs 149 | 150 | def get_attr(self, attr_name, indices=None): 151 | """Return attribute from vectorized environment (see base class).""" 152 | target_remotes = self._get_target_remotes(indices) 153 | for remote in target_remotes: 154 | remote.send(('get_attr', attr_name)) 155 | return [remote.recv() for remote in target_remotes] 156 | 157 | def set_attr(self, attr_name, value, indices=None): 158 | """Set attribute inside vectorized environments (see base class).""" 159 | target_remotes = self._get_target_remotes(indices) 160 | for remote in target_remotes: 161 | remote.send(('set_attr', (attr_name, 
value))) 162 | for remote in target_remotes: 163 | remote.recv() 164 | 165 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 166 | """Call instance methods of vectorized environments.""" 167 | target_remotes = self._get_target_remotes(indices) 168 | for remote in target_remotes: 169 | remote.send(('env_method', (method_name, method_args, method_kwargs))) 170 | return [remote.recv() for remote in target_remotes] 171 | 172 | def _get_target_remotes(self, indices): 173 | """ 174 | Get the connection object needed to communicate with the wanted 175 | envs that are in subprocesses. 176 | 177 | :param indices: (None,int,Iterable) refers to indices of envs. 178 | :return: ([multiprocessing.Connection]) Connection object to communicate between processes. 179 | """ 180 | indices = self._get_indices(indices) 181 | return [self.remotes[i] for i in indices] 182 | 183 | 184 | def _flatten_obs(obs, space): 185 | """ 186 | Flatten observations, depending on the observation space. 187 | 188 | :param obs: (list or tuple where X is dict, tuple or ndarray) observations. 189 | A list or tuple of observations, one per environment. 190 | Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays. 191 | :return (OrderedDict, tuple or ndarray) flattened observations. 192 | A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays. 193 | Each NumPy array has the environment index as its first axis. 194 | """ 195 | assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" 196 | assert len(obs) > 0, "need observations from at least one environment" 197 | 198 | if isinstance(space, gym.spaces.Dict): 199 | assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" 200 | assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" 201 | return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) 202 | elif isinstance(space, gym.spaces.Tuple): 203 | assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" 204 | obs_len = len(space.spaces) 205 | return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len))) 206 | else: 207 | return np.stack(obs) 208 | -------------------------------------------------------------------------------- /common/vec_env/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for dealing with vectorized environments. 3 | """ 4 | 5 | from collections import OrderedDict 6 | 7 | import gym 8 | import numpy as np 9 | 10 | 11 | def copy_obs_dict(obs): 12 | """ 13 | Deep-copy a dict of numpy arrays. 14 | 15 | :param obs: (OrderedDict): a dict of numpy arrays. 16 | :return (OrderedDict) a dict of copied numpy arrays. 17 | """ 18 | assert isinstance(obs, OrderedDict), "unexpected type for observations '{}'".format(type(obs)) 19 | return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) 20 | 21 | 22 | def dict_to_obs(space, obs_dict): 23 | """ 24 | Convert an internal representation raw_obs into the appropriate type 25 | specified by space. 26 | 27 | :param space: (gym.spaces.Space) an observation space. 28 | :param obs_dict: (OrderedDict) a dict of numpy arrays. 29 | :return (ndarray, tuple or dict): returns an observation 30 | of the same type as space. 
If space is Dict, function is identity; 31 | if space is Tuple, converts dict to Tuple; otherwise, space is 32 | unstructured and returns the value raw_obs[None]. 33 | """ 34 | if isinstance(space, gym.spaces.Dict): 35 | return obs_dict 36 | elif isinstance(space, gym.spaces.Tuple): 37 | assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space" 38 | return tuple((obs_dict[i] for i in range(len(space.spaces)))) 39 | else: 40 | assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" 41 | return obs_dict[None] 42 | 43 | 44 | def obs_space_info(obs_space): 45 | """ 46 | Get dict-structured information about a gym.Space. 47 | 48 | Dict spaces are represented directly by their dict of subspaces. 49 | Tuple spaces are converted into a dict with keys indexing into the tuple. 50 | Unstructured spaces are represented by {None: obs_space}. 51 | 52 | :param obs_space: (gym.spaces.Space) an observation space 53 | :return (tuple) A tuple (keys, shapes, dtypes): 54 | keys: a list of dict keys. 55 | shapes: a dict mapping keys to shapes. 56 | dtypes: a dict mapping keys to dtypes. 57 | """ 58 | if isinstance(obs_space, gym.spaces.Dict): 59 | assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" 60 | subspaces = obs_space.spaces 61 | elif isinstance(obs_space, gym.spaces.Tuple): 62 | subspaces = {i: space for i, space in enumerate(obs_space.spaces)} 63 | else: 64 | assert not hasattr(obs_space, 'spaces'), "Unsupported structured space '{}'".format(type(obs_space)) 65 | subspaces = {None: obs_space} 66 | keys = [] 67 | shapes = {} 68 | dtypes = {} 69 | for key, box in subspaces.items(): 70 | keys.append(key) 71 | shapes[key] = box.shape 72 | dtypes[key] = box.dtype 73 | return keys, shapes, dtypes 74 | -------------------------------------------------------------------------------- /common/vec_env/vec_check_nan.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | from common.vec_env import VecEnvWrapper 6 | 7 | 8 | class VecCheckNan(VecEnvWrapper): 9 | """ 10 | NaN and inf checking wrapper for vectorized environment, will raise a warning by default, 11 | allowing you to know from what the NaN of inf originated from. 12 | 13 | :param venv: (VecEnv) the vectorized environment to wrap 14 | :param raise_exception: (bool) Whether or not to raise a ValueError, instead of a UserWarning 15 | :param warn_once: (bool) Whether or not to only warn once. 
16 | :param check_inf: (bool) Whether or not to check for +inf or -inf as well 17 | """ 18 | 19 | def __init__(self, venv, raise_exception=False, warn_once=True, check_inf=True): 20 | VecEnvWrapper.__init__(self, venv) 21 | self.raise_exception = raise_exception 22 | self.warn_once = warn_once 23 | self.check_inf = check_inf 24 | self._actions = None 25 | self._observations = None 26 | self._user_warned = False 27 | 28 | def step_async(self, actions): 29 | self._check_val(async_step=True, actions=actions) 30 | 31 | self._actions = actions 32 | self.venv.step_async(actions) 33 | 34 | def step_wait(self): 35 | observations, rewards, news, infos = self.venv.step_wait() 36 | 37 | self._check_val(async_step=False, observations=observations, rewards=rewards, news=news) 38 | 39 | self._observations = observations 40 | return observations, rewards, news, infos 41 | 42 | def reset(self): 43 | observations = self.venv.reset() 44 | self._actions = None 45 | 46 | self._check_val(async_step=False, observations=observations) 47 | 48 | self._observations = observations 49 | return observations 50 | 51 | def _check_val(self, *, async_step, **kwargs): 52 | # if warn and warn once and have warned once: then stop checking 53 | if not self.raise_exception and self.warn_once and self._user_warned: 54 | return 55 | 56 | found = [] 57 | for name, val in kwargs.items(): 58 | has_nan = np.any(np.isnan(val)) 59 | has_inf = self.check_inf and np.any(np.isinf(val)) 60 | if has_inf: 61 | found.append((name, "inf")) 62 | if has_nan: 63 | found.append((name, "nan")) 64 | 65 | if found: 66 | self._user_warned = True 67 | msg = "" 68 | for i, (name, type_val) in enumerate(found): 69 | msg += "found {} in {}".format(type_val, name) 70 | if i != len(found) - 1: 71 | msg += ", " 72 | 73 | msg += ".\r\nOriginated from the " 74 | 75 | if not async_step: 76 | if self._actions is None: 77 | msg += "environment observation (at reset)" 78 | else: 79 | msg += "environment, Last given value was: \r\n\taction={}".format(self._actions) 80 | else: 81 | msg += "RL model, Last given value was: \r\n\tobservations={}".format(self._observations) 82 | 83 | if self.raise_exception: 84 | raise ValueError(msg) 85 | else: 86 | warnings.warn(msg, UserWarning) 87 | -------------------------------------------------------------------------------- /common/vec_env/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | from common.vec_env import VecEnvWrapper 7 | 8 | 9 | class VecFrameStack(VecEnvWrapper): 10 | """ 11 | Frame stacking wrapper for vectorized environment 12 | 13 | :param venv: (VecEnv) the vectorized environment to wrap 14 | :param n_stack: (int) Number of frames to stack 15 | """ 16 | 17 | def __init__(self, venv, n_stack): 18 | self.venv = venv 19 | self.n_stack = n_stack 20 | wrapped_obs_space = venv.observation_space 21 | low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1) 22 | high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1) 23 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) 24 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 25 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 26 | 27 | def step_wait(self): 28 | observations, rewards, dones, infos = self.venv.step_wait() 29 | last_ax_size = observations.shape[-1] 30 | self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1) 31 | 
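        # np.roll shifts the stacked buffer left along the last (channel) axis, so
        # the oldest frame's slots wrap around to the end and are overwritten below
        # by the newest observation (e.g. with n_stack=4 single-channel frames the
        # slot order goes from [t-3, t-2, t-1, t] to [t-2, t-1, t, t+1]).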
for i, done in enumerate(dones): 32 | if done: 33 | if 'terminal_observation' in infos[i]: 34 | old_terminal = infos[i]['terminal_observation'] 35 | new_terminal = np.concatenate( 36 | (self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1) 37 | infos[i]['terminal_observation'] = new_terminal 38 | else: 39 | warnings.warn( 40 | "VecFrameStack wrapping a VecEnv without terminal_observation info") 41 | self.stackedobs[i] = 0 42 | self.stackedobs[..., -observations.shape[-1]:] = observations 43 | return self.stackedobs, rewards, dones, infos 44 | 45 | def reset(self): 46 | """ 47 | Reset all environments 48 | """ 49 | obs = self.venv.reset() 50 | self.stackedobs[...] = 0 51 | self.stackedobs[..., -obs.shape[-1]:] = obs 52 | return self.stackedobs 53 | 54 | def close(self): 55 | self.venv.close() 56 | -------------------------------------------------------------------------------- /common/vec_env/vec_normalize.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | 5 | from common.vec_env import VecEnvWrapper 6 | from common.running_mean_std import RunningMeanStd 7 | 8 | 9 | class VecNormalize(VecEnvWrapper): 10 | """ 11 | A moving average, normalizing wrapper for vectorized environment. 12 | has support for saving/loading moving average, 13 | 14 | :param venv: (VecEnv) the vectorized environment to wrap 15 | :param training: (bool) Whether to update or not the moving average 16 | :param norm_obs: (bool) Whether to normalize observation or not (default: True) 17 | :param norm_reward: (bool) Whether to normalize rewards or not (default: True) 18 | :param clip_obs: (float) Max absolute value for observation 19 | :param clip_reward: (float) Max value absolute for discounted reward 20 | :param gamma: (float) discount factor 21 | :param epsilon: (float) To avoid division by zero 22 | """ 23 | 24 | def __init__(self, venv, training=True, norm_obs=True, norm_reward=True, 25 | clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8): 26 | VecEnvWrapper.__init__(self, venv) 27 | self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) 28 | self.ret_rms = RunningMeanStd(shape=()) 29 | self.clip_obs = clip_obs 30 | self.clip_reward = clip_reward 31 | # Returns: discounted rewards 32 | self.ret = np.zeros(self.num_envs) 33 | self.gamma = gamma 34 | self.epsilon = epsilon 35 | self.training = training 36 | self.norm_obs = norm_obs 37 | self.norm_reward = norm_reward 38 | self.old_obs = np.array([]) 39 | 40 | def step_wait(self): 41 | """ 42 | Apply sequence of actions to sequence of environments 43 | actions -> (observations, rewards, news) 44 | 45 | where 'news' is a boolean vector indicating whether each element is new. 
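        Observations are normalized with running mean/std statistics and clipped to
        [-clip_obs, clip_obs]; rewards are divided by the running standard deviation
        of the discounted return and clipped to [-clip_reward, clip_reward].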
46 | """ 47 | obs, rews, news, infos = self.venv.step_wait() 48 | self.ret = self.ret * self.gamma + rews 49 | self.old_obs = obs 50 | obs = self._normalize_observation(obs) 51 | if self.norm_reward: 52 | if self.training: 53 | self.ret_rms.update(self.ret) 54 | rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) 55 | self.ret[news] = 0 56 | return obs, rews, news, infos 57 | 58 | def _normalize_observation(self, obs): 59 | """ 60 | :param obs: (numpy tensor) 61 | """ 62 | if self.norm_obs: 63 | if self.training: 64 | self.obs_rms.update(obs) 65 | obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs, 66 | self.clip_obs) 67 | return obs 68 | else: 69 | return obs 70 | 71 | def get_original_obs(self): 72 | """ 73 | returns the unnormalized observation 74 | 75 | :return: (numpy float) 76 | """ 77 | return self.old_obs 78 | 79 | def reset(self): 80 | """ 81 | Reset all environments 82 | """ 83 | obs = self.venv.reset() 84 | if len(np.array(obs).shape) == 1: # for when num_cpu is 1 85 | self.old_obs = [obs] 86 | else: 87 | self.old_obs = obs 88 | self.ret = np.zeros(self.num_envs) 89 | return self._normalize_observation(obs) 90 | 91 | def save_running_average(self, path): 92 | """ 93 | :param path: (str) path to log dir 94 | """ 95 | for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']): 96 | with open("{}/{}.pkl".format(path, name), 'wb') as file_handler: 97 | pickle.dump(rms, file_handler) 98 | 99 | def load_running_average(self, path): 100 | """ 101 | :param path: (str) path to log dir 102 | """ 103 | for name in ['obs_rms', 'ret_rms']: 104 | with open("{}/{}.pkl".format(path, name), 'rb') as file_handler: 105 | setattr(self, name, pickle.load(file_handler)) 106 | -------------------------------------------------------------------------------- /common/vec_env/vec_video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from gym.wrappers.monitoring import video_recorder 4 | 5 | # from stable_baselines_tf2 import logger 6 | from common.vec_env import VecEnvWrapper, DummyVecEnv, VecNormalize, VecFrameStack, SubprocVecEnv 7 | 8 | 9 | class VecVideoRecorder(VecEnvWrapper): 10 | """ 11 | Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. 12 | It requires ffmpeg or avconv to be installed on the machine. 13 | 14 | :param venv: (VecEnv or VecEnvWrapper) 15 | :param video_folder: (str) Where to save videos 16 | :param record_video_trigger: (func) Function that defines when to start recording. 17 | The function takes the current number of step, 18 | and returns whether we should start recording or not. 
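        For example, ``record_video_trigger=lambda step: step % 10000 == 0`` would
        start a new recording every 10000 steps (an illustrative choice, not a default).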
19 | :param video_length: (int) Length of recorded videos 20 | :param name_prefix: (str) Prefix to the video name 21 | """ 22 | 23 | def __init__(self, venv, video_folder, record_video_trigger, 24 | video_length=200, name_prefix='rl-video'): 25 | 26 | VecEnvWrapper.__init__(self, venv) 27 | 28 | self.env = venv 29 | # Temp variable to retrieve metadata 30 | temp_env = venv 31 | 32 | # Unwrap to retrieve metadata dict 33 | # that will be used by gym recorder 34 | while isinstance(temp_env, VecNormalize) or isinstance(temp_env, VecFrameStack): 35 | temp_env = temp_env.venv 36 | 37 | if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv): 38 | metadata = temp_env.get_attr('metadata')[0] 39 | else: 40 | metadata = temp_env.metadata 41 | 42 | self.env.metadata = metadata 43 | 44 | self.record_video_trigger = record_video_trigger 45 | self.video_recorder = None 46 | 47 | self.video_folder = os.path.abspath(video_folder) 48 | # Create output folder if needed 49 | os.makedirs(self.video_folder, exist_ok=True) 50 | 51 | self.name_prefix = name_prefix 52 | self.step_id = 0 53 | self.video_length = video_length 54 | 55 | self.recording = False 56 | self.recorded_frames = 0 57 | 58 | def reset(self): 59 | obs = self.venv.reset() 60 | self.start_video_recorder() 61 | return obs 62 | 63 | def start_video_recorder(self): 64 | self.close_video_recorder() 65 | 66 | video_name = '{}-step-{}-to-step-{}'.format(self.name_prefix, self.step_id, 67 | self.step_id + self.video_length) 68 | base_path = os.path.join(self.video_folder, video_name) 69 | self.video_recorder = video_recorder.VideoRecorder( 70 | env=self.env, 71 | base_path=base_path, 72 | metadata={'step_id': self.step_id} 73 | ) 74 | 75 | self.video_recorder.capture_frame() 76 | self.recorded_frames = 1 77 | self.recording = True 78 | 79 | def _video_enabled(self): 80 | return self.record_video_trigger(self.step_id) 81 | 82 | def step_wait(self): 83 | obs, rews, dones, infos = self.venv.step_wait() 84 | 85 | self.step_id += 1 86 | if self.recording: 87 | self.video_recorder.capture_frame() 88 | self.recorded_frames += 1 89 | if self.recorded_frames > self.video_length: 90 | # logger.info("Saving video to ", self.video_recorder.path) 91 | self.close_video_recorder() 92 | elif self._video_enabled(): 93 | self.start_video_recorder() 94 | 95 | return obs, rews, dones, infos 96 | 97 | def close_video_recorder(self): 98 | if self.recording: 99 | self.video_recorder.close() 100 | self.recording = False 101 | self.recorded_frames = 1 102 | 103 | def close(self): 104 | VecEnvWrapper.close(self) 105 | self.close_video_recorder() 106 | 107 | def __del__(self): 108 | self.close() 109 | -------------------------------------------------------------------------------- /ddpg/ddpg.py: -------------------------------------------------------------------------------- 1 | # This implementation is based on codes of Jongmin Lee & Byeong-jun Lee 2 | import time 3 | import random 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow_probability as tfp 7 | from tqdm import tqdm 8 | import pickle 9 | from base.rl import ActorCriticRLAlgorithm 10 | 11 | class Actor(tf.keras.layers.Layer): 12 | 13 | def __init__(self, env): 14 | super(Actor, self).__init__() 15 | assert NotImplementedError 16 | 17 | # @tf.function 18 | def call(self, inputs, **kwargs): 19 | assert NotImplementedError 20 | 21 | def step(self, obs, deterministic=False): 22 | assert NotImplementedError 23 | 24 | 25 | class VNetwork(tf.keras.layers.Layer): 26 | 27 | def 
__init__(self, obs_shape, output_dim=1): 28 | super(VNetwork, self).__init__() 29 | assert NotImplementedError 30 | 31 | # @tf.function 32 | def call(self, inputs, **kwargs): 33 | assert NotImplementedError 34 | 35 | 36 | class QNetwork(tf.keras.layers.Layer): 37 | 38 | def __init__(self, obs_shape): 39 | super(QNetwork, self).__init__() 40 | assert NotImplementedError 41 | 42 | # @tf.function 43 | def call(self, inputs, **kwargs): 44 | assert NotImplementedError 45 | 46 | 47 | class DDPG(ActorCriticRLAlgorithm): 48 | 49 | def __init__(self, env, ent_coef='auto', seed=0): 50 | super(DDPG, self).__init__() 51 | assert NotImplementedError 52 | 53 | def update_target(self): 54 | assert NotImplementedError 55 | # for target, source in zip(sel f.target_params, self.source_params): 56 | # target.set_weights( (1 - self.tau) * target.get_weights() + self.tau * source.get_weights() ) 57 | 58 | # @tf.function 59 | def train(self, obs, action, reward, next_obs, done): 60 | # Casting from float64 to float32 61 | assert NotImplementedError 62 | 63 | def learn(self, total_timesteps, log_interval=2, seed=0, callback=None, verbose=1): 64 | assert NotImplementedError 65 | 66 | def predict(self, obs, deterministic=False): 67 | obs_rank = len(obs.shape) 68 | if len(obs.shape) == 1: 69 | obs = np.array([obs]) 70 | assert len(obs.shape) == 2 71 | 72 | action = self.actor.step(obs) 73 | 74 | if obs_rank == 1: 75 | return action[0], None 76 | 77 | else: 78 | return action, None 79 | 80 | -------------------------------------------------------------------------------- /dqn/__init__.py: -------------------------------------------------------------------------------- 1 | from dqn.dqn import DQN, ReplayBuffer 2 | from dqn.policy import MlpPolicy 3 | 4 | 5 | def wrap_atari_dqn(env): 6 | """ 7 | wrap the environment in atari wrappers for DQN 8 | :param env: (Gym Environment) the environment 9 | :return: (Gym Environment) the wrapped environment 10 | """ 11 | from common.atari_wrappers import wrap_deepmind 12 | return wrap_deepmind(env, frame_stack=True, scale=False) -------------------------------------------------------------------------------- /dqn/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/dqn/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /dqn/__pycache__/dqn.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/dqn/__pycache__/dqn.cpython-35.pyc -------------------------------------------------------------------------------- /dqn/__pycache__/dqn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/dqn/__pycache__/dqn.cpython-37.pyc -------------------------------------------------------------------------------- /dqn/__pycache__/policy.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/dqn/__pycache__/policy.cpython-35.pyc -------------------------------------------------------------------------------- 
/dqn/__pycache__/policy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/dqn/__pycache__/policy.cpython-37.pyc -------------------------------------------------------------------------------- /dqn/dqn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from gym.spaces import MultiDiscrete, Box 4 | 5 | from tqdm import tqdm 6 | from functools import partial 7 | from common import tf_util, LinearSchedule 8 | from common.vec_env import VecEnv 9 | 10 | from base.rl import ValueBasedRLAlgorithm 11 | from base.replay_buffer import ReplayBuffer, PrioritizedReplayBuffer 12 | 13 | import copy 14 | import gym 15 | 16 | # For Save/Load 17 | import os 18 | import pickle 19 | import cloudpickle 20 | import json 21 | import zipfile 22 | from common.save_util import params_to_bytes 23 | from common.save_util import data_to_json 24 | 25 | 26 | class DQN(ValueBasedRLAlgorithm): 27 | def __init__(self, policy_class, env, gamma=0.99, learning_rate=5e-4, buffer_size=50000, 28 | exploration_fraction=0.1, exploration_final_eps=0.02, train_freq=1, batch_size=32, double_q=True, 29 | learning_starts=1000, target_network_update_freq=500, prioritized_replay=False, 30 | prioritized_replay_alpha=0.6, prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None, 31 | prioritized_replay_eps=1e-6, _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False, 32 | dueling=True, 33 | model_path='~/params/'): 34 | 35 | #Create an instance for save and load path 36 | self.model_path = model_path 37 | # Create an instance of DQNPolicy (obs_space, act_space, n_env, n_steps, n_batch, name) 38 | self.env = env 39 | self.observation_space = self.env.observation_space 40 | self.action_space = self.env.action_space 41 | self.policy = policy_class(self.observation_space, self.action_space, 1, 1, None, 'q', dueling=dueling) 42 | self.q_function = self.policy.qnet.call # Q-Function : obs -> action-dim vector 43 | 44 | # Create another instance of DQNPolicy 45 | self.target_policy = policy_class(self.observation_space, self.action_space, 1, 1, None, 'target_q', dueling=dueling) 46 | self.target_q_function = self.target_policy.qnet.call # Q-Function : obs -> action-dim vector 47 | 48 | self.double_q = double_q 49 | if self.double_q: 50 | self.double_policy = policy_class(self.observation_space, self.action_space, 1, 1, None, 'double_q', dueling=dueling) 51 | self.double_q_function = self.double_policy.qnet.call 52 | 53 | self.buffer_size = buffer_size 54 | self.replay_buffer = None 55 | 56 | self.prioritized_replay = prioritized_replay 57 | self.prioritized_replay_eps = prioritized_replay_eps 58 | self.prioritized_replay_alpha = prioritized_replay_alpha 59 | self.prioritized_replay_beta0 = prioritized_replay_beta0 60 | self.prioritized_replay_beta_iters = prioritized_replay_beta_iters 61 | 62 | self.num_timesteps = 0 63 | self.learning_starts = learning_starts 64 | self.train_freq = train_freq 65 | self.batch_size = batch_size 66 | self.target_network_update_freq = target_network_update_freq 67 | self.exploration_final_eps = exploration_final_eps 68 | self.exploration_fraction = exploration_fraction 69 | self.learning_rate = learning_rate 70 | self.gamma = gamma 71 | self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate) 72 | 73 | self.proba_step = 
self.policy.proba_step 74 | self.exploration = None 75 | self.episode_reward = None 76 | self.n_actions = self.action_space.nvec if isinstance(self.action_space, MultiDiscrete) else self.action_space.n 77 | 78 | self.qfunc_layers = self.policy.qnet.trainable_layers 79 | self.target_qfunc_layers = self.target_policy.qnet.trainable_layers 80 | 81 | self.double_q_function_layers = [] 82 | 83 | if self.double_q: 84 | self.double_q_function_layers = self.double_policy.qnet.trainable_layers 85 | 86 | self.trainable_layers = self.qfunc_layers + self.target_qfunc_layers + self.double_q_function_layers 87 | self.params = self.qfunc_layers.trainable_variables + self.target_qfunc_layers.trainable_variables\ 88 | + self.double_q_function_layers 89 | 90 | self.update_target() 91 | 92 | self.initialize_variables() 93 | 94 | def act(self, obs, eps=1., stochastic=True): 95 | batch_size = np.shape(obs)[0] 96 | max_actions = np.argmax(self.q_function(obs), axis=1) 97 | 98 | if stochastic: 99 | random_actions = np.random.randint(low=0, high=self.n_actions, size=batch_size) 100 | chose_random = np.random.uniform(size=np.stack([batch_size]), low=0, high=1) 101 | epsgreedy_actions = np.where(chose_random < eps, random_actions, max_actions) 102 | 103 | return epsgreedy_actions 104 | 105 | else: 106 | return max_actions 107 | 108 | @tf.function 109 | def train(self, obs_t, act_t, rew_t, obs_tp, done_mask, importance_weights): 110 | 111 | if self.double_q: 112 | q_tp1_best_using_online_net = tf.argmax(self.double_q_function(obs_tp), axis=1) 113 | q_tp1_best = tf.reduce_sum(self.target_q_function(obs_tp) 114 | * tf.one_hot(q_tp1_best_using_online_net, self.n_actions), axis=1) 115 | else: 116 | q_tp1_best = tf.reduce_max(self.target_q_function(obs_tp), axis=1) 117 | 118 | q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best 119 | q_t_selected_target = tf.cast(rew_t, tf.float32) + tf.cast(self.gamma, tf.float32) * q_tp1_best_masked 120 | 121 | with tf.GradientTape() as tape: 122 | q_t_selected = tf.reduce_sum(self.q_function(obs_t) * tf.one_hot(act_t, self.n_actions), axis=1) 123 | td_error = q_t_selected - tf.stop_gradient(q_t_selected_target) 124 | errors = tf_util.huber_loss(td_error) 125 | weighted_error = tf.reduce_mean(errors) 126 | 127 | grads = tape.gradient(weighted_error, self.qfunc_layers.trainable_variables) 128 | self.optimizer.apply_gradients(zip(grads, self.qfunc_layers.trainable_variables)) 129 | 130 | return td_error, weighted_error 131 | 132 | @tf.function 133 | def initialize_variables(self): 134 | zero_like_state = tf.zeros((1,) + self.observation_space.shape) 135 | 136 | self.q_function(zero_like_state) 137 | self.target_q_function(zero_like_state) 138 | if self.double_q: 139 | self.double_q_function(zero_like_state) 140 | 141 | def update_target(self): 142 | for var, var_target in zip(self.qfunc_layers, self.target_qfunc_layers): 143 | w = var.get_weights() 144 | var_target.set_weights(w) 145 | 146 | # def update_double(self): 147 | # for var, var_double in zip(self.qfunc_layers, self.double_q_function_layers): 148 | # w = var.get_weights() 149 | # var_double.set_weights(w) 150 | 151 | def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DQN", 152 | reset_num_timesteps=True): 153 | 154 | # Create the replay buffer 155 | if self.prioritized_replay: 156 | self.replay_buffer = PrioritizedReplayBuffer(self.buffer_size, alpha=self.prioritized_replay_alpha) 157 | if self.prioritized_replay_beta_iters is None: 158 | prioritized_replay_beta_iters = total_timesteps 159 
| else: 160 | prioritized_replay_beta_iters = self.prioritized_replay_beta_iters 161 | 162 | self.beta_schedule = LinearSchedule(prioritized_replay_beta_iters, 163 | initial_p=self.prioritized_replay_beta0, 164 | final_p=1.0) 165 | 166 | else: 167 | self.replay_buffer = ReplayBuffer(self.buffer_size) 168 | 169 | # Create the schedule for exploration starting from 1. 170 | self.exploration = LinearSchedule(schedule_timesteps=int(self.exploration_fraction * total_timesteps), 171 | initial_p=1.0, 172 | final_p=self.exploration_final_eps) 173 | 174 | episode_rewards = [0.0] 175 | episode_successes = [] 176 | 177 | saved_mean_rewards = None 178 | model_saved = False 179 | 180 | obs = self.env.reset() 181 | error = 0 182 | 183 | self.episode_reward = np.zeros((1,)) 184 | 185 | for _ in tqdm(range(total_timesteps)): 186 | # self.env.render() 187 | # Take action and update exploration to the newest value 188 | eps = self.exploration.value(self.num_timesteps) 189 | env_action = self.act(np.array(obs)[None], eps=eps, stochastic=True)[0] 190 | new_obs, rew, done, info = self.env.step(env_action) 191 | 192 | # Store transition in the replay buffer. 193 | self.replay_buffer.add(obs, env_action, rew, new_obs, np.float32(done)) 194 | obs = copy.deepcopy(new_obs) 195 | episode_rewards[-1] += rew 196 | 197 | if done: 198 | maybe_is_success = info.get('is_success') 199 | if maybe_is_success is not None: 200 | episode_successes.append(float(maybe_is_success)) 201 | if not isinstance(self.env, VecEnv): 202 | obs = self.env.reset() 203 | episode_rewards.append(0.0) 204 | 205 | can_sample = self.replay_buffer.can_sample(self.batch_size) 206 | 207 | if can_sample and self.num_timesteps > self.learning_starts: 208 | if self.num_timesteps % self.train_freq == 0: 209 | # Minimize the error in Bellman's equation on a batch sampled from replay buffer. 210 | 211 | # Sample a batch from the replay buffer 212 | if self.prioritized_replay: 213 | (obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = \ 214 | self.replay_buffer.sample(self.batch_size, 215 | beta=self.beta_schedule.value(self.num_timesteps)) 216 | 217 | else: 218 | obses_t, actions, rewards, obses_tp1, dones = self.replay_buffer.sample(self.batch_size) 219 | weights = np.ones_like(rewards) 220 | batch_idxes = None 221 | 222 | # Minimize the error in Bellman's equation on the sampled batch 223 | td_errors, error = self.train(obses_t, actions, rewards, obses_tp1, dones, weights) 224 | 225 | # if self.double_q: 226 | # self.update_double() 227 | 228 | if self.prioritized_replay: 229 | new_priorities = np.abs(td_errors) + self.prioritized_replay_eps 230 | self.replay_buffer.update_priorities(batch_idxes, new_priorities) 231 | 232 | if self.num_timesteps % self.target_network_update_freq == 0: 233 | # Update target network periodically. 
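                    # This is a hard update: the online Q-network weights are copied
                    # verbatim into the target network (see update_target above).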
234 | self.update_target() 235 | 236 | if len(episode_rewards[-101:-1]) == 0: 237 | mean_100ep_reward = -np.inf 238 | 239 | else: 240 | mean_100ep_reward = round(float(np.mean(episode_rewards[-101:-1])), 1) 241 | 242 | num_episodes = len(episode_rewards) 243 | 244 | if done and log_interval is not None and len(episode_rewards) % log_interval == 0: 245 | print("- steps : ", self.num_timesteps) 246 | print("- episodes : ", num_episodes) 247 | print("- mean 100 episode reward : %.4f" % mean_100ep_reward) 248 | print("- recent mean TD error : %.4f" % error) 249 | print("- % time spent exploring : ", int(100 * self.exploration.value(self.num_timesteps))) 250 | 251 | # Save if mean_100ep_reward is lager than the past best result 252 | if saved_mean_rewards is None or saved_mean_rewards < mean_100ep_reward: 253 | self.save(self.model_path) 254 | model_saved = True 255 | saved_mean_rewards = mean_100ep_reward 256 | if model_saved: 257 | print("best case: ", saved_mean_rewards) 258 | 259 | self.num_timesteps += 1 260 | 261 | return self 262 | def play(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DQN", 263 | reset_num_timesteps=True): 264 | 265 | self.episode_reward = np.zeros((1,)) 266 | 267 | obs = self.env.reset() 268 | 269 | for _ in tqdm(range(total_timesteps)): 270 | self.env.render() 271 | env_action = self.act(np.array(obs)[None], eps=0, stochastic=True)[0] 272 | new_obs, rew, done, info = self.env.step(env_action) 273 | obs = copy.deepcopy(new_obs) 274 | 275 | if done: 276 | maybe_is_success = info.get('is_success') 277 | if maybe_is_success is not None: 278 | pass 279 | if not isinstance(self.env, VecEnv): 280 | obs = self.env.reset() 281 | return self 282 | 283 | def predict(self, observation, state=None, mask=None, deterministic=True): 284 | observation = np.array(observation) 285 | observation = observation.reshape((-1,) + self.observation_space.shape) 286 | actions, _, _ = self.policy.step(observation, deterministic=deterministic) 287 | actions = actions[0] 288 | 289 | return actions, None 290 | 291 | def action_probability(self, observation, state=None, mask=None, actions=None, logp=False): 292 | observation = np.array(observation) 293 | observation = observation.reshape((-1,) + self.observation_space.shape) 294 | actions_proba = self.proba_step(observation, state, mask) 295 | 296 | if actions is not None: # comparing the action distribution, to given actions 297 | actions = np.array([actions]) 298 | assert isinstance(self.action_space, gym.spaces.Discrete) 299 | actions = actions.reshape((-1,)) 300 | assert observation.shape[0] == actions.shape[0], "Error: batch sizes differ for actions and observations." 
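            # Fancy indexing: for each row i, keep only the probability assigned to
            # actions[i], giving a (batch_size,) vector of P(a_i | s_i).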
301 | actions_proba = actions_proba[np.arange(actions.shape[0]), actions] 302 | # normalize action proba shape 303 | actions_proba = actions_proba.reshape((-1, 1)) 304 | 305 | if logp: 306 | actions_proba = np.log(actions_proba) 307 | 308 | actions_proba = actions_proba[0] 309 | 310 | return actions_proba 311 | 312 | def get_parameters(self): 313 | parameters = [] 314 | weights = [] 315 | for layer in self.trainable_layers: 316 | # print(layer.name) 317 | weights.append(layer.get_weights()) 318 | 319 | weights = np.array(weights) 320 | weights = weights.reshape(np.shape(self.trainable_layers.trainable_variables)) 321 | 322 | for idx, variable in enumerate(self.trainable_layers.trainable_variables): 323 | weight = weights[idx] 324 | parameters.append((variable.name, weight)) 325 | return parameters 326 | 327 | def get_parameter_list(self): 328 | return self.params 329 | 330 | def save(self, save_path, cloudpickle=True): 331 | data = self.get_parameters() 332 | if isinstance(save_path, str): 333 | _, ext = os.path.splitext(save_path) 334 | if ext == "": 335 | save_path += ".pkl" 336 | with open(save_path, 'wb') as f: 337 | pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) 338 | 339 | def load_parameters(self, parameters, exact_match=False): 340 | print(len(parameters), len(self.trainable_layers.weights)) 341 | 342 | assert len(parameters) == len(self.trainable_layers.weights) 343 | weights = [] 344 | for variable, parameter in zip(self.trainable_layers.weights, parameters): 345 | name, value = parameter 346 | if exact_match: 347 | assert name == variable.name 348 | weights.append(value) 349 | for i in range(len(self.trainable_layers)): 350 | self.trainable_layers[i].set_weights((weights[2 * i], weights[2 * i + 1])) 351 | 352 | def load(self, load_path, cloudpickle=True): 353 | # Parameter cloudpickle does not work now 354 | self.initialize_variables() 355 | if isinstance(load_path, str): 356 | _, ext = os.path.splitext(load_path) 357 | if ext == "": 358 | load_path += ".pkl" 359 | with open(load_path, 'rb') as f: 360 | data = pickle.load(f) 361 | self.load_parameters(data) 362 | -------------------------------------------------------------------------------- /dqn/enjoy_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | from dqn import DQN 6 | from policy import MlpPolicy, CnnPolicy 7 | 8 | 9 | def main(args): 10 | """ 11 | Run a trained model for the cartpole problem 12 | :param args: (ArgumentParser) the input arguments 13 | """ 14 | env = gym.make("CartPole-v0") 15 | model = DQN( 16 | env=env, 17 | policy_class=MlpPolicy, 18 | learning_rate=5e-4, 19 | buffer_size=50000, 20 | double_q=False, 21 | prioritized_replay=True, 22 | dueling=True, 23 | exploration_fraction=0.2, 24 | exploration_final_eps=0.02, 25 | model_path='cartpole_model' 26 | ) 27 | model = model.load("cartpole_model") 28 | 29 | while True: 30 | obs, done = env.reset(), False 31 | episode_rew = 0 32 | while not done: 33 | if not args.no_render: 34 | env.render() 35 | action, _ = model.predict(obs) 36 | obs, rew, done, _ = env.step(action) 37 | episode_rew += rew 38 | print("Episode reward", episode_rew) 39 | # No render is only used for automatic testing 40 | if args.no_render: 41 | break 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description="Enjoy trained DQN on cartpole") 46 | parser.add_argument('--no-render', default=False, action="store_true", help="Disable rendering") 47 | args = parser.parse_args() 
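    # Example invocation (assumes a saved "cartpole_model.pkl" from a prior
    # training run):  python enjoy_example.py --no-render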
48 | main(args) -------------------------------------------------------------------------------- /dqn/policy.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from gym.spaces import Discrete 5 | from base.policy import BasePolicy 6 | from common.tf_util import conv, conv_to_fc, linear 7 | 8 | 9 | class DQNPolicy(BasePolicy): 10 | """ 11 | Policy object that implements a DQN policy 12 | 13 | :param ob_space: (Gym Space) The observation space of the environment 14 | :param ac_space: (Gym Space) The action space of the environment 15 | :param n_env: (int) The number of environments to run 16 | :param n_steps: (int) The number of steps to run for each environment 17 | :param n_batch: (int) The number of batch to run (n_envs * n_steps) 18 | :param reuse: (bool) If the policy is reusable or not 19 | :param scale: (bool) whether or not to scale the input 20 | :param obs_phs: (TensorFlow Tensor, TensorFlow Tensor) a tuple containing an override for observation placeholder 21 | and the processed observation placeholder respectivly 22 | :param dueling: (bool) if true double the output MLP to compute a baseline for action scores 23 | """ 24 | 25 | def __init__(self, ob_space, ac_space, n_env, n_steps, n_batch, name='q', reuse=False, scale=False, dueling=True): 26 | # DQN policies need an override for the obs placeholder, due to the architecture of the code 27 | super(DQNPolicy, self).__init__(ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=scale) 28 | assert isinstance(ac_space, Discrete), "Error: the action space for DQN must be of type gym.spaces.Discrete" 29 | self.n_actions = ac_space.n 30 | self.value_fn = None 31 | self.q_values = None 32 | self.dueling = dueling 33 | self.policy_proba = None 34 | 35 | def step(self, obs, state=None, mask=None, deterministic=True): 36 | """ 37 | Returns the q_values for a single step 38 | 39 | :param obs: (np.ndarray float or int) The current observation of the environment 40 | :param state: (np.ndarray float) The last states (used in recurrent policies) 41 | :param mask: (np.ndarray float) The last masks (used in recurrent policies) 42 | :param deterministic: (bool) Whether or not to return deterministic actions. 
43 | :return: (np.ndarray int, np.ndarray float, np.ndarray float) actions, q_values, states 44 | """ 45 | raise NotImplementedError 46 | 47 | def proba_step(self, obs, state=None, mask=None): 48 | """ 49 | Returns the action probability for a single step 50 | 51 | :param obs: (np.ndarray float or int) The current observation of the environment 52 | :param state: (np.ndarray float) The last states (used in recurrent policies) 53 | :param mask: (np.ndarray float) The last masks (used in recurrent policies) 54 | :return: (np.ndarray float) the action probability 55 | """ 56 | raise NotImplementedError 57 | 58 | 59 | class CNNetwork(tf.keras.layers.Layer): 60 | def __init__(self): 61 | super(CNNetwork, self).__init__() 62 | layer_conv1 = tf.keras.layers.Conv2D(name='c1', filters=32, kernel_size=8, strides=4, padding='valid', 63 | activation='relu', 64 | kernel_initializer=tf.keras.initializers.Orthogonal(np.sqrt(2))) 65 | 66 | layer_conv2 = tf.keras.layers.Conv2D(name='c2', filters=64, kernel_size=4, strides=2, padding='valid', 67 | activation='relu', 68 | kernel_initializer=tf.keras.initializers.Orthogonal(np.sqrt(2))) 69 | 70 | layer_conv3 = tf.keras.layers.Conv2D(name='c3', filters=64, kernel_size=3, strides=1, padding='valid', 71 | activation='relu', 72 | kernel_initializer=tf.keras.initializers.Orthogonal(np.sqrt(2))) 73 | 74 | layer_flat = tf.keras.layers.Flatten(name='fc') 75 | 76 | layer_dense = tf.keras.layers.Dense(512, name='fc1', activation='relu') 77 | 78 | # layer_dropout = tf.keras.layers.Dropout(0.5) 79 | 80 | # self.model = [layer_conv1, layer_conv2, layer_conv3, layer_flat, layer_dense, layer_dropout] 81 | self.model = [layer_conv1, layer_conv2, layer_conv3, layer_flat, layer_dense] 82 | 83 | @tf.function 84 | def call(self, input): 85 | h = tf.cast(input, tf.float32) 86 | for layer in self.model: 87 | # print(layer.name) 88 | h = layer(h) 89 | return h 90 | 91 | 92 | class QNetwork(tf.keras.layers.Layer): 93 | def __init__(self, layers, obs_shape, n_action, name='q', layer_norm=False, dueling=True, n_batch=None, activation='relu', 94 | cnn_extractor=CNNetwork, feature_extraction="cnn"): 95 | super(QNetwork, self).__init__() 96 | self.layer_norm = layer_norm 97 | self.dueling = dueling 98 | self.layers = [] 99 | self.layer_norms = [] 100 | self.activation = activation 101 | 102 | self.feature_extraction = feature_extraction 103 | 104 | if self.feature_extraction != "cnn": 105 | for i, layersize in enumerate(layers): 106 | if i == 0: 107 | layer = tf.keras.layers.Dense(layersize, name=name+'/l1', 108 | activation=activation, input_shape=(n_batch,) + obs_shape) 109 | 110 | else: 111 | layer = tf.keras.layers.Dense(layersize, name=name+'/l%d' % (i+1), 112 | activation=activation) 113 | self.layers.append(layer) 114 | 115 | if self.layer_norm: 116 | self.layer_norms_QNet.append(tf.keras.layers.LayerNormalization(epsilon=1e-4)) 117 | 118 | self.layer_out = tf.keras.layers.Dense(n_action, name=name + '/out') 119 | self.trainable_layers = self.layers + [self.layer_out] + self.layer_norms 120 | 121 | else: 122 | self.cnn_extractor = cnn_extractor 123 | 124 | self.conv_net = CNNetwork() 125 | self.conv_layers = self.conv_net.model 126 | 127 | self.layer_out = tf.keras.layers.Dense(n_action, name=name + '/out') 128 | 129 | self.trainable_layers = self.conv_layers[0:3] + [self.conv_layers[4], self.layer_out] 130 | 131 | if self.dueling: 132 | self.layer_norms_VNet = [] 133 | self.layers_VNet = [] 134 | 135 | for i, layersize in enumerate(layers): 136 | if i == 0: 137 | layer = 
tf.keras.layers.Dense(layersize, name=name+'/v/l1', activation=activation, 138 | input_shape=(n_batch,) + obs_shape) 139 | else: 140 | layer = tf.keras.layers.Dense(layersize, name=name + '/v/l%d' % (i+1), activation=activation) 141 | 142 | self.layers_VNet.append(layer) 143 | 144 | if self.layer_norm: 145 | self.layer_norms_VNet.append(tf.keras.layers.LayerNormalization(epsilon=1e-4)) 146 | 147 | self.layer_out_VNet = tf.keras.layers.Dense(1, name=name+'/v/out') 148 | self.trainable_layers = self.trainable_layers \ 149 | + self.layers_VNet + [self.layer_out_VNet] + self.layer_norms_VNet 150 | 151 | 152 | @tf.function 153 | def call(self, input): 154 | if self.feature_extraction == "cnn": 155 | extracted_features = self.conv_net(input) 156 | h = extracted_features 157 | else: 158 | h = input 159 | for i, layer in enumerate(self.layers): 160 | h = layer(h) 161 | if self.layer_norm: 162 | h = self.layer_norms[i](h) 163 | action_scores = self.layer_out(h) 164 | 165 | # TODO : Implement Dueling Network Here 166 | if self.dueling: 167 | # Value Network 168 | if self.feature_extraction == "cnn": 169 | h = extracted_features 170 | else: 171 | h = input 172 | for i, layer in enumerate(self.layers_VNet): 173 | h = layer(h) 174 | if self.layer_norm: 175 | h = self.layer_norms_VNet[i](h) 176 | 177 | state_scores = self.layer_out_VNet(h) 178 | 179 | action_scores_mean = tf.reduce_mean(action_scores, axis=1) 180 | action_scores_centered = action_scores - tf.expand_dims(action_scores_mean, axis=1) 181 | 182 | q_out = state_scores + action_scores_centered 183 | else: 184 | q_out = action_scores 185 | 186 | return q_out 187 | 188 | 189 | class FeedForwardPolicy(DQNPolicy): 190 | def __init__(self, ob_space, ac_space, n_env, n_steps, n_batch, name='q', reuse=False, layers=None, 191 | cnn_extractor=CNNetwork, feature_extraction="mlp", 192 | layer_norm=False, dueling=False, act_fun=tf.nn.relu, **kwargs): 193 | super(FeedForwardPolicy, self).__init__(ob_space, ac_space, n_env, n_steps, 194 | n_batch, dueling=dueling, reuse=reuse) 195 | if layers is None: 196 | layers = [64, 64] 197 | 198 | self.reuse = reuse 199 | self.kwargs = kwargs 200 | self.layer_norm = layer_norm 201 | self.activation_function = act_fun 202 | self.qnet = QNetwork(layers, self.ob_space.shape, self.n_actions, name, layer_norm, dueling, n_batch, 203 | self.activation_function, cnn_extractor, feature_extraction) 204 | 205 | @tf.function 206 | def q_value(self, obs): 207 | self.q_values = self.qnet(obs) 208 | self.policy_proba = tf.nn.softmax(self.q_values, axis=-1) 209 | return self.qnet(obs) 210 | 211 | def step(self, obs, state=None, mask=None, deterministic=True): 212 | # q_values, actions_proba = self.sess.run([self.q_values, self.policy_proba], {self.obs_ph: obs}) 213 | q_values = self.q_value(obs) 214 | # actions_proba = self.policy_proba(obs) 215 | actions_proba = tf.nn.softmax(q_values) 216 | if deterministic: 217 | actions = np.argmax(q_values, axis=1) 218 | else: 219 | # Unefficient sampling 220 | # TODO: replace the loop 221 | # maybe with Gumbel-max trick ? 
(http://amid.fish/humble-gumbel) 222 | actions = np.zeros((len(obs),), dtype=np.int64) 223 | 224 | for action_idx in range(len(obs)): 225 | actions[action_idx] = np.random.choice(self.n_actions, p=actions_proba[action_idx]) 226 | 227 | return actions, q_values, None 228 | 229 | def proba_step(self, obs, state=None, mask=None): 230 | return tf.nn.softmax(self.q_value(obs), axis=-1) 231 | 232 | 233 | class CnnPolicy(FeedForwardPolicy): 234 | """ 235 | Policy object that implements DQN policy, using a CNN (the nature CNN) 236 | :param ob_space: (Gym Space) The observation space of the environment 237 | :param ac_space: (Gym Space) The action space of the environment 238 | :param n_env: (int) The number of environments to run 239 | :param n_steps: (int) The number of steps to run for each environment 240 | :param n_batch: (int) The number of batches to run (n_envs * n_steps) 241 | :param reuse: (bool) If the policy is reusable or not 242 | :param obs_phs: (TensorFlow Tensor, TensorFlow Tensor) a tuple containing an override for observation placeholder 243 | and the processed observation placeholder respectively 244 | :param dueling: (bool) if true, double the output MLP to compute a baseline for action scores 245 | :param _kwargs: (dict) Extra keyword arguments for the nature CNN feature extraction 246 | """ 247 | 248 | def __init__(self, ob_space, ac_space, n_env, n_steps, n_batch, name='q', 249 | reuse=False, dueling=True, **_kwargs): 250 | super(CnnPolicy, self).__init__(ob_space, ac_space, n_env, n_steps, n_batch, name, reuse, 251 | feature_extraction="cnn", dueling=dueling, 252 | layer_norm=False, **_kwargs) 253 | 254 | 255 | class MlpPolicy(FeedForwardPolicy): 256 | """ 257 | Policy object that implements DQN policy, using an MLP (2 layers of 64) 258 | 259 | :param ob_space: (Gym Space) The observation space of the environment 260 | :param ac_space: (Gym Space) The action space of the environment 261 | :param n_env: (int) The number of environments to run 262 | :param n_steps: (int) The number of steps to run for each environment 263 | :param n_batch: (int) The number of batches to run (n_envs * n_steps) 264 | :param reuse: (bool) If the policy is reusable or not 265 | :param obs_phs: (TensorFlow Tensor, TensorFlow Tensor) a tuple containing an override for observation placeholder 266 | and the processed observation placeholder respectively 267 | :param dueling: (bool) if true, double the output MLP to compute a baseline for action scores 268 | :param _kwargs: (dict) Extra keyword arguments for the MLP feature extraction 269 | """ 270 | def __init__(self, ob_space, ac_space, n_env, n_steps, n_batch, name='q', 271 | reuse=False, dueling=True, **_kwargs): 272 | super(MlpPolicy, self).__init__(ob_space, ac_space, n_env, n_steps, n_batch, name, reuse, 273 | feature_extraction="mlp", dueling=dueling, 274 | layer_norm=False, **_kwargs) 275 | 276 | -------------------------------------------------------------------------------- /dqn/run_Break_double.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | 4 | import bench, logger 5 | from common.misc_util import set_global_seeds 6 | from common.atari_wrappers import make_atari 7 | from dqn import DQN 8 | # from dqn import wrap_atari_dqn 9 | from policy import CnnPolicy, MlpPolicy 10 | 11 | 12 | def wrap_atari_dqn(env): 13 | """ 14 | wrap the environment in atari wrappers for DQN 15 | :param env: (Gym Environment) the environment 16 | :return: (Gym Environment) the wrapped
environment 17 | """ 18 | from common.atari_wrappers import wrap_deepmind 19 | return wrap_deepmind(env, frame_stack=True, scale=False) 20 | 21 | 22 | def main(): 23 | """ 24 | Run the atari test 25 | """ 26 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 27 | parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') 28 | parser.add_argument('--seed', help='RNG seed', type=int, default=0) 29 | parser.add_argument('--prioritized', type=int, default=1) 30 | parser.add_argument('--dueling', type=int, default=1) 31 | parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6) 32 | parser.add_argument('--num-timesteps', type=int, default=int(1e7)) 33 | 34 | args = parser.parse_args() 35 | logger.configure() 36 | set_global_seeds(args.seed) 37 | env = make_atari(args.env) 38 | env = bench.Monitor(env, logger.get_dir()) 39 | env = wrap_atari_dqn(env) 40 | policy = partial(CnnPolicy, dueling=args.dueling == 1) 41 | 42 | # model = DQN( 43 | # env=env, 44 | # policy=policy, 45 | # learning_rate=1e-4, 46 | # buffer_size=10000, 47 | # exploration_fraction=0.1, 48 | # exploration_final_eps=0.01, 49 | # train_freq=4, 50 | # learning_starts=10000, 51 | # target_network_update_freq=1000, 52 | # gamma=0.99, 53 | # prioritized_replay=bool(args.prioritized), 54 | # prioritized_replay_alpha=args.prioritized_replay_alpha, 55 | # ) 56 | 57 | # policy: 'CnnPolicy' 58 | # n_timesteps: !!float 59 | # 1e7 60 | # buffer_size: 10000 61 | # learning_rate: !!float 62 | # 1e-4 63 | # learning_starts: 10000 64 | # target_network_update_freq: 1000 65 | # train_freq: 4 66 | # exploration_final_eps: 0.01 67 | # exploration_fraction: 0.1 68 | # prioritized_replay_alpha: 0.6 69 | # prioritized_replay: True 70 | model = DQN( 71 | env=env, 72 | policy_class=CnnPolicy, 73 | learning_rate=1e-4, 74 | buffer_size=10000, 75 | double_q=True, 76 | prioritized_replay=False, 77 | prioritized_replay_alpha=0.6, 78 | dueling=False, 79 | train_freq=4, 80 | learning_starts=10000, 81 | exploration_fraction=0.1, 82 | exploration_final_eps=0.01, 83 | target_network_update_freq=1000, 84 | model_path='atari_Breakout_double' 85 | ) 86 | # generate_expert_traj(model, 'expert_cartpole', n_timesteps=int(1e5), n_episodes=10) 87 | # model.learn(total_timesteps=args.num_timesteps, seed=args.seed) 88 | 89 | env.close() 90 | 91 | 92 | if __name__ == '__main__': 93 | main() -------------------------------------------------------------------------------- /dqn/run_Break_duel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | 4 | import bench, logger 5 | from common.misc_util import set_global_seeds 6 | from common.atari_wrappers import make_atari 7 | from dqn import DQN 8 | # from dqn import wrap_atari_dqn 9 | from policy import CnnPolicy, MlpPolicy 10 | 11 | 12 | def wrap_atari_dqn(env): 13 | """ 14 | wrap the environment in atari wrappers for DQN 15 | :param env: (Gym Environment) the environment 16 | :return: (Gym Environment) the wrapped environment 17 | """ 18 | from common.atari_wrappers import wrap_deepmind 19 | return wrap_deepmind(env, frame_stack=True, scale=False) 20 | 21 | 22 | def main(): 23 | """ 24 | Run the atari test 25 | """ 26 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 27 | parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') 28 | parser.add_argument('--seed', help='RNG seed', 
type=int, default=0) 29 | parser.add_argument('--prioritized', type=int, default=1) 30 | parser.add_argument('--dueling', type=int, default=1) 31 | parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6) 32 | parser.add_argument('--num-timesteps', type=int, default=int(1e7)) 33 | 34 | args = parser.parse_args() 35 | logger.configure() 36 | set_global_seeds(args.seed) 37 | env = make_atari(args.env) 38 | env = bench.Monitor(env, logger.get_dir()) 39 | env = wrap_atari_dqn(env) 40 | policy = partial(CnnPolicy, dueling=args.dueling == 1) 41 | 42 | # model = DQN( 43 | # env=env, 44 | # policy=policy, 45 | # learning_rate=1e-4, 46 | # buffer_size=10000, 47 | # exploration_fraction=0.1, 48 | # exploration_final_eps=0.01, 49 | # train_freq=4, 50 | # learning_starts=10000, 51 | # target_network_update_freq=1000, 52 | # gamma=0.99, 53 | # prioritized_replay=bool(args.prioritized), 54 | # prioritized_replay_alpha=args.prioritized_replay_alpha, 55 | # ) 56 | model = DQN( 57 | env=env, 58 | policy_class=CnnPolicy, 59 | learning_rate=1e-4, 60 | buffer_size=10000, 61 | double_q=False, 62 | prioritized_replay=True, 63 | prioritized_replay_alpha=0.6, 64 | dueling=True, 65 | train_freq=4, 66 | learning_starts=10000, 67 | exploration_fraction=0.1, 68 | exploration_final_eps=0.01, 69 | target_network_update_freq=1000, 70 | model_path='atari_Breakout_duel' 71 | ) 72 | # model.learn(total_timesteps=args.num_timesteps, seed=args.seed) 73 | model.load('atari_Breakout_duel') 74 | model.evaluate(100) 75 | env.close() 76 | 77 | 78 | if __name__ == '__main__': 79 | main() -------------------------------------------------------------------------------- /dqn/train_atari_Breakout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import argparse 4 | from functools import partial 5 | 6 | import bench, logger 7 | from common.misc_util import set_global_seeds 8 | from common.atari_wrappers import make_atari 9 | from dqn import DQN 10 | # from dqn import wrap_atari_dqn 11 | from policy import CnnPolicy, MlpPolicy 12 | 13 | from common.vec_env import SubprocVecEnv, VecFrameStack, VecNormalize, VecEnvWrapper, VecVideoRecorder 14 | 15 | 16 | def wrap_atari_dqn(env): 17 | """ 18 | wrap the environment in atari wrappers for DQN 19 | :param env: (Gym Environment) the environment 20 | :return: (Gym Environment) the wrapped environment 21 | """ 22 | from common.atari_wrappers import wrap_deepmind 23 | return wrap_deepmind(env, frame_stack=True, scale=False) 24 | 25 | 26 | def main(): 27 | """ 28 | Run the atari test 29 | """ 30 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') 32 | parser.add_argument('--seed', help='RNG seed', type=int, default=0) 33 | parser.add_argument('--prioritized', type=int, default=1) 34 | parser.add_argument('--dueling', type=int, default=1) 35 | parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6) 36 | parser.add_argument('--num-timesteps', type=int, default=int(1e7)) 37 | 38 | args = parser.parse_args() 39 | logger.configure() 40 | set_global_seeds(args.seed) 41 | 42 | env = make_atari(args.env) 43 | env.action_space.seed(args.seed) 44 | env = bench.Monitor(env, logger.get_dir()) 45 | env = wrap_atari_dqn(env) 46 | 47 | model = DQN( 48 | env=env, 49 | policy_class=CnnPolicy, 50 | buffer_size=10000, 51 | learning_rate=1e-4, 52 | learning_starts=10000, 53 | 
target_network_update_freq=1000, 54 | train_freq=4, 55 | exploration_final_eps=0.01, 56 | exploration_fraction=0.1, 57 | prioritized_replay=True, 58 | model_path='atari_test_Breakout' 59 | ) 60 | model.learn(total_timesteps=args.num_timesteps) 61 | env.close() 62 | 63 | 64 | if __name__ == '__main__': 65 | main() -------------------------------------------------------------------------------- /dqn/train_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | from dqn import DQN 6 | from policy import MlpPolicy, CnnPolicy 7 | 8 | def main(args): 9 | """ 10 | Train and save the DQN model, for the cartpole problem 11 | 12 | :param args: (ArgumentParser) the input arguments 13 | """ 14 | env = gym.make("CartPole-v0") 15 | 16 | model = DQN( 17 | env=env, 18 | policy_class=MlpPolicy, 19 | learning_rate=1e-3, 20 | buffer_size=50000, 21 | double_q=False, 22 | prioritized_replay=True, 23 | dueling=True, 24 | exploration_fraction=0.2, 25 | exploration_final_eps=0.02, 26 | model_path='cartpole_model_test' 27 | ) 28 | model.learn(total_timesteps=args.max_timesteps) 29 | model.evaluate(num_epsiodes=50) 30 | print("\nTrain Finished") 31 | model = DQN( 32 | env=env, 33 | policy_class=MlpPolicy, 34 | learning_rate=1e-3, 35 | buffer_size=50000, 36 | double_q=False, 37 | prioritized_replay=True, 38 | dueling=True, 39 | exploration_fraction=0.2, 40 | exploration_final_eps=0.02, 41 | model_path='cartpole_model_test' 42 | ) 43 | print("\nBefore Loading") 44 | model.evaluate(num_epsiodes=50) 45 | model.load("cartpole_model_test") 46 | model.evaluate(num_epsiodes=50) 47 | print("Finished") 48 | 49 | 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description="Train DQN on cartpole") 54 | parser.add_argument('--max-timesteps', default=30000, type=int, help="Maximum number of timesteps") 55 | args = parser.parse_args() 56 | main(args) 57 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: tf2 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _tflow_select=2.3.0=mkl 7 | - absl-py=0.8.1=py37_0 8 | - astor=0.8.0=py37_0 9 | - blas=1.0=mkl 10 | - c-ares=1.15.0=h7b6447c_1001 11 | - ca-certificates=2019.11.27=0 12 | - certifi=2019.11.28=py37_0 13 | - decorator=4.4.1=py_0 14 | - gast=0.2.2=py37_0 15 | - google-pasta=0.1.8=py_0 16 | - grpcio=1.16.1=py37hf8bcb03_1 17 | - h5py=2.9.0=py37h7918eee_0 18 | - hdf5=1.10.4=hb1b8bf9_0 19 | - intel-openmp=2019.4=243 20 | - keras-applications=1.0.8=py_0 21 | - keras-preprocessing=1.1.0=py_1 22 | - libedit=3.1.20181209=hc058e9b_0 23 | - libffi=3.2.1=hd88cf55_4 24 | - libgcc-ng=9.1.0=hdf63c60_0 25 | - libgfortran-ng=7.3.0=hdf63c60_0 26 | - libprotobuf=3.10.1=hd408876_0 27 | - libstdcxx-ng=9.1.0=hdf63c60_0 28 | - markdown=3.1.1=py37_0 29 | - mkl=2019.4=243 30 | - mkl-service=2.3.0=py37he904b0f_0 31 | - mkl_fft=1.0.15=py37ha843d7b_0 32 | - mkl_random=1.1.0=py37hd6b4f25_0 33 | - ncurses=6.1=he6710b0_1 34 | - numpy=1.17.4=py37hc1035e2_0 35 | - numpy-base=1.17.4=py37hde5b4d6_0 36 | - openssl=1.1.1d=h7b6447c_3 37 | - opt_einsum=3.1.0=py_0 38 | - pip=19.3.1=py37_0 39 | - protobuf=3.10.1=py37he6710b0_0 40 | - python=3.7.5=h0371630_0 41 | - readline=7.0=h7b6447c_5 42 | - scipy=1.3.2=py37h7c811a0_0 43 | - setuptools=42.0.2=py37_0 44 | - six=1.13.0=py37_0 45 | - sqlite=3.30.1=h7b6447c_0 46 | - 
tensorboard=2.0.0=pyhb38c66f_1 47 | - tensorflow=2.0.0=mkl_py37h66b46cc_0 48 | - tensorflow-base=2.0.0=mkl_py37h9204916_0 49 | - tensorflow-estimator=2.0.0=pyh2649769_0 50 | - tensorflow-probability=0.8.0=py_0 51 | - termcolor=1.1.0=py37_1 52 | - tk=8.6.8=hbc83047_0 53 | - tqdm=4.40.0=py_0 54 | - werkzeug=0.16.0=py_0 55 | - wheel=0.33.6=py37_0 56 | - wrapt=1.11.2=py37h7b6447c_0 57 | - xz=5.2.4=h14c3975_4 58 | - zlib=1.2.11=h7b6447c_3 59 | - pip: 60 | - astroid==2.3.3 61 | - cffi==1.13.2 62 | - cloudpickle==1.2.2 63 | - cython==0.29.14 64 | - fasteners==0.15 65 | - future==0.18.2 66 | - glfw==1.8.6 67 | - gym==0.15.4 68 | - imageio==2.6.1 69 | - isort==4.3.21 70 | - lazy-object-proxy==1.4.3 71 | - mccabe==0.6.1 72 | - monotonic==1.5 73 | - opencv-python==4.1.2.30 74 | - pillow==6.2.1 75 | - pycparser==2.19 76 | - pyglet==1.3.2 77 | - pylint==2.4.4 78 | - typed-ast==1.4.0 79 | prefix: /Users/raydr/Anaconda3/envs/tf20 80 | 81 | -------------------------------------------------------------------------------- /environment_gpu.yml: -------------------------------------------------------------------------------- 1 | name: tf2-gpu 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _tflow_select=2.3.0=mkl 7 | - absl-py=0.8.1=py37_0 8 | - astor=0.8.0=py37_0 9 | - blas=1.0=mkl 10 | - c-ares=1.15.0=h7b6447c_1001 11 | - ca-certificates=2019.11.27=0 12 | - certifi=2019.11.28=py37_0 13 | - cloudpickle=1.1.1=py_0 14 | - cudatoolkit=10.1.243=h6bb024c_0 15 | - cudnn=7.6.4=cuda10.1_0 16 | - decorator=4.4.1=py_0 17 | - gast=0.2.2=py37_0 18 | - google-pasta=0.1.8=py_0 19 | - grpcio=1.16.1=py37hf8bcb03_1 20 | - h5py=2.9.0=py37h7918eee_0 21 | - hdf5=1.10.4=hb1b8bf9_0 22 | - intel-openmp=2019.4=243 23 | - keras-applications=1.0.8=py_0 24 | - keras-preprocessing=1.1.0=py_1 25 | - libedit=3.1.20181209=hc058e9b_0 # 못깜 26 | - libffi=3.2.1=hd88cf55_4 # 못깜 27 | - libgcc-ng=9.1.0=hdf63c60_0 # 못깜 28 | - libgfortran-ng=7.3.0=hdf63c60_0 # 못깜 29 | - libprotobuf=3.10.1=hd408876_0 # 못깜 30 | - libstdcxx-ng=9.1.0=hdf63c60_0 # 못깜 31 | - markdown=3.1.1=py37_0 32 | - mkl=2019.4=243 33 | - mkl-service=2.3.0=py37he904b0f_0 34 | - mkl_fft=1.0.15=py37ha843d7b_0 35 | - mkl_random=1.1.0=py37hd6b4f25_0 36 | - ncurses=6.1=he6710b0_1 37 | - numpy=1.17.4=py37hc1035e2_0 38 | - numpy-base=1.17.4=py37hde5b4d6_0 39 | - openssl=1.1.1d=h7b6447c_3 40 | - opt_einsum=3.1.0=py_0 41 | - patchelf=0.10=he6710b0_0 42 | - pip=19.3.1=py37_0 43 | - protobuf=3.10.1=py37he6710b0_0 44 | - python=3.7.5=h0371630_0 45 | - readline=7.0=h7b6447c_5 46 | - scipy=1.3.2=py37h7c811a0_0 47 | - setuptools=42.0.2=py37_0 48 | - six=1.13.0=py37_0 49 | - sqlite=3.30.1=h7b6447c_0 50 | - tensorboard=2.0.0=pyhb38c66f_1 51 | - tensorflow=2.0.0=mkl_py37h66b46cc_0 52 | - tensorflow-base=2.0.0=mkl_py37h9204916_0 53 | - tensorflow-estimator=2.0.0=pyh2649769_0 54 | - tensorflow-probability=0.8.0=py_0 55 | - termcolor=1.1.0=py37_1 56 | - tk=8.6.8=hbc83047_0 57 | - tqdm=4.40.2=py_0 58 | - werkzeug=0.16.0=py_0 59 | - wheel=0.33.6=py37_0 60 | - wrapt=1.11.2=py37h7b6447c_0 61 | - xz=5.2.4=h14c3975_4 62 | - zlib=1.2.11=h7b6447c_3 63 | - pip: 64 | - astroid==2.3.3 65 | - blessings==1.7 66 | - cffi==1.13.2 67 | - cython==0.29.14 68 | - fasteners==0.15 69 | - future==0.18.2 70 | - glfw==1.8.6 71 | - gpustat==0.6.0 72 | - gym==0.15.4 73 | - imageio==2.6.1 74 | - isort==4.3.21 75 | - lazy-object-proxy==1.4.3 76 | - mccabe==0.6.1 77 | - monotonic==1.5 78 | - mujoco-py==2.0.2.9 # Yet 79 | - nvidia-ml-py3==7.352.0 80 | - opencv-python==4.1.2.30 81 | - pillow==6.2.1 82 
| - psutil==5.6.7 83 | - pycparser==2.19 84 | - pyglet==1.3.2 85 | - pylint==2.4.4 86 | - setgpu==0.0.7 87 | - tensorflow-gpu==2.0.0 88 | - typed-ast==1.4.0 89 | prefix: /Users/raydr/Anaconda3/envs/tf20 90 | 91 | -------------------------------------------------------------------------------- /sac/__pycache__/sac.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/sac/__pycache__/sac.cpython-35.pyc -------------------------------------------------------------------------------- /sac/__pycache__/sac.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/stable-baselines-tf2/c3ddc50f9af5031706d063971267543d3980c58f/sac/__pycache__/sac.cpython-37.pyc -------------------------------------------------------------------------------- /sac/plot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 2, 4 | "metadata": { 5 | "language_info": { 6 | "name": "python", 7 | "codemirror_mode": { 8 | "name": "ipython", 9 | "version": 3 10 | }, 11 | "version": "3.7.5" 12 | }, 13 | "orig_nbformat": 2, 14 | "file_extension": ".py", 15 | "mimetype": "text/x-python", 16 | "name": "python", 17 | "npconvert_exporter": "python", 18 | "pygments_lexer": "ipython3", 19 | "version": 3 20 | }, 21 | "cells": [ 22 | { 23 | "source": [ 24 | "import matplotlib.pyplot as plt\n", 25 | "import numpy as np\n", 26 | "import os\n", 27 | "%matplotlib inline" 28 | ], 29 | "cell_type": "code", 30 | "outputs": [], 31 | "metadata": {}, 32 | "execution_count": 7 33 | }, 34 | { 35 | "source": [ 36 | "envnames = ['Hopper-v2', 'Walker2d-v2', 'HalfCheetah-v2', 'Ant-v2']\n", 37 | "envname = envnames[3]\n", 38 | "path = 'results_final/%s/' % envname\n", 39 | "filenames = os.listdir(path)\n", 40 | "pathlist = [path + fname for fname in filenames]" 41 | ], 42 | "cell_type": "code", 43 | "outputs": [], 44 | "metadata": {}, 45 | "execution_count": 20 46 | }, 47 | { 48 | "source": [ 49 | "arrlist = []\n", 50 | "\n", 51 | "for path in pathlist:\n", 52 | " arr = np.load(path)\n", 53 | " arrlist.append(arr)\n", 54 | "\n", 55 | "arrlist = np.array(arrlist)\n", 56 | "arrlist = arrlist[:,:100]" 57 | ], 58 | "cell_type": "code", 59 | "outputs": [], 60 | "metadata": {}, 61 | "execution_count": 21 62 | }, 63 | { 64 | "source": [ 65 | "arrlist.shape" 66 | ], 67 | "cell_type": "code", 68 | "outputs": [ 69 | { 70 | "output_type": "execute_result", 71 | "data": { 72 | "text/plain": "(4, 100, 1)" 73 | }, 74 | "metadata": {}, 75 | "execution_count": 22 76 | } 77 | ], 78 | "metadata": {}, 79 | "execution_count": 22 80 | }, 81 | { 82 | "source": [ 83 | "arrlist = np.squeeze(arrlist)\n", 84 | "if len(arrlist.shape) == 1:\n", 85 | " arrlist = np.array([arrlist]) \n", 86 | "\n", 87 | "arrlen = len(arrlist)\n", 88 | "print(arrlen)" 89 | ], 90 | "cell_type": "code", 91 | "outputs": [ 92 | { 93 | "output_type": "stream", 94 | "name": "stdout", 95 | "text": "4\n" 96 | } 97 | ], 98 | "metadata": {}, 99 | "execution_count": 23 100 | }, 101 | { 102 | "source": [ 103 | "rewmean = arrlist.mean(axis=0)\n", 104 | "rewstderr = arrlist.std(axis=0)/np.sqrt(arrlen)\n", 105 | "arrshape = rewmean.shape\n", 106 | "stepsize = 0.005\n", 107 | "x_range = np.arange(start=stepsize, stop=stepsize + stepsize*arrshape[0], step=stepsize)" 108 | ], 109 | "cell_type": 
"code", 110 | "outputs": [], 111 | "metadata": {}, 112 | "execution_count": 24 113 | }, 114 | { 115 | "source": [ 116 | "plt.plot(x_range, rewmean)\n", 117 | "plt.fill_between(x_range, rewmean + rewstderr, rewmean - rewstderr, alpha=0.5)\n", 118 | "#plt.fill_between(x_range, np.min(arrlist, axis=0), np.max(arrlist, axis=0), alpha=0.5)\n", 119 | "plt.title(envname)\n", 120 | "plt.xlabel('million steps')\n", 121 | "plt.ylabel('average return')" 122 | ], 123 | "cell_type": "code", 124 | "outputs": [ 125 | { 126 | "output_type": "execute_result", 127 | "data": { 128 | "text/plain": "Text(0, 0.5, 'average return')" 129 | }, 130 | "metadata": {}, 131 | "execution_count": 25 132 | } 133 | ], 134 | "metadata": {}, 135 | "execution_count": 25 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [] 143 | } 144 | ] 145 | } -------------------------------------------------------------------------------- /sac/sac.py: -------------------------------------------------------------------------------- 1 | # This implementation is based on codes of Jongmin Lee & Byeong-jun Lee 2 | import time 3 | import random 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow_probability as tfp 7 | from tqdm import tqdm 8 | import pickle 9 | from base.rl import ActorCriticRLAlgorithm 10 | from base.replay_buffer import ReplayBuffer 11 | 12 | @tf.function 13 | def clip_with_gradient(x, low=-1, high=1): 14 | clip_high = tf.cast(x > high, tf.float32) 15 | clip_low = tf.cast(x < low, tf.float32) 16 | return x + tf.stop_gradient((high - x) * clip_high + (low - x) * clip_low) 17 | 18 | @tf.function 19 | def apply_squashing_func(sample, logp): 20 | """ 21 | Squash the ouput of the gaussian distribution and account for that in the log probability. 
22 | :param sample: (tf.Tensor) Action sampled from Gaussian distribution 23 | :param logp: (tf.Tensor) Log probability before squashing 24 | """ 25 | # Squash the output 26 | squashed_action = tf.tanh(sample) 27 | squashed_action_logp = \ 28 | logp - tf.reduce_sum(tf.math.log( 29 | clip_with_gradient(1 - squashed_action ** 2, low=0, high=1) + 1e-6), axis=1) 30 | # incurred by change of variable 31 | return squashed_action, squashed_action_logp 32 | 33 | 34 | class SquashedGaussianActor(tf.keras.layers.Layer): 35 | 36 | def __init__(self, env): 37 | super(SquashedGaussianActor, self).__init__() 38 | # obs_shape, action_dim, 39 | self.obs_shape = env.observation_space.shape 40 | self.action_dim = env.action_space.shape[0] 41 | self.max_action = env.action_space.high[0] 42 | 43 | # Actor parameters 44 | self.l1 = tf.keras.layers.Dense(64, activation='relu', name='f0', input_shape=(None,) + self.obs_shape) 45 | self.l2 = tf.keras.layers.Dense(64, activation='relu', name='f1') 46 | self.l3_mu = tf.keras.layers.Dense(self.action_dim, name='f2_mu') 47 | self.l3_log_std = tf.keras.layers.Dense(self.action_dim, name='f2_log_std') 48 | 49 | @tf.function 50 | def call(self, inputs, **kwargs): 51 | h = self.l1(inputs) 52 | h = self.l2(h) 53 | mean = self.l3_mu(h) 54 | log_std = self.l3_log_std(h) 55 | std = tf.exp(log_std) 56 | 57 | dist = tfp.distributions.MultivariateNormalDiag(mean, std) 58 | sampled_action = dist.sample() 59 | sampled_action_logp = dist.log_prob(sampled_action) 60 | squahsed_action, squahsed_action_logp = apply_squashing_func(sampled_action, sampled_action_logp) 61 | 62 | return squahsed_action, tf.reshape(squahsed_action_logp, (-1,1)) 63 | 64 | def dist(self, inputs): 65 | h = self.l1(inputs) 66 | h = self.l2(h) 67 | mean = self.l3_mu(h) 68 | log_std = self.l3_log_std(h) 69 | std = tf.exp(log_std) 70 | dist = tfp.distributions.MultivariateNormalDiag(mean, std) 71 | 72 | return dist 73 | 74 | def step(self, obs, deterministic=False): 75 | if deterministic: 76 | dist = self.dist(obs) 77 | mean_action = dist.mean().numpy() 78 | mean_action = np.nan_to_num(mean_action) 79 | squashed_action = np.tanh(mean_action) 80 | 81 | else: 82 | squashed_action, _ = self.call(obs) 83 | squashed_action = np.nan_to_num(squashed_action) 84 | # squashed_action = squashed_action.numpy() 85 | 86 | return squashed_action * self.max_action 87 | 88 | 89 | class VNetwork(tf.keras.layers.Layer): 90 | 91 | def __init__(self, obs_shape, output_dim=1): 92 | super(VNetwork, self).__init__() 93 | 94 | self.v_l0 = tf.keras.layers.Dense(64, activation='relu', name='v/f0', input_shape=(None,) + obs_shape) 95 | self.v_l1 = tf.keras.layers.Dense(64, activation='relu', name='v/f1') 96 | self.v_l2 = tf.keras.layers.Dense(output_dim, name='v/f2') 97 | 98 | @tf.function 99 | def call(self, inputs, **kwargs): 100 | h = self.v_l0(inputs) 101 | h = self.v_l1(h) 102 | v = self.v_l2(h) 103 | return v 104 | 105 | 106 | class QNetwork(tf.keras.layers.Layer): 107 | 108 | def __init__(self, obs_shape, num_critics=2): 109 | super(QNetwork, self).__init__() 110 | self.num_critics = num_critics 111 | 112 | self.qs_l0, self.qs_l1, self.qs_l2 = [], [], [] 113 | for i in range(self.num_critics): 114 | self.qs_l0.append(tf.keras.layers.Dense(64, activation='relu', name='q%d/f0' % i, input_shape=(None,) + obs_shape)) 115 | self.qs_l1.append(tf.keras.layers.Dense(64, activation='relu', name='q%d/f1' % i)) 116 | self.qs_l2.append(tf.keras.layers.Dense(1, name='q%d/f2' % i)) 117 | 118 | @tf.function 119 | def call(self, inputs, **kwargs): 
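        # `inputs` is an (obs, action) tuple: each of the `num_critics` heads sees the
        # concatenated [obs, action] vector and returns its own scalar Q estimate, so
        # callers can take the element-wise minimum over the returned list (clipped double-Q).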
120 | obs, action = inputs 121 | obs_action = tf.concat([obs, action], axis=1) 122 | qs = [] 123 | for i in range(self.num_critics): 124 | h = self.qs_l0[i](obs_action) 125 | h = self.qs_l1[i](h) 126 | q = self.qs_l2[i](h) 127 | qs.append(q) 128 | 129 | return qs 130 | 131 | 132 | class SAC(ActorCriticRLAlgorithm): 133 | 134 | def __init__(self, env, test_env, policy_class=SquashedGaussianActor, 135 | ent_coef='auto', reward_scale=1, seed=0): 136 | super(SAC, self).__init__(policy_class=policy_class, env=env, test_env=test_env) 137 | 138 | self.seed = seed 139 | tf.random.set_seed(seed) 140 | np.random.seed(seed) 141 | random.seed(seed) 142 | 143 | self.env = env 144 | self.test_env = test_env 145 | self.max_action = self.env.action_space.high[0] 146 | self.reward_scale = reward_scale 147 | self.obs_shape = self.env.observation_space.shape 148 | self.state_dim = self.env.observation_space.shape[0] 149 | self.action_dim = self.env.action_space.shape[0] 150 | self.replay_buffer = ReplayBuffer(size=64000) 151 | 152 | self.num_critics = 2 153 | self.gamma = 0.99 154 | self.tau = 0.05 155 | self.learning_rate = 3e-4 156 | self.batch_size = 256 157 | self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) 158 | self.ent_coef = ent_coef 159 | 160 | # self.optimizer_variables = [] 161 | self.info_labels = ['actor_loss', 'v_loss', 'q_loss', 'mean(v)', 162 | 'mean(qs)', 'ent_coef', 'entropy', 'logp_pi'] 163 | 164 | # Entropy coefficient (auto or fixed) 165 | if isinstance(self.ent_coef, str) and self.ent_coef == 'auto': 166 | # Default initial value of ent_coef when learned 167 | init_value = 1.0 168 | self.log_ent_coef = tf.keras.backend.variable(init_value, dtype=tf.float32, name='log_ent_coef') 169 | self.ent_coefficient = tf.exp(self.log_ent_coef) 170 | self.entropy_variables = [self.log_ent_coef] 171 | 172 | else: 173 | self.log_ent_coef = tf.math.log(self.ent_coef) 174 | self.ent_coefficient = tf.constant(self.ent_coef) 175 | 176 | # Actor, Critic Networks 177 | self.actor = policy_class(self.env) 178 | self.v = VNetwork(self.obs_shape) 179 | self.q = QNetwork(self.obs_shape, num_critics=self.num_critics) 180 | self.v_target = VNetwork(self.obs_shape) 181 | 182 | self.actor_variables = self.actor.trainable_variables 183 | self.critic_variables = self.v.trainable_variables + self.q.trainable_variables 184 | 185 | self.actor_optimizer = tf.keras.optimizers.Adam(self.learning_rate) 186 | self.critic_optimizer = tf.keras.optimizers.Adam(self.learning_rate) 187 | 188 | if isinstance(ent_coef, str) and ent_coef == 'auto': 189 | self.entropy_optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate) 190 | 191 | self.optimizer_variables = self.actor.trainable_variables + self.v.trainable_variables + \ 192 | self.q.trainable_variables + self.v_target.trainable_variables 193 | 194 | # @tf.function 195 | def update_target(self, target_params, source_params): 196 | for target, source in zip(target_params, source_params): 197 | tf.keras.backend.set_value(target, (1 - self.tau) * target + self.tau * source) 198 | 199 | @tf.function 200 | def initialize_variables(self): 201 | zero_like_state = tf.zeros((1,) + self.obs_shape) 202 | zero_like_action = tf.zeros((1,self.action_dim)) 203 | self.actor(zero_like_state) 204 | self.v(zero_like_state) 205 | self.v_target(zero_like_state) 206 | self.q(inputs=(zero_like_state, zero_like_action)) 207 | 208 | @tf.function 209 | def train(self, obs, action, reward, next_obs, done): 210 | # Casting from float64 to float32 211 | obs = 
tf.cast(obs, tf.float32) 212 | action = tf.cast(action, tf.float32) / self.max_action 213 | reward = tf.cast(reward, tf.float32)[:, None] * self.reward_scale 214 | next_obs = tf.cast(next_obs, tf.float32) 215 | done = tf.cast(done, tf.float32)[:, None] 216 | 217 | dist = self.actor.dist(obs) 218 | 219 | with tf.GradientTape() as tape_actor: 220 | # Actor training (pi) 221 | action_pi, logp_pi = self.actor.call(obs) 222 | qs_pi = self.q.call(inputs=(obs, action_pi)) 223 | # min_q_target = tf.reduce_min(qs_pi, axis=0) 224 | actor_loss = tf.reduce_mean(tf.math.exp(self.log_ent_coef) * logp_pi - qs_pi[0]) 225 | 226 | actor_variables = self.actor.trainable_variables 227 | grads_actor = tape_actor.gradient(actor_loss, actor_variables) 228 | actor_op = self.actor_optimizer.apply_gradients(zip(grads_actor, actor_variables)) 229 | 230 | with tf.control_dependencies([actor_op]): 231 | v_target = self.v_target(next_obs) 232 | min_q_pi = tf.reduce_min(qs_pi, axis=0) # (batch, 1) 233 | v_backup = tf.stop_gradient(min_q_pi - tf.math.exp(self.log_ent_coef) * logp_pi) # (batch, 1) 234 | q_backup = tf.stop_gradient(reward + (1 - done) * self.gamma * v_target) # (batch, 1) 235 | 236 | with tf.GradientTape() as tape_critic: 237 | # Critic training (V, Q) 238 | v = self.v(obs) 239 | v_loss = 0.5 * tf.reduce_mean((v_backup - v) ** 2) # MSE, scalar 240 | 241 | qs = self.q(inputs=(obs, action)) 242 | q_losses = [0.5 * tf.reduce_mean((q_backup - qs[k]) ** 2) for k in range(self.num_critics)] # (2, batch) 243 | q_loss = tf.reduce_sum(q_losses, axis=0) # scalar 244 | 245 | value_loss = v_loss + q_loss 246 | 247 | critic_variables = self.v.trainable_variables + self.q.trainable_variables 248 | grads_critic = tape_critic.gradient(value_loss, critic_variables) 249 | self.critic_optimizer.apply_gradients(zip(grads_critic, critic_variables)) 250 | 251 | if isinstance(self.ent_coef, str) and self.ent_coef == 'auto': 252 | with tf.GradientTape() as tape_ent: 253 | ent_coef_loss = -tf.reduce_mean(self.log_ent_coef * tf.stop_gradient(logp_pi + self.target_entropy)) 254 | 255 | entropy_variables = [self.log_ent_coef] 256 | grads_ent = tape_ent.gradient(ent_coef_loss, entropy_variables) 257 | self.entropy_optimizer.apply_gradients(zip(grads_ent, entropy_variables)) 258 | 259 | return actor_loss, tf.reduce_mean(v_loss), tf.reduce_mean(q_loss), tf.reduce_mean(v), tf.reduce_mean(qs), \ 260 | tf.math.exp(self.log_ent_coef), tf.reduce_mean(dist.entropy()), tf.reduce_mean(logp_pi) 261 | 262 | 263 | def learn(self, total_timesteps, log_interval=640, callback=None, verbose=1, 264 | eval_interval=5000, eval_rollout=True, save_path=None, save_interval=500000): 265 | 266 | self.initialize_variables() 267 | for target, source in zip(self.v_target.trainable_variables, self.v.trainable_variables): 268 | tf.keras.backend.set_value(target, source.numpy()) 269 | 270 | start_time = time.time() 271 | episode_rewards = [] 272 | eval_rewards = [] 273 | 274 | obs = self.env.reset() 275 | current_episode_reward = 0 276 | 277 | for step in tqdm(range(total_timesteps), desc='SAC', ncols=70): 278 | if callback is not None: 279 | if callback(locals(), globals()) is False: 280 | break 281 | 282 | # Take an action 283 | action = np.reshape(self.predict(np.array([obs]), deterministic=False)[0], -1) 284 | next_obs, reward, done, _ = self.env.step(action) 285 | 286 | # Store transition in the replay buffer. 
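            # Off-policy update: every transition is stored; once the buffer can provide
            # `batch_size` samples, one gradient step is taken per environment step
            # (see the can_sample check below), followed by a Polyak update of the target V network.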
287 | self.replay_buffer.add(obs, action, reward, next_obs, float(done)) 288 | obs = next_obs 289 | current_episode_reward += reward 290 | 291 | if done: 292 | obs = self.env.reset() 293 | episode_rewards.append(current_episode_reward) 294 | current_episode_reward = 0.0 295 | 296 | if self.replay_buffer.can_sample(self.batch_size): 297 | obss, actions, rewards, next_obss, dones = self.replay_buffer.sample(self.batch_size) # action is normalize 298 | 299 | step_info = self.train(obss, actions, rewards, next_obss, dones) 300 | 301 | if verbose >= 1 and step % log_interval == 0: 302 | print('\n============================') 303 | print('%15s: %10.6f' % ('10ep_rewmean', np.mean(episode_rewards[-10:]))) 304 | for i, label in enumerate(self.info_labels): 305 | print('%15s: %10.6f' %(label, step_info[i].numpy())) 306 | print('============================\n') 307 | 308 | self.update_target(self.v_target.trainable_variables, self.v.trainable_variables) 309 | 310 | if step % eval_interval == 0: 311 | if eval_rollout: 312 | eval_rewards.append(self.evaluate(1)) 313 | else: 314 | eval_rewards.append(episode_rewards[-1]) 315 | 316 | if step % save_interval == 0 and save_path is not None: 317 | print('** Saving models and evaluation returns..') 318 | np.save(save_path + "/%s_rews_seed%d_iter%d.npy"%(self.env.spec.id, self.seed, step), 319 | np.array(eval_rewards)) 320 | self.save(save_path + "/%s_model_seed%d.zip" % (self.env.spec.id, self.seed) ) 321 | 322 | return eval_rewards 323 | 324 | def predict(self, obs, deterministic=False): 325 | obs_rank = len(obs.shape) 326 | if len(obs.shape) == 1: 327 | obs = np.array([obs]) 328 | assert len(obs.shape) == 2 329 | 330 | action = self.actor.step(obs, deterministic=deterministic) 331 | # action = np.clip(action, self.action_space.low, self.action_space.high) 332 | 333 | if obs_rank == 1: 334 | return action[0], None 335 | else: 336 | return action, None 337 | 338 | def load(self, filepath): 339 | self.initialize_variables() 340 | 341 | with open(filepath, 'rb') as f: 342 | parameters = pickle.load(f) 343 | self.load_parameters(parameters) 344 | -------------------------------------------------------------------------------- /sac/train_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import setGPU 3 | import gym 4 | import numpy as np 5 | 6 | from sac import SAC 7 | 8 | def main(args): 9 | """ 10 | Train and save the SAC model, for the halfcheetah problem 11 | 12 | :param args: (ArgumentParser) the input arguments 13 | """ 14 | env = gym.make(args.env) 15 | test_env = gym.make(args.env) 16 | 17 | if args.ent_coef is None: 18 | args.ent_coef = 'auto' 19 | 20 | model = SAC(env=env, 21 | test_env=test_env, 22 | seed=int(args.seed), 23 | ent_coef=args.ent_coef, 24 | reward_scale=5. 
25 | ) 26 | ep_rewards = model.learn(total_timesteps=int(args.max_timesteps), 27 | save_path=args.save_path) 28 | 29 | model.save(args.save_path + "/%s_model_seed%d_fin_auto.zip"%(args.env, int(args.seed))) 30 | np.save(args.save_path + "/%s_rews_seed%d_fin_auto.npy"%(args.env, int(args.seed)), np.array(ep_rewards)) 31 | 32 | # print("Saving model to halfcheetah_model.zip") 33 | # model.learn(total_timesteps=100) 34 | # model.load("halfcheetah_model.zip") 35 | 36 | model.evaluate(10) 37 | 38 | 39 | if __name__ == '__main__': 40 | parser = argparse.ArgumentParser(description="Train SAC") 41 | parser.add_argument('--max-timesteps', default=2000000, type=int, help="Maximum number of timesteps") 42 | parser.add_argument('--seed', default=1, type=int, help="Random seed for training") 43 | parser.add_argument('--env', default="HalfCheetah-v2") 44 | parser.add_argument('--ent_coef', default='auto') 45 | parser.add_argument('--reward_scale', default=1.) 46 | parser.add_argument('--save_path', default="results_final") 47 | 48 | args = parser.parse_args() 49 | print("- Environment : %s" % (args.env) ) 50 | print("- Seed : %d" % (args.seed) ) 51 | print("- Ent_coef : %s" % str(args.ent_coef) ) 52 | 53 | main(args) 54 | 55 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for SEED in 1 2 3 4 5 3 | do 4 | PYTHONPATH=. python sac/train_example.py --seed=$SEED; 5 | done --------------------------------------------------------------------------------
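As an aside (not a file in this repository), here is a minimal standalone sketch of the dueling aggregation used by QNetwork.call in dqn/policy.py, Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)); the array values below are purely illustrative:

import numpy as np

action_scores = np.array([[1.0, 2.0, 3.0]])  # advantage head A(s, a): batch of 1, 3 actions
state_score = np.array([[10.0]])             # value head V(s)
centered = action_scores - action_scores.mean(axis=1, keepdims=True)
q_out = state_score + centered
print(q_out)  # [[ 9. 10. 11.]] -- the mean of the Q-values equals V(s), as the centering intends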