├── DQN_network.py ├── README.md ├── demo_result ├── assault.png ├── mspacman.png └── robotank_100M.png ├── env.py ├── main.py ├── replay_memory.py ├── runIt.sh └── utils.py /DQN_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import init 5 | import numpy as np 6 | import math 7 | 8 | import sys 9 | import datetime 10 | def print_now(cmd): 11 | time_now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 12 | print('%s %s' % (time_now, cmd)) 13 | sys.stdout.flush() 14 | 15 | 16 | class NoisyLinear(nn.Module): 17 | def __init__(self, in_features, out_features, std_init=0.1): 18 | super(NoisyLinear, self).__init__() 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | # Uniform Distribution bounds: 22 | # U(-1/sqrt(p), 1/sqrt(p)) 23 | self.lowerU = -1.0 / math.sqrt(in_features) # 24 | self.upperU = 1.0 / math.sqrt(in_features) # 25 | self.sigma_0 = std_init 26 | self.sigma_ij_in = self.sigma_0 / math.sqrt(self.in_features) 27 | self.sigma_ij_out = self.sigma_0 / math.sqrt(self.out_features) 28 | 29 | """ 30 | Registre_Buffer: Adds a persistent buffer to the module. 31 | A buffer that is not to be considered as a model parameter -- like "running_mean" in BatchNorm 32 | It is a "persistent state" and can be accessed as attributes --> self.weight_epsilon 33 | """ 34 | self.weight_mu = nn.Parameter(torch.empty(out_features, in_features)) 35 | self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features)) 36 | self.register_buffer('weight_epsilon', torch.empty(out_features, in_features)) 37 | 38 | self.bias_mu = nn.Parameter(torch.empty(out_features)) 39 | self.bias_sigma = nn.Parameter(torch.empty(out_features)) 40 | self.register_buffer('bias_epsilon', torch.empty(out_features)) 41 | 42 | self.reset_parameters() 43 | self.sample_noise() 44 | 45 | def reset_parameters(self): 46 | self.weight_mu.data.uniform_(self.lowerU, self.upperU) 47 | self.weight_sigma.data.fill_(self.sigma_ij_in) 48 | 49 | self.bias_mu.data.uniform_(self.lowerU, self.upperU) 50 | self.bias_sigma.data.fill_(self.sigma_ij_out) 51 | 52 | def sample_noise(self): 53 | eps_in = self.func_f(self.in_features) 54 | eps_out = self.func_f(self.out_features) 55 | # Take the outter product 56 | """ 57 | >>> v1 = torch.arange(1., 5.) [1, 2, 3, 4] 58 | >>> v2 = torch.arange(1., 4.) 
[1, 2, 3] 59 | >>> torch.ger(v1, v2) 60 | tensor([[ 1., 2., 3.], 61 | [ 2., 4., 6.], 62 | [ 3., 6., 9.], 63 | [ 4., 8., 12.]]) 64 | """ 65 | eps_ij = eps_out.ger(eps_in) 66 | self.weight_epsilon.copy_(eps_ij) 67 | self.bias_epsilon.copy_(eps_out) 68 | 69 | def func_f(self, n): # size 70 | # factorised Gaussian noise: f(x) = sign(x) * sqrt(|x|) with x ~ N(0, 1), as in the paper 71 | x = torch.randn(n) 72 | return x.sign().mul_(x.abs().sqrt_()) 73 | 74 | def forward(self, x): 75 | if self.training: 76 | return F.linear(x, self.weight_mu + self.weight_sigma*self.weight_epsilon, 77 | self.bias_mu + self.bias_sigma *self.bias_epsilon) 78 | 79 | else: 80 | return F.linear(x, self.weight_mu, 81 | self.bias_mu) 82 | 83 | class DQN(nn.Module): 84 | def __init__(self, num_inputs, hidden_size=512, num_actions=1, use_duel=False, use_noisy_net=False): 85 | super(DQN, self).__init__() 86 | init_ = lambda m: init(m, 87 | nn.init.orthogonal_, 88 | lambda x: nn.init.constant_(x, 0), 89 | nn.init.calculate_gain('relu')) 90 | init2_ = lambda m: init(m, 91 | nn.init.orthogonal_, 92 | lambda x: nn.init.constant_(x, 0)) 93 | self.use_duel = use_duel 94 | self.use_noisy_net = use_noisy_net 95 | 96 | self.conv1 = init_(nn.Conv2d(num_inputs, 32, 8, stride=4)) 97 | self.conv2 = init_(nn.Conv2d(32, 64, 4, stride=2)) 98 | self.conv3 = init_(nn.Conv2d(64, 32, 3, stride=1)) 99 | 100 | 101 | 102 | if use_noisy_net: 103 | Linear = NoisyLinear 104 | else: 105 | Linear = nn.Linear 106 | 107 | if self.use_duel: 108 | self.val_fc = Linear(32*7*7, hidden_size) 109 | self.val = Linear(hidden_size, 1) 110 | self.adv_fc = Linear(32*7*7, hidden_size) 111 | self.adv = Linear(hidden_size, num_actions) 112 | if not use_noisy_net: 113 | self.val_fc = init_(self.val_fc) 114 | self.adv_fc = init_(self.adv_fc) 115 | self.val = init2_(self.val) 116 | self.adv = init2_(self.adv) 117 | 118 | else: 119 | self.fc = Linear(32*7*7, hidden_size) 120 | self.critic_linear = Linear(hidden_size, num_actions) 121 | if not use_noisy_net: 122 | self.fc = init_(self.fc) 123 | self.critic_linear = init2_(self.critic_linear) 124 | 125 | self.train() 126 | 127 | 128 | def forward(self, x): 129 | x = x / 255.0 130 | x = F.relu(self.conv1(x)) 131 | x = F.relu(self.conv2(x)) 132 | x = F.relu(self.conv3(x)) 133 | x = x.view(x.size(0), -1) 134 | if self.use_duel: 135 | val = self.val(F.relu(self.val_fc(x))) 136 | adv = self.adv(F.relu(self.adv_fc(x))) 137 | y = val + adv - adv.mean(1, keepdim=True) 138 | else: 139 | x = F.relu(self.fc(x)) 140 | y = self.critic_linear(x) 141 | return y 142 | def sample_noise(self): 143 | if self.use_noisy_net: 144 | if self.use_duel: 145 | self.val_fc.sample_noise() 146 | self.val.sample_noise() 147 | self.adv_fc.sample_noise() 148 | self.adv.sample_noise() 149 | else: 150 | self.fc.sample_noise() 151 | self.critic_linear.sample_noise() 152 | 153 | 154 | class C51(nn.Module): 155 | def __init__(self, num_inputs, hidden_size=512, num_actions=4, 156 | use_duel=False, use_noisy_net=False, atoms=51, vmin=-10, vmax=10, use_qr_c51=False): 157 | super(C51, self).__init__() 158 | self.atoms = atoms 159 | self.vmin = vmin 160 | self.vmax = vmax 161 | self.num_actions = num_actions 162 | self.use_duel = use_duel 163 | self.use_noisy_net = use_noisy_net 164 | self.use_qr_c51 = use_qr_c51 165 | 166 | 167 | init_ = lambda m: init(m, 168 | nn.init.kaiming_uniform_, 169 | lambda x: nn.init.constant_(x, 0), 170 | nonlinearity='relu', 171 | mode='fan_in') 172 | init2_ = lambda m: init(m, 173 | nn.init.kaiming_uniform_, 174 | lambda x: nn.init.constant_(x, 0), 175 | nonlinearity='relu', 176 | mode='fan_in') 177 |
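    # Quick reference for how this head is read out downstream (see main.py): for each
    # action the network emits a probability vector over `atoms` fixed support points
    #     z_i = vmin + i * (vmax - vmin) / (atoms - 1),   i = 0, ..., atoms - 1
    # and the scalar Q-value is recovered as Q(s, a) = sum_i z_i * p_i(s, a).
    # With the defaults atoms=51, vmin=-10, vmax=10 the spacing is delta_z = 0.4, e.g.
    #     >>> support = torch.linspace(-10.0, 10.0, 51)    # [-10.0, -9.6, ..., 9.6, 10.0]
    #     >>> q_sa = (prob_sa * support).sum(-1)           # prob_sa: [batch x actions x 51] (illustrative name)
    # When use_qr_c51 is True the same tensor instead holds quantile values and the
    # Q-value is their mean (see select_action / optimize_model in main.py).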
178 | 179 | self.conv1 = init_(nn.Conv2d(num_inputs, 32, 8, stride=4)) 180 | self.conv2 = init_(nn.Conv2d(32, 64, 4, stride=2)) 181 | self.conv3 = init_(nn.Conv2d(64, 32, 3, stride=1)) 182 | 183 | if use_noisy_net: 184 | Linear = NoisyLinear 185 | else: 186 | Linear = nn.Linear 187 | 188 | self.fc1 = Linear(32*7*7, hidden_size) 189 | self.fc2 = Linear(hidden_size, num_actions*atoms) 190 | 191 | if self.use_duel: 192 | self.val_fc = Linear(32*7*7, hidden_size) 193 | self.val = Linear(hidden_size, atoms) 194 | 195 | # Param init 196 | if not use_noisy_net: 197 | self.fc1 = init_(self.fc1) 198 | self.fc2 = init2_(self.fc2) 199 | if self.use_duel: 200 | self.val_fc = init_(self.val_fc) 201 | self.val = init2_(self.val) 202 | 203 | 204 | 205 | def forward(self, x): 206 | x = x / 255.0 207 | x = F.relu(self.conv1(x)) 208 | x = F.relu(self.conv2(x)) 209 | x = F.relu(self.conv3(x)) 210 | x = x.view(x.size(0), -1) 211 | 212 | if self.use_duel: 213 | val_x = F.relu(self.val_fc(x)) 214 | values = self.val(val_x).unsqueeze(1) # from batch x atoms to batch x 1 x atoms 215 | 216 | x = F.relu(self.fc1(x)) 217 | x = self.fc2(x) 218 | x_batch = x.view(-1, self.num_actions, self.atoms) 219 | 220 | duel = values + x_batch - x_batch.mean(1, keepdim=True) 221 | if self.use_qr_c51: 222 | y = duel 223 | else: 224 | y = F.log_softmax(duel, dim = 2).exp() # y is of shape [batch x action x atoms] 225 | else: 226 | # A Tensor of shape [batch x actions x atoms]. 227 | x = F.relu(self.fc1(x)) 228 | x = self.fc2(x) 229 | x_batch = x.view(-1, self.num_actions, self.atoms) 230 | if self.use_qr_c51: 231 | y = x_batch 232 | else: 233 | y = F.log_softmax(x_batch, dim=2).exp() # y is of shape [batch x action x atoms] 234 | 235 | return y 236 | 237 | def sample_noise(self): 238 | if self.use_noisy_net: 239 | if self.use_duel: 240 | self.fc1.sample_noise() 241 | self.fc2.sample_noise() 242 | self.val_fc.sample_noise() 243 | self.val.sample_noise() 244 | else: 245 | self.fc1.sample_noise() 246 | self.fc2.sample_noise() 247 | 248 | class IQN_C51(nn.Module): 249 | def __init__(self, num_inputs, hidden_size=512, num_actions=4, 250 | use_duel=False, use_noisy_net=False): 251 | super(IQN_C51, self).__init__() 252 | self.num_actions = num_actions 253 | self.use_duel = use_duel 254 | self.use_noisy_net = use_noisy_net 255 | self.quantile_embedding_dim = 64 256 | self.pi = np.pi 257 | 258 | 259 | init_ = lambda m: init(m, 260 | nn.init.kaiming_uniform_, 261 | lambda x: nn.init.constant_(x, 0), 262 | gain=nn.init.calculate_gain('relu'), 263 | mode='fan_in') 264 | init2_ = lambda m: init(m, 265 | nn.init.kaiming_uniform_, 266 | lambda x: nn.init.constant_(x, 0), 267 | gain=nn.init.calculate_gain('relu'), 268 | mode='fan_in') 269 | 270 | 271 | self.conv1 = init_(nn.Conv2d(num_inputs, 32, 8, stride=4)) 272 | self.conv2 = init_(nn.Conv2d(32, 64, 4, stride=2)) 273 | self.conv3 = init_(nn.Conv2d(64, 32, 3, stride=1)) 274 | 275 | if use_noisy_net: 276 | Linear = NoisyLinear 277 | else: 278 | Linear = nn.Linear 279 | # ---------------------------------------------------------------------------- 280 | # self.fc1 = Linear(32*7*7, hidden_size) 281 | self.fc2 = Linear(hidden_size, num_actions*1) 282 | # ---------------------------------------------------------------------------- 283 | Atari_Input = torch.FloatTensor(1, num_inputs, 84, 84) 284 | temp_fea = self.conv3(self.conv2(self.conv1(Atari_Input))) 285 | temp_fea = temp_fea.view(temp_fea.size(0), -1) 286 | state_net_size = temp_fea.size(1) 287 | del Atari_Input 288 | del temp_fea 289 | 290 | 
self.quantile_fc0 = nn.Linear(self.quantile_embedding_dim, state_net_size) 291 | self.quantile_fc1 = nn.Linear(state_net_size, hidden_size) 292 | # ---------------------------------------------------------------------------- 293 | if self.use_duel: 294 | self.quantile_fc_value = Linear(hidden_size, 1) 295 | 296 | # Param init 297 | if not use_noisy_net: 298 | self.quantile_fc0 = init2_(self.quantile_fc0) 299 | self.quantile_fc1 = init2_(self.quantile_fc1) 300 | self.fc2 = init2_(self.fc2) 301 | if self.use_duel: 302 | self.quantile_fc_value = init2_(self.quantile_fc_value) 303 | 304 | 305 | def forward(self, x, num_quantiles): 306 | x = x / 255.0 307 | x = F.relu(self.conv1(x)) 308 | x = F.relu(self.conv2(x)) 309 | x = F.relu(self.conv3(x)) 310 | x = x.view(x.size(0), -1) 311 | 312 | BATCH_SIZE = x.size(0) 313 | state_net_size = x.size(1) 314 | 315 | tau = torch.FloatTensor(BATCH_SIZE * num_quantiles, 1).to(x) 316 | tau.uniform_(0, 1) 317 | # ---------------------------------------------------------------------------------------------- 318 | quantile_net = torch.FloatTensor([i for i in range(1, 1+self.quantile_embedding_dim)]).to(x) 319 | # ------------------------------------------------------------------------------------------------- 320 | tau_expand = tau.unsqueeze(-1).expand(-1, -1, self.quantile_embedding_dim) # [Batch*Np x 1 x 64] 321 | quantile_net = quantile_net.view(1, 1, -1) # [1 x 1 x 64] --> [Batch*Np x 1 x 64] 322 | quantile_net = quantile_net.expand(BATCH_SIZE*num_quantiles, 1, self.quantile_embedding_dim) 323 | cos_tau = torch.cos(quantile_net * self.pi * tau_expand) # [Batch*Np x 1 x 64] 324 | cos_tau = cos_tau.squeeze(1) # [Batch*Np x 64] 325 | # ------------------------------------------------------------------------------------------------- 326 | out = F.relu(self.quantile_fc0(cos_tau)) # [Batch*Np x feaSize] 327 | # fea_tile = torch.cat([x]*num_quantiles, dim=0) 328 | fea_tile = x.unsqueeze(1).expand(-1, num_quantiles, -1) # [Batch x Np x feaSize] 329 | out = out.view(BATCH_SIZE, num_quantiles, -1) # [Batch x Np x feaSize] 330 | product = (fea_tile * out).view(BATCH_SIZE*num_quantiles, -1) 331 | combined_fea = F.relu(self.quantile_fc1(product)) # (Batch*atoms, 512) 332 | 333 | if self.use_duel: 334 | values = self.quantile_fc_value(combined_fea) # from [batch*atoms x 1] to [Batch x 1 x Atoms] 335 | values = values.view(-1, num_quantiles).unsqueeze(1) 336 | 337 | x = self.fc2(combined_fea) 338 | x_batch = x.view(BATCH_SIZE, num_quantiles, self.num_actions) 339 | # After transpose, x_batch becomes [batch x actions x atoms] 340 | x_batch = x_batch.transpose(1, 2).contiguous() 341 | action_component = x_batch - x_batch.mean(1, keepdim=True) 342 | 343 | duel_y = values + action_component 344 | y = duel_y 345 | else: 346 | x = self.fc2(combined_fea) 347 | # [batch x atoms x actions]. 348 | y = x.view(BATCH_SIZE, num_quantiles, self.num_actions) 349 | # output should be # A Tensor of shape [batch x actions x atoms]. 
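        # Shape walk-through with illustrative numbers (batch=32, num_quantiles=64, num_actions=6):
        #     fc2 output            : [32*64 x 6]
        #     view (line above)     : [32 x 64 x 6]   (batch x quantiles x actions)
        #     transpose(1, 2) below : [32 x 6 x 64]   (batch x actions x quantiles)
        # so the caller can gather along the action dimension exactly as for the C51 / QR heads.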
350 | y = y.transpose(1, 2).contiguous() 351 | # ------------------------------------------------------------------------------------------------ # 352 | return y, tau # [batch x actions x atoms] 353 | 354 | def sample_noise(self): 355 | if self.use_noisy_net: 356 | if self.use_duel: 357 | self.fc2.sample_noise() 358 | self.quantile_fc_value.sample_noise() 359 | else: 360 | self.fc2.sample_noise() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dqn-pytorch 2 | Towards learning Rainbow DQN and all that good stuff in PyTorch. 3 | 4 | Run the code with 5 | 6 | time ./runIt.sh 7 | 8 | A few notes (bugs may exist): 9 | 10 | 1. Testing results for 10M frames are shown in Figure 1. 11 | 12 | The comparisons are between "DDQN + C51", "DDQN + QR-C200" and "DDQN + IQN-64-64-32". 13 | 14 | 2. Rainbow in this repo runs rather slowly on my machine (_TITAN Xp with Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz_). It only reached 75 FPS with a prioritized memory of size 50,000 on the game _SpaceInvaders_, so it may not be able to finish 200M frames within 10 days. (This could be due to the sub-optimal way new transitions are pushed into the PER buffer.) 15 | 16 | 3. The result inconsistency (shown in the figures below) w.r.t. the Google Dopamine implementation mainly comes from the fact that we use v4 environments, while the results reported by Google Dopamine use v0 environments with "sticky" actions. 17 | 18 | [Figure 1] 19 | ![alt text](https://raw.githubusercontent.com/dannysdeng/dqn-pytorch/master/demo_result/assault.png) 20 | ![alt text](https://raw.githubusercontent.com/dannysdeng/dqn-pytorch/master/demo_result/mspacman.png) 21 | ![alt text](https://raw.githubusercontent.com/dannysdeng/dqn-pytorch/master/demo_result/robotank_100M.png) 22 | 23 | Useful references: 24 | 25 | [0] IQN implementation reference: https://github.com/google/dopamine/tree/master/dopamine 26 | 27 | [1] Very helpful PyTorch code base: https://github.com/qfettes/DeepRL-Tutorials 28 | 29 | [2] Tutorial on C51: https://mtomassoli.github.io/2017/12/08/distributional_rl/ 30 | 31 | [3] Hyperparameter for the target-update interval (lower might be better): https://www.noob-programmer.com/openai-retro-contest/how-to-score-6k-in-leaderboard/ 32 | 33 | [4] Ape-X (distributed prioritized experience replay), which can outperform Rainbow: https://arxiv.org/pdf/1803.00933.pdf 34 | 35 | -------------------------------------------------------------------------------- /demo_result/assault.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannysdeng/dqn-pytorch/88c4aaf350c554c7d5a5caf63ca432f1ec6945e4/demo_result/assault.png -------------------------------------------------------------------------------- /demo_result/mspacman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannysdeng/dqn-pytorch/88c4aaf350c554c7d5a5caf63ca432f1ec6945e4/demo_result/mspacman.png -------------------------------------------------------------------------------- /demo_result/robotank_100M.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannysdeng/dqn-pytorch/88c4aaf350c554c7d5a5caf63ca432f1ec6945e4/demo_result/robotank_100M.png -------------------------------------------------------------------------------- /env.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Be really careful when constructing the replay memory buffer, as the environment has been wrapped to produce 3 | "state" and "next_state" using self.stacked_obs. 4 | """ 5 | 6 | import os 7 | 8 | import gym 9 | import numpy as np 10 | import torch 11 | from gym.spaces.box import Box 12 | 13 | from baselines import bench 14 | from baselines.common.atari_wrappers import make_atari, wrap_deepmind 15 | from baselines.common.vec_env import VecEnvWrapper 16 | from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 17 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 18 | from baselines.common.vec_env.vec_normalize import VecNormalize as VecNormalize_ 19 | 20 | import sys 21 | import datetime 22 | def print_now(cmd): 23 | time_now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 24 | print('%s %s' % (time_now, cmd)) 25 | sys.stdout.flush() 26 | 27 | # Build a single environment (Atari here; non-image, MuJoCo-style envs also pass through) 28 | def make_env(env_id, seed, rank, log_dir, add_timestep, allow_early_resets): 29 | assert(log_dir is not None) 30 | def _thunk(): 31 | env = gym.make(env_id) 32 | is_atari = hasattr(gym.envs, 'atari') and isinstance(env.unwrapped, gym.envs.atari.atari_env.AtariEnv) 33 | if is_atari: 34 | env = make_atari(env_id) 35 | env.seed(seed + rank) 36 | 37 | obs_shape = env.observation_space.shape 38 | if add_timestep: 39 | if len(obs_shape) == 1 and str(env).find('TimeLimit') > -1: 40 | print_now('Adding timestep wrapper to env') 41 | env = AddTimestep(env) 42 | 43 | env = bench.Monitor(env, os.path.join(log_dir, str(rank)), allow_early_resets=allow_early_resets) 44 | if is_atari: 45 | env = wrap_deepmind(env) 46 | # If the input is of shape (W, H, 3), wrap for PyTorch (N, 3, W, H) 47 | obs_shape = env.observation_space.shape 48 | if len(obs_shape) == 3 and obs_shape[2] in [1, 3]: 49 | env = TransposeImage(env) 50 | return env 51 | return _thunk 52 | 53 | ### 54 | # Vectorizer to give observations of shape [num_processes x 4 x 84 x 84] 55 | ### 56 | def make_vec_envs(env_name, seed, num_processes, gamma, log_dir, 57 | add_timestep, device, allow_early_resets, num_frame_stack=None): 58 | envs = [make_env(env_name, seed, i, log_dir, add_timestep, allow_early_resets) for i in range(num_processes)] 59 | if len(envs) > 1: 60 | envs = SubprocVecEnv(envs) 61 | else: 62 | envs = DummyVecEnv(envs) 63 | # 64 | # This is for MuJoCo, maybe?
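    # The VecNormalize branch below only applies to 1-D (state-vector) observations such as
    # MuJoCo-style tasks; Atari observations are 3-D images, so they skip it and get the
    # 4-frame stack further down instead. A sketch of the call made from main.py
    # (argument values are just main.py's defaults, shown for illustration):
    #     >>> envs = make_vec_envs('PongNoFrameskip-v4', seed=1234, num_processes=1,
    #     ...                      gamma=0.99, log_dir='./agentLog', add_timestep=False,
    #     ...                      device=torch.device('cpu'), allow_early_resets=False)
    #     >>> obs = envs.reset()    # FloatTensor of shape [1 x 4 x 84 x 84], values in [0, 255]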
65 | if len(envs.observation_space.shape) == 1: 66 | print_now('Performning VecNormalize as observation_space is of shape 1') 67 | if gamma is None: 68 | envs = VecNormalize(envs, ret=False) 69 | else: 70 | envs = VecNormalize(envs, gamma=gamma) 71 | # 72 | envs = VecPyTorch(envs, device) 73 | if num_frame_stack is not None: 74 | # If there is some pre-defined framestack: 75 | envs = VecPyTorchFrameStack(envs, num_frame_stack, device) 76 | elif len(envs.observation_space.shape) == 3: 77 | print_now('Using default 4-frame stack for image-based envs') 78 | envs = VecPyTorchFrameStack(envs, 4, device) 79 | # 80 | return envs 81 | 82 | class VecPyTorch(VecEnvWrapper): 83 | def __init__(self, venv, device): 84 | """ Return only every 'skip'-th frame """ 85 | super(VecPyTorch, self).__init__(venv) 86 | self.device = device 87 | def reset(self): 88 | obs = self.venv.reset() 89 | obs = torch.from_numpy(obs).float().to(self.device) 90 | return obs 91 | def step_async(self, actions): 92 | actions = actions.squeeze(1).cpu().numpy() 93 | self.venv.step_async(actions) 94 | def step_wait(self): 95 | obs, reward, done, info = self.venv.step_wait() 96 | obs = torch.from_numpy(obs).float().to(self.device) 97 | reward = torch.from_numpy(reward).unsqueeze(dim=1).float() # N --> N x 1 98 | return obs, reward, done, info 99 | 100 | class VecPyTorchFrameStack(VecEnvWrapper): 101 | """ OpenAI-baseline style framestack """ 102 | def __init__(self, venv, nstack, device=None): 103 | self.venv = venv 104 | self.nstack = nstack 105 | wrapped_ob_space = venv.observation_space # should be 1 x 84 x 84 106 | self.shape_dim0 = wrapped_ob_space.shape[0] # shape_dim0 is 1 107 | 108 | # wrapped_ob_space.low is ZERO matrix of size 1 x 84 x 84, we make it 4 x 84 x 84 now 109 | # wrapped_ob_space.high is 255-matrix of size 1 x 84 x 84, we make it 4 x 84 x 84 now 110 | low = np.repeat(wrapped_ob_space.low, self.nstack, axis=0) 111 | high = np.repeat(wrapped_ob_space.high, self.nstack, axis=0) 112 | 113 | if device is None: 114 | device = torch.device('cpu') 115 | new_shape_tuple = (venv.num_envs, ) + low.shape # num_processes x 4 x 84 x 84 116 | self.stacked_obs = torch.zeros(new_shape_tuple).to(device) 117 | 118 | observation_space = gym.spaces.Box( 119 | low=low, high=high, dtype=venv.observation_space.dtype) 120 | 121 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 122 | 123 | def step_wait(self): 124 | obs, rewards, dones, infos = self.venv.step_wait() 125 | # This is stacking 4 frames together 126 | # self.stacked_obs[:, :-1] is everything (first 3) except the last one, 127 | # self.stacked_obs[:, -1:] is everything (last 3) except the first one 128 | self.stacked_obs[:, :-self.shape_dim0] = self.stacked_obs[:, self.shape_dim0:] # essentially pops the first 1 out 129 | for i, done in enumerate(dones): 130 | if done: 131 | self.stacked_obs[i] = 0 132 | # 133 | # self.stacked_obs[:, -1:] = obs 134 | self.stacked_obs[:, -self.shape_dim0:] = obs # put the new observation at the last 1 position 135 | return self.stacked_obs, rewards, dones, infos 136 | 137 | def reset(self): 138 | obs = self.venv.reset() 139 | # Zero-out everything in the stacked env 140 | self.stacked_obs.zero_() 141 | self.stacked_obs[:, -self.shape_dim0:] = obs # put the first state (new observation) at the last 1 position 142 | return self.stacked_obs 143 | 144 | def close(self): 145 | self.venv.close() 146 | 147 | 148 | ## Helper Wraper: 149 | class AddTimestep(gym.ObservationWrapper): 150 | def __init__(self, env=None): 151 | 
super(AddTimestep, self).__init__(env) 152 | self.observation_space = Box( 153 | self.observation_space.low[0], 154 | self.observation_space.high[0], 155 | [self.observation_space.shape[0] + 1], 156 | dtype=self.observation_space.dtype) 157 | 158 | def observation(self, observation): 159 | return np.concatenate((observation, [self.env._elapsed_steps])) 160 | 161 | class TransposeImage(gym.ObservationWrapper): 162 | def __init__(self, env=None): 163 | super(TransposeImage, self).__init__(env) 164 | obs_shape = self.observation_space.shape 165 | self.observation_space = Box( 166 | self.observation_space.low[0, 0, 0], 167 | self.observation_space.high[0, 0, 0], 168 | [obs_shape[2], obs_shape[1], obs_shape[0]], 169 | dtype=self.observation_space.dtype) 170 | 171 | def observation(self, observation): 172 | # Observation is of type Tensor 173 | return observation.transpose(2, 0, 1) 174 | 175 | 176 | class VecNormalize(VecNormalize_): 177 | 178 | def __init__(self, *args, **kwargs): 179 | super(VecNormalize, self).__init__(*args, **kwargs) 180 | self.training = True 181 | 182 | def _obfilt(self, obs): 183 | if self.ob_rms: 184 | if self.training: 185 | self.ob_rms.update(obs) 186 | obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob) 187 | return obs 188 | else: 189 | return obs 190 | 191 | def train(self): 192 | self.training = True 193 | 194 | def eval(self): 195 | self.training = False 196 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | The design architecture follows https://github.com/ikostrikov/pytorch-a2c-ppo-acktr 3 | Each components follow closely with the great tutorial: https://github.com/qfettes/DeepRL-Tutorials 4 | """ 5 | import copy 6 | import glob 7 | import os 8 | import time 9 | from collections import deque 10 | import random 11 | import argparse 12 | 13 | import gym 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.optim as optim 19 | import gc 20 | 21 | 22 | """ A2C specific arguments """ 23 | #import algo 24 | #from arguments import get_args 25 | 26 | # from envs import make_vec_envs 27 | # from model import Policy 28 | # from storage import RolloutStorage 29 | # from utils import get_vec_normalize 30 | 31 | # DQN specific arguments 32 | from DQN_network import DQN, C51, IQN_C51 33 | from replay_memory import ReplayMemory, PrioritizedReplayBuffer 34 | from utils import init 35 | from env import make_vec_envs 36 | from baselines.common.schedules import LinearSchedule 37 | from collections import namedtuple 38 | 39 | import sys 40 | import datetime 41 | def print_now(cmd): 42 | time_now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 43 | print('%s %s' % (time_now, cmd)) 44 | sys.stdout.flush() 45 | 46 | # Arguments 47 | parser = argparse.ArgumentParser(description='DQN Pytorch') 48 | parser.add_argument('--env-name', default='PongNoFrameskip-v4', 49 | help='environment to train on (default: PongNoFrameskip-v4)') 50 | parser.add_argument('--log-dir', default='./agentLog', 51 | help='directory to save agent logs (default: ./agentLog)') 52 | parser.add_argument('--save-dir', default='./saved_model', 53 | help='directory to save agent logs (default: ./saved_model)') 54 | parser.add_argument('--seed', type=int, default=1234, 55 | help='random seed (default: 1234)') 56 | parser.add_argument('--save-interval', 
type=int, default=100, 57 | help='save interval, one save every n timesteps (default: 100)') 58 | parser.add_argument('--total-timestep', type=float, default=1e8, 59 | help='total timesteps (default: 1e8)') 60 | parser.add_argument('--num-processes', type=int, default=1, 61 | help='num processes (default: 1)') 62 | parser.add_argument('--gamma', type=float, default=0.99, 63 | help='discount factor gamma (default: 0.99)') 64 | parser.add_argument('--kappa', type=float, default=1.0, 65 | help='kappa for the Huber quantile loss in IQN (default: 1.0)') 66 | parser.add_argument('--add-timestep', action='store_true', default=False, 67 | help='add timestep to observations') 68 | parser.add_argument('--no-cuda', action='store_true', default=False, 69 | help='disables CUDA training (defaults to using CUDA)') 70 | parser.add_argument('--batch-size', type=int, default=32, 71 | help='batch size in DQN (default: 32)') 72 | parser.add_argument('--train-freq', type=int, default=4, 73 | help='DQN training frequency - every 4 frames') 74 | parser.add_argument('--target-update', type=int, default=32000, 75 | help='target-network update frequency - every 32,000 steps') 76 | parser.add_argument('--memory-size', type=int, default=1000000, 77 | help='replay memory size - 1,000,000 transitions') 78 | parser.add_argument('--learning-starts', type=int, default=80000, 79 | help='learning starts after 80,000 transitions') 80 | parser.add_argument('--num-lookahead', type=int, default=3, 81 | help='number of look-ahead steps for n-step returns (default: 3)') 82 | 83 | parser.add_argument('--use-double-dqn', action='store_true', default=False, 84 | help='use double DQN') 85 | 86 | parser.add_argument('--use-prioritized-buffer', action='store_true', default=False, 87 | help='use prioritized replay buffer') 88 | 89 | parser.add_argument('--use-n-step', action='store_true', default=False, 90 | help='use n-step returns') 91 | 92 | parser.add_argument('--use-duel', action='store_true', default=False, 93 | help='use dueling architecture') 94 | 95 | parser.add_argument('--use-noisy-net', action='store_true', default=False, 96 | help='use noisy networks for exploration') 97 | 98 | parser.add_argument('--use-C51', action='store_true', default=False, 99 | help='use categorical value distribution C51') 100 | 101 | parser.add_argument('--use-QR-C51', action='store_true', default=False, 102 | help='use quantile regression (QR-DQN)') 103 | 104 | parser.add_argument('--use-IQN-C51', action='store_true', default=False, 105 | help='use Implicit Quantile Network (IQN)') 106 | 107 | parser.add_argument('--use_low_footprint', action='store_true', default=False, 108 | help='use low-memory-footprint replay buffer') 109 | 110 | parser.add_argument('--N_tau', type=int, default=64, 111 | help='number of quantile samples N (IQN paper)') 112 | parser.add_argument('--Np_tau', type=int, default=64, 113 | help="number of target quantile samples N' (IQN paper)") 114 | parser.add_argument('--K_quantile', type=int, default=32, 115 | help="number of quantile samples K for action selection (IQN paper)") 116 | 117 | parser.add_argument('--adam_lr', type=float, default=-1, 118 | help="Adam learning rate (-1 uses the per-algorithm default)") 119 | parser.add_argument('--adam_eps', type=float, default=-1, 120 | help="Adam epsilon (-1 uses the per-algorithm default)") 121 | 122 | 123 | 124 | args = parser.parse_args() 125 | args.cuda = not args.no_cuda and torch.cuda.is_available() 126 | GAMMA = args.gamma 127 | BATCH_SIZE = args.batch_size 128 | TRAIN_FREQ = args.train_freq 129 | TARGET_UPDATE = args.target_update # 130 | 131 | LOW_FOOTPRINT = args.use_low_footprint 132 | # Q-Learning Parameters 133 | DOUBLE_Q_LEARNING = args.use_double_dqn #False 134 | PRIORITIZED_MEMORY = args.use_prioritized_buffer #False 135 | USE_N_STEP =
args.use_n_step 136 | NUM_LOOKAHEAD = args.num_lookahead 137 | USE_DUEL = args.use_duel 138 | USE_NOISY_NET = args.use_noisy_net 139 | USE_C51 = args.use_C51 140 | USE_QR_C51 = args.use_QR_C51 141 | USE_IQN_C51 = args.use_IQN_C51 142 | 143 | 144 | if USE_IQN_C51: 145 | assert(USE_QR_C51 is True) 146 | 147 | if USE_QR_C51: 148 | assert(USE_C51 is True) 149 | 150 | if not USE_N_STEP: 151 | NUM_LOOKAHEAD = 1 152 | # --------------------------------------------------- # 153 | exploration_fraction = 0.1 154 | exploration_final_eps_1 = 0.1 155 | exploration_final_eps_2 = 0.01 156 | if args.adam_lr == -1: 157 | adam_lr = 5e-5 if USE_IQN_C51 or USE_QR_C51 or USE_C51 else 6.25e-4 # 158 | adam_eps = 3.125e-4 if USE_IQN_C51 or USE_QR_C51 or USE_C51 else 1.5e-4 # 159 | else: 160 | adam_lr = args.adam_lr 161 | adam_eps = args.adam_eps 162 | # --------------------------------------------------- # 163 | # Booking Keeping 164 | print_now('------- Begin DQN with --------') 165 | print_now('Using Low Footprint memory: {}'.format(LOW_FOOTPRINT)) 166 | print_now('Using Double DQN: {}'.format(DOUBLE_Q_LEARNING)) 167 | print_now('Using Prioritized buffer: {}'.format(PRIORITIZED_MEMORY)) 168 | print_now('Using N-step reward with N = {}: {}'.format(NUM_LOOKAHEAD, USE_N_STEP)) 169 | print_now('Using Duel (advantage): {}'.format(USE_DUEL)) 170 | print_now('Using Noisy Net: {}'.format(USE_NOISY_NET)) 171 | print_now('Using C51 {}'.format(USE_C51)) 172 | print_now('Using Quantile Regression C51: {}'.format(USE_QR_C51)) 173 | print_now('Using Implicit Quantile Net C51: {}'.format(USE_IQN_C51)) 174 | print_now('Adam learning rate: {}, eps: {}'.format(adam_lr, adam_eps)) 175 | print_now('Seed: {}'.format(args.seed)) 176 | print_now('------- -------------- --------') 177 | print_now('Task: {}'.format(args.env_name)) 178 | time.sleep(0.1) 179 | # -------------------------------------------------------------------###### 180 | device = torch.device("cuda" if args.cuda else "cpu") 181 | # -------------------------------------------------------------------###### 182 | if USE_C51: 183 | C51_atoms = 51 184 | C51_vmax = 10.0 185 | C51_vmin = -10.0 186 | C51_support = torch.linspace(C51_vmin, C51_vmax, C51_atoms).view(1, 1, C51_atoms).to(device) # Shape 1 x 1 x 51 187 | C51_delta = (C51_vmax - C51_vmin) / (C51_atoms - 1) 188 | 189 | if USE_QR_C51: 190 | C51_atoms = 200 191 | QR_C51_atoms = 200 #C51_atoms 192 | QR_C51_quantile_weight = 1.0 / QR_C51_atoms 193 | # tau 194 | QR_C51_cum_density = (2 * np.arange(QR_C51_atoms) + 1) / (2.0 * QR_C51_atoms) 195 | QR_C51_cum_density = torch.tensor(QR_C51_cum_density, device=device, dtype=torch.float).view(1, 1, -1, 1) 196 | QR_C51_cum_density = QR_C51_cum_density.expand(args.batch_size, QR_C51_atoms, QR_C51_atoms, -1) 197 | if USE_IQN_C51: 198 | C51_atoms = None 199 | QR_C51_atoms = None 200 | QR_C51_quantile_weight = None 201 | QR_C51_cum_density = None 202 | 203 | 204 | 205 | """ 2(i-1) + 1 206 | tau_i = --------------- for i = 1, 2, ..., N 207 | 2N 208 | """ 209 | 210 | 211 | # Seeds 212 | random.seed(args.seed) 213 | np.random.seed(args.seed) 214 | torch.manual_seed(args.seed) 215 | torch.cuda.manual_seed_all(args.seed) 216 | 217 | import torch.backends.cudnn as cudnn 218 | cudnn.deterministic = True 219 | cudnn.benchmark = False # False should be fully deterministic 220 | 221 | # Importand - logging 222 | try: 223 | print_now('Creating log directory at: %s' % (args.log_dir)) 224 | os.makedirs(args.log_dir) 225 | except OSError: 226 | files = glob.glob(os.path.join(args.log_dir, 
'*.monitor.csv')) 227 | for f in files: 228 | os.remove(f) 229 | print_now('Reset log directory contents at: %s' % (args.log_dir)) 230 | 231 | eval_log_dir = args.log_dir + "_eval" 232 | 233 | try: 234 | os.makedirs(eval_log_dir) 235 | except OSError: 236 | files = glob.glob(os.path.join(eval_log_dir, '*.monitor.csv')) 237 | for f in files: 238 | os.remove(f) 239 | 240 | # Env following https://github.com/ikostrikov/pytorch-a2c-ppo-acktr 241 | print_now('Using device: {}'.format(device)) 242 | envs = make_vec_envs(args.env_name, args.seed, args.num_processes, 243 | args.gamma, args.log_dir, args.add_timestep, device, False) 244 | 245 | action_space = envs.action_space.n 246 | if USE_IQN_C51: 247 | policy_net = IQN_C51(num_inputs=4, num_actions=action_space, 248 | use_duel=USE_DUEL, use_noisy_net=USE_NOISY_NET).to(device) 249 | target_net = IQN_C51(num_inputs=4, num_actions=action_space, 250 | use_duel=USE_DUEL, use_noisy_net=USE_NOISY_NET).to(device) 251 | elif USE_C51: 252 | policy_net = C51(num_inputs=4, num_actions=action_space, atoms=C51_atoms, 253 | use_duel=USE_DUEL, use_noisy_net=USE_NOISY_NET, use_qr_c51=USE_QR_C51).to(device) 254 | target_net = C51(num_inputs=4, num_actions=action_space, atoms=C51_atoms, 255 | use_duel=USE_DUEL, use_noisy_net=USE_NOISY_NET, use_qr_c51=USE_QR_C51).to(device) 256 | if USE_QR_C51: 257 | C51_atoms = None 258 | else: 259 | policy_net = DQN(num_inputs=4, num_actions=action_space, use_duel=USE_DUEL, use_noisy_net=USE_NOISY_NET).to(device) 260 | target_net = DQN(num_inputs=4, num_actions=action_space, use_duel=USE_DUEL, use_noisy_net=USE_NOISY_NET).to(device) 261 | target_net.load_state_dict(policy_net.state_dict()) 262 | policy_net.train() 263 | target_net.eval() 264 | Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward')) 265 | optimizer = optim.Adam(policy_net.parameters(), lr=adam_lr, eps=adam_eps) 266 | # -------------------------------------------------------------------###### 267 | if PRIORITIZED_MEMORY: 268 | memory = PrioritizedReplayBuffer(args.memory_size, args.total_timestep, args.learning_starts) 269 | else: 270 | memory = ReplayMemory(args.memory_size, low_footprint=LOW_FOOTPRINT) 271 | 272 | nstep_buffer = [] 273 | def n_step_preprocess(st_0, action, st_1, reward, done): 274 | transition = Transition(st_0, action, st_1, reward) 275 | if done: 276 | # Clear out the buffer 277 | while len(nstep_buffer) > 1: 278 | n_step_reward = sum([nstep_buffer[i].reward.item()*(GAMMA**i) for i in range(len(nstep_buffer))]) 279 | prev_transition = nstep_buffer.pop(0) 280 | temp_st0 = prev_transition.state 281 | temp_action = prev_transition.action 282 | temp_reward = torch.tensor([[n_step_reward]], dtype=torch.float) 283 | memory.push(temp_st0, temp_action, None, temp_reward) 284 | # 285 | n_step_reward = sum([nstep_buffer[i].reward.item()*(GAMMA**i) for i in range(len(nstep_buffer))]) 286 | prev_transition = nstep_buffer.pop(0) 287 | assert(len(nstep_buffer) == 0) 288 | return prev_transition.state, prev_transition.action, None, torch.tensor([[n_step_reward]], dtype=torch.float) 289 | 290 | elif len(nstep_buffer) < NUM_LOOKAHEAD - 1: 291 | nstep_buffer.append(transition) 292 | return None, None, None, None #st_0, action, st_1, reward 293 | else: 294 | nstep_buffer.append(transition) 295 | n_step_reward = sum([nstep_buffer[i].reward.item()*(GAMMA**i) for i in range(NUM_LOOKAHEAD)]) 296 | prev_transition = nstep_buffer.pop(0) 297 | # return prev_st0, prev_action, st_1, torch.tensor([[n_step_reward]], dtype=torch.float).to(device) 
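        # Worked example (NUM_LOOKAHEAD=3, GAMMA=0.99): once the buffer holds the transitions
        # for steps t, t+1 and t+2, the tuple handed back for storage is
        #     ( s_t, a_t, s_{t+3}, r_t + 0.99 * r_{t+1} + 0.99**2 * r_{t+2} )
        # and optimize_model() later bootstraps from s_{t+3} with GAMMA**NUM_LOOKAHEAD = 0.99**3,
        # which together form the usual n-step return target.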
298 | assert(len(nstep_buffer) < NUM_LOOKAHEAD) 299 | return prev_transition.state, prev_transition.action, st_1, torch.tensor([[n_step_reward]], dtype=torch.float) 300 | # 301 | 302 | def IQN_next_distribution(args, non_final_next_states, batch_reward, non_final_mask): 303 | """ 304 | This is for Inverse Quantile Network 305 | """ 306 | def get_action_argmax_next_Q_sa_IQN(args, next_states): 307 | if DOUBLE_Q_LEARNING: 308 | next_dist, _ = policy_net(next_states, args.K_quantile) 309 | #next_dist = next_dist * 1 / next_dist.size(1) 310 | else: 311 | next_dist, _ = target_net(next_states, args.K_quantile) 312 | #next_dist = next_dist * 1 / next_dist.size(1) 313 | # combined = next_dist.sum(dim=2) 314 | combined = next_dist.mean(dim=2) 315 | next_Q_sa = combined.max(1)[1] # next_Q_sa is of size: [batch ] of action index 316 | next_Q_sa = next_Q_sa.view(next_states.size(0), 1, 1) # Make it to be size of [32 x 1 x 1] 317 | next_Q_sa = next_Q_sa.expand(-1, -1, args.Np_tau) # Expand to be [32 x 1 x 51], one action, expand to support 318 | return next_Q_sa 319 | 320 | with torch.no_grad(): 321 | quantiles_next = torch.zeros((BATCH_SIZE, args.Np_tau), device=device, dtype=torch.float) 322 | max_next_action = get_action_argmax_next_Q_sa_IQN(args, non_final_next_states) 323 | 324 | if USE_NOISY_NET: 325 | target_net.sample_noise() 326 | 327 | next_y, _ = target_net(non_final_next_states, args.Np_tau) 328 | quantiles_next[non_final_mask] = next_y.gather(1, max_next_action).squeeze(1) 329 | # output should change from [32 x 1 x 51] --> [32 x 51] 330 | # batch_reward should be of size [32 x 1] 331 | quantiles_next = batch_reward.expand(-1, quantiles_next.size(1)) + (GAMMA**NUM_LOOKAHEAD) * quantiles_next 332 | return quantiles_next.detach() 333 | 334 | def next_distribution(non_final_next_states, batch_reward, non_final_mask): 335 | """ 336 | This is for Quantile Regression C51 337 | """ 338 | def get_action_argmax_next_Q_sa_QRC51(next_states): 339 | if DOUBLE_Q_LEARNING: 340 | next_dist = policy_net(next_states) 341 | #next_dist = next_dist * 1 / next_dist.size(1) 342 | else: 343 | next_dist = target_net(next_states) 344 | #next_dist = next_dist * 1 / next_dist.size(1) 345 | 346 | #next_Q_sa = next_dist.sum(dim=2).max(1)[1] # next_Q_sa is of size: [batch ] of action index 347 | next_Q_sa = next_dist.mean(dim=2).max(1)[1] # next_Q_sa is of size: [batch ] of action index 348 | next_Q_sa = next_Q_sa.view(next_states.size(0), 1, 1) # Make it to be size of [32 x 1 x 1] 349 | next_Q_sa = next_Q_sa.expand(-1, -1, QR_C51_atoms) # Expand to be [32 x 1 x 51], one action, expand to support 350 | return next_Q_sa 351 | 352 | with torch.no_grad(): 353 | quantiles_next = torch.zeros((BATCH_SIZE, QR_C51_atoms), device=device, dtype=torch.float) 354 | max_next_action = get_action_argmax_next_Q_sa_QRC51(non_final_next_states) 355 | if USE_NOISY_NET: 356 | target_net.sample_noise() 357 | quantiles_next[non_final_mask] = target_net(non_final_next_states).gather(1, max_next_action).squeeze(1) 358 | # output should change from [32 x 1 x 51] --> [32 x 51] 359 | # batch_reward should be of size [32 x 1] 360 | quantiles_next = batch_reward + (GAMMA**NUM_LOOKAHEAD) * quantiles_next 361 | 362 | return quantiles_next.detach() 363 | 364 | 365 | 366 | def project_distribution(batch_state, batch_action, non_final_next_states, batch_reward, non_final_mask): 367 | """ 368 | This is for orignal C51, with KL-divergence. 
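    For each support atom z_j with target-network probability p_j(s', a*), the Bellman-updated
    atom is projected back onto the fixed support (Bellemare et al., 2017):

        Tz_j = clamp(r + gamma^n * z_j, Vmin, Vmax)      (the gamma^n term is masked out for terminal s')
        b_j  = (Tz_j - Vmin) / delta_z,   l = floor(b_j),   u = ceil(b_j)
        m_l += p_j(s', a*) * (u - b_j)
        m_u += p_j(s', a*) * (b_j - l)

    The returned m is the projected target distribution; optimize_model() then takes the
    cross-entropy loss -(m * log p(s, a)).sum(-1) against the online network.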
369 | """ 370 | def get_action_argmax_next_Q_sa(next_states): 371 | if DOUBLE_Q_LEARNING: 372 | next_dist = policy_net(next_states) * C51_support # Next_Distribution is of size: [batch x action x atoms] 373 | else: 374 | next_dist = target_net(next_states) * C51_support # Next_Distribution is of size: [batch x action x atoms] 375 | next_Q_sa = next_dist.sum(dim=2).max(1)[1] # next_Q_sa is of size: [batch ] of action index 376 | next_Q_sa = next_Q_sa.view(next_states.size(0), 1, 1) # Make it to be size of [32 x 1 x 1] 377 | next_Q_sa = next_Q_sa.expand(-1, -1, C51_atoms) # Expand to be [32 x 1 x 51], one action, expand to support 378 | return next_Q_sa 379 | 380 | with torch.no_grad(): 381 | max_next_dist = torch.zeros((BATCH_SIZE, 1, C51_atoms), device=device, dtype=torch.float) 382 | max_next_dist += 1.0 / C51_atoms 383 | # 384 | max_next_action = get_action_argmax_next_Q_sa(non_final_next_states) 385 | if USE_NOISY_NET: 386 | target_net.sample_noise() 387 | max_next_dist[non_final_mask] = target_net(non_final_next_states).gather(1, max_next_action) 388 | max_next_dist = max_next_dist.squeeze() 389 | # 390 | # Mapping 391 | Tz = batch_reward.view(-1, 1) + (GAMMA**NUM_LOOKAHEAD) * C51_support.view(1, -1) * non_final_mask.to(torch.float).view(-1, 1) 392 | Tz = Tz.clamp(C51_vmin, C51_vmax) 393 | C51_b = (Tz - C51_vmin) / C51_delta 394 | C51_L = C51_b.floor().to(torch.int64) 395 | C51_U = C51_b.ceil().to(torch.int64) 396 | C51_L[ (C51_U > 0) * (C51_L == C51_U)] -= 1 397 | C51_U[ (C51_L < (C51_atoms - 1)) * (C51_L == C51_U)] += 1 398 | offset = torch.linspace(0, (BATCH_SIZE - 1) * C51_atoms, BATCH_SIZE) 399 | offset = offset.unsqueeze(dim=1) 400 | offset = offset.expand(BATCH_SIZE, C51_atoms).to(batch_action) # I believe this is to(device) 401 | 402 | # I believe this is analogous to torch.zeros(), but "new_zeros" keeps the type as the original tensor? 403 | m = batch_state.new_zeros(BATCH_SIZE, C51_atoms) # Returns a Tensor of size size filled with 0. 
same dtype 404 | m.view(-1).index_add_(0, (C51_L + offset).view(-1), (max_next_dist * (C51_U.float() - C51_b)).view(-1)) 405 | m.view(-1).index_add_(0, (C51_U + offset).view(-1), (max_next_dist * (C51_b - C51_L.float())).view(-1)) 406 | return m 407 | # -------------------------------------------------------------------###### 408 | 409 | 410 | 411 | # -------------------------------------------------------------------###### 412 | # Two stage epsilon decay following https://blog.openai.com/openai-baselines-dqn/ 413 | # But this is similar to the curve of expoenntial decay 414 | eps_schedule1 = LinearSchedule(schedule_timesteps=int(1e6), # first 1 million 415 | initial_p=1.0, 416 | final_p =exploration_final_eps_1) 417 | 418 | eps_schedule2 = LinearSchedule(schedule_timesteps=int(25e6), # next 24 million 419 | initial_p=exploration_final_eps_1, 420 | final_p =exploration_final_eps_2) 421 | 422 | steps_done = 0 423 | def select_action(state, action_space): 424 | global steps_done 425 | # eps_threshold = EPS_END + (EPS_STRAT-EPS_END) * math.exp(-1*steps_done / EPS_DECAY) 426 | eps_threshold = eps_schedule1.value(steps_done) if steps_done <= 1e6 else eps_schedule2.value(steps_done) 427 | steps_done += 1 428 | if USE_NOISY_NET or random.random() > eps_threshold: 429 | with torch.no_grad(): 430 | if USE_IQN_C51: 431 | if USE_NOISY_NET: 432 | policy_net.sample_noise() 433 | y, _ = policy_net(state, args.K_quantile) 434 | # y = y * 1.0 / y.size(1) 435 | y = y.mean(dim=2).max(1) 436 | action = y[1].view(1, 1) 437 | 438 | elif USE_QR_C51: 439 | if USE_NOISY_NET: 440 | policy_net.sample_noise() 441 | y = policy_net(state) 442 | # y = y * QR_C51_quantile_weight 443 | y = y.mean(dim=2).max(1) 444 | action = y[1].view(1, 1) 445 | 446 | elif USE_C51: 447 | if USE_NOISY_NET: 448 | policy_net.sample_noise() 449 | y = policy_net(state) 450 | y = y * C51_support 451 | y = y.sum(dim=2).max(1) 452 | action = y[1].view(1, 1) 453 | else: 454 | if USE_NOISY_NET: 455 | policy_net.sample_noise() 456 | y = policy_net(state) 457 | y = y.max(1) # (tensor([0.2177], grad_fn=), tensor([0])) 458 | action = y[1].view(1, 1) 459 | else: 460 | action = torch.tensor([[random.randrange(action_space)]], device=device, dtype=torch.long) 461 | return action 462 | 463 | # optimize 464 | if USE_IQN_C51: 465 | X_ZERO_IQN_C51 = torch.zeros((args.Np_tau, BATCH_SIZE, args.N_tau), dtype=torch.float).to(device) 466 | elif USE_QR_C51: 467 | X_ZERO_QR_C51 = torch.zeros((QR_C51_atoms, BATCH_SIZE, QR_C51_atoms), dtype=torch.float).to(device) 468 | X_ZERO = torch.zeros((BATCH_SIZE, 1), dtype=torch.float).to(device) 469 | def optimize_model(): 470 | # 471 | def huber(x, k=1.0): 472 | return torch.where(x.abs() < k, 0.5 * x.pow(2), k * (x.abs() - 0.5 * k)) 473 | def huber_loss_fast(x, xzero): 474 | # cond = (x.abs() < 1.0).float().detach() 475 | # return 0.5 * x.pow(2) * cond + (x.abs() - 0.5) * (1.0 - cond) 476 | return F.smooth_l1_loss(x, xzero, reduction='none') 477 | 478 | # print_now('in optimize_model, device = {}'.format(device)) 479 | if PRIORITIZED_MEMORY: 480 | transitions, batch_index, batch_weight_IS = memory.sample(BATCH_SIZE) 481 | batch_weight_IS = torch.tensor(batch_weight_IS).to(device) # [32,] 482 | else: 483 | transitions = memory.sample(BATCH_SIZE) 484 | 485 | batch = Transition(*zip(*transitions)) 486 | non_final = tuple(map(lambda s: s is not None, batch.next_state)) 487 | non_final_mask = torch.tensor(non_final, device=device, dtype=torch.uint8) 488 | sanity_check = [s for s in batch.next_state if s is not None] 489 | if 
len(sanity_check) == 0: 490 | return None, None, None 491 | non_final_next_states = torch.cat(sanity_check).to(device) 492 | # 493 | state_batch = torch.cat(batch.state).to(device) 494 | action_batch = torch.cat(batch.action).to(device) # this is of shape [32 x 1] 495 | reward_batch = torch.cat(batch.reward).to(device) 496 | 497 | # 498 | if USE_IQN_C51: 499 | IQN_C51_action = action_batch.unsqueeze(dim=-1).expand(-1, -1, args.N_tau) 500 | IQN_C51_reward = reward_batch.view(-1, 1) # [32 x 1] 501 | if USE_NOISY_NET: 502 | policy_net.sample_noise() 503 | y, my_tau = policy_net(state_batch, args.N_tau) 504 | quantiles = y.gather(1, IQN_C51_action).squeeze(1) # from [32 x 1 x 51] to [32 x 51] 505 | quantiles_next = IQN_next_distribution(args, non_final_next_states, IQN_C51_reward, non_final_mask) # [32, 51] 506 | # 507 | # -----------Google Implementation ----------- 508 | # (1) Make target quantile to be [Batch x Np_tau x 1] 509 | quantiles_next = quantiles_next.unsqueeze(-1) 510 | # (2) Make current quantile to be [Batch x N_tau x 1] 511 | quantiles = quantiles.unsqueeze(-1) 512 | # (3) Shape of bellman_erors and huber_loss: [Batch x Np_tau x N_tau x 1] 513 | # 514 | # [Batch x Np_tau x None x 1] - [Batch x None x N_tau x 1] 515 | diff = quantiles_next.unsqueeze(2) - quantiles.unsqueeze(1) 516 | # (4) Huber Loss 517 | huber_diff = huber(diff, args.kappa) 518 | # (5) !!! # Reshape replay_quantiles to [Batch x N_tau x 1] 519 | my_tau = my_tau.view(y.shape[0], args.N_tau, 1) # [N_tau x Batch x 1] 520 | # my_tau = my_tau.transpose(0, 1).contiguous() # [Batch x N_tau x 1] 521 | my_tau = my_tau.unsqueeze(1) # [Batch x 1 x N_tau x 1] 522 | my_tau = my_tau.expand(-1, args.Np_tau, -1, -1) # [Batch x Np_tau x N_tau x 1] 523 | # ----------- ----------- 524 | # (6) # Shape: batch_size x num_tau_prime_samples x num_tau_samples x 1. 525 | loss = (huber_diff * (my_tau - (diff<0).float()).abs()) / args.kappa # Divided by kappa 526 | # (7) 527 | # Sum over current quantile value (num_tau_samples) dimension, 528 | # average over target quantile value (num_tau_prime_samples) dimension. 529 | # [batch_size x Np_tau x N_tau x 1.] 
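            # At this point `loss` holds the element-wise quantile Huber loss
            #     rho_tau^kappa(delta) = |tau - 1{delta < 0}| * L_kappa(delta) / kappa
            # with shape [Batch x Np_tau x N_tau x 1]. The reduction below sums over the
            # current-quantile dimension (N_tau) and averages over the target-quantile
            # dimension (Np_tau), leaving one loss value per sample, i.e. shape [Batch].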
530 | loss = loss.squeeze(3).sum(-1).mean(-1) 531 | 532 | if PRIORITIZED_MEMORY: 533 | loss_PER = loss.detach().abs().cpu().numpy() 534 | if len(loss.shape) == 2: 535 | batch_weight_IS = batch_weight_IS.view(BATCH_SIZE, 1) 536 | assert(len(loss.shape) == len(batch_weight_IS.shape)) 537 | loss = loss * batch_weight_IS 538 | loss = loss.mean() 539 | ds = y.detach() * 1.0 / y.size(1) 540 | Q_sa = ds.sum(dim=2).gather(1, action_batch) 541 | elif USE_QR_C51: 542 | QR_C51_action = action_batch.unsqueeze(dim=-1).expand(-1, -1, QR_C51_atoms) 543 | QR_C51_reward = reward_batch.view(-1, 1) # [32 x 1] 544 | # 545 | if USE_NOISY_NET: 546 | policy_net.sample_noise() 547 | y = policy_net(state_batch) 548 | quantiles = y.gather(1, QR_C51_action) # [32 x 1 x 51] 549 | quantiles = quantiles.squeeze(1) # [32 x 51] 550 | # 551 | quantiles_next = next_distribution( non_final_next_states, QR_C51_reward, non_final_mask) # [32, 51] 552 | # 553 | # -----------Google Implementation ----------- 554 | # (1) Make target quantile to be [Batch x Np_tau x 1] 555 | quantiles_next = quantiles_next.unsqueeze(-1) 556 | # (2) Make current quantile to be [Batch x N_tau x 1] 557 | quantiles = quantiles.unsqueeze(-1) 558 | # (3) Shape of bellman_erors and huber_loss: [Batch x Np_tau x N_tau x 1] 559 | # 560 | # [Batch x Np_tau x None x 1] - [Batch x None x N_tau x 1] 561 | diff = quantiles_next.unsqueeze(2) - quantiles.unsqueeze(1) 562 | # (4) Huber Loss 563 | huber_diff = huber(diff) 564 | # ----------- ----------- 565 | # (6) # Shape: batch_size x num_tau_prime_samples x num_tau_samples x 1. 566 | loss = (huber_diff * (QR_C51_cum_density - (diff<0).float()).abs()) / 1.0 # Divided by kappa 567 | # (7) 568 | # Sum over current quantile value (num_tau_samples) dimension, 569 | # average over target quantile value (num_tau_prime_samples) dimension. 570 | # [batch_size x Np_tau x N_tau x 1.] 571 | loss = loss.squeeze(3).sum(-1).mean(-1) 572 | # # [51 x 32 x 1 ] [1, 32, 51] 573 | # diff = quantiles_next.t().unsqueeze(-1) - quantiles.unsqueeze(0) # diff is of shape [51, 32 51] 574 | 575 | if PRIORITIZED_MEMORY: 576 | loss_PER = loss.detach().abs().cpu().numpy() 577 | if len(loss.shape) == 2: 578 | batch_weight_IS = batch_weight_IS.view(BATCH_SIZE, 1) 579 | assert(len(loss.shape) == len(batch_weight_IS.shape)) 580 | loss = loss * batch_weight_IS 581 | loss = loss.mean() 582 | # 583 | ds = y.detach() * QR_C51_quantile_weight 584 | Q_sa = ds.sum(dim=2).gather(1, action_batch) 585 | elif USE_C51: 586 | # [32 x 1 x 1] [32 x 1 x 51] 587 | C51_action = action_batch.unsqueeze(dim=-1).expand(-1, -1, C51_atoms) 588 | C51_reward = reward_batch.view(-1, 1, 1) # [32 x 1 x 1] 589 | # [32 x 1 x 51] ---> [32 x 51] 590 | if USE_NOISY_NET: 591 | policy_net.sample_noise() 592 | y = policy_net(state_batch) 593 | current_dist = y.gather(1, C51_action).squeeze() 594 | target_prob = project_distribution(state_batch, C51_action, non_final_next_states, C51_reward, non_final_mask) # torch.Size([32, 51]) 595 | loss = -(target_prob * current_dist.log()).sum(-1) # KL Divergence 596 | if PRIORITIZED_MEMORY: 597 | loss_PER = loss.detach().squeeze().abs().cpu().numpy() 598 | if len(loss.shape) == 2: 599 | batch_weight_IS = batch_weight_IS.view(BATCH_SIZE, 1) 600 | assert(len(loss.shape) == len(batch_weight_IS.shape)) 601 | loss = loss * batch_weight_IS # .view(BATCH_SIZE, 1) 602 | loss = loss.mean() 603 | # 604 | ds = y.detach() * C51_support 605 | Q_sa = ds.sum(dim=2).gather(1, action_batch) 606 | else: 607 | # # Normal DQN. 
Minimize expected TD error ------------------------###### 608 | if USE_NOISY_NET: 609 | policy_net.sample_noise() 610 | Q_sa = policy_net(state_batch).gather(1, action_batch) 611 | next_Q_sa = torch.zeros((BATCH_SIZE, 1), device=device) 612 | if DOUBLE_Q_LEARNING: 613 | # Double DQN, getting action from policy net. 614 | # See https://medium.freecodecamp.org/improvements-in-deep-q-learning-dueling-double-dqn-prioritized-experience-replay-and-fixed-58b130cc5682 615 | with torch.no_grad(): 616 | # Get action, no noisy in policy net 617 | if USE_NOISY_NET: 618 | target_net.sample_noise() 619 | target_Q_sa = target_net(non_final_next_states) 620 | action_from_policy_Q_sa = policy_net(non_final_next_states).max(1)[1].unsqueeze(1) # max of the first dimension --> tuple(val, index). 621 | Q_sa_double_DQN = target_Q_sa.gather(1, action_from_policy_Q_sa) # We use the action index from policy net 622 | next_Q_sa[non_final_mask] = Q_sa_double_DQN 623 | else: 624 | # Vanilla DQN, getting action from target_net 625 | with torch.no_grad(): 626 | # Get action, no noisy in policy net 627 | if USE_NOISY_NET: 628 | target_net.sample_noise() 629 | target_Q_sa = target_net(non_final_next_states) 630 | Q_sa_DQN = target_Q_sa.max(1)[0].unsqueeze(1) 631 | next_Q_sa[non_final_mask] = Q_sa_DQN 632 | 633 | Expected_Q_sa = reward_batch + ((GAMMA**NUM_LOOKAHEAD) * next_Q_sa) 634 | # 635 | if PRIORITIZED_MEMORY: 636 | diff = Q_sa - Expected_Q_sa 637 | loss = huber_loss_fast(diff, X_ZERO) 638 | if len(loss.shape) == 2: 639 | batch_weight_IS = batch_weight_IS.view(BATCH_SIZE, 1) 640 | assert(len(loss.shape) == len(batch_weight_IS.shape)) 641 | loss = loss * batch_weight_IS 642 | loss = loss.mean() 643 | # 644 | TD_error = Q_sa.detach() - Expected_Q_sa.detach() 645 | TD_error = TD_error.cpu().numpy().squeeze() 646 | abs_TD_error = abs(TD_error) 647 | else: 648 | loss = F.smooth_l1_loss(Q_sa, Expected_Q_sa) 649 | # -------------------------------------------------------------------###### 650 | if PRIORITIZED_MEMORY: 651 | if USE_C51 or USE_QR_C51 or USE_IQN_C51: 652 | memory.update_priority_on_tree(batch_index, loss_PER) 653 | else: 654 | memory.update_priority_on_tree(batch_index, abs_TD_error) 655 | # -------------------------------------------------------------------###### 656 | optimizer.zero_grad() 657 | loss.backward() 658 | # for param in policy_net.parameters(): 659 | # param.grad.data.clamp_(-1, 1) 660 | optimizer.step() 661 | # -------------------------------------------------------------------###### 662 | Qval = Q_sa.cpu().detach().numpy().squeeze() 663 | return loss, np.mean(Qval), np.mean(reward_batch.cpu().numpy().squeeze()) 664 | # 665 | 666 | def save_model(): 667 | save_path = os.path.join(args.save_dir) 668 | try: 669 | os.makedirs(save_path) 670 | except OSError: 671 | pass 672 | # 673 | # Convert model to CPU 674 | save_model = policy_net 675 | if args.cuda: 676 | save_model = copy.deepcopy(policy_net).cpu() 677 | # save_model = [save_model, getattr(get_vec_normalize(envs), 'ob_rms', None)] 678 | torch.save(save_model, 679 | os.path.join(save_path, "%s.pt"%(args.env_name))) 680 | gc.collect() 681 | 682 | # main 683 | def main(): 684 | global steps_done 685 | torch.set_num_threads(1) 686 | loss = None 687 | Q_sa = None 688 | batch_reward_mean = None 689 | update_count = 0 690 | action_history = deque(maxlen=1000) 691 | episode_rewards = deque(maxlen=100) 692 | # -------------------------------------------------------------------###### 693 | # if PRIORITIZED_MEMORY: 694 | # state = 
PER_pre_fill_memory(envs) # reset would be called inside 695 | # else: 696 | # state = envs.reset() 697 | state = envs.reset() 698 | # -------------------------------------------------------------------###### 699 | start = time.time() 700 | for t in range(int(args.total_timestep)): 701 | action = select_action(state, action_space) 702 | action_history.append(action.item()) 703 | #st_0 = copy.deepcopy(state) # IMPORTANT. Make a deep copy as state will be come next_state AUTOMATICALLY 704 | st_0 = state.clone() # IMPORTANT. Make a deep copy as state will be come next_state AUTOMATICALLY 705 | 706 | next_state, reward, done, info = envs.step(action) 707 | #st_1 = copy.deepcopy(next_state) # Just to re-iterate the importance, that's all 708 | st_1 = next_state.clone() # Just to re-iterate the importance, that's all 709 | if 'episode' in info[0].keys(): 710 | episode_rewards.append(info[0]['episode']['r']) 711 | # We only ensure one environment here 712 | # -------------------------------------------------------------------###### 713 | if USE_N_STEP: 714 | st_0, action, st_1, reward = n_step_preprocess(st_0, action, st_1, reward, done[0]) 715 | # -------------------------------------------------------------------###### 716 | if done[0]: 717 | memory.push(st_0, action, None, reward) 718 | elif st_0 is not None: 719 | memory.push(st_0, action, st_1, reward) 720 | state = next_state 721 | # 722 | if t > args.learning_starts and t % TRAIN_FREQ == 0: 723 | update_count += 1 724 | loss, Q_sa, batch_reward_mean = optimize_model() 725 | if t % args.save_interval == 0: 726 | save_model() 727 | # 728 | if t > args.learning_starts and t % TARGET_UPDATE == 0: 729 | print_now('Updated target network at %d' % (t)) 730 | target_net.load_state_dict(policy_net.state_dict()) 731 | 732 | # Book Keeping 733 | end = time.time() 734 | eps_threshold = eps_schedule1.value(steps_done) if steps_done <= 1e6 else eps_schedule2.value(steps_done) 735 | 736 | if t%500 == 0 and len(episode_rewards) > 0: 737 | print_now('Upd {} timestep {} FPS {} - last {} ep rew: mean : {:.1f} min/max: {:.1f}/{:.1f} action_std: {:.3f} eps_val: {:.4f} loss: {:.4f} Qval {:.2f} Nrew: {:.2f}'.format( 738 | update_count, t, 739 | int(t / (end-start)), 740 | len(episode_rewards), np.mean(episode_rewards), np.min(episode_rewards), np.max(episode_rewards), 741 | np.std(action_history), eps_threshold, 742 | loss.item() if loss else -9999, 743 | Q_sa if Q_sa else 0, 744 | batch_reward_mean if batch_reward_mean else 0 745 | )) 746 | elif len(episode_rewards) == 0: 747 | print_now('Upd {}, timestep {}, FPS {}'.format( 748 | update_count, t, 749 | int(t / (end-start)), 750 | len(episode_rewards), -1, -1, -1 751 | )) 752 | # 753 | # 754 | # 755 | if __name__ == "__main__": 756 | main() 757 | 758 | -------------------------------------------------------------------------------- /replay_memory.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import random 3 | import numpy as np 4 | from utils import SumTree 5 | from utils import SumSegmentTree, MinSegmentTree 6 | from multiprocessing import Pool, Manager 7 | 8 | import torch 9 | 10 | Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward')) 11 | class ReplayMemory(object): 12 | def __init__(self, capacity, low_footprint=False, num_workers=1): 13 | self.capacity = capacity 14 | self.memory = [] 15 | self.position = 0 16 | self.low_footprint = low_footprint 17 | self.index_list = list(range(capacity)) 18 | 
self.num_workers = num_workers 19 | if self.low_footprint and self.num_workers > 1: 20 | raise NotImplementedError('Multi processing for replay not implemented') 21 | 22 | def _get_transition(self, index): 23 | """ 24 | For low-footprint. Need to check whether a transition is valid. 25 | e.g., if next_state is None, it means that it is an end. 26 | So it could be: 27 | [Zero, Zero, s_{t-1}, s_{t}] 28 | or: 29 | [Zero, s_{t-2}, s_{t-1}, s_{t}] 30 | 31 | Case 1: 32 | A memory item is not valid because it is overridden. 33 | [s1, s2, s3, s4, ..., sN-1, sN] --> s1 is not valid 34 | [sN+1, s2, s3, s4, ..., sN-1, sN] --> s2 to s4 is not valid, cuz s1 is removed 35 | [sN+1, sN+2, s3, s4, ..., sN-1, sN] --> s3 to s6 is not valid, cuz s2 is removed 36 | 37 | Case 2: 38 | A memory item needs to be concated zeros, because it is an early met terminal state 39 | [s1, s2_T, s3, s4, ..., sN-1, sN] --> s2_T is valid, but needs to be cat [Zero, Zero, s1, s2_T] 40 | 41 | [sN+1, s2_T, s3, s4, ..., sN-1, sN] --> s2_T is not valid, cuz s1 is removed 42 | 43 | [sN+1, sN+2, s3, s4_T, ..., sN-1, sN] --> s4_T should be valid. As we can do [Zero, Zero, s3, s4_T] 44 | 45 | """ 46 | tran = self.memory[index] 47 | s0, a, s1, r = tran.state, tran.action, tran.next_state, tran.reward 48 | # 49 | state_list = [s0] 50 | next_state_list = [s1] # s1 will be None if it is a terminal state 51 | # 52 | for i in range(1, 3+1): 53 | prev_tran = self.memory[index-i] 54 | s0_prev, s1_prev = prev_tran.state, prev_tran.next_state 55 | if s1_prev is None: 56 | break 57 | state_list.append(s0_prev) 58 | next_state_list.append(s1_prev) 59 | # 60 | state_list = state_list[::-1] 61 | next_state_list = next_state_list[::-1] 62 | # ----------------------------------------- 63 | if s1 is not None and len(next_state_list) < 4: 64 | next_state_list = [state_list[0]] + next_state_list 65 | if len(next_state_list) < 4: 66 | next_state_list = [torch.zeros_like(s0)] * (4-len(next_state_list)) + next_state_list 67 | 68 | if len(state_list) < 4: 69 | state_list = [torch.zeros_like(s0)] * (4 - len(state_list)) + state_list 70 | 71 | # ----------------------------------------- 72 | state = torch.cat(state_list, dim=1) # from [1 x 1 x 84 x 84] to [1 x 4 x 84 x 84] 73 | next_state = torch.cat(next_state_list, dim=1) if s1 is not None else None # from [1 x 1 x 84 x 84] to [1 x 4 x 84 x 84] 74 | return Transition(state, a, next_state, r) 75 | 76 | # args is like def push(self, state, action, next_state, reward), 77 | # So Transition(state, action, next_state, reward) becomes what we need to store 78 | def push(self, *args): 79 | # ------------------------------------------------------ 80 | if len(self.memory) < self.capacity: 81 | self.memory.append(None) 82 | # ------------------------------------------------------ 83 | if self.low_footprint: 84 | # Only store the latest frame 85 | state, action, next_state, reward = args 86 | if next_state is not None: 87 | self.memory[self.position] = Transition(state[:, -1:, :, :].cpu(), 88 | action.cpu(), 89 | next_state[:, -1:, :, :].cpu(), 90 | reward.cpu()) 91 | else: 92 | self.memory[self.position] = Transition(state[:, -1:, :, :].cpu(), 93 | action.cpu(), 94 | None, 95 | reward.cpu()) 96 | else: 97 | self.memory[self.position] = Transition(*args) 98 | # ------------------------------------------------------ 99 | self.position = (self.position + 1) % self.capacity 100 | # ------------------------------------------------------ 101 | def sample(self, batch_size): 102 | if self.low_footprint: 103 | if 
len(self.memory) < self.capacity: 104 | valid_range = self.index_list[3:len(self.memory)] 105 | else: 106 | if self.position <= 3: 107 | valid_range = self.index_list[self.position+3:] 108 | else: 109 | valid_range = self.index_list[3:self.position] + self.index_list[self.position+3:] 110 | output_batch = [] 111 | while len(output_batch) < batch_size: 112 | out_index = random.sample(valid_range, batch_size) 113 | for index in out_index: 114 | this_one = self._get_transition(index) 115 | if this_one is not None: 116 | output_batch.append(this_one) 117 | if len(output_batch) == batch_size: 118 | break 119 | return output_batch 120 | else: 121 | return random.sample(self.memory, batch_size) 122 | 123 | def __len__(self): 124 | return len(self.memory) 125 | 126 | class PrioritizedReplayBuffer(): 127 | """ 128 | PrioritizedReplayBuffer From OpenAI Baseline 129 | """ 130 | def __init__(self, size, T_max, learn_start): 131 | self._storage = [] 132 | self._maxsize = size 133 | self._next_idx = 0 134 | # 135 | it_capacity = 1 136 | while it_capacity < size: 137 | it_capacity *= 2 138 | self._sumTree = SumSegmentTree(it_capacity) 139 | self._minTree = MinSegmentTree(it_capacity) 140 | self._max_priority = 1.0 141 | # 142 | self.e = 0.01 143 | self.alpha = 0.5 # tradeoff between taking only experience with high-priority samples 144 | self.beta = 0.4 # Importance Sampling, from 0.4 -> 1.0 over the course of training 145 | self.beta_increment = (1 - self.beta) / (T_max - learn_start) 146 | self.abs_error_clipUpper = 1.0 147 | self.NORMALIZE_BY_BATCH = False # In openAI baseline, normalize by whole 148 | 149 | def __len__(self): 150 | return len(self._storage) 151 | 152 | def push(self, state, action, next_state, reward): 153 | idx = self._next_idx 154 | # 155 | # Setting maximum priority for new transitions. 
Total priority will be updated 156 | if next_state is not None: 157 | data = Transition(state.cpu(), action.cpu(), next_state.cpu(), reward.cpu()) 158 | else: 159 | data = Transition(state.cpu(), action.cpu(), None, reward.cpu()) 160 | # 161 | if self._next_idx >= len(self._storage): 162 | self._storage += data, 163 | else: 164 | self._storage[self._next_idx] = data 165 | self._next_idx = (self._next_idx + 1) % self._maxsize 166 | # 167 | self._sumTree[idx] = self._max_priority ** self.alpha 168 | self._minTree[idx] = self._max_priority ** self.alpha 169 | 170 | def sample(self, batch_size): 171 | # indices = self._sample_proportional(batch_size) 172 | indices = [] 173 | batch_sample = [] 174 | weights = [] 175 | # Increase the beta each time we sample a new mini-batch until it reaches 1.0 176 | self.beta = min(self.beta + self.beta_increment, 1.0) 177 | # 178 | total_priority = self._sumTree.sum(0, len(self._storage) - 1) 179 | priority_segment = total_priority / batch_size 180 | # 181 | min_priority = self._minTree.min() / self._sumTree.sum() 182 | max_weight_ALL_memory = (min_priority * len(self._storage)) ** (-self.beta) 183 | # 184 | for i in range(batch_size): 185 | mass = (i + random.random()) * priority_segment 186 | index = self._sumTree.find_prefixsum_idx(mass) 187 | # P(j) --> stochastic priority 188 | stochastic_p = self._sumTree[index] / total_priority 189 | this_weight_IS = (stochastic_p * len(self._storage)) ** (-self.beta) 190 | """ 191 | Importance Sampling Weight: 192 | [ 1 1 ]^(beta) 193 | | --- * -----------| 194 | [ N prob_min ] 195 | """ 196 | this_weight_IS /= max_weight_ALL_memory 197 | # Append to list 198 | weights += this_weight_IS, 199 | batch_sample += self._storage[index], 200 | indices += index, 201 | # 202 | return batch_sample, indices, weights 203 | def update_priority_on_tree(self, tree_idx, abs_TD_errors): 204 | assert(len(tree_idx) == len(abs_TD_errors)) 205 | abs_TD_errors = np.nan_to_num(abs_TD_errors) + self.e 206 | abs_TD_errors = abs_TD_errors.tolist() 207 | # 208 | for index, priority in zip(tree_idx, abs_TD_errors): 209 | assert(priority > 0) 210 | assert(0<=index<=len(self._storage)) 211 | self._sumTree[index] = priority ** self.alpha 212 | self._minTree[index] = priority ** self.alpha 213 | # 214 | self._max_priority = max(self._max_priority, priority) 215 | # 216 | 217 | 218 | class PrioritizedReplayBuffer_slow(): 219 | # Deprecated 220 | def __init__(self, capacity, T_max, learn_start): 221 | self.capacity = capacity 222 | self.count = 0 223 | # We may want better data structure: https://jaromiru.com/2016/11/07/lets-make-a-dqn-double-learning-and-prioritized-experience-replay/ 224 | self.memory = SumTree(capacity) 225 | 226 | self.e = 0.01 227 | self.alpha = 0.5 # tradeoff between taking only experience with high-priority samples 228 | self.beta = 0.4 # Importance Sampling, from 0.4 -> 1.0 over the course of training 229 | self.beta_increment = (1 - self.beta) / (T_max - learn_start) 230 | self.abs_error_clipUpper = 1.0 231 | self.NORMALIZE_BY_BATCH = False # In openAI baseline, normalize by whole 232 | 233 | def push(self, state, action, next_state, reward): 234 | # Find the max priority. Recall that treeArr is of size 2*capacity - 1. 
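        # For example, with capacity = 4 the tree array is laid out as
        #     treeArr = [p0+p1+p2+p3,  p0+p1,  p2+p3,  p0, p1, p2, p3]
        # so the slice treeArr[-capacity:] taken below covers exactly the leaf
        # priorities [p0, p1, p2, p3].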
235 | # And all the priorioties lie on the leaves of the tree 236 | self.count += 1 237 | self.count = max(self.count, self.capacity) 238 | all_priority = self.memory.treeArr[-self.memory.capacity:][:self.count] 239 | max_priority = np.max(all_priority) 240 | if max_priority == 0: 241 | max_priority = self.abs_error_clipUpper 242 | # Setting maximum priority for new transitions. Total priority will be updated 243 | if next_state is not None: 244 | transition = Transition(state.cpu(), action.cpu(), next_state.cpu(), reward.cpu()) 245 | else: 246 | transition = Transition(state.cpu(), action.cpu(), None, reward.cpu()) 247 | self.memory.push(max_priority, transition) 248 | 249 | def sample(self, batch_size): 250 | """ 251 | Let N = batch_size 252 | 253 | 1. First, divide the range [0, priority_total] into N ranges 254 | 2. Next, uniformly sample one value per range (out of N ranges) 255 | 3. Then, go and search the SumTree, 256 | the transitions with (priority score == sampled values) are retrieved 257 | 4. Finally, calculate Importance Sampling weight, W_is, for each of the element in the minibatch 258 | """ 259 | n = batch_size 260 | this_batch = [] 261 | batch_index = [] # np.empty((n, ), dtype=np.int32) 262 | batch_weight_IS = [] # np.empty((n, 1), dtype=np.float32) 263 | 264 | # Calculate the priority segment by dividing the ranges 265 | total_priority = self.memory.get_total_priority() 266 | priority_segment = total_priority / batch_size 267 | 268 | # Increase the beta each time we sample a new mini-batch until it reaches 1.0 269 | self.beta = min(self.beta + self.beta_increment, 1.0) 270 | 271 | # Calculate the max_weight 272 | all_priority = self.memory.treeArr[-self.memory.capacity:][:self.count] 273 | prob_min = min(all_priority) / total_priority 274 | assert(prob_min > 0) 275 | """ 276 | N is the batch size 277 | 278 | [ prob_min * N ]^(-beta) [ 1 1 ]^(beta) 279 | |------------------| ---> | --- * -----------| 280 | [ 1 ] [ N prob_min ] 281 | 282 | """ 283 | # Getting the MAX of importance sampling weight for nomalization 284 | max_weight_ALL_memory = (prob_min * n)**(-self.beta) 285 | max_weight_THIS_BATCH = -1 286 | # 287 | for i in range(batch_size): 288 | # A value is sample from each range 289 | A = A_the_ith_range = priority_segment * i 290 | B = B_the_ith_range = priority_segment * (i + 1) 291 | sampled_value = np.random.uniform(A, B) 292 | 293 | # The transition that corresponds to the "sampled_value" is retrieved 294 | index, priority, data = self.memory.get_leaf(sampled_value) 295 | transition = data 296 | 297 | # P(j) --> stochastic priority 298 | stochastic_p = priority / total_priority 299 | 300 | """ 301 | Importance Sampling Weight: 302 | [ 1 1 ]^(beta) 303 | | --- * -----------| 304 | [ N prob_min ] 305 | """ 306 | this_weight_IS = (stochastic_p * n) ** (-self.beta) 307 | 308 | if self.NORMALIZE_BY_BATCH and max_weight_THIS_BATCH <= this_weight_IS: 309 | max_weight_THIS_BATCH = this_weight_IS 310 | 311 | # List append 312 | batch_weight_IS += this_weight_IS, # batch_weight_IS[i, 0] = this_weight_IS 313 | batch_index += index, #batch_index[i] = index 314 | this_batch += transition, 315 | # 316 | batch_weight_IS = np.asarray(batch_weight_IS).T # --> make it 32 x 1 317 | batch_index = np.asarray(batch_index) 318 | # ------------------------------------------------------------------- # 319 | if self.NORMALIZE_BY_BATCH: 320 | batch_weight_IS /= max_weight_THIS_BATCH # Kaixin from Berkeley 321 | else: 322 | batch_weight_IS /= max_weight_ALL_memory # OpenAI Baseline 323 | 
# ------------------------------------------------------------------- # 324 | return this_batch, batch_index, batch_weight_IS 325 | 326 | def update_priority_on_tree(self, tree_idx, abs_TD_errors): 327 | """ 328 | A bunch of tree indices and a bunch of TD_errors 329 | """ 330 | abs_TD_errors = np.nan_to_num(abs_TD_errors) 331 | abs_TD_errors += self.e # p_t = |delta_t| + e 332 | clipped_errors = np.minimum(abs_TD_errors, self.abs_error_clipUpper) 333 | pt_alpha = np.power(clipped_errors, self.alpha) 334 | for index, prob in zip(tree_idx, pt_alpha): 335 | self.memory.update(index, prob) 336 | # Remember to deal with EMPTY MEMORY PROBLEM 337 | -------------------------------------------------------------------------------- /runIt.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0; 2 | SEED=1; 3 | GAME="Assault"; 4 | NAME="DQN-C51-"$GAME"-SEED-"$SEED; 5 | LOG_NAME="./TXT_LOGS/myLog_"$NAME".txt"; 6 | 7 | if [ ! -f $LOG_NAME ]; then 8 | 9 | OMP_NUM_THREADS=1 time python main.py \ 10 | --env-name $GAME"NoFrameskip-v4" \ 11 | --log-dir "./agentLog_"$SEED \ 12 | --save-dir "./saved_model_"$SEED \ 13 | --total-timestep 1e8 \ 14 | --memory-size 50000 \ 15 | --learning-starts 20000 \ 16 | --target-update 8192 \ 17 | --seed $SEED \ 18 | --use-double-dqn \ 19 | --use-C51 \ 20 | 2>&1 | tee $LOG_NAME; 21 | 22 | else 23 | echo "Danger close. The log file at -- $LOG_NAME -- exists!!" 24 | fi 25 | 26 | # --use-QR-C51 \ 27 | # --use-prioritized-buffer \ 28 | # --use-n-step \ 29 | # --use-duel \ 30 | # --use-noisy-net \ 31 | # --use-C51 \ 32 | # --use-QR-C51 \ 33 | 34 | # Default to use Vanilla DQN 35 | # --use-double-dqn 36 | # --use-prioritized-buffer 37 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import torch 3 | import torch.nn as nn 4 | 5 | import numpy as np 6 | import math 7 | 8 | def init(module, weight_init, bias_init, gain=1, mode=None, nonlinearity='relu'): 9 | if mode is not None: 10 | weight_init(module.weight.data, mode=mode, nonlinearity=nonlinearity) 11 | else: 12 | weight_init(module.weight.data, gain=gain) 13 | bias_init(module.bias.data) 14 | return module 15 | 16 | # https://github.com/openai/baselines/blob/master/baselines/common/tf_util.py#L87 17 | def init_normc_(weight, gain=1): 18 | weight.normal_(0, 1) 19 | weight *= gain / torch.sqrt(weight.pow(2).sum(1, keepdim=True)) 20 | 21 | class SegmentTree(object): 22 | def __init__(self, capacity, operation, neutral_element): 23 | """Build a Segment Tree data structure. 24 | 25 | https://en.wikipedia.org/wiki/Segment_tree 26 | 27 | Can be used as regular array, but with two 28 | important differences: 29 | 30 | a) setting item's value is slightly slower. 31 | It is O(lg capacity) instead of O(1). 32 | b) user has access to an efficient ( O(log segment size) ) 33 | `reduce` operation which reduces `operation` over 34 | a contiguous subsequence of items in the array. 35 | 36 | Paramters 37 | --------- 38 | capacity: int 39 | Total size of the array - must be a power of two. 40 | operation: lambda obj, obj -> obj 41 | and operation for combining elements (eg. sum, max) 42 | must form a mathematical group together with the set of 43 | possible values for array elements (i.e. be associative) 44 | neutral_element: obj 45 | neutral element for the operation above. eg. float('-inf') 46 | for max and 0 for sum. 
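        A minimal usage sketch (illustrative; it uses the SumSegmentTree
        subclass defined further below, and unset leaves stay at the
        neutral element 0.0):

        >>> tree = SumSegmentTree(4)
        >>> tree[0] = 1.0
        >>> tree[1] = 2.0
        >>> tree[2] = 3.0
        >>> tree.sum()                    # reduce over the whole array
        6.0
        >>> tree.sum(0, 2)                # arr[0] + arr[1] (end is treated as exclusive)
        3.0
        >>> tree.find_prefixsum_idx(2.5)  # largest i with arr[0] + ... + arr[i-1] <= 2.5
        1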
47 | """ 48 | assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2." 49 | self._capacity = capacity 50 | self._value = [neutral_element for _ in range(2 * capacity)] 51 | self._operation = operation 52 | 53 | def _reduce_helper(self, start, end, node, node_start, node_end): 54 | if start == node_start and end == node_end: 55 | return self._value[node] 56 | mid = (node_start + node_end) // 2 57 | if end <= mid: 58 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 59 | else: 60 | if mid + 1 <= start: 61 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 62 | else: 63 | return self._operation( 64 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 65 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) 66 | ) 67 | 68 | def reduce(self, start=0, end=None): 69 | """Returns result of applying `self.operation` 70 | to a contiguous subsequence of the array. 71 | 72 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 73 | 74 | Parameters 75 | ---------- 76 | start: int 77 | beginning of the subsequence 78 | end: int 79 | end of the subsequences 80 | 81 | Returns 82 | ------- 83 | reduced: obj 84 | result of reducing self.operation over the specified range of array elements. 85 | """ 86 | if end is None: 87 | end = self._capacity 88 | if end < 0: 89 | end += self._capacity 90 | end -= 1 91 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 92 | 93 | def __setitem__(self, idx, val): 94 | # index of the leaf 95 | idx += self._capacity 96 | self._value[idx] = val 97 | idx //= 2 98 | while idx >= 1: 99 | self._value[idx] = self._operation( 100 | self._value[2 * idx], 101 | self._value[2 * idx + 1] 102 | ) 103 | idx //= 2 104 | 105 | def __getitem__(self, idx): 106 | assert 0 <= idx < self._capacity 107 | return self._value[self._capacity + idx] 108 | 109 | 110 | class SumSegmentTree(SegmentTree): 111 | def __init__(self, capacity): 112 | super(SumSegmentTree, self).__init__( 113 | capacity=capacity, 114 | operation=operator.add, 115 | neutral_element=0.0 116 | ) 117 | 118 | def sum(self, start=0, end=None): 119 | """Returns arr[start] + ... + arr[end]""" 120 | return super(SumSegmentTree, self).reduce(start, end) 121 | 122 | def find_prefixsum_idx(self, prefixsum): 123 | """Find the highest index `i` in the array such that 124 | sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum 125 | 126 | if array values are probabilities, this function 127 | allows to sample indexes according to the discrete 128 | probability efficiently. 
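        For example, with leaf values [1, 2, 3, 0], find_prefixsum_idx(2.5)
        returns 1, because arr[0] = 1 <= 2.5 while arr[0] + arr[1] = 3 > 2.5.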
129 | 130 | Parameters 131 | ---------- 132 | perfixsum: float 133 | upperbound on the sum of array prefix 134 | 135 | Returns 136 | ------- 137 | idx: int 138 | highest index satisfying the prefixsum constraint 139 | """ 140 | assert 0 <= prefixsum <= self.sum() + 1e-5 141 | idx = 1 142 | while idx < self._capacity: # while non-leaf 143 | if self._value[2 * idx] > prefixsum: 144 | idx = 2 * idx 145 | else: 146 | prefixsum -= self._value[2 * idx] 147 | idx = 2 * idx + 1 148 | return idx - self._capacity 149 | 150 | 151 | class MinSegmentTree(SegmentTree): 152 | def __init__(self, capacity): 153 | super(MinSegmentTree, self).__init__( 154 | capacity=capacity, 155 | operation=min, 156 | neutral_element=float('inf') 157 | ) 158 | 159 | def min(self, start=0, end=None): 160 | """Returns min(arr[start], ..., arr[end])""" 161 | 162 | return super(MinSegmentTree, self).reduce(start, end) 163 | 164 | # 165 | # Utils for prioritized replay buffer and sampling 166 | # Segment tree data structure where parent node values are sum/max of children node values 167 | # https://github.com/simoninithomas/Deep_reinforcement_learning_Course/blob/master/Dueling%20Double%20DQN%20with%20PER%20and%20fixed-q%20targets/Dueling%20Deep%20Q%20Learning%20with%20Doom%20%28%2B%20double%20DQNs%20and%20Prioritized%20Experience%20Replay%29.ipynb 168 | class SumTree(object): 169 | """ Deprecated """ 170 | def __init__(self, capacity): 171 | """ 172 | Initialize the tree with all nodes = 0 173 | Initialize the data with all values = 0 174 | """ 175 | self.capacity = capacity 176 | self.position = 0 177 | self.dataArr = np.zeros(capacity, dtype=object) 178 | self.treeArr = np.zeros(2*capacity - 1) 179 | # Generate the tree with all nodes values = 0 180 | # To understand this calculation (2 * capacity - 1) look at the schema below 181 | # Remember we are in a binary node (each node has max 2 children) so 2x size of leaf (capacity) - 1 (root node) 182 | # Parent nodes = capacity - 1 183 | # Leaf nodes = capacity 184 | """ tree: 185 | 0 186 | / \ 187 | 0 0 188 | / \ / \ 189 | 0 0 0 0 [Size: capacity] it's at this line that there is the priorities score (aka pi) 190 | """ 191 | def push(self, priority, data): 192 | """ Look at what index we want to put the new transition at """ 193 | tree_index = self.position + self.capacity - 1 194 | """ 195 | tree: 196 | 0 197 | / \ 198 | 0 0 199 | / \ / \ 200 | tree_index 0 0 0 201 | 202 | We fill the leaves from left to right 203 | """ 204 | self.dataArr[self.position] = data # Update data frame 205 | self.update(tree_index, priority) # Update the leaf, using the function below 206 | # 207 | self.position += 1 208 | if self.position >= self.capacity: 209 | self.position = 0 210 | # 211 | 212 | def update(self, tree_index, priority): 213 | """ 214 | Change_of_Score = new priority score - former priority score 215 | """ 216 | delta_score = priority - self.treeArr[tree_index] 217 | self.treeArr[tree_index] = priority 218 | 219 | # Propagate this change through tree 220 | """ 221 | Here we want to access the line above 222 | THE NUMBERS IN THIS TREE ARE THE "INDEXES" NOT THE PRIORITY VALUES 223 | 224 | 0 225 | / \ 226 | 1 2 227 | / \ / \ 228 | 3 4 5 [6] 229 | 230 | If we are in leaf at index 6, we updated the priority score 231 | We need then to update index 2 node 232 | So tree_index = (tree_index - 1) // 2 233 | tree_index = (6-1)//2 234 | tree_index = 2 (because // round the result) 235 | """ 236 | while tree_index != 0: 237 | tree_index = (tree_index - 1) // 2 238 | self.treeArr[tree_index] 
+= delta_score 239 | 240 | def get_leaf(self, v): 241 | """ 242 | Return the leaf_index, that is the "priority value" of the transition at that leaf. 243 | 244 | Tree structure and array storage: 245 | Tree index: 246 | 0 -> storing priority sum 247 | / \ 248 | 1 2 249 | / \ / \ 250 | 3 4 5 6 -> storing priority for experiences 251 | Array type for storing: 252 | [0,1,2,3,4,5,6] 253 | """ 254 | parent_index = 0 255 | while True: 256 | left_child_index = 2 * parent_index + 1 257 | right_child_index = left_child_index + 1 258 | # 259 | # If we reach bottom, end the search 260 | if left_child_index >= len(self.treeArr): 261 | LEAF_index = parent_index 262 | break 263 | else: # downward search, always search for a higher priority node 264 | if v <= self.treeArr[left_child_index]: 265 | parent_index = left_child_index 266 | else: 267 | v -= self.treeArr[left_child_index] 268 | parent_index = right_child_index 269 | # 270 | # The corresponding data index: 271 | data_index = LEAF_index - self.capacity + 1 272 | return LEAF_index, self.treeArr[LEAF_index], self.dataArr[data_index] 273 | 274 | def get_total_priority(self): 275 | return self.treeArr[0] # The root node contains the total priority 276 | 277 | 278 | 279 | def PER_pre_fill_memory(envs): 280 | """ 281 | Pre-filling the memory buffer if we are doing Prioritized Experience Replay 282 | """ 283 | state = envs.reset() 284 | print_now('[Warning] Begin to pre-fill [Prioritized Experience Replay Memory]') 285 | for j in range(args.memory_size): 286 | action = torch.tensor([[random.randrange(action_space)]], device=device, dtype=torch.long) 287 | st_0 = copy.deepcopy(state) # IMPORTANT. Make a deep copy as state will be come next_state AUTOMATICALLY 288 | next_state, reward, done, info = envs.step(action) 289 | st_1 = copy.deepcopy(next_state) 290 | # We only ensure one environment here 291 | # -------------------------------------------------------------------###### 292 | if USE_N_STEP: 293 | st_0, action, st_1, reward = n_step_preprocess(st_0, action, st_1, reward, done[0]) 294 | # -------------------------------------------------------------------###### 295 | if done[0]: 296 | memory.push(st_0, action, None, reward) 297 | print_now('Pre-filling Replay Memory %d / %d -- action: %d' % (j+1, args.memory_size, action.item())) 298 | elif st_0 is not None: 299 | memory.push(st_0, action, st_1, reward) 300 | print_now('Pre-filling Replay Memory %d / %d -- action: %d' % (j+1, args.memory_size, action.item())) 301 | state = next_state 302 | return state 303 | # 304 | --------------------------------------------------------------------------------
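A quick, standalone sanity check (illustrative toy numbers only) of the importance-sampling weights used by both prioritized buffers in replay_memory.py, i.e. w_i = (N * P(i))^(-beta) normalized by the largest weight over the whole memory:

import numpy as np

# Toy leaf priorities p_i = (|TD_i| + e) ** alpha, as they would sit on the sum tree
priorities = np.array([0.9, 0.1, 0.5, 2.0])
N = len(priorities)
beta = 0.4

P = priorities / priorities.sum()   # stochastic sampling probabilities P(i)
w = (N * P) ** (-beta)              # raw importance-sampling weights
w /= (N * P.min()) ** (-beta)       # normalize by the max weight over ALL memory (OpenAI-baseline style)
print(P, w)                         # the highest-priority transition gets the smallest weight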