├── README.md ├── agents.py ├── environments ├── doom │ ├── README.md │ ├── config │ │ ├── basic.cfg │ │ ├── basic.wad │ │ ├── deadly_corridor.cfg │ │ ├── deadly_corridor.wad │ │ ├── defend_the_center.cfg │ │ ├── defend_the_center.wad │ │ ├── defend_the_line.cfg │ │ ├── defend_the_line.wad │ │ ├── health_gathering.cfg │ │ ├── health_gathering.wad │ │ ├── predict_position.cfg │ │ └── predict_position.wad │ ├── doom_env.py │ └── img │ │ ├── basic.png │ │ ├── corridor.png │ │ ├── def_center.png │ │ ├── def_line.png │ │ ├── health.png │ │ └── predict.png ├── snake │ └── snake_env.py └── windy_grid_world │ ├── evil_wgw_env.py │ └── wgw_env.py ├── img ├── dqn_categorical.png ├── dqn_classic.png ├── dqn_dueling.png ├── dqn_quantile.png ├── sac_p_network.png ├── sac_q_network.png └── sac_v_network.png ├── methods.py ├── train_agents.ipynb └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # rl_algorithms 2 | Implementations of different off-policy reinforcement learning algorithms. 3 | 4 | # Framework 5 | 6 | 1. Module [methods.py](methods.py) contains [TensorFlow](https://www.tensorflow.org) implementations of various neural network architectures used in value-based deep reinforcement learning. 7 | 8 | 2. Module [agents.py](agents.py) contains general **Agent** class and various wrappers around it which represent corresponding deep RL algorithms. 9 | 10 | 3. Module [utils.py](utils.py) contains **Replay Buffer** implementation together with a wrapper around **OpenAI gym Atari 2600** environment necessary for reproducing original DeepMind results. 11 | 12 | 4. Jupyter notebook [train_agents.ipynb](train_agents.ipynb) contains examples of how to use the proposed framework to train deep RL agents on various environments. 13 | 14 | # Available algorithms 15 | 16 | - Deep Q-Network [Volodymyr Mnih et al. "Human-level control through deep reinforcement learning." Nature (2015)](https://pra.open.tips/storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) 17 |

18 | 19 |

20 | 21 | - Dueling Deep Q-Network [Ziyu Wang et al. "Dueling network architectures for deep reinforcement learning." ICML (2016).](https://arxiv.org/pdf/1511.06581.pdf) 22 |

23 | 24 |

25 | 26 | - Categorical Deep Q-Network [Marc G. Bellemare, Will Dabney, and Rémi Munos. "A distributional perspective on reinforcement learning." ICML (2017).](https://arxiv.org/pdf/1707.06887) 27 |

28 | 29 |

30 | 31 | - Quantile Regression Deep Q-Network [Will Dabney, Mark Rowland, Marc G. Bellemare, and Rémi Munos. "Distributional Reinforcement Learning with Quantile Regression." AAAI (2018).](https://arxiv.org/pdf/1710.10044) 32 |

33 | 34 |

35 | 36 | - Soft Actor-Critic [Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, and Sergey Levine. "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor." ICML (2018).](https://arxiv.org/pdf/1801.01290) 37 | 38 |

39 | 40 | 41 | 42 |
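Putting the modules above together, a typical training run looks roughly like the following sketch (illustrative only; the choice of environment, `num_actions=3`, and all hyperparameter values are placeholders, and [train_agents.ipynb](train_agents.ipynb) contains the reference examples):

```python
from environments.snake.snake_env import Snake
from agents import DQNAgent

env = Snake(grid_size=(8, 8))            # observations have shape (8, 8, 5)
agent = DQNAgent(env, num_actions=3,     # Snake uses 3 relative actions (keep direction / turn)
                 state_shape=[8, 8, 5])

# fill the replay buffer with random-policy transitions and set the exploration schedule
agent.set_parameters(replay_memory_size=100000, replay_start_size=10000,
                     init_eps=1.0, final_eps=0.02, annealing_steps=100000)

# run training; checkpoints and the learning curve are written to rl_models/DQN
agent.train(gpu_id=0, batch_size=32, exploration="e-greedy",
            agent_update_freq=4, target_update_freq=5000)
```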

43 | 44 | **Note.** Images of different neural network architectures are based on the images from the [Dueling architectures](https://arxiv.org/pdf/1511.06581.pdf) paper. The original images were copied and adapted to reflect features of particular architectures and learning algorithms. 45 | -------------------------------------------------------------------------------- /agents.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | from IPython import display 5 | import matplotlib.pyplot as plt 6 | 7 | from methods import * 8 | from utils import * 9 | 10 | 11 | def softmax(x): 12 | e_x = np.exp(x - np.max(x)) 13 | return e_x / e_x.sum() 14 | 15 | 16 | class Agent: 17 | 18 | def __init__(self, env, num_actions, state_shape=[8, 8, 5], 19 | save_path="rl_models", model_name="agent"): 20 | 21 | self.train_env = env 22 | self.num_actions = num_actions 23 | 24 | self.path = save_path + "/" + model_name 25 | if not os.path.exists(self.path): 26 | os.makedirs(self.path) 27 | 28 | def init_weights(self): 29 | 30 | global_vars = tf.global_variables( 31 | scope="agent") + tf.global_variables(scope="target") 32 | self.init = tf.variables_initializer(global_vars) 33 | self.saver = tf.train.Saver() 34 | 35 | self.agent_vars = tf.trainable_variables(scope="agent") 36 | self.target_vars = tf.trainable_variables(scope="target") 37 | 38 | def set_parameters(self, 39 | replay_memory_size=50000, 40 | replay_start_size=10000, 41 | init_eps=1, 42 | final_eps=0.02, 43 | annealing_steps=100000, 44 | discount_factor=0.99, 45 | max_episode_length=2000, 46 | frame_history_len=1): 47 | 48 | # create experience replay and fill it with random policy samples 49 | self.rep_buffer = ReplayBuffer(size=replay_memory_size, 50 | frame_history_len=frame_history_len) 51 | frame_count = 0 52 | while (frame_count < replay_start_size): 53 | last_obs = self.train_env.reset() 54 | for time_step in range(max_episode_length): 55 | 56 | last_idx = self.rep_buffer.store_frame(last_obs) 57 | recent_obs = self.rep_buffer.encode_recent_observation() 58 | action = np.random.randint(self.num_actions) 59 | obs, reward, done = self.train_env.step(action)[:3] 60 | self.rep_buffer.store_effect(last_idx, action, reward, done) 61 | 62 | frame_count += 1 63 | 64 | if done: 65 | break 66 | last_obs = obs 67 | 68 | # define epsilon decrement schedule for exploration 69 | self.eps = init_eps 70 | self.final_eps = final_eps 71 | self.eps_drop = (init_eps - final_eps) / annealing_steps 72 | 73 | self.gamma = discount_factor 74 | self.max_ep_length = max_episode_length 75 | 76 | def train(self, 77 | gpu_id=0, 78 | batch_size=32, 79 | exploration="e-greedy", 80 | agent_update_freq=4, 81 | target_update_freq=5000, 82 | tau=1, 83 | max_num_epochs=50000, 84 | performance_print_freq=500, 85 | save_freq=10000, 86 | from_epoch=0): 87 | 88 | config = self.gpu_config(gpu_id) 89 | target_ops = self.update_target_graph(tau) 90 | self.batch_size = batch_size 91 | 92 | with tf.Session(config=config) as sess: 93 | 94 | if from_epoch == 0: 95 | sess.run(self.init) 96 | train_rewards = [] 97 | frame_counts = [] 98 | frame_count = 0 99 | num_epochs = 0 100 | 101 | else: 102 | self.saver.restore(sess, self.path+"/model-"+str(from_epoch)) 103 | train_rewards = list( 104 | np.load(self.path+"/learning_curve.npz")["r"]) 105 | frame_counts = list( 106 | np.load(self.path+"/learning_curve.npz")["f"]) 107 | frame_count = frame_counts[-1] 108 | num_epochs = from_epoch 109 | 110 | episode_count = 0 111 | 
ep_lifetimes = [] 112 | 113 | while num_epochs < max_num_epochs: 114 | 115 | train_ep_reward = 0 116 | 117 | # reset the environment / start new game 118 | last_obs = self.train_env.reset() 119 | for time_step in range(self.max_ep_length): 120 | 121 | last_idx = self.rep_buffer.store_frame(last_obs) 122 | recent_obs = self.rep_buffer.encode_recent_observation() 123 | 124 | # choose action e-greedily 125 | action = self.choose_action(sess, recent_obs, exploration) 126 | 127 | # make step in the environment 128 | obs, reward, done = self.train_env.step(action)[:3] 129 | 130 | # save transition into experience replay 131 | self.rep_buffer.store_effect( 132 | last_idx, action, reward, done) 133 | 134 | # update current state and statistics 135 | frame_count += 1 136 | train_ep_reward += reward 137 | 138 | # reduce epsilon according to schedule 139 | if self.eps > self.final_eps: 140 | self.eps -= self.eps_drop 141 | 142 | # update network weights 143 | if frame_count % agent_update_freq == 0: 144 | 145 | batch = self.rep_buffer.sample(batch_size) 146 | self.update_agent_weights(sess, batch) 147 | 148 | # update target network 149 | if tau == 1: 150 | if frame_count % target_update_freq == 0: 151 | self.update_target_weights(sess, target_ops) 152 | else: 153 | self.update_target_weights(sess, target_ops) 154 | 155 | # save network wieghts checkpoint and learning curve 156 | if frame_count % save_freq == 1: 157 | num_epochs += 1 158 | try: 159 | self.saver.save( 160 | sess, 161 | self.path+"/model", 162 | global_step=num_epochs) 163 | np.savez( 164 | self.path+"/learning_curve.npz", 165 | r=train_rewards, 166 | f=frame_counts, 167 | l=ep_lifetimes) 168 | 169 | # if game is over, reset the environment 170 | if done: 171 | break 172 | last_obs = obs 173 | 174 | episode_count += 1 175 | train_rewards.append(train_ep_reward) 176 | frame_counts.append(frame_count) 177 | ep_lifetimes.append(time_step+1) 178 | 179 | # print performance once in a while 180 | if episode_count % performance_print_freq == 0: 181 | avg_reward = np.mean( 182 | train_rewards[-performance_print_freq:]) 183 | avg_lifetime = np.mean( 184 | ep_lifetimes[-performance_print_freq:]) 185 | print("frame count:", frame_count) 186 | print("average reward:", avg_reward) 187 | print("epsilon:", round(self.eps, 3)) 188 | print("average lifetime:", avg_lifetime) 189 | print("-------------------------------") 190 | 191 | def choose_action(self, sess, s, exploration="e-greedy"): 192 | 193 | if (exploration == "greedy"): 194 | a = self.agent_net.get_q_argmax(sess, [s]) 195 | elif (exploration == "e-greedy"): 196 | if np.random.rand(1) < self.eps: 197 | a = np.random.randint(self.num_actions) 198 | else: 199 | a = self.agent_net.get_q_argmax(sess, [s]) 200 | elif (exploration == "boltzmann"): 201 | q_values = self.agent_net.get_q_values_s(sess, [s]) 202 | logits = q_values / self.eps 203 | probs = softmax(logits).ravel() 204 | a = np.random.choice(self.num_actions, p=probs) 205 | elif (exploration == "policy"): 206 | probs = self.agent_net.get_p_values_s(sess, [s]).ravel() 207 | a = np.random.choice(self.num_actions, p=probs) 208 | else: 209 | return 0 210 | return a 211 | 212 | def update_agent_weights(self, sess, batch): 213 | 214 | # estimate the right hand side of the Bellman equation 215 | agent_actions = self.agent_net.get_q_argmax(sess, batch.s_) 216 | q_double = self.target_net.get_q_values_sa( 217 | sess, batch.s_, agent_actions) 218 | targets = batch.r + (self.gamma * q_double * (1 - batch.done)) 219 | 220 | # update agent network 221 
| self.agent_net.update(sess, batch.s, batch.a, targets) 222 | 223 | def update_target_graph(self, tau): 224 | op_holder = [] 225 | for agnt, trgt in zip(self.agent_vars, self.target_vars): 226 | op = trgt.assign(agnt.value()*tau + (1 - tau)*trgt.value()) 227 | op_holder.append(op) 228 | return op_holder 229 | 230 | def update_target_weights(self, sess, op_holder): 231 | for op in op_holder: 232 | sess.run(op) 233 | 234 | def play(self, 235 | gpu_id=0, 236 | max_episode_length=2000, 237 | from_epoch=0): 238 | 239 | config = self.gpu_config(gpu_id) 240 | with tf.Session(config=config) as sess: 241 | self.saver.restore(sess, self.path+"/model-"+str(from_epoch)) 242 | s = self.train_env.reset() 243 | R = 0 244 | for time_step in range(max_episode_length): 245 | a = self.agent_net.get_q_argmax(sess, [s])[0] 246 | s, r, done = self.train_env.step(a) 247 | R += r 248 | self.train_env.plot_state() 249 | display.clear_output(wait=True) 250 | display.display(plt.gcf()) 251 | if done: 252 | break 253 | return R 254 | 255 | def gpu_config(self, gpu_id): 256 | if (gpu_id == -1): 257 | config = tf.ConfigProto() 258 | else: 259 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 260 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 261 | config = tf.ConfigProto() 262 | config.gpu_options.allow_growth = True 263 | config.intra_op_parallelism_threads = 1 264 | config.inter_op_parallelism_threads = 1 265 | return config 266 | 267 | ############################ Deep Q-Network agent ############################# 268 | 269 | 270 | class DQNAgent(Agent): 271 | 272 | def __init__( 273 | self, env, num_actions, state_shape=[8, 8, 5], 274 | convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128], 275 | activation_fn=tf.nn.relu, 276 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32), 277 | gradient_clip=10.0, 278 | save_path="rl_models", model_name="DQN"): 279 | 280 | super(DQNAgent, self).__init__( 281 | env, num_actions, state_shape=state_shape, 282 | save_path=save_path, model_name=model_name) 283 | 284 | tf.reset_default_graph() 285 | self.agent_net = DeepQNetwork( 286 | self.num_actions, state_shape=state_shape, 287 | convs=convs, fully_connected=fully_connected, 288 | activation_fn=activation_fn, optimizer=optimizer, 289 | gradient_clip=gradient_clip, scope="agent") 290 | self.target_net = DeepQNetwork( 291 | self.num_actions, state_shape=state_shape, 292 | convs=convs, fully_connected=fully_connected, 293 | activation_fn=activation_fn, optimizer=optimizer, 294 | gradient_clip=gradient_clip, scope="target") 295 | self.init_weights() 296 | 297 | ######################## Dueling Deep Q-Network agent ######################### 298 | 299 | 300 | class DuelDQNAgent(Agent): 301 | 302 | def __init__( 303 | self, env, num_actions, state_shape=[8, 8, 5], 304 | convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[64], 305 | activation_fn=tf.nn.relu, 306 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32), 307 | gradient_clip=10.0, 308 | save_path="rl_models", model_name="DuelDQN"): 309 | 310 | super(DuelDQNAgent, self).__init__( 311 | env, num_actions, state_shape=state_shape, 312 | save_path=save_path, model_name=model_name) 313 | 314 | tf.reset_default_graph() 315 | self.agent_net = DuelingDeepQNetwork( 316 | self.num_actions, state_shape=state_shape, 317 | convs=convs, fully_connected=fully_connected, 318 | activation_fn=activation_fn, optimizer=optimizer, 319 | gradient_clip=gradient_clip, scope="agent") 320 | self.target_net = DuelingDeepQNetwork( 321 | self.num_actions, state_shape=state_shape, 322 | 
convs=convs, fully_connected=fully_connected, 323 | activation_fn=activation_fn, optimizer=optimizer, 324 | gradient_clip=gradient_clip, scope="target") 325 | self.init_weights() 326 | 327 | ###################### Categorical Deep Q-Network agent ####################### 328 | 329 | 330 | class CatDQNAgent(Agent): 331 | 332 | def __init__( 333 | self, env, num_actions, state_shape=[8, 8, 5], 334 | convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128], 335 | activation_fn=tf.nn.relu, num_atoms=21, v=(-10, 10), 336 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32), 337 | save_path="rl_models", model_name="CatDQN"): 338 | 339 | super(CatDQNAgent, self).__init__( 340 | env, num_actions, state_shape=state_shape, 341 | save_path=save_path, model_name=model_name) 342 | 343 | tf.reset_default_graph() 344 | self.agent_net = CategoricalDeepQNetwork( 345 | self.num_actions, state_shape=state_shape, 346 | convs=convs, fully_connected=fully_connected, 347 | activation_fn=tf.nn.relu, num_atoms=num_atoms, 348 | v=v, optimizer=optimizer, scope="agent") 349 | self.target_net = CategoricalDeepQNetwork( 350 | self.num_actions, state_shape=state_shape, 351 | convs=convs, fully_connected=fully_connected, 352 | activation_fn=tf.nn.relu, num_atoms=num_atoms, 353 | v=v, optimizer=optimizer, scope="target") 354 | self.init_weights() 355 | 356 | def update_agent_weights(self, sess, batch): 357 | 358 | # estimate categorical projection of the RHS of the Bellman equation 359 | agent_actions = self.agent_net.get_q_argmax(sess, batch.s_) 360 | probs_targets = self.target_net.cat_proj( 361 | sess, batch.s_, agent_actions, batch.r, 362 | batch.done, gamma=self.gamma) 363 | 364 | # update agent network 365 | self.agent_net.update(sess, batch.s, batch.a, probs_targets) 366 | 367 | ################## Quantile Regression Deep Q-Network agent ################### 368 | 369 | 370 | class QuantRegDQNAgent(Agent): 371 | 372 | def __init__( 373 | self, env, num_actions, state_shape=[8, 8, 5], 374 | convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128], 375 | activation_fn=tf.nn.relu, num_atoms=50, kappa=1.0, 376 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32), 377 | save_path="rl_models", model_name="QuantRegDQN"): 378 | 379 | super(QuantRegDQNAgent, self).__init__( 380 | env, num_actions, state_shape=state_shape, 381 | save_path=save_path, model_name=model_name) 382 | 383 | tf.reset_default_graph() 384 | self.agent_net = QuantileDeepQNetwork( 385 | self.num_actions, state_shape=state_shape, 386 | convs=convs, fully_connected=fully_connected, 387 | activation_fn=tf.nn.relu, num_atoms=num_atoms, 388 | kappa=kappa, optimizer=optimizer, scope="agent") 389 | self.target_net = QuantileDeepQNetwork( 390 | self.num_actions, state_shape=state_shape, 391 | convs=convs, fully_connected=fully_connected, 392 | activation_fn=tf.nn.relu, num_atoms=num_atoms, 393 | kappa=kappa, optimizer=optimizer, scope="target") 394 | self.init_weights() 395 | 396 | def update_agent_weights(self, sess, batch): 397 | 398 | # calculate target atoms produced by Bellman operator 399 | agent_actions = self.agent_net.get_q_argmax(sess, batch.s_) 400 | next_atoms = self.target_net.get_atoms_sa( 401 | sess, batch.s_, agent_actions) 402 | target_atoms = batch.r[:, None] + \ 403 | self.gamma * next_atoms * (1 - batch.done[:, None]) 404 | # update agent network 405 | self.agent_net.update(sess, batch.s, batch.a, target_atoms) 406 | 407 | ########################### Soft Actor-Critic agent ########################### 408 | 409 | 410 | class 
SACAgent(Agent): 411 | 412 | def __init__( 413 | self, env, num_actions, state_shape=[8, 8, 5], 414 | convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128], 415 | activation_fn=tf.nn.relu, 416 | temperature=1.0, 417 | optimizers=[tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32), 418 | tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32), 419 | tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32)], 420 | save_path="rl_models", model_name="SAC"): 421 | 422 | super(SACAgent, self).__init__( 423 | env, num_actions, state_shape=state_shape, 424 | save_path=save_path, model_name=model_name) 425 | 426 | tf.reset_default_graph() 427 | self.agent_net = SoftActorCriticNetwork( 428 | self.num_actions, state_shape=state_shape, 429 | convs=convs, fully_connected=fully_connected, 430 | activation_fn=activation_fn, 431 | optimizers=optimizers, scope="agent") 432 | self.target_net = SoftActorCriticNetwork( 433 | self.num_actions, state_shape=state_shape, 434 | convs=convs, fully_connected=fully_connected, 435 | activation_fn=activation_fn, 436 | optimizers=optimizers, scope="target") 437 | self.init_weights() 438 | self.t = temperature 439 | 440 | def update_agent_weights(self, sess, batch): 441 | 442 | probs = self.agent_net.get_p_values_s(sess, batch.s) 443 | c = probs.cumsum(axis=1) 444 | u = np.random.rand(len(c), 1) 445 | actions = (u < c).argmax(axis=1) 446 | 447 | v_values = self.agent_net.get_v_values_s(sess, batch.s).reshape(-1) 448 | v_values_next = self.target_net.get_v_values_s( 449 | sess, batch.s_).reshape(-1) 450 | q_values = self.agent_net.get_q_values_s(sess, batch.s) 451 | p_logits = self.agent_net.get_p_logits_s(sess, batch.s) 452 | 453 | x = np.arange(self.batch_size) 454 | q_values_selected = q_values[x, actions] 455 | p_logits_selected = p_logits[x, actions] 456 | 457 | q_targets = batch.r / self.t + \ 458 | self.gamma * v_values_next * (1 - batch.done) 459 | v_targets = q_values_selected - p_logits_selected 460 | p_targets = q_values_selected - v_values 461 | 462 | # update agent network 463 | self.agent_net.update_q(sess, batch.s, batch.a, q_targets) 464 | self.agent_net.update_v(sess, batch.s, v_targets) 465 | self.agent_net.update_p(sess, batch.s, actions, p_targets) 466 | -------------------------------------------------------------------------------- /environments/doom/README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | This is a gym-like wrapper over [VizDoom](http://vizdoom.cs.put.edu.pl) environment for deep reinforcement learning research. To install **vizdoom** python library, please follow the instructions from [VizDoom GitHub repository](https://github.com/mwydmuch/ViZDoom). 
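Once the library is installed, the following minimal check should run without errors (the only **vizdoom** class the wrapper relies on is `DoomGame`):

```python
# sanity check: succeeds only if the vizdoom package is installed correctly
from vizdoom import DoomGame

game = DoomGame()
```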
4 | 5 | ## How to use 6 | 7 | ```python 8 | from environments.doom.doom_env import DoomBasic 9 | 10 | # create environment instance 11 | env = DoomBasic() 12 | 13 | # reset the environment and get preprocessed 84x84x1 observation 14 | obs = env.reset() 15 | 16 | # make step in the environment 17 | next_obs, reward, done = env.step(action)  # action: int in [0, env.num_actions) 18 | 19 | # get original RGB image of current observation 20 | rgb_obs = env.get_obs_rgb() 21 | ``` 22 | 23 | # Available scenarios 24 | 25 | Here is the list of available scenarios (copied and adapted from [here](https://github.com/mwydmuch/ViZDoom/blob/master/scenarios/README.md)): 26 | 27 | |![doom1](img/basic.png) | ![doom2](img/def_center.png) | ![doom3](img/def_line.png) 28 | |:---:|:---:|:---:| 29 | |**Basic**|**Defend the center**|**Defend the line**| 30 | |![doom4](img/health.png) | ![doom5](img/predict.png) | ![doom6](img/corridor.png) 31 | |**Health gathering**|**Predict the position**|**Deadly corridor**| 32 | 33 | ## Basic 34 | 35 | Map is a rectangle with gray walls, ceiling and floor. Player is spawned along the longer wall, in the center. A red, circular monster is spawned randomly somewhere along the opposite wall. Player can only go left/right and shoot. 1 hit is enough to kill the monster. Episode finishes when the monster is killed or on timeout. 36 | 37 | - **Actions:** move left, move right, shoot 38 | - **Rewards:** +100 for killing the monster; -6 for missing; -1 otherwise 39 | - **Episode termination:** monster is killed or after 300 time steps 40 | 41 | ## Defend the center 42 | 43 | Map is a large circle. Player is spawned in the exact center. 5 melee-only monsters are spawned along the wall. Monsters are 44 | killed after a single shot. After dying each monster is respawned after some time. Ammo is limited to 26 bullets. Episode ends when the player dies or on timeout. 45 | 46 | - **Actions:** turn left, turn right, shoot 47 | - **Rewards:** +1 for killing the monster; -0.1 for missing; -0.1 for losing health; -1 for death; 0 otherwise 48 | - **Episode termination:** player is killed or after 2100 time steps 49 | 50 | ## Defend the line 51 | 52 | Map is a rectangle. Player is spawned along the longer wall, in the center. 3 melee-only and 3 shooting monsters are spawned along the opposite wall. Monsters are killed after a single shot, at first. After dying each monster is respawned after some time and can endure more damage. Episode ends when the player dies or on timeout. 53 | 54 | - **Actions:** move left, move right, turn left, turn right, shoot 55 | - **Rewards:** +1 for killing the monster; -0.1 for missing; -0.1 for losing health; -1 for death; 0 otherwise 56 | - **Episode termination:** player is killed or after 2100 time steps 57 | 58 | ## Health gathering 59 | 60 | Map is a rectangle with green, acidic floor which hurts the player periodically. Initially there are some medkits spread uniformly over the map. A new medkit falls from the skies every now and then. Medkits heal some portion of the player's health - to survive, the agent needs to pick them up. Episode finishes after player's death or on timeout. 61 | 62 | - **Actions:** turn left, turn right, move forward 63 | - **Rewards:** -100 for death; +1 otherwise 64 | - **Episode termination:** player is killed or after 2100 time steps 65 | 66 | ## Predict position 67 | 68 | Map is a rectangular room. Player is spawned along the longer wall, in the center. A monster is spawned randomly somewhere along the opposite wall and walks between left and right corners along the wall.
Player is equipped with a rocket launcher and a single rocket. Episode ends when the missile hits a wall/the monster or on timeout. 69 | 70 | - **Actions:** turn left, turn right, shoot 71 | - **Rewards:** +1 for killing the monster; -0.001 otherwise 72 | - **Episode termination:** monster is killed or after 300 time steps 73 | 74 | ## Deadly corridor 75 | 76 | Map is a corridor with shooting monsters on both sides (6 monsters in total). A green vest is placed at the opposite end of the corridor. Reward is proportional (negative or positive) to the change of the distance between the player and the vest. If the player ignores the monsters on the sides and runs straight for the vest, he will be killed somewhere along the way. 77 | 78 | - **Actions:** move left, move right, move forward, move backward, turn left, turn right, shoot 79 | - **Rewards:** -100 for death; +dX (-dX) for getting closer (further) to the vest 80 | - **Episode termination:** player is killed or after 2100 time steps 81 | -------------------------------------------------------------------------------- /environments/doom/config/basic.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = basic.wad 6 | doom_map = map01 7 | 8 | # Rewards 9 | living_reward = -1 10 | 11 | # Rendering options 12 | screen_resolution = RES_320X240 13 | screen_format = CRCGCB 14 | render_hud = True 15 | render_crosshair = false 16 | render_weapon = true 17 | render_decals = false 18 | render_particles = false 19 | window_visible = true 20 | 21 | # make episodes start after 14 tics (after unholstering the gun) 22 | episode_start_time = 14 23 | 24 | # make episodes finish after 300 actions (tics) 25 | episode_timeout = 300 26 | 27 | # Available buttons 28 | available_buttons = 29 | { 30 | MOVE_LEFT 31 | MOVE_RIGHT 32 | ATTACK 33 | } 34 | 35 | # Game variables that will be in the state 36 | available_game_variables = { AMMO2 HEALTH } 37 | 38 | mode = PLAYER 39 | doom_skill = 5 40 | -------------------------------------------------------------------------------- /environments/doom/config/basic.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/basic.wad -------------------------------------------------------------------------------- /environments/doom/config/deadly_corridor.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = deadly_corridor.wad 6 | 7 | # Skill 5 is recommended for the scenario to be a challenge.
8 | doom_skill = 5 9 | 10 | # Rewards 11 | death_penalty = 100 12 | #living_reward = 0 13 | 14 | # Rendering options 15 | screen_resolution = RES_320X240 16 | screen_format = CRCGCB 17 | render_hud = true 18 | render_crosshair = false 19 | render_weapon = true 20 | render_decals = false 21 | render_particles = false 22 | window_visible = true 23 | 24 | episode_timeout = 2100 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | MOVE_LEFT 30 | MOVE_RIGHT 31 | ATTACK 32 | MOVE_FORWARD 33 | MOVE_BACKWARD 34 | TURN_LEFT 35 | TURN_RIGHT 36 | } 37 | 38 | # Game variables that will be in the state 39 | available_game_variables = { HEALTH } 40 | 41 | mode = PLAYER 42 | 43 | 44 | -------------------------------------------------------------------------------- /environments/doom/config/deadly_corridor.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/deadly_corridor.wad -------------------------------------------------------------------------------- /environments/doom/config/defend_the_center.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = defend_the_center.wad 6 | 7 | # Rewards 8 | death_penalty = 1 9 | 10 | # Rendering options 11 | screen_resolution = RES_640X480 12 | screen_format = CRCGCB 13 | render_hud = True 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 10 tics (after unholstering the gun) 21 | episode_start_time = 10 22 | 23 | # make episodes finish after 2100 actions (tics) 24 | episode_timeout = 2100 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | TURN_LEFT 30 | TURN_RIGHT 31 | ATTACK 32 | } 33 | 34 | # Game variables that will be in the state 35 | available_game_variables = { AMMO2 HEALTH } 36 | 37 | mode = PLAYER 38 | doom_skill = 3 39 | -------------------------------------------------------------------------------- /environments/doom/config/defend_the_center.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/defend_the_center.wad -------------------------------------------------------------------------------- /environments/doom/config/defend_the_line.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = defend_the_line.wad 6 | 7 | # Rewards 8 | death_penalty = 1 9 | 10 | # Rendering options 11 | screen_resolution = RES_320X240 12 | screen_format = CRCGCB 13 | render_hud = True 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 10 tics (after unholstering the gun) 21 | episode_start_time = 10 22 | 23 | # make episodes finish after 2100 actions (tics) 24 | episode_timeout = 2100 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | MOVE_LEFT 30 | MOVE_RIGHT 31 | TURN_LEFT 32 | TURN_RIGHT 33 | ATTACK 34 | } 35 | 36 | # Game variables that will be in the state 37 | available_game_variables = { AMMO2 HEALTH} 38 | 39 | mode = PLAYER 40 | doom_skill = 3 41 | -------------------------------------------------------------------------------- /environments/doom/config/defend_the_line.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/defend_the_line.wad -------------------------------------------------------------------------------- /environments/doom/config/health_gathering.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = health_gathering.wad 6 | 7 | # Each step is good for you! 8 | living_reward = 1 9 | # And death is not! 10 | death_penalty = 100 11 | 12 | # Rendering options 13 | screen_resolution = RES_320X240 14 | screen_format = CRCGCB 15 | render_hud = false 16 | render_crosshair = false 17 | render_weapon = false 18 | render_decals = false 19 | render_particles = false 20 | window_visible = true 21 | 22 | # make episodes finish after 2100 actions (tics) 23 | episode_timeout = 2100 24 | 25 | # Available buttons 26 | available_buttons = 27 | { 28 | TURN_LEFT 29 | TURN_RIGHT 30 | MOVE_FORWARD 31 | } 32 | 33 | # Game variables that will be in the state 34 | available_game_variables = { } 35 | 36 | mode = PLAYER -------------------------------------------------------------------------------- /environments/doom/config/health_gathering.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/health_gathering.wad -------------------------------------------------------------------------------- /environments/doom/config/predict_position.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = predict_position.wad 6 | 7 | # Rewards 8 | living_reward = -0.001 9 | 10 | # Rendering options 11 | screen_resolution = RES_800X450 12 | screen_format = CRCGCB 13 | render_hud = false 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 16 tics (after producing the rocket launcher) 21 | episode_start_time = 16 22 | 23 | # make episodes finish after 300 actions (tics) 24 | episode_timeout = 300 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | TURN_LEFT 30 | TURN_RIGHT 31 | ATTACK 32 | } 33 | 34 | # Empty list is allowed, in case you are lazy. 35 | available_game_variables = { HEALTH } 36 | 37 | game_args += +sv_noautoaim 1 38 | 39 | mode = PLAYER 40 | doom_skill = 1 41 | -------------------------------------------------------------------------------- /environments/doom/config/predict_position.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/predict_position.wad -------------------------------------------------------------------------------- /environments/doom/doom_env.py: -------------------------------------------------------------------------------- 1 | from vizdoom import DoomGame 2 | from PIL import Image 3 | import numpy as np 4 | 5 | from IPython import display 6 | import matplotlib.pyplot as plt 7 | 8 | ########################## Doom environment template class ########################## 9 | 10 | class DoomEnvironment: 11 | 12 | def __init__(self, scenario, path_to_config="doom/config"): 13 | self.game = DoomGame() 14 | self.game.load_config(path_to_config+"/"+scenario+".cfg") 15 | self.game.set_doom_scenario_path(path_to_config+"/"+scenario+".wad") 16 | self.game.set_window_visible(False) 17 | self.game.init() 18 | self.num_actions = len(self.game.get_available_buttons()) 19 | 20 | def reset(self): 21 | self.game.new_episode() 22 | game_state = self.game.get_state() 23 | obs = game_state.screen_buffer 24 | self.h, self.w = obs.shape[1:3] 25 | self.current_obs = self.preprocess_obs(obs) 26 | if self.game.get_available_game_variables_size() == 2: 27 | self.ammo, self.health = game_state.game_variables 28 | return self.get_obs() 29 | 30 | def get_obs(self): 31 | return self.current_obs[:, :, None] 32 | 33 | def get_obs_rgb(self): 34 | img = self.game.get_state().screen_buffer 35 | img = np.rollaxis(img, 0, 3) 36 | img = np.reshape(img, [self.h, self.w, 3]) 37 | return img.astype(np.uint8) 38 | 39 | def preprocess_obs(self, obs): 40 | img = np.rollaxis(obs, 0, 3) 41 | img = np.reshape(img, [self.h, self.w, 3]).astype(np.float32) 42 | img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114 43 | img = Image.fromarray(img) 44 | img = img.resize((84, 84), Image.BILINEAR) 45 | img = np.array(img) 46 | return img.astype(np.uint8) 47 | 48 | def action_to_doom(self, a): 49 | action = [0 for i in range(self.num_actions)] 50 | action[int(a)] = 1 51 | return action 52 | 53 | def step(self, a): 54 | action = self.action_to_doom(a) 55 | reward = self.game.make_action(action) 56 | 57 | done = self.game.is_episode_finished() 58 | 59 | if done: 60 | new_obs = np.zeros_like(self.current_obs, dtype=np.uint8) 61 | else: 62 | game_state = self.game.get_state() 63 | new_obs = game_state.screen_buffer 64 | new_obs = self.preprocess_obs(new_obs) 65 | 66 | 
self.current_obs = new_obs 67 | 68 | return self.get_obs(), reward, done 69 | 70 | def watch_random_play(self, max_ep_length=1000, frame_skip=4): 71 | self.reset() 72 | for i in range(max_ep_length): 73 | a = np.random.randint(self.num_actions) 74 | obs, reward, done = self.step(a) 75 | if done: break 76 | 77 | img = self.get_obs_rgb() 78 | if i % frame_skip == 0: 79 | plt.imshow(img) 80 | display.clear_output(wait=True) 81 | display.display(plt.gcf()) 82 | 83 | ####################################### Basic ####################################### 84 | 85 | class DoomBasic(DoomEnvironment): 86 | 87 | def __init__(self, path_to_config="doom/config"): 88 | super(DoomBasic, self).__init__(scenario="basic", 89 | path_to_config=path_to_config) 90 | 91 | ################################## Defend the line ################################## 92 | 93 | class DoomDefendTheLine(DoomEnvironment): 94 | 95 | def __init__(self, path_to_config="doom/config"): 96 | super(DoomDefendTheLine, self).__init__(scenario="defend_the_line", 97 | path_to_config=path_to_config) 98 | 99 | def step(self, a): 100 | action = self.action_to_doom(a) 101 | reward = self.game.make_action(action) 102 | 103 | done = self.game.is_episode_finished() 104 | 105 | if done: 106 | new_obs = np.zeros_like(self.current_obs, dtype=np.uint8) 107 | else: 108 | game_state = self.game.get_state() 109 | new_obs = game_state.screen_buffer 110 | new_obs = self.preprocess_obs(new_obs) 111 | new_ammo, new_health = game_state.game_variables 112 | 113 | if (reward == 1.0): reward += 0.1 114 | if (new_ammo < self.ammo): reward -= 0.1 115 | if (new_health < self.health): reward -= 0.1 116 | 117 | self.ammo, self.health = new_ammo, new_health 118 | 119 | self.current_obs = new_obs 120 | 121 | return self.get_obs(), reward, done 122 | 123 | ################################# Defend the center ################################# 124 | 125 | class DoomDefendTheCenter(DoomEnvironment): 126 | 127 | def __init__(self, path_to_config="doom/config"): 128 | super(DoomDefendTheCenter, self).__init__(scenario="defend_the_center", 129 | path_to_config=path_to_config) 130 | 131 | def step(self, a): 132 | action = self.action_to_doom(a) 133 | reward = self.game.make_action(action) 134 | 135 | done = self.game.is_episode_finished() 136 | 137 | if done: 138 | new_obs = np.zeros_like(self.current_obs, dtype=np.uint8) 139 | else: 140 | game_state = self.game.get_state() 141 | new_obs = game_state.screen_buffer 142 | new_obs = self.preprocess_obs(new_obs) 143 | new_ammo, new_health = game_state.game_variables 144 | 145 | if (reward == 1.0): reward += 0.1 146 | if (new_ammo < self.ammo): reward -= 0.1 147 | if (new_health < self.health): reward -= 0.1 148 | 149 | self.ammo, self.health = new_ammo, new_health 150 | 151 | self.current_obs = new_obs 152 | 153 | return self.get_obs(), reward, done 154 | 155 | ############################### Predict the position ################################ 156 | 157 | class DoomPredictThePosition(DoomEnvironment): 158 | 159 | def __init__(self, path_to_config="doom/config"): 160 | super(DoomPredictThePosition, self).__init__(scenario="predict_position", 161 | path_to_config=path_to_config) 162 | 163 | ############################### Predict the position ################################ 164 | 165 | class DoomHealthGathering(DoomEnvironment): 166 | 167 | def __init__(self, path_to_config="doom/config"): 168 | super(DoomHealthGathering, self).__init__(scenario="health_gathering", 169 | path_to_config=path_to_config) 170 | 171 | 
############################### Predict the position ################################ 172 | 173 | class DoomDeadlyCorridor(DoomEnvironment): 174 | 175 | def __init__(self, path_to_config="doom/config"): 176 | super(DoomDeadlyCorridor, self).__init__(scenario="deadly_corridor", 177 | path_to_config=path_to_config) -------------------------------------------------------------------------------- /environments/doom/img/basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/basic.png -------------------------------------------------------------------------------- /environments/doom/img/corridor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/corridor.png -------------------------------------------------------------------------------- /environments/doom/img/def_center.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/def_center.png -------------------------------------------------------------------------------- /environments/doom/img/def_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/def_line.png -------------------------------------------------------------------------------- /environments/doom/img/health.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/health.png -------------------------------------------------------------------------------- /environments/doom/img/predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/predict.png -------------------------------------------------------------------------------- /environments/snake/snake_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | class Snake: 5 | 6 | def __init__(self, grid_size=(8, 8)): 7 | """ 8 | Classic Snake game implemented as Gym environment. 9 | 10 | Parameters 11 | ---------- 12 | grid_size: tuple 13 | tuple of two parameters: (height, width) 14 | """ 15 | 16 | self.height, self.width = grid_size 17 | self.state = np.zeros(grid_size) 18 | self.x, self.y = [], [] 19 | self.dir = None 20 | self.food = None 21 | self.opt_tab = self.opt_table(grid_size) 22 | 23 | def reset(self): 24 | """ 25 | Resets the state of the environment and returns an initial observation. 26 | 27 | Returns 28 | ------- 29 | observation: numpy.array of size (width, height, 1) 30 | the initial observation of the space. 
31 | """ 32 | 33 | self.state = np.zeros((self.height, self.width)) 34 | 35 | x_tail = np.random.randint(self.height) 36 | y_tail = np.random.randint(self.width) 37 | 38 | xs = [x_tail, ] 39 | ys = [y_tail, ] 40 | 41 | for i in range(2): 42 | nbrs = self.get_neighbors(xs[-1], ys[-1]) 43 | while 1: 44 | idx = np.random.randint(0, len(nbrs)) 45 | x0 = nbrs[idx][0] 46 | y0 = nbrs[idx][1] 47 | occupied = [list(pt) for pt in zip(xs, ys)] 48 | if not [x0, y0] in occupied: 49 | xs.append(x0) 50 | ys.append(y0) 51 | break 52 | 53 | for x_t, y_t in list(zip(xs, ys)): 54 | self.state[x_t, y_t] = 1 55 | 56 | self.generate_food() 57 | self.x = xs 58 | self.y = ys 59 | self.update_dir() 60 | 61 | return self.get_state() 62 | 63 | def step(self, a): 64 | """ 65 | Run one timestep of the environment's dynamics. When end of 66 | episode is reached, you are responsible for calling `reset()` 67 | to reset this environment's state. 68 | 69 | Args 70 | ---- 71 | action: int from {0, 1, 2, 3} 72 | an action provided by the environment 73 | 74 | Returns 75 | ------- 76 | observation: numpy.array of size (width, height, 1) 77 | agent's observation of the current environment 78 | reward: int from {-1, 0, 1} 79 | amount of reward returned after previous action 80 | done: boolean 81 | whether the episode has ended, in which case further step() 82 | calls will return undefined results 83 | """ 84 | 85 | self.update_dir() 86 | x_, y_ = self.next_cell(self.x[-1], self.y[-1], a) 87 | 88 | # snake dies if hitting the walls 89 | if x_ < 0 or x_ == self.height or y_ < 0 or y_ == self.width: 90 | return self.get_state(), -1, True 91 | 92 | # snake dies if hitting its tail with head 93 | if self.state[x_, y_] == 1: 94 | if (x_ == self.x[0] and y_ == self.y[0]): 95 | pass 96 | else: 97 | return self.get_state(), -1, True 98 | 99 | self.x.append(x_) 100 | self.y.append(y_) 101 | 102 | # snake elongates after eating a food 103 | if self.state[x_, y_] == 3: 104 | self.state[x_, y_] = 1 105 | done = self.generate_food() 106 | return self.get_state(), 1, done 107 | 108 | # snake moves forward if cell ahead is empty 109 | # or currently occupied by its tail 110 | self.state[self.x[0], self.y[0]] = 0 111 | self.state[x_, y_] = 1 112 | self.x = self.x[1:] 113 | self.y = self.y[1:] 114 | return self.get_state(), 0, False 115 | 116 | def get_state(self): 117 | state = np.zeros((self.height, self.width, 5)) 118 | state[self.x[1:-1], self.y[1:-1], 0] = 1 119 | state[self.x[-1], self.y[-1], 1] = 1 120 | state[self.x[-2], self.y[-2], 2] = 1 121 | state[self.x[0], self.y[0], 3] = 1 122 | state[self.food[0], self.food[1], 4] = 1 123 | return state.astype(np.uint8) 124 | 125 | def generate_food(self): 126 | free = np.where(self.state == 0) 127 | if free[0].size == 0: 128 | return True 129 | else: 130 | idx = np.random.randint(free[0].size) 131 | self.food = free[0][idx], free[1][idx] 132 | self.state[self.food] = 3 133 | return False 134 | 135 | def next_cell(self, i, j, a): 136 | if a == 0: 137 | return i+self.dir[0], j+self.dir[1] 138 | if a == 1: 139 | return i-self.dir[1], j+self.dir[0] 140 | if a == 2: 141 | return i+self.dir[1], j-self.dir[0] 142 | 143 | def plot_state(self): 144 | state = self.get_state() 145 | img = sum([state[:,:,i]*(i+1) for i in range(5)]) 146 | plt.imshow(img, vmin=0, vmax=5, interpolation='nearest') 147 | 148 | def get_neighbors(self, i, j): 149 | """ 150 | Get all the neighbors of the point (i, j) 151 | (excluding (i, j)) 152 | """ 153 | h = self.height 154 | w = self.width 155 | nbrs = [[i + k, j + m] for k in 
[-1, 0, 1] for m in [-1, 0, 1] 156 | if i + k >=0 and i + k < h and j + m >= 0 and j + m < w 157 | and not (k == m) and not (k == -m)] 158 | return nbrs 159 | 160 | def update_dir(self): 161 | x_dir = self.x[-1] - self.x[-2] 162 | y_dir = self.y[-1] - self.y[-2] 163 | self.dir = (x_dir, y_dir) 164 | 165 | ########################## Optimal action selection ########################## 166 | 167 | def opt_table(self, grid_size): 168 | n = grid_size[0] 169 | t = np.zeros(grid_size, dtype=np.int) 170 | t[0] = np.arange(n) 171 | for i in range(n//2): 172 | t[1:,(n-1)-2*i] = np.arange(n-1) + n+2*i*(n-1) 173 | t[1:,(n-2)-2*i][::-1] = np.arange(n-1) + 2*n-1+2*i*(n-1) 174 | return t 175 | 176 | def opt_action(self): 177 | x, y = self.x[-1], self.y[-1] 178 | self.update_dir() 179 | n = self.height 180 | mod = n ** 2 181 | tab_xy = self.opt_tab[x, y] 182 | pos_a = -1 183 | for a in range(3): 184 | x_, y_ = self.next_cell(x, y, a) 185 | if (x_ < n and y_ < n and x_ >= 0 and y_ >= 0):  # next cell must lie inside the grid 186 | tab_xy_ = self.opt_tab[x_, y_] 187 | if ((tab_xy+1) % mod == tab_xy_): 188 | return a 189 | if ((tab_xy-1) % mod == tab_xy_): 190 | pos_a = a 191 | return pos_a -------------------------------------------------------------------------------- /environments/windy_grid_world/evil_wgw_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .wgw_env import WindyGridWorld 3 | 4 | 5 | class EvilWindyGridWorld(WindyGridWorld): 6 | 7 | def __init__( 8 | self, 9 | grid_size=(7, 10), 10 | stochasticity=0.1, 11 | visual=False): 12 | self.w, self.h = grid_size 13 | self.stochasticity = stochasticity 14 | self.visual = visual 15 | 16 | # x position of the wall 17 | self.x_wall = self.w // 2 18 | # y position of the hole in the wall 19 | self.y_hole = self.h - 4 20 | self.y_hole2 = self.h - 7 21 | 22 | self.reset() 23 | 24 | def move(self, a): 25 | """ find valid coordinates of the agent after executing action 26 | """ 27 | x, y = self.pos 28 | self.field[x, y] = 0 29 | x, y = self.wind_shift(x, y) 30 | 31 | if a == 0: 32 | x_, y_ = x + 1, y 33 | if a == 1: 34 | x_, y_ = x, y + 1 35 | if a == 2: 36 | x_, y_ = x - 1, y 37 | if a == 3: 38 | x_, y_ = x, y - 1 39 | 40 | # check if new position does not conflict with the wall 41 | if x_ == self.x_wall and y != self.y_hole and y != self.y_hole2: 42 | x_, y_ = x, y 43 | return self.clip_xy(x_, y_) 44 | 45 | def reset(self): 46 | """ resets the environment 47 | """ 48 | self.field = np.zeros((self.w, self.h)) 49 | self.field[self.x_wall, :] = 1 50 | self.field[self.x_wall, self.y_hole] = 0 51 | self.field[self.x_wall, self.y_hole2] = 0 52 | self.field[self.x_wall + 1, self.y_hole2 + 1] = -1 53 | self.field[self.x_wall + 1, self.y_hole2 - 1] = -1 54 | self.field[0, 0] = 2 55 | self.pos = (0, 0) 56 | obs = self.get_observation() 57 | return obs 58 | 59 | def step(self, a): 60 | """ makes a step in the environment 61 | """ 62 | 63 | if np.random.rand() < self.stochasticity: 64 | a = np.random.randint(4) 65 | 66 | self.field[self.pos] = 0 67 | self.pos = self.move(a) 68 | self.field[self.pos] = 2 69 | 70 | done = False 71 | reward = 0 72 | if self.pos == (self.w - 1, 0): 73 | # episode finished successfully 74 | done = True 75 | reward = 1 76 | if (self.pos == (self.x_wall + 1, self.y_hole2 + 1) or 77 | self.pos == (self.x_wall + 1, self.y_hole2 - 1)): 78 | # episode finished unsuccessfully 79 | done = True 80 | reward = -1 81 | 82 | next_obs = self.get_observation() 83 | return next_obs, reward, done 84 |
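For reference, a minimal random-policy rollout on the environment above (illustrative only; the import path follows the convention used in the Doom wrapper's README, and actual training examples live in train_agents.ipynb):

```python
import numpy as np
from environments.windy_grid_world.evil_wgw_env import EvilWindyGridWorld

# with visual=True, observations are the rotated grid, i.e. shape (10, 7, 1) for the default 7x10 field
env = EvilWindyGridWorld(grid_size=(7, 10), stochasticity=0.1, visual=True)
obs = env.reset()

episode_return, done, t = 0, False, 0
while not done and t < 200:              # cap the rollout at 200 steps
    a = np.random.randint(4)             # actions 0-3 move the agent along the two grid axes
    obs, reward, done = env.step(a)
    episode_return += reward
    t += 1

print("return:", episode_return, "steps:", t)
```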
-------------------------------------------------------------------------------- /environments/windy_grid_world/wgw_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import time 4 | from IPython import display 5 | 6 | 7 | class WindyGridWorld: 8 | 9 | def __init__( 10 | self, 11 | grid_size=(11, 14), 12 | stochasticity=0.1, 13 | visual=False): 14 | self.w, self.h = grid_size 15 | self.stochasticity = stochasticity 16 | self.visual = visual 17 | 18 | # x position of the wall 19 | self.x_wall = self.w // 2 20 | # y position of the hole in the wall 21 | self.y_hole = self.h - 4 22 | 23 | self.reset() 24 | 25 | def clip_xy(self, x, y): 26 | """ clip coordinates if they go beyond the grid 27 | """ 28 | x_ = np.clip(x, 0, self.w - 1) 29 | y_ = np.clip(y, 0, self.h - 1) 30 | return x_, y_ 31 | 32 | def wind_shift(self, x, y): 33 | """ apply wind shift to areas where wind is blowing 34 | """ 35 | if x == 1: 36 | return self.clip_xy(x, y + 1) 37 | elif x > 1 and x < self.x_wall: 38 | return self.clip_xy(x, y + 2) 39 | else: 40 | return x, y 41 | 42 | def move(self, a): 43 | """ find valid coordinates of the agent after executing action 44 | """ 45 | x, y = self.pos 46 | self.field[x, y] = 0 47 | x, y = self.wind_shift(x, y) 48 | 49 | if a == 0: 50 | x_, y_ = x + 1, y 51 | if a == 1: 52 | x_, y_ = x, y + 1 53 | if a == 2: 54 | x_, y_ = x - 1, y 55 | if a == 3: 56 | x_, y_ = x, y - 1 57 | 58 | # check if new position does not conflict with the wall 59 | if x_ == self.x_wall and y != self.y_hole: 60 | x_, y_ = x, y 61 | return self.clip_xy(x_, y_) 62 | 63 | def get_observation(self): 64 | if self.visual: 65 | obs = np.rot90(self.field)[:, :, None] 66 | else: 67 | obs = self.pos 68 | return obs 69 | 70 | def reset(self): 71 | """ resets the environment 72 | """ 73 | self.field = np.zeros((self.w, self.h)) 74 | self.field[self.x_wall, :] = 1 75 | self.field[self.x_wall, self.y_hole] = 0 76 | self.field[0, 0] = 2 77 | self.pos = (0, 0) 78 | obs = self.get_observation() 79 | return obs 80 | 81 | def step(self, a): 82 | """ makes a step in the environment 83 | """ 84 | 85 | if np.random.rand() < self.stochasticity: 86 | a = np.random.randint(4) 87 | 88 | self.field[self.pos] = 0 89 | self.pos = self.move(a) 90 | self.field[self.pos] = 2 91 | 92 | done = False 93 | reward = 0 94 | if self.pos == (self.w - 1, 0): 95 | # episode finished successfully 96 | done = True 97 | reward = 1 98 | next_obs = self.get_observation() 99 | return next_obs, reward, done 100 | 101 | def play_with_policy(self, policy, max_iter=100, visualize=True): 102 | """ play with given policy 103 | returns: 104 | episode return, number of time steps 105 | """ 106 | self.reset() 107 | for i in range(max_iter): 108 | a = np.argmax(policy[self.pos]) 109 | next_obs, reward, done = self.step(a) 110 | # plot grid world state 111 | if visualize: 112 | img = np.rot90(1-self.field) 113 | plt.imshow(img, cmap="gray") 114 | display.clear_output(wait=True) 115 | display.display(plt.gcf()) 116 | time.sleep(0.01) 117 | if done: 118 | break 119 | if visualize: 120 | display.clear_output(wait=True) 121 | return reward, i+1 122 | -------------------------------------------------------------------------------- /img/dqn_categorical.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/dqn_categorical.png 
-------------------------------------------------------------------------------- /img/dqn_classic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/dqn_classic.png -------------------------------------------------------------------------------- /img/dqn_dueling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/dqn_dueling.png -------------------------------------------------------------------------------- /img/dqn_quantile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/dqn_quantile.png -------------------------------------------------------------------------------- /img/sac_p_network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/sac_p_network.png -------------------------------------------------------------------------------- /img/sac_q_network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/sac_q_network.png -------------------------------------------------------------------------------- /img/sac_v_network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/sac_v_network.png -------------------------------------------------------------------------------- /methods.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import tensorflow as tf 4 | import tensorflow.contrib.layers as layers 5 | from tensorflow.contrib.layers import convolution2d as conv 6 | from tensorflow.contrib.layers import fully_connected as fc 7 | from tensorflow.contrib.layers import xavier_initializer as xavier 8 | 9 | ############################################################################### 10 | ################################ Core modules ################################# 11 | ############################################################################### 12 | 13 | 14 | def conv_module(input_layer, convs, activation_fn=tf.nn.relu): 15 | """ convolutional module 16 | """ 17 | out = input_layer 18 | for num_outputs, kernel_size, stride in convs: 19 | out = conv( 20 | out, 21 | num_outputs=num_outputs, 22 | kernel_size=kernel_size, 23 | stride=stride, 24 | padding="VALID", 25 | activation_fn=activation_fn) 26 | return out 27 | 28 | 29 | def fc_module(input_layer, fully_connected, activation_fn=tf.nn.relu): 30 | """ fully connected module 31 | """ 32 | out = input_layer 33 | for num_outputs in fully_connected: 34 | out = fc( 35 | out, 36 | num_outputs=num_outputs, 37 | activation_fn=activation_fn, 38 | weights_initializer=xavier()) 39 | return out 40 | 41 | 42 | def full_module( 43 | input_layer, convs, fully_connected, 44 | num_outputs, activation_fn=tf.nn.relu): 45 | """ convolutional + fully connected + output 46 | """ 47 | out = input_layer 48 | out = conv_module(out, convs, activation_fn) 49 | out = 
layers.flatten(out) 50 | out = fc_module(out, fully_connected, activation_fn) 51 | out = fc_module(out, [num_outputs], None) 52 | return out 53 | 54 | ############################################################################### 55 | ############################### Deep Q-Network ################################ 56 | ############################################################################### 57 | 58 | 59 | class DeepQNetwork: 60 | 61 | def __init__( 62 | self, 63 | num_actions, 64 | state_shape=[8, 8, 5], 65 | convs=[[32, 4, 2], [64, 2, 1]], 66 | fully_connected=[128], 67 | activation_fn=tf.nn.relu, 68 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32), 69 | gradient_clip=10.0, 70 | scope="dqn", 71 | reuse=False): 72 | 73 | with tf.variable_scope(scope, reuse=reuse): 74 | 75 | ################### Neural network architecture ################### 76 | 77 | input_shape = [None] + state_shape 78 | self.input_states = tf.placeholder( 79 | dtype=tf.float32, shape=input_shape) 80 | 81 | self.q_values = full_module( 82 | self.input_states, convs, fully_connected, 83 | num_actions, activation_fn) 84 | 85 | ##################### Optimization procedure ###################### 86 | 87 | # convert input actions to indices for q-values selection 88 | self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None]) 89 | indices_range = tf.range(tf.shape(self.input_actions)[0]) 90 | action_indices = tf.stack( 91 | [indices_range, self.input_actions], axis=1) 92 | 93 | # select q-values for input actions 94 | self.q_values_selected = tf.gather_nd( 95 | self.q_values, action_indices) 96 | 97 | # select best actions (according to q-values) 98 | self.q_argmax = tf.argmax(self.q_values, axis=1) 99 | 100 | # define loss function and update rule 101 | self.q_targets = tf.placeholder(dtype=tf.float32, shape=[None]) 102 | self.loss = tf.losses.huber_loss( 103 | self.q_targets, self.q_values_selected, delta=gradient_clip) 104 | self.update_model = optimizer.minimize(self.loss) 105 | 106 | def get_q_values_s(self, sess, states): 107 | feed_dict = {self.input_states: states} 108 | q_values = sess.run(self.q_values, feed_dict) 109 | return q_values 110 | 111 | def get_q_values_sa(self, sess, states, actions): 112 | feed_dict = {self.input_states: states, self.input_actions: actions} 113 | q_values_selected = sess.run(self.q_values_selected, feed_dict) 114 | return q_values_selected 115 | 116 | def get_q_argmax(self, sess, states): 117 | feed_dict = {self.input_states: states} 118 | q_argmax = sess.run(self.q_argmax, feed_dict) 119 | return q_argmax 120 | 121 | def update(self, sess, states, actions, q_targets): 122 | feed_dict = {self.input_states: states, 123 | self.input_actions: actions, 124 | self.q_targets: q_targets} 125 | sess.run(self.update_model, feed_dict) 126 | 127 | ############################################################################### 128 | ########################### Dueling Deep Q-Network ############################ 129 | ############################################################################### 130 | 131 | 132 | class DuelingDeepQNetwork: 133 | 134 | def __init__( 135 | self, 136 | num_actions, 137 | state_shape=[8, 8, 5], 138 | convs=[[32, 4, 2], [64, 2, 1]], 139 | fully_connected=[64], 140 | activation_fn=tf.nn.relu, 141 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32), 142 | gradient_clip=10.0, 143 | scope="duel_dqn", 144 | reuse=False): 145 | 146 | with tf.variable_scope(scope, reuse=reuse): 147 | 148 | ################### Neural network architecture 
################### 149 | 150 | input_shape = [None] + state_shape 151 | self.input_states = tf.placeholder( 152 | dtype=tf.float32, shape=input_shape) 153 | 154 | out = conv_module(self.input_states, convs, activation_fn) 155 | val, adv = tf.split(out, num_or_size_splits=2, axis=3) 156 | self.v_values = full_module( 157 | val, [], fully_connected, 1, activation_fn) 158 | self.a_values = full_module( 159 | adv, [], fully_connected, num_actions, activation_fn) 160 | 161 | a_values_mean = tf.reduce_mean( 162 | self.a_values, axis=1, keepdims=True) 163 | a_values_centered = tf.subtract(self.a_values, a_values_mean) 164 | self.q_values = self.v_values + a_values_centered 165 | 166 | ##################### Optimization procedure ###################### 167 | 168 | # convert input actions to indices for q-values selection 169 | self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None]) 170 | indices_range = tf.range(tf.shape(self.input_actions)[0]) 171 | action_indices = tf.stack( 172 | [indices_range, self.input_actions], axis=1) 173 | 174 | # select q-values for input actions 175 | self.q_values_selected = tf.gather_nd( 176 | self.q_values, action_indices) 177 | 178 | # select best actions (according to q-values) 179 | self.q_argmax = tf.argmax(self.q_values, axis=1) 180 | 181 | # define loss function and update rule 182 | self.q_targets = tf.placeholder(dtype=tf.float32, shape=[None]) 183 | self.loss = tf.losses.huber_loss( 184 | self.q_targets, self.q_values_selected, delta=gradient_clip) 185 | self.update_model = optimizer.minimize(self.loss) 186 | 187 | def get_q_values_s(self, sess, states): 188 | feed_dict = {self.input_states: states} 189 | q_values = sess.run(self.q_values, feed_dict) 190 | return q_values 191 | 192 | def get_q_values_sa(self, sess, states, actions): 193 | feed_dict = {self.input_states: states, self.input_actions: actions} 194 | q_values_selected = sess.run(self.q_values_selected, feed_dict) 195 | return q_values_selected 196 | 197 | def get_q_argmax(self, sess, states): 198 | feed_dict = {self.input_states: states} 199 | q_argmax = sess.run(self.q_argmax, feed_dict) 200 | return q_argmax 201 | 202 | def update(self, sess, states, actions, q_targets): 203 | feed_dict = {self.input_states: states, 204 | self.input_actions: actions, 205 | self.q_targets: q_targets} 206 | sess.run(self.update_model, feed_dict) 207 | 208 | ############################################################################### 209 | ######################### Categorical Deep Q-Network ########################## 210 | ############################################################################### 211 | 212 | 213 | class CategoricalDeepQNetwork: 214 | 215 | def __init__( 216 | self, 217 | num_actions, 218 | state_shape=[8, 8, 5], 219 | convs=[[32, 4, 2], [64, 2, 1]], 220 | fully_connected=[128], 221 | num_atoms=21, 222 | v=(-10, 10), 223 | activation_fn=tf.nn.relu, 224 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32), 225 | scope="cat_dqn", 226 | reuse=False): 227 | 228 | with tf.variable_scope(scope, reuse=reuse): 229 | 230 | ################### Neural network architecture ################### 231 | 232 | input_shape = [None] + state_shape 233 | self.input_states = tf.placeholder( 234 | dtype=tf.float32, shape=input_shape) 235 | 236 | # distribution parameters 237 | self.num_atoms = num_atoms 238 | self.v_min, self.v_max = v 239 | self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1) 240 | self.z = np.linspace( 241 | start=self.v_min, stop=self.v_max, num=num_atoms) 242 | 
243 | # main module 244 | out = full_module( 245 | self.input_states, convs, fully_connected, 246 | num_outputs=num_actions*num_atoms, activation_fn=activation_fn) 247 | 248 | self.logits = tf.reshape(out, shape=[-1, num_actions, num_atoms]) 249 | self.probs = tf.nn.softmax(self.logits, axis=2) 250 | self.q_values = tf.reduce_sum( 251 | tf.multiply(self.probs, self.z), axis=2) 252 | 253 | ##################### Optimization procedure ###################### 254 | 255 | # convert input actions to indices for probs and q-values selection 256 | self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None]) 257 | indices_range = tf.range(tf.shape(self.input_actions)[0]) 258 | action_indices = tf.stack( 259 | [indices_range, self.input_actions], axis=1) 260 | 261 | # select q-values and probs for input actions 262 | self.q_values_selected = tf.gather_nd( 263 | self.q_values, action_indices) 264 | self.probs_selected = tf.gather_nd(self.probs, action_indices) 265 | 266 | # select best actions (according to q-values) 267 | self.q_argmax = tf.argmax(self.q_values, axis=1) 268 | 269 | # define loss function and update rule 270 | self.probs_targets = tf.placeholder( 271 | dtype=tf.float32, shape=[None, self.num_atoms]) 272 | self.loss = -tf.reduce_sum( 273 | self.probs_targets * tf.log(self.probs_selected+1e-6)) 274 | self.update_model = optimizer.minimize(self.loss) 275 | 276 | def get_q_values_s(self, sess, states): 277 | feed_dict = {self.input_states: states} 278 | q_values = sess.run(self.q_values, feed_dict) 279 | return q_values 280 | 281 | def get_q_values_sa(self, sess, states, actions): 282 | feed_dict = {self.input_states: states, self.input_actions: actions} 283 | q_values_selected = sess.run(self.q_values_selected, feed_dict) 284 | return q_values_selected 285 | 286 | def get_q_argmax(self, sess, states): 287 | feed_dict = {self.input_states: states} 288 | q_argmax = sess.run(self.q_argmax, feed_dict) 289 | return q_argmax 290 | 291 | def get_probs_s(self, sess, states): 292 | feed_dict = {self.input_states: states} 293 | probs = sess.run(self.probs, feed_dict) 294 | return probs 295 | 296 | def get_probs_sa(self, sess, states, actions): 297 | feed_dict = {self.input_states: states, self.input_actions: actions} 298 | probs_selected = sess.run(self.probs_selected, feed_dict) 299 | return probs_selected 300 | 301 | def update(self, sess, states, actions, probs_targets): 302 | feed_dict = {self.input_states: states, 303 | self.input_actions: actions, 304 | self.probs_targets: probs_targets} 305 | sess.run(self.update_model, feed_dict) 306 | 307 | def cat_proj(self, sess, states, actions, rewards, done, gamma=0.99): 308 | """ 309 | Categorical algorithm from https://arxiv.org/abs/1707.06887 310 | """ 311 | 312 | atoms_targets = rewards[:, None] + gamma * self.z * (1 - done[:, None]) 313 | tz = np.clip(atoms_targets, self.v_min, self.v_max) 314 | tz_z = tz[:, None, :] - self.z[None, :, None] 315 | tz_z = np.clip((1.0 - (np.abs(tz_z) / self.delta_z)), 0, 1) 316 | 317 | probs = self.get_probs_sa(sess, states, actions) 318 | probs_targets = np.einsum('bij,bj->bi', tz_z, probs) 319 | 320 | return probs_targets 321 | 322 | ############################################################################### 323 | ########################### Quantile Deep Q-Network ########################### 324 | ############################################################################### 325 | 326 | 327 | class QuantileDeepQNetwork: 328 | 329 | def __init__( 330 | self, 331 | num_actions, 332 | state_shape=[8, 8, 
5], 333 | convs=[[32, 4, 2], [64, 2, 1]], 334 | fully_connected=[128], 335 | num_atoms=50, 336 | kappa=1.0, 337 | activation_fn=tf.nn.relu, 338 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32), 339 | scope="qr_dqn", 340 | reuse=False): 341 | 342 | with tf.variable_scope(scope, reuse=reuse): 343 | 344 | ################### Neural network architecture ################### 345 | 346 | input_shape = [None] + state_shape 347 | self.input_states = tf.placeholder( 348 | dtype=tf.float32, shape=input_shape) 349 | 350 | # distribution parameters 351 | tau_min = 1 / (2 * num_atoms) 352 | tau_max = 1 - tau_min 353 | tau_vector = tf.lin_space( 354 | start=tau_min, stop=tau_max, num=num_atoms) 355 | 356 | # reshape tau to matrix for fast loss calculation 357 | tau_matrix = tf.tile(tau_vector, [num_atoms]) 358 | self.tau_matrix = tf.reshape( 359 | tau_matrix, shape=[num_atoms, num_atoms]) 360 | 361 | # main module 362 | out = full_module( 363 | self.input_states, convs, fully_connected, 364 | num_outputs=num_actions*num_atoms, activation_fn=activation_fn) 365 | self.atoms = tf.reshape(out, shape=[-1, num_actions, num_atoms]) 366 | self.q_values = tf.reduce_mean(self.atoms, axis=2) 367 | 368 | ##################### Optimization procedure ###################### 369 | 370 | # convert input actions to indices for atoms and q-values selection 371 | self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None]) 372 | indices_range = tf.range(tf.shape(self.input_actions)[0]) 373 | action_indices = tf.stack( 374 | [indices_range, self.input_actions], axis=1) 375 | 376 | # select q-values for input actions 377 | self.q_values_selected = tf.gather_nd( 378 | self.q_values, action_indices) 379 | self.atoms_selected = tf.gather_nd(self.atoms, action_indices) 380 | 381 | # select best actions (according to q-values) 382 | self.q_argmax = tf.argmax(self.q_values, axis=1) 383 | 384 | # reshape chosen atoms to matrix for fast loss calculation 385 | atoms_matrix = tf.tile(self.atoms_selected, [1, num_atoms]) 386 | self.atoms_matrix = tf.reshape( 387 | atoms_matrix, shape=[-1, num_atoms, num_atoms]) 388 | 389 | # reshape target atoms to matrix for fast loss calculation 390 | self.atoms_targets = tf.placeholder( 391 | dtype=tf.float32, shape=[None, num_atoms]) 392 | targets_matrix = tf.tile(self.atoms_targets, [1, num_atoms]) 393 | targets_matrix = tf.reshape( 394 | targets_matrix, shape=[-1, num_atoms, num_atoms]) 395 | self.targets_matrix = tf.transpose(targets_matrix, perm=[0, 2, 1]) 396 | 397 | # define loss function and update rule 398 | atoms_diff = self.targets_matrix - self.atoms_matrix 399 | delta_atoms_diff = tf.where( 400 | atoms_diff < 0, 401 | tf.ones_like(atoms_diff), 402 | tf.zeros_like(atoms_diff)) 403 | huber_weights = tf.abs( 404 | self.tau_matrix - delta_atoms_diff) / num_atoms 405 | self.loss = tf.losses.huber_loss( 406 | self.targets_matrix, self.atoms_matrix, weights=huber_weights, 407 | delta=kappa, reduction=tf.losses.Reduction.SUM) 408 | self.update_model = optimizer.minimize(self.loss) 409 | 410 | def get_q_values_s(self, sess, states): 411 | feed_dict = {self.input_states: states} 412 | q_values = sess.run(self.q_values, feed_dict) 413 | return q_values 414 | 415 | def get_q_values_sa(self, sess, states, actions): 416 | feed_dict = {self.input_states: states, self.input_actions: actions} 417 | q_values_selected = sess.run(self.q_values_selected, feed_dict) 418 | return q_values_selected 419 | 420 | def get_q_argmax(self, sess, states): 421 | feed_dict = {self.input_states: states} 422 
| q_argmax = sess.run(self.q_argmax, feed_dict) 423 | return q_argmax 424 | 425 | def get_atoms_s(self, sess, states): 426 | feed_dict = {self.input_states: states} 427 | atoms = sess.run(self.atoms, feed_dict) 428 | return atoms 429 | 430 | def get_atoms_sa(self, sess, states, actions): 431 | feed_dict = {self.input_states: states, self.input_actions: actions} 432 | atoms_selected = sess.run(self.atoms_selected, feed_dict) 433 | return atoms_selected 434 | 435 | def update(self, sess, states, actions, atoms_targets): 436 | feed_dict = {self.input_states: states, 437 | self.input_actions: actions, 438 | self.atoms_targets: atoms_targets} 439 | sess.run(self.update_model, feed_dict) 440 | 441 | ############################################################################### 442 | ############################## Soft Actor-Critic ############################## 443 | ############################################################################### 444 | 445 | 446 | class SoftActorCriticNetwork: 447 | 448 | def __init__(self, num_actions, state_shape=[8, 8, 5], 449 | convs=[[32, 4, 2], [64, 2, 1]], 450 | fully_connected=[128], 451 | activation_fn=tf.nn.relu, 452 | optimizers=[tf.train.AdamOptimizer(2.5e-4), 453 | tf.train.AdamOptimizer(2.5e-4), 454 | tf.train.AdamOptimizer(2.5e-4)], 455 | scope="sac", reuse=False): 456 | 457 | with tf.variable_scope(scope, reuse=reuse): 458 | 459 | ################### Neural network architecture ################### 460 | 461 | input_shape = [None] + state_shape 462 | self.input_states = tf.placeholder( 463 | dtype=tf.float32, shape=input_shape) 464 | 465 | self.v_values = full_module( 466 | self.input_states, convs, fully_connected, 467 | 1, activation_fn) 468 | self.q_values = full_module( 469 | self.input_states, convs, fully_connected, 470 | num_actions, activation_fn) 471 | self.p_logits = full_module( 472 | self.input_states, convs, fully_connected, 473 | num_actions, activation_fn) 474 | self.p_values = layers.softmax(self.p_logits) 475 | 476 | ##################### Optimization procedure ###################### 477 | 478 | # convert input actions to indices for p-logits and q-values selection 479 | self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None]) 480 | indices_range = tf.range(tf.shape(self.input_actions)[0]) 481 | action_indices = tf.stack( 482 | [indices_range, self.input_actions], axis=1) 483 | 484 | q_values_selected = tf.gather_nd(self.q_values, action_indices) 485 | p_logits_selected = tf.gather_nd(self.p_logits, action_indices) 486 | 487 | # choose best actions (according to q-values) 488 | self.q_argmax = tf.argmax(self.q_values, axis=1) 489 | 490 | # define loss function and update rule 491 | self.q_targets = tf.placeholder(dtype=tf.float32, shape=[None]) 492 | self.v_targets = tf.placeholder(dtype=tf.float32, shape=[None]) 493 | self.p_targets = tf.placeholder(dtype=tf.float32, shape=[None]) 494 | 495 | q_loss = tf.losses.huber_loss(self.q_targets, q_values_selected) 496 | self.q_loss = tf.reduce_sum(q_loss) 497 | q_optimizer = optimizers[0] 498 | 499 | v_loss = tf.losses.huber_loss( 500 | self.v_targets[:, None], self.v_values) 501 | self.v_loss = tf.reduce_sum(v_loss) 502 | v_optimizer = optimizers[1] 503 | 504 | p_loss = tf.losses.huber_loss(self.p_targets, p_logits_selected) 505 | self.p_loss = tf.reduce_sum(p_loss) 506 | p_optimizer = optimizers[2] 507 | 508 | self.update_q_values = q_optimizer.minimize(self.q_loss) 509 | self.update_v_values = v_optimizer.minimize(self.v_loss) 510 | self.update_p_logits = 
p_optimizer.minimize(self.p_loss) 511 | 512 | def get_q_argmax(self, sess, states): 513 | feed_dict = {self.input_states: states} 514 | q_argmax = sess.run(self.q_argmax, feed_dict) 515 | return q_argmax 516 | 517 | def get_q_values_s(self, sess, states): 518 | feed_dict = {self.input_states: states} 519 | q_values = sess.run(self.q_values, feed_dict) 520 | return q_values 521 | 522 | def get_v_values_s(self, sess, states): 523 | feed_dict = {self.input_states: states} 524 | v_values = sess.run(self.v_values, feed_dict) 525 | return v_values 526 | 527 | def get_p_logits_s(self, sess, states): 528 | feed_dict = {self.input_states: states} 529 | p_logits = sess.run(self.p_logits, feed_dict) 530 | return p_logits 531 | 532 | def get_p_values_s(self, sess, states): 533 | feed_dict = {self.input_states: states} 534 | p_values = sess.run(self.p_values, feed_dict) 535 | return p_values 536 | 537 | def update_q(self, sess, states, actions, q_targets): 538 | 539 | feed_dict = {self.input_states: states, 540 | self.input_actions: actions, 541 | self.q_targets: q_targets} 542 | sess.run(self.update_q_values, feed_dict) 543 | 544 | def update_v(self, sess, states, v_targets): 545 | 546 | feed_dict = {self.input_states: states, 547 | self.v_targets: v_targets} 548 | sess.run(self.update_v_values, feed_dict) 549 | 550 | def update_p(self, sess, states, actions, p_targets): 551 | 552 | feed_dict = {self.input_states: states, 553 | self.input_actions: actions, 554 | self.p_targets: p_targets} 555 | sess.run(self.update_p_logits, feed_dict) 556 | -------------------------------------------------------------------------------- /train_agents.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import gym\n", 17 | "from utils import *\n", 18 | "from agents import *\n", 19 | "from environments.snake.snake_env import Snake" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Snake Environment" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "### Environment initializtion" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "env = Snake(grid_size=(8, 8))\n", 43 | "num_actions = 3" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "### Agent training" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# Create basic agent which consists of two networks: agent and target.\n", 60 | "# Checkpoints of networks' weights and learning curves will be saved\n", 61 | "# in \"save_path/model_name\" folder.\n", 62 | "snake_agent = DQNAgent(env, num_actions, state_shape=[8, 8, 5],\n", 63 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],\n", 64 | " save_path=\"snake_models\", model_name=\"dqn_8x8\")" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Set basic hyper parameters (for full list see \"set_parameters\" method).\n", 74 | "# Create replay buffer and fill it with random transitions.\n", 75 | 
"snake_agent.set_parameters(max_episode_length=1000, replay_memory_size=100000, replay_start_size=10000,\n", 76 | " discount_factor=0.999, final_eps=0.01, annealing_steps=100000)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": { 83 | "scrolled": true 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "# Set training hyper parameters (for full list see \"train\" method).\n", 88 | "# Set gpu_id = -1 to use cpu instead if gpu, otherwise set it to gpu device id.\n", 89 | "snake_agent.train(gpu_id=-1, exploration=\"boltzmann\", save_freq=500000, max_num_epochs=1000)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "### Other agents" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# Classic deep Q-network\n", 106 | "snake_agent = DQNAgent(env, num_actions, state_shape=[8, 8, 5],\n", 107 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],\n", 108 | " save_path=\"snake_models\", model_name=\"dqn_8x8\")\n", 109 | "\n", 110 | "# Dueling deep Q-network\n", 111 | "snake_agent = DuelDQNAgent(env, num_actions, state_shape=[8, 8, 5],\n", 112 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[64],\n", 113 | " save_path=\"snake_models\", model_name=\"dueldqn_8x8\")\n", 114 | "\n", 115 | "# Categorical deep Q-network (C51)\n", 116 | "snake_agent = CatDQNAgent(env, num_actions, state_shape=[8, 8, 5],\n", 117 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],\n", 118 | " v=(-5, 25), num_atoms=51,\n", 119 | " save_path=\"snake_models\", model_name=\"catdqn_8x8\")\n", 120 | "\n", 121 | "# Quantile regression deep Q-network (QR-DQN)\n", 122 | "snake_agent = QuantRegDQNAgent(env, num_actions, state_shape=[8, 8, 5],\n", 123 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],\n", 124 | " num_atoms=100, kappa=1.0,\n", 125 | " save_path=\"snake_models\", model_name=\"quantdqn_8x8\")\n", 126 | "\n", 127 | "# Soft Actor-Critic\n", 128 | "snake_agent = SACAgent(env, num_actions, state_shape=[8, 8, 5],\n", 129 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],\n", 130 | " temperature=0.1,\n", 131 | " save_path=\"snake_models\", model_name=\"sac_8x8\")" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": { 137 | "collapsed": true 138 | }, 139 | "source": [ 140 | "# Atari Environment" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "### Environment initializtion" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "game_id = \"PongNoFrameskip-v4\"\n", 157 | "env = wrap_deepmind(gym.make(game_id))\n", 158 | "num_actions = env.unwrapped.action_space.n" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "### Agent training" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "atari_agent = DQNAgent(env, num_actions, state_shape=[84, 84, 4],\n", 175 | " convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],\n", 176 | " save_path=\"atari_models\", model_name=\"dqn_boi\")" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "atari_agent.set_parameters(max_episode_length=100000, 
discount_factor=0.99, final_eps=0.01,\n", 186 | " replay_memory_size=1000000, replay_start_size=50, annealing_steps=1000000,\n", 187 | " frame_history_len=4)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "atari_agent.train(gpu_id=-1, exploration=\"e-greedy\", save_freq=50000, \n", 197 | " max_num_epochs=1000, performance_print_freq=50)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "### Other agents" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "# Classic deep Q-network\n", 214 | "atari_agent = DQNAgent(env, num_actions, state_shape=[84, 84, 4],\n", 215 | " convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],\n", 216 | " save_path=\"atari_models\", model_name=\"dqn_boi\")\n", 217 | "\n", 218 | "# Dueling deep Q-network\n", 219 | "atari_agent = DuelDQNAgent(env, num_actions, state_shape=[84, 84, 4],\n", 220 | " convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[256],\n", 221 | " save_path=\"atari_models\", model_name=\"dueldqn_boi\")\n", 222 | "\n", 223 | "# Categorical deep Q-network (C51)\n", 224 | "atari_agent = CatDQNAgent(env, num_actions, state_shape=[84, 84, 4],\n", 225 | " convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],\n", 226 | " v=(-10, 10), num_atoms=51,\n", 227 | " save_path=\"atari_models\", model_name=\"catdqn_boi\")\n", 228 | "\n", 229 | "# Quantile regression deep Q-network (QR-DQN)\n", 230 | "atari_agent = QuantRegDQNAgent(env, num_actions, state_shape=[84, 84, 4],\n", 231 | " convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],\n", 232 | " num_atoms=200, kappa=1,\n", 233 | " save_path=\"atari_models\", model_name=\"quantdqn_boi\")" 234 | ] 235 | } 236 | ], 237 | "metadata": { 238 | "kernelspec": { 239 | "display_name": "Python [default]", 240 | "language": "python", 241 | "name": "python3" 242 | }, 243 | "language_info": { 244 | "codemirror_mode": { 245 | "name": "ipython", 246 | "version": 3 247 | }, 248 | "file_extension": ".py", 249 | "mimetype": "text/x-python", 250 | "name": "python", 251 | "nbconvert_exporter": "python", 252 | "pygments_lexer": "ipython3", 253 | "version": "3.6.5" 254 | } 255 | }, 256 | "nbformat": 4, 257 | "nbformat_minor": 2 258 | } 259 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is copied/apdated from https://github.com/berkeleydeeprlcourse/homework/tree/master/hw3 3 | """ 4 | 5 | import numpy as np 6 | import random 7 | from collections import namedtuple, deque 8 | import gym 9 | from gym import spaces 10 | from PIL import Image 11 | 12 | #################################################################################################### 13 | ######################################## Experience Replay ######################################### 14 | #################################################################################################### 15 | 16 | def sample_n_unique(sampling_f, n): 17 | """Helper function. Given a function `sampling_f` that returns 18 | comparable objects, sample n such unique objects. 
19 | """ 20 | res = [] 21 | while len(res) < n: 22 | candidate = sampling_f() 23 | if candidate not in res: 24 | res.append(candidate) 25 | return res 26 | 27 | class ReplayBuffer(object): 28 | def __init__(self, size, frame_history_len): 29 | """This is a memory efficient implementation of the replay buffer. 30 | The sepecific memory optimizations use here are: 31 | - only store each frame once rather than k times 32 | even if every observation normally consists of k last frames 33 | - store frames as np.uint8 (actually it is most time-performance 34 | to cast them back to float32 on GPU to minimize memory transfer 35 | time) 36 | - store frame_t and frame_(t+1) in the same buffer. 37 | For the tipical use case in Atari Deep RL buffer with 1M frames the total 38 | memory footprint of this buffer is 10^6 * 84 * 84 bytes ~= 7 gigabytes 39 | Warning! Assumes that returning frame of zeros at the beginning 40 | of the episode, when there is less frames than `frame_history_len`, 41 | is acceptable. 42 | Parameters 43 | ---------- 44 | size: int 45 | Max number of transitions to store in the buffer. When the buffer 46 | overflows the old memories are dropped. 47 | frame_history_len: int 48 | Number of memories to be retried for each observation. 49 | """ 50 | self.size = size 51 | self.frame_history_len = frame_history_len 52 | 53 | self.next_idx = 0 54 | self.num_in_buffer = 0 55 | 56 | self.obs = None 57 | self.action = None 58 | self.reward = None 59 | self.done = None 60 | 61 | self.transition = namedtuple('Transition', ('s', 'a', 'r', 's_', 'done')) 62 | 63 | def can_sample(self, batch_size): 64 | """Returns true if `batch_size` different transitions can be sampled from the buffer.""" 65 | return batch_size + 1 <= self.num_in_buffer 66 | 67 | def _encode_sample(self, idxes): 68 | obs_batch = np.concatenate([self._encode_observation(idx)[None] for idx in idxes], 0) 69 | act_batch = self.action[idxes] 70 | rew_batch = self.reward[idxes] 71 | next_obs_batch = np.concatenate([self._encode_observation(idx + 1)[None] for idx in idxes], 0) 72 | done_mask = np.array([1.0 if self.done[idx] else 0.0 for idx in idxes], dtype=np.float32) 73 | 74 | return self.transition(obs_batch, act_batch, rew_batch, next_obs_batch, done_mask) 75 | 76 | 77 | def sample(self, batch_size): 78 | """Sample `batch_size` different transitions. 79 | i-th sample transition is the following: 80 | when observing `obs_batch[i]`, action `act_batch[i]` was taken, 81 | after which reward `rew_batch[i]` was received and subsequent 82 | observation next_obs_batch[i] was observed, unless the epsiode 83 | was done which is represented by `done_mask[i]` which is equal 84 | to 1 if episode has ended as a result of that action. 85 | Parameters 86 | ---------- 87 | batch_size: int 88 | How many transitions to sample. 
89 | Returns 90 | ------- 91 | obs_batch: np.array 92 | Array of shape 93 | (batch_size, img_h, img_w, img_c * frame_history_len) 94 | and dtype np.uint8 95 | act_batch: np.array 96 | Array of shape (batch_size,) and dtype np.int32 97 | rew_batch: np.array 98 | Array of shape (batch_size,) and dtype np.float32 99 | next_obs_batch: np.array 100 | Array of shape 101 | (batch_size, img_h, img_w, img_c * frame_history_len) 102 | and dtype np.uint8 103 | done_mask: np.array 104 | Array of shape (batch_size,) and dtype np.float32 105 | """ 106 | assert self.can_sample(batch_size) 107 | idxes = sample_n_unique(lambda: random.randint(0, self.num_in_buffer - 2), batch_size) 108 | return self._encode_sample(idxes) 109 | 110 | def encode_recent_observation(self): 111 | """Return the most recent `frame_history_len` frames. 112 | Returns 113 | ------- 114 | observation: np.array 115 | Array of shape (img_h, img_w, img_c * frame_history_len) 116 | and dtype np.uint8, where observation[:, :, i*img_c:(i+1)*img_c] 117 | encodes frame at time `t - frame_history_len + i` 118 | """ 119 | assert self.num_in_buffer > 0 120 | return self._encode_observation((self.next_idx - 1) % self.size) 121 | 122 | def _encode_observation(self, idx): 123 | end_idx = idx + 1 # make noninclusive 124 | start_idx = end_idx - self.frame_history_len 125 | # this checks if we are using low-dimensional observations, such as RAM 126 | # state, in which case we just directly return the latest RAM. 127 | if len(self.obs.shape) == 2: 128 | return self.obs[end_idx-1] 129 | # if there weren't enough frames ever in the buffer for context 130 | if start_idx < 0 and self.num_in_buffer != self.size: 131 | start_idx = 0 132 | for idx in range(start_idx, end_idx - 1): 133 | if self.done[idx % self.size]: 134 | start_idx = idx + 1 135 | missing_context = self.frame_history_len - (end_idx - start_idx) 136 | # if zero padding is needed for missing context 137 | # or we are on the boundry of the buffer 138 | if start_idx < 0 or missing_context > 0: 139 | frames = [np.zeros_like(self.obs[0]) for _ in range(missing_context)] 140 | for idx in range(start_idx, end_idx): 141 | frames.append(self.obs[idx % self.size]) 142 | return np.concatenate(frames, 2) 143 | else: 144 | # this optimization has potential to saves about 30% compute time \o/ 145 | img_h, img_w = self.obs.shape[1], self.obs.shape[2] 146 | return self.obs[start_idx:end_idx].transpose(1, 2, 0, 3).reshape(img_h, img_w, -1) 147 | 148 | def store_frame(self, frame): 149 | """Store a single frame in the buffer at the next available index, overwriting 150 | old frames if necessary. 151 | Parameters 152 | ---------- 153 | frame: np.array 154 | Array of shape (img_h, img_w, img_c) and dtype np.uint8 155 | the frame to be stored 156 | Returns 157 | ------- 158 | idx: int 159 | Index at which the frame is stored. To be used for `store_effect` later. 160 | """ 161 | if self.obs is None: 162 | self.obs = np.empty([self.size] + list(frame.shape), dtype=np.uint8) 163 | self.action = np.empty([self.size], dtype=np.int32) 164 | self.reward = np.empty([self.size], dtype=np.float32) 165 | self.done = np.empty([self.size], dtype=np.bool) 166 | self.obs[self.next_idx] = frame 167 | 168 | ret = self.next_idx 169 | self.next_idx = (self.next_idx + 1) % self.size 170 | self.num_in_buffer = min(self.size, self.num_in_buffer + 1) 171 | 172 | return ret 173 | 174 | def store_effect(self, idx, action, reward, done): 175 | """Store effects of action taken after obeserving frame stored 176 | at index idx. 
The reason `store_frame` and `store_effect` is broken 177 | up into two functions is so that once can call `encode_recent_observation` 178 | in between. 179 | Paramters 180 | --------- 181 | idx: int 182 | Index in buffer of recently observed frame (returned by `store_frame`). 183 | action: int 184 | Action that was performed upon observing this frame. 185 | reward: float 186 | Reward that was received when the actions was performed. 187 | done: bool 188 | True if episode was finished after performing that action. 189 | """ 190 | self.action[idx] = action 191 | self.reward[idx] = reward 192 | self.done[idx] = done 193 | 194 | #################################################################################################### 195 | ######################################## DM Atari Wrappers ######################################### 196 | #################################################################################################### 197 | 198 | class NoopResetEnv(gym.Wrapper): 199 | def __init__(self, env=None, noop_max=30): 200 | """Sample initial states by taking random number of no-ops on reset. 201 | No-op is assumed to be action 0. 202 | """ 203 | super(NoopResetEnv, self).__init__(env) 204 | self.noop_max = noop_max 205 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 206 | 207 | def _reset(self): 208 | """ Do no-op action for a number of steps in [1, noop_max].""" 209 | self.env.reset() 210 | noops = np.random.randint(1, self.noop_max + 1) 211 | for _ in range(noops): 212 | obs, _, _, _ = self.env.step(0) 213 | return obs 214 | 215 | class FireResetEnv(gym.Wrapper): 216 | def __init__(self, env=None): 217 | """Take action on reset for environments that are fixed until firing.""" 218 | super(FireResetEnv, self).__init__(env) 219 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 220 | assert len(env.unwrapped.get_action_meanings()) >= 3 221 | 222 | def _reset(self): 223 | self.env.reset() 224 | obs, _, _, _ = self.env.step(1) 225 | obs, _, _, _ = self.env.step(2) 226 | return obs 227 | 228 | class EpisodicLifeEnv(gym.Wrapper): 229 | def __init__(self, env=None): 230 | """Make end-of-life == end-of-episode, but only reset on true game over. 231 | Done by DeepMind for the DQN and co. since it helps value estimation. 232 | """ 233 | super(EpisodicLifeEnv, self).__init__(env) 234 | self.lives = 0 235 | self.was_real_done = True 236 | self.was_real_reset = False 237 | 238 | def _step(self, action): 239 | obs, reward, done, info = self.env.step(action) 240 | self.was_real_done = done 241 | # check current lives, make loss of life terminal, 242 | # then update lives to handle bonus lives 243 | lives = self.env.unwrapped.ale.lives() 244 | if lives < self.lives and lives > 0: 245 | # for Qbert somtimes we stay in lives == 0 condtion for a few frames 246 | # so its important to keep lives > 0, so that we only reset once 247 | # the environment advertises done. 248 | done = True 249 | self.lives = lives 250 | return obs, reward, done, info 251 | 252 | def _reset(self): 253 | """Reset only when lives are exhausted. 254 | This way all states are still reachable even though lives are episodic, 255 | and the learner need not know about any of this behind-the-scenes. 
256 | """ 257 | if self.was_real_done: 258 | obs = self.env.reset() 259 | self.was_real_reset = True 260 | else: 261 | # no-op step to advance from terminal/lost life state 262 | obs, _, _, _ = self.env.step(0) 263 | self.was_real_reset = False 264 | self.lives = self.env.unwrapped.ale.lives() 265 | return obs 266 | 267 | class MaxAndSkipEnv(gym.Wrapper): 268 | def __init__(self, env=None, skip=4): 269 | """Return only every `skip`-th frame""" 270 | super(MaxAndSkipEnv, self).__init__(env) 271 | # most recent raw observations (for max pooling across time steps) 272 | self._obs_buffer = deque(maxlen=2) 273 | self._skip = skip 274 | 275 | def _step(self, action): 276 | total_reward = 0.0 277 | done = None 278 | for _ in range(self._skip): 279 | obs, reward, done, info = self.env.step(action) 280 | self._obs_buffer.append(obs) 281 | total_reward += reward 282 | if done: 283 | break 284 | 285 | max_frame = np.max(np.stack(self._obs_buffer), axis=0) 286 | 287 | return max_frame, total_reward, done, info 288 | 289 | def _reset(self): 290 | """Clear past frame buffer and init. to first obs. from inner env.""" 291 | self._obs_buffer.clear() 292 | obs = self.env.reset() 293 | self._obs_buffer.append(obs) 294 | return obs 295 | 296 | def _process_frame84(frame): 297 | img = np.reshape(frame, [210, 160, 3]).astype(np.float32) 298 | img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114 299 | img = Image.fromarray(img) 300 | resized_screen = img.resize((84, 110), Image.BILINEAR) 301 | resized_screen = np.array(resized_screen) 302 | x_t = resized_screen[18:102, :] 303 | x_t = np.reshape(x_t, [84, 84, 1]) 304 | return x_t.astype(np.uint8) 305 | 306 | class ProcessFrame84(gym.Wrapper): 307 | def __init__(self, env=None): 308 | super(ProcessFrame84, self).__init__(env) 309 | self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1)) 310 | 311 | def _step(self, action): 312 | obs, reward, done, info = self.env.step(action) 313 | return _process_frame84(obs), reward, done, info 314 | 315 | def _reset(self): 316 | return _process_frame84(self.env.reset()) 317 | 318 | class ClippedRewardsWrapper(gym.Wrapper): 319 | def _step(self, action): 320 | obs, reward, done, info = self.env.step(action) 321 | return obs, np.sign(reward), done, info 322 | 323 | def wrap_deepmind_ram(env): 324 | env = EpisodicLifeEnv(env) 325 | env = NoopResetEnv(env, noop_max=30) 326 | env = MaxAndSkipEnv(env, skip=4) 327 | if 'FIRE' in env.unwrapped.get_action_meanings(): 328 | env = FireResetEnv(env) 329 | env = ClippedRewardsWrapper(env) 330 | return env 331 | 332 | def wrap_deepmind(env): 333 | assert 'NoFrameskip' in env.spec.id 334 | env = EpisodicLifeEnv(env) 335 | env = NoopResetEnv(env, noop_max=30) 336 | env = MaxAndSkipEnv(env, skip=4) 337 | if 'FIRE' in env.unwrapped.get_action_meanings(): 338 | env = FireResetEnv(env) 339 | env = ProcessFrame84(env) 340 | env = ClippedRewardsWrapper(env) 341 | return env --------------------------------------------------------------------------------