├── README.md
├── agents.py
├── environments
│   ├── doom
│   │   ├── README.md
│   │   ├── config
│   │   │   ├── basic.cfg
│   │   │   ├── basic.wad
│   │   │   ├── deadly_corridor.cfg
│   │   │   ├── deadly_corridor.wad
│   │   │   ├── defend_the_center.cfg
│   │   │   ├── defend_the_center.wad
│   │   │   ├── defend_the_line.cfg
│   │   │   ├── defend_the_line.wad
│   │   │   ├── health_gathering.cfg
│   │   │   ├── health_gathering.wad
│   │   │   ├── predict_position.cfg
│   │   │   └── predict_position.wad
│   │   ├── doom_env.py
│   │   └── img
│   │       ├── basic.png
│   │       ├── corridor.png
│   │       ├── def_center.png
│   │       ├── def_line.png
│   │       ├── health.png
│   │       └── predict.png
│   ├── snake
│   │   └── snake_env.py
│   └── windy_grid_world
│       ├── evil_wgw_env.py
│       └── wgw_env.py
├── img
│   ├── dqn_categorical.png
│   ├── dqn_classic.png
│   ├── dqn_dueling.png
│   ├── dqn_quantile.png
│   ├── sac_p_network.png
│   ├── sac_q_network.png
│   └── sac_v_network.png
├── methods.py
├── train_agents.ipynb
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # rl_algorithms
2 | Implementations of different off-policy reinforcement learning algorithms.
3 |
4 | # Framework
5 |
6 | 1. Module [methods.py](methods.py) contains [TensorFlow](https://www.tensorflow.org) implementations of various neural network architectures used in value-based deep reinforcement learning.
7 |
8 | 2. Module [agents.py](agents.py) contains a general **Agent** class and various wrappers around it which represent the corresponding deep RL algorithms.
9 |
10 | 3. Module [utils.py](utils.py) contains a **Replay Buffer** implementation together with a wrapper around the **OpenAI Gym Atari 2600** environments, which is necessary for reproducing the original DeepMind results.
11 |
12 | 4. Jupyter notebook [train_agents.ipynb](train_agents.ipynb) contains examples of how to use the proposed framework to train deep RL agents on various environments; a minimal sketch is shown below.
13 |
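A minimal end-to-end sketch (the Snake environment, hyperparameters, and model name below are illustrative; see [train_agents.ipynb](train_agents.ipynb) for the complete workflows):

```python
from agents import DQNAgent
from environments.snake.snake_env import Snake

# 8x8 Snake board; observations have shape (8, 8, 5)
env = Snake(grid_size=(8, 8))

# build the agent and its target network
agent = DQNAgent(env, num_actions=3, state_shape=[8, 8, 5],
                 save_path="rl_models", model_name="snake_dqn")

# fill the replay buffer with random-policy transitions and set schedules
agent.set_parameters(replay_memory_size=50000, replay_start_size=10000,
                     annealing_steps=100000, discount_factor=0.99)

# train on CPU (gpu_id=-1); checkpoints are written to rl_models/snake_dqn
agent.train(gpu_id=-1, batch_size=32, exploration="e-greedy",
            max_num_epochs=50, performance_print_freq=500)
```
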
14 | # Available algorithms
15 |
16 | - Deep Q-Network [Volodymyr Mnih et al. "Human-level control through deep reinforcement learning." Nature (2015)](https://pra.open.tips/storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
17 |
18 |
19 |
20 |
21 | - Dueling Deep Q-Network [Ziyu Wang et al. "Dueling network architectures for deep reinforcement learning." ICML (2016).](https://arxiv.org/pdf/1511.06581.pdf)
22 |
23 |
24 |
25 |
26 | - Categorical Deep Q-Network [Marc G. Bellemare, Will Dabney, and Rémi Munos. "A distributional perspective on reinforcement learning." ICML (2017).](https://arxiv.org/pdf/1707.06887)
27 |
28 |
29 |
30 |
31 | - Quantile Regression Deep Q-Network [Will Dabney, Mark Rowland, Marc G. Bellemare, and Rémi Munos. "Distributional Reinforcement Learning with Quantile Regression." AAAI (2018).](https://arxiv.org/pdf/1710.10044)
32 |
33 |
34 |
35 |
36 | - Soft Actor-Critic [Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, and Sergey Levine. "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor." ICML (2018).](https://arxiv.org/pdf/1801.01290)
37 |
38 |
39 |
40 |
41 |
42 |
43 |
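Each algorithm is exposed as an agent class in [agents.py](agents.py); the names below are the class definitions in that module, with a rough mapping in the comments:

```python
from agents import DQNAgent          # Deep Q-Network (with double Q-learning targets)
from agents import DuelDQNAgent      # Dueling Deep Q-Network
from agents import CatDQNAgent       # Categorical Deep Q-Network (distributional, C51-style)
from agents import QuantRegDQNAgent  # Quantile Regression Deep Q-Network
from agents import SACAgent          # Soft Actor-Critic for discrete action spaces
```

All of them share the `set_parameters` / `train` / `play` interface of the base `Agent` class.
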
44 | **Note.** Images of different neural network architectures are based on the images from the [Dueling architectures](https://arxiv.org/pdf/1511.06581.pdf) paper. The original images were copied and adapted to reflect features of particular architectures and learning algorithms.
45 |
--------------------------------------------------------------------------------
/agents.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | from IPython import display
5 | import matplotlib.pyplot as plt
6 |
7 | from methods import *
8 | from utils import *
9 |
10 |
11 | def softmax(x):
12 | e_x = np.exp(x - np.max(x))
13 | return e_x / e_x.sum()
14 |
15 |
16 | class Agent:
17 |
18 | def __init__(self, env, num_actions, state_shape=[8, 8, 5],
19 | save_path="rl_models", model_name="agent"):
20 |
21 | self.train_env = env
22 | self.num_actions = num_actions
23 |
24 | self.path = save_path + "/" + model_name
25 | if not os.path.exists(self.path):
26 | os.makedirs(self.path)
27 |
28 | def init_weights(self):
29 |
30 | global_vars = tf.global_variables(
31 | scope="agent") + tf.global_variables(scope="target")
32 | self.init = tf.variables_initializer(global_vars)
33 | self.saver = tf.train.Saver()
34 |
35 | self.agent_vars = tf.trainable_variables(scope="agent")
36 | self.target_vars = tf.trainable_variables(scope="target")
37 |
38 | def set_parameters(self,
39 | replay_memory_size=50000,
40 | replay_start_size=10000,
41 | init_eps=1,
42 | final_eps=0.02,
43 | annealing_steps=100000,
44 | discount_factor=0.99,
45 | max_episode_length=2000,
46 | frame_history_len=1):
47 |
48 | # create experience replay and fill it with random policy samples
49 | self.rep_buffer = ReplayBuffer(size=replay_memory_size,
50 | frame_history_len=frame_history_len)
51 | frame_count = 0
52 | while (frame_count < replay_start_size):
53 | last_obs = self.train_env.reset()
54 | for time_step in range(max_episode_length):
55 |
56 | last_idx = self.rep_buffer.store_frame(last_obs)
57 | recent_obs = self.rep_buffer.encode_recent_observation()
58 | action = np.random.randint(self.num_actions)
59 | obs, reward, done = self.train_env.step(action)[:3]
60 | self.rep_buffer.store_effect(last_idx, action, reward, done)
61 |
62 | frame_count += 1
63 |
64 | if done:
65 | break
66 | last_obs = obs
67 |
68 | # define epsilon decrement schedule for exploration
69 | self.eps = init_eps
70 | self.final_eps = final_eps
71 | self.eps_drop = (init_eps - final_eps) / annealing_steps
72 |
73 | self.gamma = discount_factor
74 | self.max_ep_length = max_episode_length
75 |
76 | def train(self,
77 | gpu_id=0,
78 | batch_size=32,
79 | exploration="e-greedy",
80 | agent_update_freq=4,
81 | target_update_freq=5000,
82 | tau=1,
83 | max_num_epochs=50000,
84 | performance_print_freq=500,
85 | save_freq=10000,
86 | from_epoch=0):
87 |
88 | config = self.gpu_config(gpu_id)
89 | target_ops = self.update_target_graph(tau)
90 | self.batch_size = batch_size
91 |
92 | with tf.Session(config=config) as sess:
93 |
94 | if from_epoch == 0:
95 | sess.run(self.init)
96 | train_rewards = []
97 | frame_counts = []
98 | frame_count = 0
99 | num_epochs = 0
100 |
101 | else:
102 | self.saver.restore(sess, self.path+"/model-"+str(from_epoch))
103 | train_rewards = list(
104 | np.load(self.path+"/learning_curve.npz")["r"])
105 | frame_counts = list(
106 | np.load(self.path+"/learning_curve.npz")["f"])
107 | frame_count = frame_counts[-1]
108 | num_epochs = from_epoch
109 |
110 | episode_count = 0
111 | ep_lifetimes = []
112 |
113 | while num_epochs < max_num_epochs:
114 |
115 | train_ep_reward = 0
116 |
117 | # reset the environment / start new game
118 | last_obs = self.train_env.reset()
119 | for time_step in range(self.max_ep_length):
120 |
121 | last_idx = self.rep_buffer.store_frame(last_obs)
122 | recent_obs = self.rep_buffer.encode_recent_observation()
123 |
124 | # choose action e-greedily
125 | action = self.choose_action(sess, recent_obs, exploration)
126 |
127 | # make step in the environment
128 | obs, reward, done = self.train_env.step(action)[:3]
129 |
130 | # save transition into experience replay
131 | self.rep_buffer.store_effect(
132 | last_idx, action, reward, done)
133 |
134 | # update current state and statistics
135 | frame_count += 1
136 | train_ep_reward += reward
137 |
138 | # reduce epsilon according to schedule
139 | if self.eps > self.final_eps:
140 | self.eps -= self.eps_drop
141 |
142 | # update network weights
143 | if frame_count % agent_update_freq == 0:
144 |
145 | batch = self.rep_buffer.sample(batch_size)
146 | self.update_agent_weights(sess, batch)
147 |
148 | # update target network
149 | if tau == 1:
150 | if frame_count % target_update_freq == 0:
151 | self.update_target_weights(sess, target_ops)
152 | else:
153 | self.update_target_weights(sess, target_ops)
154 |
155 | # save network weights checkpoint and learning curve
156 | if frame_count % save_freq == 1:
157 | num_epochs += 1
158 | try:
159 | self.saver.save(
160 | sess, self.path+"/model",
161 | global_step=num_epochs)
162 | np.savez(
163 | self.path+"/learning_curve.npz",
164 | r=train_rewards, f=frame_counts,
165 | l=ep_lifetimes)
166 | except Exception:
167 | pass
168 |
169 | # if game is over, reset the environment
170 | if done:
171 | break
172 | last_obs = obs
173 |
174 | episode_count += 1
175 | train_rewards.append(train_ep_reward)
176 | frame_counts.append(frame_count)
177 | ep_lifetimes.append(time_step+1)
178 |
179 | # print performance once in a while
180 | if episode_count % performance_print_freq == 0:
181 | avg_reward = np.mean(
182 | train_rewards[-performance_print_freq:])
183 | avg_lifetime = np.mean(
184 | ep_lifetimes[-performance_print_freq:])
185 | print("frame count:", frame_count)
186 | print("average reward:", avg_reward)
187 | print("epsilon:", round(self.eps, 3))
188 | print("average lifetime:", avg_lifetime)
189 | print("-------------------------------")
190 |
191 | def choose_action(self, sess, s, exploration="e-greedy"):
192 |
193 | if (exploration == "greedy"):
194 | a = self.agent_net.get_q_argmax(sess, [s])
195 | elif (exploration == "e-greedy"):
196 | if np.random.rand(1) < self.eps:
197 | a = np.random.randint(self.num_actions)
198 | else:
199 | a = self.agent_net.get_q_argmax(sess, [s])
200 | elif (exploration == "boltzmann"):
201 | q_values = self.agent_net.get_q_values_s(sess, [s])
202 | logits = q_values / self.eps
203 | probs = softmax(logits).ravel()
204 | a = np.random.choice(self.num_actions, p=probs)
205 | elif (exploration == "policy"):
206 | probs = self.agent_net.get_p_values_s(sess, [s]).ravel()
207 | a = np.random.choice(self.num_actions, p=probs)
208 | else:
209 | return 0
210 | return a
211 |
212 | def update_agent_weights(self, sess, batch):
213 |
214 | # estimate the right hand side of the Bellman equation
215 | agent_actions = self.agent_net.get_q_argmax(sess, batch.s_)
216 | q_double = self.target_net.get_q_values_sa(
217 | sess, batch.s_, agent_actions)
218 | targets = batch.r + (self.gamma * q_double * (1 - batch.done))
219 |
220 | # update agent network
221 | self.agent_net.update(sess, batch.s, batch.a, targets)
222 |
223 | def update_target_graph(self, tau):
224 | op_holder = []
225 | for agnt, trgt in zip(self.agent_vars, self.target_vars):
226 | op = trgt.assign(agnt.value()*tau + (1 - tau)*trgt.value())
227 | op_holder.append(op)
228 | return op_holder
229 |
230 | def update_target_weights(self, sess, op_holder):
231 | for op in op_holder:
232 | sess.run(op)
233 |
234 | def play(self,
235 | gpu_id=0,
236 | max_episode_length=2000,
237 | from_epoch=0):
238 |
239 | config = self.gpu_config(gpu_id)
240 | with tf.Session(config=config) as sess:
241 | self.saver.restore(sess, self.path+"/model-"+str(from_epoch))
242 | s = self.train_env.reset()
243 | R = 0
244 | for time_step in range(max_episode_length):
245 | a = self.agent_net.get_q_argmax(sess, [s])[0]
246 | s, r, done = self.train_env.step(a)[:3]
247 | R += r
248 | self.train_env.plot_state()
249 | display.clear_output(wait=True)
250 | display.display(plt.gcf())
251 | if done:
252 | break
253 | return R
254 |
255 | def gpu_config(self, gpu_id):
256 | if (gpu_id == -1):
257 | config = tf.ConfigProto()
258 | else:
259 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
260 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
261 | config = tf.ConfigProto()
262 | config.gpu_options.allow_growth = True
263 | config.intra_op_parallelism_threads = 1
264 | config.inter_op_parallelism_threads = 1
265 | return config
266 |
267 | ############################ Deep Q-Network agent #############################
268 |
269 |
270 | class DQNAgent(Agent):
271 |
272 | def __init__(
273 | self, env, num_actions, state_shape=[8, 8, 5],
274 | convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],
275 | activation_fn=tf.nn.relu,
276 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32),
277 | gradient_clip=10.0,
278 | save_path="rl_models", model_name="DQN"):
279 |
280 | super(DQNAgent, self).__init__(
281 | env, num_actions, state_shape=state_shape,
282 | save_path=save_path, model_name=model_name)
283 |
284 | tf.reset_default_graph()
285 | self.agent_net = DeepQNetwork(
286 | self.num_actions, state_shape=state_shape,
287 | convs=convs, fully_connected=fully_connected,
288 | activation_fn=activation_fn, optimizer=optimizer,
289 | gradient_clip=gradient_clip, scope="agent")
290 | self.target_net = DeepQNetwork(
291 | self.num_actions, state_shape=state_shape,
292 | convs=convs, fully_connected=fully_connected,
293 | activation_fn=activation_fn, optimizer=optimizer,
294 | gradient_clip=gradient_clip, scope="target")
295 | self.init_weights()
296 |
297 | ######################## Dueling Deep Q-Network agent #########################
298 |
299 |
300 | class DuelDQNAgent(Agent):
301 |
302 | def __init__(
303 | self, env, num_actions, state_shape=[8, 8, 5],
304 | convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[64],
305 | activation_fn=tf.nn.relu,
306 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32),
307 | gradient_clip=10.0,
308 | save_path="rl_models", model_name="DuelDQN"):
309 |
310 | super(DuelDQNAgent, self).__init__(
311 | env, num_actions, state_shape=state_shape,
312 | save_path=save_path, model_name=model_name)
313 |
314 | tf.reset_default_graph()
315 | self.agent_net = DuelingDeepQNetwork(
316 | self.num_actions, state_shape=state_shape,
317 | convs=convs, fully_connected=fully_connected,
318 | activation_fn=activation_fn, optimizer=optimizer,
319 | gradient_clip=gradient_clip, scope="agent")
320 | self.target_net = DuelingDeepQNetwork(
321 | self.num_actions, state_shape=state_shape,
322 | convs=convs, fully_connected=fully_connected,
323 | activation_fn=activation_fn, optimizer=optimizer,
324 | gradient_clip=gradient_clip, scope="target")
325 | self.init_weights()
326 |
327 | ###################### Categorical Deep Q-Network agent #######################
328 |
329 |
330 | class CatDQNAgent(Agent):
331 |
332 | def __init__(
333 | self, env, num_actions, state_shape=[8, 8, 5],
334 | convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],
335 | activation_fn=tf.nn.relu, num_atoms=21, v=(-10, 10),
336 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32),
337 | save_path="rl_models", model_name="CatDQN"):
338 |
339 | super(CatDQNAgent, self).__init__(
340 | env, num_actions, state_shape=state_shape,
341 | save_path=save_path, model_name=model_name)
342 |
343 | tf.reset_default_graph()
344 | self.agent_net = CategoricalDeepQNetwork(
345 | self.num_actions, state_shape=state_shape,
346 | convs=convs, fully_connected=fully_connected,
347 | activation_fn=activation_fn, num_atoms=num_atoms,
348 | v=v, optimizer=optimizer, scope="agent")
349 | self.target_net = CategoricalDeepQNetwork(
350 | self.num_actions, state_shape=state_shape,
351 | convs=convs, fully_connected=fully_connected,
352 | activation_fn=activation_fn, num_atoms=num_atoms,
353 | v=v, optimizer=optimizer, scope="target")
354 | self.init_weights()
355 |
356 | def update_agent_weights(self, sess, batch):
357 |
358 | # estimate categorical projection of the RHS of the Bellman equation
359 | agent_actions = self.agent_net.get_q_argmax(sess, batch.s_)
360 | probs_targets = self.target_net.cat_proj(
361 | sess, batch.s_, agent_actions, batch.r,
362 | batch.done, gamma=self.gamma)
363 |
364 | # update agent network
365 | self.agent_net.update(sess, batch.s, batch.a, probs_targets)
366 |
367 | ################## Quantile Regression Deep Q-Network agent ###################
368 |
369 |
370 | class QuantRegDQNAgent(Agent):
371 |
372 | def __init__(
373 | self, env, num_actions, state_shape=[8, 8, 5],
374 | convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],
375 | activation_fn=tf.nn.relu, num_atoms=50, kappa=1.0,
376 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32),
377 | save_path="rl_models", model_name="QuantRegDQN"):
378 |
379 | super(QuantRegDQNAgent, self).__init__(
380 | env, num_actions, state_shape=state_shape,
381 | save_path=save_path, model_name=model_name)
382 |
383 | tf.reset_default_graph()
384 | self.agent_net = QuantileDeepQNetwork(
385 | self.num_actions, state_shape=state_shape,
386 | convs=convs, fully_connected=fully_connected,
387 | activation_fn=activation_fn, num_atoms=num_atoms,
388 | kappa=kappa, optimizer=optimizer, scope="agent")
389 | self.target_net = QuantileDeepQNetwork(
390 | self.num_actions, state_shape=state_shape,
391 | convs=convs, fully_connected=fully_connected,
392 | activation_fn=activation_fn, num_atoms=num_atoms,
393 | kappa=kappa, optimizer=optimizer, scope="target")
394 | self.init_weights()
395 |
396 | def update_agent_weights(self, sess, batch):
397 |
398 | # calculate target atoms produced by Bellman operator
399 | agent_actions = self.agent_net.get_q_argmax(sess, batch.s_)
400 | next_atoms = self.target_net.get_atoms_sa(
401 | sess, batch.s_, agent_actions)
402 | target_atoms = batch.r[:, None] + \
403 | self.gamma * next_atoms * (1 - batch.done[:, None])
404 | # update agent network
405 | self.agent_net.update(sess, batch.s, batch.a, target_atoms)
406 |
407 | ########################### Soft Actor-Critic agent ###########################
408 |
409 |
410 | class SACAgent(Agent):
411 |
412 | def __init__(
413 | self, env, num_actions, state_shape=[8, 8, 5],
414 | convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],
415 | activation_fn=tf.nn.relu,
416 | temperature=1.0,
417 | optimizers=[tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32),
418 | tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32),
419 | tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32)],
420 | save_path="rl_models", model_name="SAC"):
421 |
422 | super(SACAgent, self).__init__(
423 | env, num_actions, state_shape=state_shape,
424 | save_path=save_path, model_name=model_name)
425 |
426 | tf.reset_default_graph()
427 | self.agent_net = SoftActorCriticNetwork(
428 | self.num_actions, state_shape=state_shape,
429 | convs=convs, fully_connected=fully_connected,
430 | activation_fn=activation_fn,
431 | optimizers=optimizers, scope="agent")
432 | self.target_net = SoftActorCriticNetwork(
433 | self.num_actions, state_shape=state_shape,
434 | convs=convs, fully_connected=fully_connected,
435 | activation_fn=activation_fn,
436 | optimizers=optimizers, scope="target")
437 | self.init_weights()
438 | self.t = temperature
439 |
440 | def update_agent_weights(self, sess, batch):
441 |
442 | probs = self.agent_net.get_p_values_s(sess, batch.s)
443 | c = probs.cumsum(axis=1)
444 | u = np.random.rand(len(c), 1)
445 | actions = (u < c).argmax(axis=1)
446 |
447 | v_values = self.agent_net.get_v_values_s(sess, batch.s).reshape(-1)
448 | v_values_next = self.target_net.get_v_values_s(
449 | sess, batch.s_).reshape(-1)
450 | q_values = self.agent_net.get_q_values_s(sess, batch.s)
451 | p_logits = self.agent_net.get_p_logits_s(sess, batch.s)
452 |
453 | x = np.arange(self.batch_size)
454 | q_values_selected = q_values[x, actions]
455 | p_logits_selected = p_logits[x, actions]
456 |
457 | q_targets = batch.r / self.t + \
458 | self.gamma * v_values_next * (1 - batch.done)
459 | v_targets = q_values_selected - p_logits_selected
460 | p_targets = q_values_selected - v_values
461 |
462 | # update agent network
463 | self.agent_net.update_q(sess, batch.s, batch.a, q_targets)
464 | self.agent_net.update_v(sess, batch.s, v_targets)
465 | self.agent_net.update_p(sess, batch.s, actions, p_targets)
466 |
--------------------------------------------------------------------------------
/environments/doom/README.md:
--------------------------------------------------------------------------------
1 | # Description
2 |
3 | This is a gym-like wrapper over the [VizDoom](http://vizdoom.cs.put.edu.pl) environment for deep reinforcement learning research. To install the **vizdoom** Python library, please follow the instructions from the [VizDoom GitHub repository](https://github.com/mwydmuch/ViZDoom).
4 |
5 | ## How to use
6 |
7 | ```python
8 | from environments.doom.doom_env import DoomBasic
9 |
10 | # create environment instance
11 | env = DoomBasic()
12 |
13 | # reset the environment and get preprocessed 84x84x1 observation
14 | obs = env.reset()
15 |
16 | # make a step in the environment (action is an integer in [0, env.num_actions))
17 | next_obs, reward, done = env.step(action)
18 |
19 | # get original RGB image of current observation
20 | rgb_obs = env.get_obs_rgb()
21 | ```
22 |
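A complete random-policy episode can be rolled out with a loop like the following (a minimal sketch; `DoomBasic` is used for concreteness):

```python
import numpy as np
from environments.doom.doom_env import DoomBasic

env = DoomBasic()
obs = env.reset()
done, episode_return = False, 0.0

while not done:
    # sample a random index among the available buttons
    action = np.random.randint(env.num_actions)
    obs, reward, done = env.step(action)
    episode_return += reward

print("episode return:", episode_return)
```
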
23 | # Available scenarios
24 |
25 | Here is the list of available scenarios (copied and adapted from [here](https://github.com/mwydmuch/ViZDoom/blob/master/scenarios/README.md)):
26 |
27 | | |  | 
28 | |:---:|:---:|:---:|
29 | |**Basic**|**Defend the center**|**Defend the line**|
30 | | |  | 
31 | |**Health gathering**|**Predict the position**|**Deadly corridor**|
32 |
33 | ## Basic
34 |
35 | Map is a rectangle with gray walls, ceiling and floor. Player is spawned along the longer wall, in the center. A red, circular monster is spawned randomly somewhere along the opposite wall. Player can only go left/right and shoot. 1 hit is enough to kill the monster. Episode finishes when monster is killed or on timeout.
36 |
37 | - **Actions:** move left, move right, shoot
38 | - **Rewards:** +100 for killing the monster; -6 for missing; -1 otherwise
39 | - **Episode termination:** monster is killed or after 300 time steps
40 |
41 | ## Defend the center
42 |
43 | Map is a large circle. Player is spawned in the exact center. 5 melee-only monsters are spawned along the wall. Monsters are
44 | killed after a single shot. After dying, each monster is respawned after some time. Ammo is limited to 26 bullets. Episode ends when the player dies or on timeout.
45 |
46 | - **Actions:** turn left, turn right, shoot
47 | - **Rewards:** +1 for killing the monster; -0.1 for missing; -0.1 for losing health; -1 for death; 0 otherwise
48 | - **Episode termination:** player is killed or after 2100 time steps
49 |
50 | ## Defend the line
51 |
52 | Map is a rectangle. Player is spawned along the longer wall, in the center. 3 melee-only and 3 shooting monsters are spawned along the opposite wall. At first, monsters are killed after a single shot. After dying, each monster is respawned after some time and can endure more damage. Episode ends when the player dies or on timeout.
53 |
54 | - **Actions:** move left, move right, turn left, turn right, shoot
55 | - **Rewards:** +1 for killing the monster; -0.1 for missing; -0.1 for losing health; -1 for death; 0 otherwise
56 | - **Episode termination:** player is killed or after 2100 time steps
57 |
58 | ## Health gathering
59 |
60 | Map is a rectangle with green, acidic floor which hurts the player periodically. Initially there are some medkits spread uniformly over the map. A new medkit falls from the skies every now and then. Medkits heal some portion of the player's health; to survive, the agent needs to pick them up. Episode finishes after the player's death or on timeout.
61 |
62 | - **Actions:** turn left, turn right, move forward
63 | - **Rewards:** -100 for death; +1 otherwise
64 | - **Episode termination:** player is killed or after 2100 time steps
65 |
66 | ## Predict position
67 |
68 | Map is a rectangular room. Player is spawned along the longer wall, in the center. A monster is spawned randomly somewhere along the opposite wall and walks between the left and right corners along the wall. Player is equipped with a rocket launcher and a single rocket. Episode ends when the missile hits a wall or the monster, or on timeout.
69 |
70 | - **Actions:** move left, move right, shoot
71 | - **Rewards:** +1 for killing the monster; -0.0001 otherwise
72 | - **Episode termination:** monster is killed or after 300 time steps
73 |
74 | ## Deadly corridor
75 |
76 | Map is a corridor with shooting monsters on both sides (6 monsters in total). A green vest is placed at the opposite end of the corridor. Reward is proportional (negative or positive) to the change of the distance between the player and the vest. If the player ignores the monsters on the sides and runs straight for the vest, he will be killed somewhere along the way.
77 |
78 | - **Actions:** move left, move right, turn left, turn right, shoot
79 | - **Rewards:** -100 for death; +dX (-dX) for getting closer to (farther from) the vest
80 | - **Episode termination:** player is killed or after 4200 time steps
81 |
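Each scenario above has a corresponding wrapper class in [doom_env.py](doom_env.py). A minimal sketch (the `path_to_config` override is only needed when running from the repository root, since the constructors default to `"doom/config"`):

```python
from environments.doom.doom_env import (
    DoomBasic, DoomDefendTheCenter, DoomDefendTheLine,
    DoomHealthGathering, DoomPredictThePosition, DoomDeadlyCorridor)

# every wrapper loads its .cfg/.wad pair from the config directory
env = DoomDeadlyCorridor(path_to_config="environments/doom/config")
obs = env.reset()
```
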
--------------------------------------------------------------------------------
/environments/doom/config/basic.cfg:
--------------------------------------------------------------------------------
1 | # Lines starting with # are treated as comments (or with whitespaces+#).
2 | # It doesn't matter if you use capital letters or not.
3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
4 |
5 | doom_scenario_path = basic.wad
6 | doom_map = map01
7 |
8 | # Rewards
9 | living_reward = -1
10 |
11 | # Rendering options
12 | screen_resolution = RES_320X240
13 | screen_format = CRCGCB
14 | render_hud = True
15 | render_crosshair = false
16 | render_weapon = true
17 | render_decals = false
18 | render_particles = false
19 | window_visible = true
20 |
21 | # make episodes start after 20 tics (after unholstering the gun)
22 | episode_start_time = 14
23 |
24 | # make episodes finish after 300 actions (tics)
25 | episode_timeout = 300
26 |
27 | # Available buttons
28 | available_buttons =
29 | {
30 | MOVE_LEFT
31 | MOVE_RIGHT
32 | ATTACK
33 | }
34 |
35 | # Game variables that will be in the state
36 | available_game_variables = { AMMO2 HEALTH }
37 |
38 | mode = PLAYER
39 | doom_skill = 5
40 |
--------------------------------------------------------------------------------
/environments/doom/config/basic.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/basic.wad
--------------------------------------------------------------------------------
/environments/doom/config/deadly_corridor.cfg:
--------------------------------------------------------------------------------
1 | # Lines starting with # are treated as comments (or with whitespaces+#).
2 | # It doesn't matter if you use capital letters or not.
3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
4 |
5 | doom_scenario_path = deadly_corridor.wad
6 |
7 | # Skill 5 is recommended for the scenario to be a challenge.
8 | doom_skill = 5
9 |
10 | # Rewards
11 | death_penalty = 100
12 | #living_reward = 0
13 |
14 | # Rendering options
15 | screen_resolution = RES_320X240
16 | screen_format = CRCGCB
17 | render_hud = true
18 | render_crosshair = false
19 | render_weapon = true
20 | render_decals = false
21 | render_particles = false
22 | window_visible = true
23 |
24 | episode_timeout = 2100
25 |
26 | # Available buttons
27 | available_buttons =
28 | {
29 | MOVE_LEFT
30 | MOVE_RIGHT
31 | ATTACK
32 | MOVE_FORWARD
33 | MOVE_BACKWARD
34 | TURN_LEFT
35 | TURN_RIGHT
36 | }
37 |
38 | # Game variables that will be in the state
39 | available_game_variables = { HEALTH }
40 |
41 | mode = PLAYER
42 |
43 |
44 |
--------------------------------------------------------------------------------
/environments/doom/config/deadly_corridor.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/deadly_corridor.wad
--------------------------------------------------------------------------------
/environments/doom/config/defend_the_center.cfg:
--------------------------------------------------------------------------------
1 | # Lines starting with # are treated as comments (or with whitespaces+#).
2 | # It doesn't matter if you use capital letters or not.
3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
4 |
5 | doom_scenario_path = defend_the_center.wad
6 |
7 | # Rewards
8 | death_penalty = 1
9 |
10 | # Rendering options
11 | screen_resolution = RES_640X480
12 | screen_format = CRCGCB
13 | render_hud = True
14 | render_crosshair = false
15 | render_weapon = true
16 | render_decals = false
17 | render_particles = false
18 | window_visible = true
19 |
20 | # make episodes start after 10 tics (after unholstering the gun)
21 | episode_start_time = 10
22 |
23 | # make episodes finish after 2100 actions (tics)
24 | episode_timeout = 2100
25 |
26 | # Available buttons
27 | available_buttons =
28 | {
29 | TURN_LEFT
30 | TURN_RIGHT
31 | ATTACK
32 | }
33 |
34 | # Game variables that will be in the state
35 | available_game_variables = { AMMO2 HEALTH }
36 |
37 | mode = PLAYER
38 | doom_skill = 3
39 |
--------------------------------------------------------------------------------
/environments/doom/config/defend_the_center.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/defend_the_center.wad
--------------------------------------------------------------------------------
/environments/doom/config/defend_the_line.cfg:
--------------------------------------------------------------------------------
1 | # Lines starting with # are treated as comments (or with whitespaces+#).
2 | # It doesn't matter if you use capital letters or not.
3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
4 |
5 | doom_scenario_path = defend_the_line.wad
6 |
7 | # Rewards
8 | death_penalty = 1
9 |
10 | # Rendering options
11 | screen_resolution = RES_320X240
12 | screen_format = CRCGCB
13 | render_hud = True
14 | render_crosshair = false
15 | render_weapon = true
16 | render_decals = false
17 | render_particles = false
18 | window_visible = true
19 |
20 | # make episodes start after 10 tics (after unholstering the gun)
21 | episode_start_time = 10
22 |
23 | # make episodes finish after 2100 actions (tics)
24 | episode_timeout = 2100
25 |
26 | # Available buttons
27 | available_buttons =
28 | {
29 | MOVE_LEFT
30 | MOVE_RIGHT
31 | TURN_LEFT
32 | TURN_RIGHT
33 | ATTACK
34 | }
35 |
36 | # Game variables that will be in the state
37 | available_game_variables = { AMMO2 HEALTH }
38 |
39 | mode = PLAYER
40 | doom_skill = 3
41 |
--------------------------------------------------------------------------------
/environments/doom/config/defend_the_line.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/defend_the_line.wad
--------------------------------------------------------------------------------
/environments/doom/config/health_gathering.cfg:
--------------------------------------------------------------------------------
1 | # Lines starting with # are treated as comments (or with whitespaces+#).
2 | # It doesn't matter if you use capital letters or not.
3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
4 |
5 | doom_scenario_path = health_gathering.wad
6 |
7 | # Each step is good for you!
8 | living_reward = 1
9 | # And death is not!
10 | death_penalty = 100
11 |
12 | # Rendering options
13 | screen_resolution = RES_320X240
14 | screen_format = CRCGCB
15 | render_hud = false
16 | render_crosshair = false
17 | render_weapon = false
18 | render_decals = false
19 | render_particles = false
20 | window_visible = true
21 |
22 | # make episodes finish after 2100 actions (tics)
23 | episode_timeout = 2100
24 |
25 | # Available buttons
26 | available_buttons =
27 | {
28 | TURN_LEFT
29 | TURN_RIGHT
30 | MOVE_FORWARD
31 | }
32 |
33 | # Game variables that will be in the state
34 | available_game_variables = { }
35 |
36 | mode = PLAYER
--------------------------------------------------------------------------------
/environments/doom/config/health_gathering.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/health_gathering.wad
--------------------------------------------------------------------------------
/environments/doom/config/predict_position.cfg:
--------------------------------------------------------------------------------
1 | # Lines starting with # are treated as comments (or with whitespaces+#).
2 | # It doesn't matter if you use capital letters or not.
3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
4 |
5 | doom_scenario_path = predict_position.wad
6 |
7 | # Rewards
8 | living_reward = -0.001
9 |
10 | # Rendering options
11 | screen_resolution = RES_800X450
12 | screen_format = CRCGCB
13 | render_hud = false
14 | render_crosshair = false
15 | render_weapon = true
16 | render_decals = false
17 | render_particles = false
18 | window_visible = true
19 |
20 | # make episodes start after 16 tics (after producing the rocket launcher)
21 | episode_start_time = 16
22 |
23 | # make episodes finish after 300 actions (tics)
24 | episode_timeout = 300
25 |
26 | # Available buttons
27 | available_buttons =
28 | {
29 | TURN_LEFT
30 | TURN_RIGHT
31 | ATTACK
32 | }
33 |
34 | # Empty list is allowed, in case you are lazy.
35 | available_game_variables = { HEALTH }
36 |
37 | game_args += +sv_noautoaim 1
38 |
39 | mode = PLAYER
40 | doom_skill = 1
41 |
--------------------------------------------------------------------------------
/environments/doom/config/predict_position.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/config/predict_position.wad
--------------------------------------------------------------------------------
/environments/doom/doom_env.py:
--------------------------------------------------------------------------------
1 | from vizdoom import DoomGame
2 | from PIL import Image
3 | import numpy as np
4 |
5 | from IPython import display
6 | import matplotlib.pyplot as plt
7 |
8 | ########################## Doom environment template class ##########################
9 |
10 | class DoomEnvironment:
11 |
12 | def __init__(self, scenario, path_to_config="doom/config"):
13 | self.game = DoomGame()
14 | self.game.load_config(path_to_config+"/"+scenario+".cfg")
15 | self.game.set_doom_scenario_path(path_to_config+"/"+scenario+".wad")
16 | self.game.set_window_visible(False)
17 | self.game.init()
18 | self.num_actions = len(self.game.get_available_buttons())
19 |
20 | def reset(self):
21 | self.game.new_episode()
22 | game_state = self.game.get_state()
23 | obs = game_state.screen_buffer
24 | self.h, self.w = obs.shape[1:3]
25 | self.current_obs = self.preprocess_obs(obs)
26 | if self.game.get_available_game_variables_size() == 2:
27 | self.ammo, self.health = game_state.game_variables
28 | return self.get_obs()
29 |
30 | def get_obs(self):
31 | return self.current_obs[:, :, None]
32 |
33 | def get_obs_rgb(self):
34 | img = self.game.get_state().screen_buffer
35 | img = np.rollaxis(img, 0, 3)
36 | img = np.reshape(img, [self.h, self.w, 3])
37 | return img.astype(np.uint8)
38 |
39 | def preprocess_obs(self, obs):
40 | img = np.rollaxis(obs, 0, 3)
41 | img = np.reshape(img, [self.h, self.w, 3]).astype(np.float32)
42 | img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
43 | img = Image.fromarray(img)
44 | img = img.resize((84, 84), Image.BILINEAR)
45 | img = np.array(img)
46 | return img.astype(np.uint8)
47 |
48 | def action_to_doom(self, a):
49 | action = [0 for i in range(self.num_actions)]
50 | action[int(a)] = 1
51 | return action
52 |
53 | def step(self, a):
54 | action = self.action_to_doom(a)
55 | reward = self.game.make_action(action)
56 |
57 | done = self.game.is_episode_finished()
58 |
59 | if done:
60 | new_obs = np.zeros_like(self.current_obs, dtype=np.uint8)
61 | else:
62 | game_state = self.game.get_state()
63 | new_obs = game_state.screen_buffer
64 | new_obs = self.preprocess_obs(new_obs)
65 |
66 | self.current_obs = new_obs
67 |
68 | return self.get_obs(), reward, done
69 |
70 | def watch_random_play(self, max_ep_length=1000, frame_skip=4):
71 | self.reset()
72 | for i in range(max_ep_length):
73 | a = np.random.randint(self.num_actions)
74 | obs, reward, done = self.step(a)
75 | if done: break
76 |
77 | img = self.get_obs_rgb()
78 | if i % frame_skip == 0:
79 | plt.imshow(img)
80 | display.clear_output(wait=True)
81 | display.display(plt.gcf())
82 |
83 | ####################################### Basic #######################################
84 |
85 | class DoomBasic(DoomEnvironment):
86 |
87 | def __init__(self, path_to_config="doom/config"):
88 | super(DoomBasic, self).__init__(scenario="basic",
89 | path_to_config=path_to_config)
90 |
91 | ################################## Defend the line ##################################
92 |
93 | class DoomDefendTheLine(DoomEnvironment):
94 |
95 | def __init__(self, path_to_config="doom/config"):
96 | super(DoomDefendTheLine, self).__init__(scenario="defend_the_line",
97 | path_to_config=path_to_config)
98 |
99 | def step(self, a):
100 | action = self.action_to_doom(a)
101 | reward = self.game.make_action(action)
102 |
103 | done = self.game.is_episode_finished()
104 |
105 | if done:
106 | new_obs = np.zeros_like(self.current_obs, dtype=np.uint8)
107 | else:
108 | game_state = self.game.get_state()
109 | new_obs = game_state.screen_buffer
110 | new_obs = self.preprocess_obs(new_obs)
111 | new_ammo, new_health = game_state.game_variables
112 |
113 | if (reward == 1.0): reward += 0.1
114 | if (new_ammo < self.ammo): reward -= 0.1
115 | if (new_health < self.health): reward -= 0.1
116 |
117 | self.ammo, self.health = new_ammo, new_health
118 |
119 | self.current_obs = new_obs
120 |
121 | return self.get_obs(), reward, done
122 |
123 | ################################# Defend the center #################################
124 |
125 | class DoomDefendTheCenter(DoomEnvironment):
126 |
127 | def __init__(self, path_to_config="doom/config"):
128 | super(DoomDefendTheCenter, self).__init__(scenario="defend_the_center",
129 | path_to_config=path_to_config)
130 |
131 | def step(self, a):
132 | action = self.action_to_doom(a)
133 | reward = self.game.make_action(action)
134 |
135 | done = self.game.is_episode_finished()
136 |
137 | if done:
138 | new_obs = np.zeros_like(self.current_obs, dtype=np.uint8)
139 | else:
140 | game_state = self.game.get_state()
141 | new_obs = game_state.screen_buffer
142 | new_obs = self.preprocess_obs(new_obs)
143 | new_ammo, new_health = game_state.game_variables
144 |
145 | if (reward == 1.0): reward += 0.1
146 | if (new_ammo < self.ammo): reward -= 0.1
147 | if (new_health < self.health): reward -= 0.1
148 |
149 | self.ammo, self.health = new_ammo, new_health
150 |
151 | self.current_obs = new_obs
152 |
153 | return self.get_obs(), reward, done
154 |
155 | ############################### Predict the position ################################
156 |
157 | class DoomPredictThePosition(DoomEnvironment):
158 |
159 | def __init__(self, path_to_config="doom/config"):
160 | super(DoomPredictThePosition, self).__init__(scenario="predict_position",
161 | path_to_config=path_to_config)
162 |
163 | ################################## Health gathering #################################
164 |
165 | class DoomHealthGathering(DoomEnvironment):
166 |
167 | def __init__(self, path_to_config="doom/config"):
168 | super(DoomHealthGathering, self).__init__(scenario="health_gathering",
169 | path_to_config=path_to_config)
170 |
171 | ################################## Deadly corridor ##################################
172 |
173 | class DoomDeadlyCorridor(DoomEnvironment):
174 |
175 | def __init__(self, path_to_config="doom/config"):
176 | super(DoomDeadlyCorridor, self).__init__(scenario="deadly_corridor",
177 | path_to_config=path_to_config)
--------------------------------------------------------------------------------
/environments/doom/img/basic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/basic.png
--------------------------------------------------------------------------------
/environments/doom/img/corridor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/corridor.png
--------------------------------------------------------------------------------
/environments/doom/img/def_center.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/def_center.png
--------------------------------------------------------------------------------
/environments/doom/img/def_line.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/def_line.png
--------------------------------------------------------------------------------
/environments/doom/img/health.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/health.png
--------------------------------------------------------------------------------
/environments/doom/img/predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/environments/doom/img/predict.png
--------------------------------------------------------------------------------
/environments/snake/snake_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | class Snake:
5 |
6 | def __init__(self, grid_size=(8, 8)):
7 | """
8 | Classic Snake game implemented as a gym-like environment.
9 |
10 | Parameters
11 | ----------
12 | grid_size: tuple
13 | tuple of two parameters: (height, width)
14 | """
15 |
16 | self.height, self.width = grid_size
17 | self.state = np.zeros(grid_size)
18 | self.x, self.y = [], []
19 | self.dir = None
20 | self.food = None
21 | self.opt_tab = self.opt_table(grid_size)
22 |
23 | def reset(self):
24 | """
25 | Resets the state of the environment and returns an initial observation.
26 |
27 | Returns
28 | -------
29 | observation: numpy.array of size (height, width, 5)
30 | the initial observation of the space.
31 | """
32 |
33 | self.state = np.zeros((self.height, self.width))
34 |
35 | x_tail = np.random.randint(self.height)
36 | y_tail = np.random.randint(self.width)
37 |
38 | xs = [x_tail, ]
39 | ys = [y_tail, ]
40 |
41 | for i in range(2):
42 | nbrs = self.get_neighbors(xs[-1], ys[-1])
43 | while 1:
44 | idx = np.random.randint(0, len(nbrs))
45 | x0 = nbrs[idx][0]
46 | y0 = nbrs[idx][1]
47 | occupied = [list(pt) for pt in zip(xs, ys)]
48 | if not [x0, y0] in occupied:
49 | xs.append(x0)
50 | ys.append(y0)
51 | break
52 |
53 | for x_t, y_t in list(zip(xs, ys)):
54 | self.state[x_t, y_t] = 1
55 |
56 | self.generate_food()
57 | self.x = xs
58 | self.y = ys
59 | self.update_dir()
60 |
61 | return self.get_state()
62 |
63 | def step(self, a):
64 | """
65 | Run one timestep of the environment's dynamics. When end of
66 | episode is reached, you are responsible for calling `reset()`
67 | to reset this environment's state.
68 |
69 | Args
70 | ----
71 | action: int from {0, 1, 2}
72 | an action provided by the agent
73 |
74 | Returns
75 | -------
76 | observation: numpy.array of size (height, width, 5)
77 | agent's observation of the current environment
78 | reward: int from {-1, 0, 1}
79 | amount of reward returned after previous action
80 | done: boolean
81 | whether the episode has ended, in which case further step()
82 | calls will return undefined results
83 | """
84 |
85 | self.update_dir()
86 | x_, y_ = self.next_cell(self.x[-1], self.y[-1], a)
87 |
88 | # snake dies if hitting the walls
89 | if x_ < 0 or x_ == self.height or y_ < 0 or y_ == self.width:
90 | return self.get_state(), -1, True
91 |
92 | # snake dies if hitting its tail with head
93 | if self.state[x_, y_] == 1:
94 | if (x_ == self.x[0] and y_ == self.y[0]):
95 | pass
96 | else:
97 | return self.get_state(), -1, True
98 |
99 | self.x.append(x_)
100 | self.y.append(y_)
101 |
102 | # snake elongates after eating a food
103 | if self.state[x_, y_] == 3:
104 | self.state[x_, y_] = 1
105 | done = self.generate_food()
106 | return self.get_state(), 1, done
107 |
108 | # snake moves forward if cell ahead is empty
109 | # or currently occupied by its tail
110 | self.state[self.x[0], self.y[0]] = 0
111 | self.state[x_, y_] = 1
112 | self.x = self.x[1:]
113 | self.y = self.y[1:]
114 | return self.get_state(), 0, False
115 |
116 | def get_state(self):
117 | state = np.zeros((self.height, self.width, 5))
118 | state[self.x[1:-1], self.y[1:-1], 0] = 1
119 | state[self.x[-1], self.y[-1], 1] = 1
120 | state[self.x[-2], self.y[-2], 2] = 1
121 | state[self.x[0], self.y[0], 3] = 1
122 | state[self.food[0], self.food[1], 4] = 1
123 | return state.astype(np.uint8)
124 |
125 | def generate_food(self):
126 | free = np.where(self.state == 0)
127 | if free[0].size == 0:
128 | return True
129 | else:
130 | idx = np.random.randint(free[0].size)
131 | self.food = free[0][idx], free[1][idx]
132 | self.state[self.food] = 3
133 | return False
134 |
135 | def next_cell(self, i, j, a):
136 | if a == 0:
137 | return i+self.dir[0], j+self.dir[1]
138 | if a == 1:
139 | return i-self.dir[1], j+self.dir[0]
140 | if a == 2:
141 | return i+self.dir[1], j-self.dir[0]
142 |
143 | def plot_state(self):
144 | state = self.get_state()
145 | img = sum([state[:,:,i]*(i+1) for i in range(5)])
146 | plt.imshow(img, vmin=0, vmax=5, interpolation='nearest')
147 |
148 | def get_neighbors(self, i, j):
149 | """
150 | Get all the neighbors of the point (i, j)
151 | (excluding (i, j))
152 | """
153 | h = self.height
154 | w = self.width
155 | nbrs = [[i + k, j + m] for k in [-1, 0, 1] for m in [-1, 0, 1]
156 | if i + k >=0 and i + k < h and j + m >= 0 and j + m < w
157 | and not (k == m) and not (k == -m)]
158 | return nbrs
159 |
160 | def update_dir(self):
161 | x_dir = self.x[-1] - self.x[-2]
162 | y_dir = self.y[-1] - self.y[-2]
163 | self.dir = (x_dir, y_dir)
164 |
165 | ########################## Optimal action selection ##########################
166 |
167 | def opt_table(self, grid_size):
168 | n = grid_size[0]
169 | t = np.zeros(grid_size, dtype=int)
170 | t[0] = np.arange(n)
171 | for i in range(n//2):
172 | t[1:,(n-1)-2*i] = np.arange(n-1) + n+2*i*(n-1)
173 | t[1:,(n-2)-2*i][::-1] = np.arange(n-1) + 2*n-1+2*i*(n-1)
174 | return t
175 |
176 | def opt_action(self):
177 | x, y = self.x[-1], self.y[-1]
178 | self.update_dir()
179 | n = self.height
180 | mod = n ** 2
181 | tab_xy = self.opt_tab[x, y]
182 | pos_a = -1
183 | for a in range(3):
184 | x_, y_ = self.next_cell(x, y, a)
185 | if (x_<n and y_<n and x_>=0 and y_>=0):
186 | tab_xy_ = self.opt_tab[x_, y_]
187 | if ((tab_xy+1) % mod == tab_xy_):
188 | return a
189 | if ((tab_xy-1) % mod == tab_xy_):
190 | pos_a = a
191 | return pos_a
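
# Minimal usage sketch (illustrative only): roll out one episode with a
# uniformly random policy over the three relative actions {0, 1, 2}.
if __name__ == "__main__":
    env = Snake(grid_size=(8, 8))
    obs = env.reset()
    done, episode_return = False, 0
    while not done:
        obs, reward, done = env.step(np.random.randint(3))
        episode_return += reward
    print("episode return:", episode_return)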
--------------------------------------------------------------------------------
/environments/windy_grid_world/evil_wgw_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .wgw_env import WindyGridWorld
3 |
4 |
5 | class EvilWindyGridWorld(WindyGridWorld):
6 |
7 | def __init__(
8 | self,
9 | grid_size=(7, 10),
10 | stochasticity=0.1,
11 | visual=False):
12 | self.w, self.h = grid_size
13 | self.stochasticity = stochasticity
14 | self.visual = visual
15 |
16 | # x position of the wall
17 | self.x_wall = self.w // 2
18 | # y position of the hole in the wall
19 | self.y_hole = self.h - 4
20 | self.y_hole2 = self.h - 7
21 |
22 | self.reset()
23 |
24 | def move(self, a):
25 | """ find valid coordinates of the agent after executing action
26 | """
27 | x, y = self.pos
28 | self.field[x, y] = 0
29 | x, y = self.wind_shift(x, y)
30 |
31 | if a == 0:
32 | x_, y_ = x + 1, y
33 | if a == 1:
34 | x_, y_ = x, y + 1
35 | if a == 2:
36 | x_, y_ = x - 1, y
37 | if a == 3:
38 | x_, y_ = x, y - 1
39 |
40 | # check if new position does not conflict with the wall
41 | if x_ == self.x_wall and y != self.y_hole and y != self.y_hole2:
42 | x_, y_ = x, y
43 | return self.clip_xy(x_, y_)
44 |
45 | def reset(self):
46 | """ resets the environment
47 | """
48 | self.field = np.zeros((self.w, self.h))
49 | self.field[self.x_wall, :] = 1
50 | self.field[self.x_wall, self.y_hole] = 0
51 | self.field[self.x_wall, self.y_hole2] = 0
52 | self.field[self.x_wall + 1, self.y_hole2 + 1] = -1
53 | self.field[self.x_wall + 1, self.y_hole2 - 1] = -1
54 | self.field[0, 0] = 2
55 | self.pos = (0, 0)
56 | obs = self.get_observation()
57 | return obs
58 |
59 | def step(self, a):
60 | """ makes a step in the environment
61 | """
62 |
63 | if np.random.rand() < self.stochasticity:
64 | a = np.random.randint(4)
65 |
66 | self.field[self.pos] = 0
67 | self.pos = self.move(a)
68 | self.field[self.pos] = 2
69 |
70 | done = False
71 | reward = 0
72 | if self.pos == (self.w - 1, 0):
73 | # episode finished successfully
74 | done = True
75 | reward = 1
76 | if (self.pos == (self.x_wall + 1, self.y_hole2 + 1) or
77 | self.pos == (self.x_wall + 1, self.y_hole2 - 1)):
78 | # episode finished unsuccessfully
79 | done = True
80 | reward = -1
81 |
82 | next_obs = self.get_observation()
83 | return next_obs, reward, done
84 |
--------------------------------------------------------------------------------
/environments/windy_grid_world/wgw_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import time
4 | from IPython import display
5 |
6 |
7 | class WindyGridWorld:
8 |
9 | def __init__(
10 | self,
11 | grid_size=(11, 14),
12 | stochasticity=0.1,
13 | visual=False):
14 | self.w, self.h = grid_size
15 | self.stochasticity = stochasticity
16 | self.visual = visual
17 |
18 | # x position of the wall
19 | self.x_wall = self.w // 2
20 | # y position of the hole in the wall
21 | self.y_hole = self.h - 4
22 |
23 | self.reset()
24 |
25 | def clip_xy(self, x, y):
26 | """ clip coordinates if they go beyond the grid
27 | """
28 | x_ = np.clip(x, 0, self.w - 1)
29 | y_ = np.clip(y, 0, self.h - 1)
30 | return x_, y_
31 |
32 | def wind_shift(self, x, y):
33 | """ apply wind shift to areas where wind is blowing
34 | """
35 | if x == 1:
36 | return self.clip_xy(x, y + 1)
37 | elif x > 1 and x < self.x_wall:
38 | return self.clip_xy(x, y + 2)
39 | else:
40 | return x, y
41 |
42 | def move(self, a):
43 | """ find valid coordinates of the agent after executing action
44 | """
45 | x, y = self.pos
46 | self.field[x, y] = 0
47 | x, y = self.wind_shift(x, y)
48 |
49 | if a == 0:
50 | x_, y_ = x + 1, y
51 | if a == 1:
52 | x_, y_ = x, y + 1
53 | if a == 2:
54 | x_, y_ = x - 1, y
55 | if a == 3:
56 | x_, y_ = x, y - 1
57 |
58 | # check if new position does not conflict with the wall
59 | if x_ == self.x_wall and y != self.y_hole:
60 | x_, y_ = x, y
61 | return self.clip_xy(x_, y_)
62 |
63 | def get_observation(self):
64 | if self.visual:
65 | obs = np.rot90(self.field)[:, :, None]
66 | else:
67 | obs = self.pos
68 | return obs
69 |
70 | def reset(self):
71 | """ resets the environment
72 | """
73 | self.field = np.zeros((self.w, self.h))
74 | self.field[self.x_wall, :] = 1
75 | self.field[self.x_wall, self.y_hole] = 0
76 | self.field[0, 0] = 2
77 | self.pos = (0, 0)
78 | obs = self.get_observation()
79 | return obs
80 |
81 | def step(self, a):
82 | """ makes a step in the environment
83 | """
84 |
85 | if np.random.rand() < self.stochasticity:
86 | a = np.random.randint(4)
87 |
88 | self.field[self.pos] = 0
89 | self.pos = self.move(a)
90 | self.field[self.pos] = 2
91 |
92 | done = False
93 | reward = 0
94 | if self.pos == (self.w - 1, 0):
95 | # episode finished successfully
96 | done = True
97 | reward = 1
98 | next_obs = self.get_observation()
99 | return next_obs, reward, done
100 |
101 | def play_with_policy(self, policy, max_iter=100, visualize=True):
102 | """ play with given policy
103 | returns:
104 | episode return, number of time steps
105 | """
106 | self.reset()
107 | for i in range(max_iter):
108 | a = np.argmax(policy[self.pos])
109 | next_obs, reward, done = self.step(a)
110 | # plot grid world state
111 | if visualize:
112 | img = np.rot90(1-self.field)
113 | plt.imshow(img, cmap="gray")
114 | display.clear_output(wait=True)
115 | display.display(plt.gcf())
116 | time.sleep(0.01)
117 | if done:
118 | break
119 | if visualize:
120 | display.clear_output(wait=True)
121 | return reward, i+1
122 |
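# Minimal usage sketch (illustrative only): evaluate a random greedy policy
# (argmax over random per-state values) with play_with_policy.
if __name__ == "__main__":
    env = WindyGridWorld(grid_size=(11, 14), stochasticity=0.1)
    policy = np.random.rand(env.w, env.h, 4)
    ret, steps = env.play_with_policy(policy, max_iter=100, visualize=False)
    print("return:", ret, "steps:", steps)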
--------------------------------------------------------------------------------
/img/dqn_categorical.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/dqn_categorical.png
--------------------------------------------------------------------------------
/img/dqn_classic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/dqn_classic.png
--------------------------------------------------------------------------------
/img/dqn_dueling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/dqn_dueling.png
--------------------------------------------------------------------------------
/img/dqn_quantile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/dqn_quantile.png
--------------------------------------------------------------------------------
/img/sac_p_network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/sac_p_network.png
--------------------------------------------------------------------------------
/img/sac_q_network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/sac_q_network.png
--------------------------------------------------------------------------------
/img/sac_v_network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlexGrinch/rl_algorithms/e1a3d1334e9da30fcae784132f68423afcfb0fe5/img/sac_v_network.png
--------------------------------------------------------------------------------
/methods.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import tensorflow as tf
4 | import tensorflow.contrib.layers as layers
5 | from tensorflow.contrib.layers import convolution2d as conv
6 | from tensorflow.contrib.layers import fully_connected as fc
7 | from tensorflow.contrib.layers import xavier_initializer as xavier
8 |
9 | ###############################################################################
10 | ################################ Core modules #################################
11 | ###############################################################################
12 |
13 |
14 | def conv_module(input_layer, convs, activation_fn=tf.nn.relu):
15 | """ convolutional module
16 | """
17 | out = input_layer
18 | for num_outputs, kernel_size, stride in convs:
19 | out = conv(
20 | out,
21 | num_outputs=num_outputs,
22 | kernel_size=kernel_size,
23 | stride=stride,
24 | padding="VALID",
25 | activation_fn=activation_fn)
26 | return out
27 |
28 |
29 | def fc_module(input_layer, fully_connected, activation_fn=tf.nn.relu):
30 | """ fully connected module
31 | """
32 | out = input_layer
33 | for num_outputs in fully_connected:
34 | out = fc(
35 | out,
36 | num_outputs=num_outputs,
37 | activation_fn=activation_fn,
38 | weights_initializer=xavier())
39 | return out
40 |
41 |
42 | def full_module(
43 | input_layer, convs, fully_connected,
44 | num_outputs, activation_fn=tf.nn.relu):
45 | """ convolutional + fully connected + output
46 | """
47 | out = input_layer
48 | out = conv_module(out, convs, activation_fn)
49 | out = layers.flatten(out)
50 | out = fc_module(out, fully_connected, activation_fn)
51 | out = fc_module(out, [num_outputs], None)
52 | return out
53 |
54 | ###############################################################################
55 | ############################### Deep Q-Network ################################
56 | ###############################################################################
57 |
58 |
59 | class DeepQNetwork:
60 |
61 | def __init__(
62 | self,
63 | num_actions,
64 | state_shape=[8, 8, 5],
65 | convs=[[32, 4, 2], [64, 2, 1]],
66 | fully_connected=[128],
67 | activation_fn=tf.nn.relu,
68 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32),
69 | gradient_clip=10.0,
70 | scope="dqn",
71 | reuse=False):
72 |
73 | with tf.variable_scope(scope, reuse=reuse):
74 |
75 | ################### Neural network architecture ###################
76 |
77 | input_shape = [None] + state_shape
78 | self.input_states = tf.placeholder(
79 | dtype=tf.float32, shape=input_shape)
80 |
81 | self.q_values = full_module(
82 | self.input_states, convs, fully_connected,
83 | num_actions, activation_fn)
84 |
85 | ##################### Optimization procedure ######################
86 |
87 | # convert input actions to indices for q-values selection
88 | self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None])
89 | indices_range = tf.range(tf.shape(self.input_actions)[0])
90 | action_indices = tf.stack(
91 | [indices_range, self.input_actions], axis=1)
92 |
93 | # select q-values for input actions
94 | self.q_values_selected = tf.gather_nd(
95 | self.q_values, action_indices)
96 |
97 | # select best actions (according to q-values)
98 | self.q_argmax = tf.argmax(self.q_values, axis=1)
99 |
100 | # define loss function and update rule
101 | self.q_targets = tf.placeholder(dtype=tf.float32, shape=[None])
102 | self.loss = tf.losses.huber_loss(
103 | self.q_targets, self.q_values_selected, delta=gradient_clip)
104 | self.update_model = optimizer.minimize(self.loss)
105 |
106 | def get_q_values_s(self, sess, states):
107 | feed_dict = {self.input_states: states}
108 | q_values = sess.run(self.q_values, feed_dict)
109 | return q_values
110 |
111 | def get_q_values_sa(self, sess, states, actions):
112 | feed_dict = {self.input_states: states, self.input_actions: actions}
113 | q_values_selected = sess.run(self.q_values_selected, feed_dict)
114 | return q_values_selected
115 |
116 | def get_q_argmax(self, sess, states):
117 | feed_dict = {self.input_states: states}
118 | q_argmax = sess.run(self.q_argmax, feed_dict)
119 | return q_argmax
120 |
121 | def update(self, sess, states, actions, q_targets):
122 | feed_dict = {self.input_states: states,
123 | self.input_actions: actions,
124 | self.q_targets: q_targets}
125 | sess.run(self.update_model, feed_dict)
126 |
127 | ###############################################################################
128 | ########################### Dueling Deep Q-Network ############################
129 | ###############################################################################
130 |
131 |
132 | class DuelingDeepQNetwork:
133 |
134 | def __init__(
135 | self,
136 | num_actions,
137 | state_shape=[8, 8, 5],
138 | convs=[[32, 4, 2], [64, 2, 1]],
139 | fully_connected=[64],
140 | activation_fn=tf.nn.relu,
141 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32),
142 | gradient_clip=10.0,
143 | scope="duel_dqn",
144 | reuse=False):
145 |
146 | with tf.variable_scope(scope, reuse=reuse):
147 |
148 | ################### Neural network architecture ###################
149 |
150 | input_shape = [None] + state_shape
151 | self.input_states = tf.placeholder(
152 | dtype=tf.float32, shape=input_shape)
153 |
154 | out = conv_module(self.input_states, convs, activation_fn)
155 | val, adv = tf.split(out, num_or_size_splits=2, axis=3)
156 | self.v_values = full_module(
157 | val, [], fully_connected, 1, activation_fn)
158 | self.a_values = full_module(
159 | adv, [], fully_connected, num_actions, activation_fn)
160 |
161 | a_values_mean = tf.reduce_mean(
162 | self.a_values, axis=1, keepdims=True)
163 | a_values_centered = tf.subtract(self.a_values, a_values_mean)
164 | self.q_values = self.v_values + a_values_centered
165 |
166 | ##################### Optimization procedure ######################
167 |
168 | # convert input actions to indices for q-values selection
169 | self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None])
170 | indices_range = tf.range(tf.shape(self.input_actions)[0])
171 | action_indices = tf.stack(
172 | [indices_range, self.input_actions], axis=1)
173 |
174 | # select q-values for input actions
175 | self.q_values_selected = tf.gather_nd(
176 | self.q_values, action_indices)
177 |
178 | # select best actions (according to q-values)
179 | self.q_argmax = tf.argmax(self.q_values, axis=1)
180 |
181 | # define loss function and update rule
182 | self.q_targets = tf.placeholder(dtype=tf.float32, shape=[None])
183 | self.loss = tf.losses.huber_loss(
184 | self.q_targets, self.q_values_selected, delta=gradient_clip)
185 | self.update_model = optimizer.minimize(self.loss)
186 |
187 | def get_q_values_s(self, sess, states):
188 | feed_dict = {self.input_states: states}
189 | q_values = sess.run(self.q_values, feed_dict)
190 | return q_values
191 |
192 | def get_q_values_sa(self, sess, states, actions):
193 | feed_dict = {self.input_states: states, self.input_actions: actions}
194 | q_values_selected = sess.run(self.q_values_selected, feed_dict)
195 | return q_values_selected
196 |
197 | def get_q_argmax(self, sess, states):
198 | feed_dict = {self.input_states: states}
199 | q_argmax = sess.run(self.q_argmax, feed_dict)
200 | return q_argmax
201 |
202 | def update(self, sess, states, actions, q_targets):
203 | feed_dict = {self.input_states: states,
204 | self.input_actions: actions,
205 | self.q_targets: q_targets}
206 | sess.run(self.update_model, feed_dict)
207 |
208 | ###############################################################################
209 | ######################### Categorical Deep Q-Network ##########################
210 | ###############################################################################
211 |
212 |
213 | class CategoricalDeepQNetwork:
214 |
215 | def __init__(
216 | self,
217 | num_actions,
218 | state_shape=[8, 8, 5],
219 | convs=[[32, 4, 2], [64, 2, 1]],
220 | fully_connected=[128],
221 | num_atoms=21,
222 | v=(-10, 10),
223 | activation_fn=tf.nn.relu,
224 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32),
225 | scope="cat_dqn",
226 | reuse=False):
227 |
228 | with tf.variable_scope(scope, reuse=reuse):
229 |
230 | ################### Neural network architecture ###################
231 |
232 | input_shape = [None] + state_shape
233 | self.input_states = tf.placeholder(
234 | dtype=tf.float32, shape=input_shape)
235 |
236 | # distribution parameters
237 | self.num_atoms = num_atoms
238 | self.v_min, self.v_max = v
239 | self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
240 | self.z = np.linspace(
241 | start=self.v_min, stop=self.v_max, num=num_atoms)
242 |
243 | # main module
244 | out = full_module(
245 | self.input_states, convs, fully_connected,
246 | num_outputs=num_actions*num_atoms, activation_fn=activation_fn)
247 |
248 | self.logits = tf.reshape(out, shape=[-1, num_actions, num_atoms])
249 | self.probs = tf.nn.softmax(self.logits, axis=2)
250 | self.q_values = tf.reduce_sum(
251 | tf.multiply(self.probs, self.z), axis=2)
252 |
253 | ##################### Optimization procedure ######################
254 |
255 | # convert input actions to indices for probs and q-values selection
256 | self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None])
257 | indices_range = tf.range(tf.shape(self.input_actions)[0])
258 | action_indices = tf.stack(
259 | [indices_range, self.input_actions], axis=1)
260 |
261 | # select q-values and probs for input actions
262 | self.q_values_selected = tf.gather_nd(
263 | self.q_values, action_indices)
264 | self.probs_selected = tf.gather_nd(self.probs, action_indices)
265 |
266 | # select best actions (according to q-values)
267 | self.q_argmax = tf.argmax(self.q_values, axis=1)
268 |
269 | # define loss function and update rule
270 | self.probs_targets = tf.placeholder(
271 | dtype=tf.float32, shape=[None, self.num_atoms])
272 | self.loss = -tf.reduce_sum(
273 | self.probs_targets * tf.log(self.probs_selected+1e-6))
274 | self.update_model = optimizer.minimize(self.loss)
275 |
276 | def get_q_values_s(self, sess, states):
277 | feed_dict = {self.input_states: states}
278 | q_values = sess.run(self.q_values, feed_dict)
279 | return q_values
280 |
281 | def get_q_values_sa(self, sess, states, actions):
282 | feed_dict = {self.input_states: states, self.input_actions: actions}
283 | q_values_selected = sess.run(self.q_values_selected, feed_dict)
284 | return q_values_selected
285 |
286 | def get_q_argmax(self, sess, states):
287 | feed_dict = {self.input_states: states}
288 | q_argmax = sess.run(self.q_argmax, feed_dict)
289 | return q_argmax
290 |
291 | def get_probs_s(self, sess, states):
292 | feed_dict = {self.input_states: states}
293 | probs = sess.run(self.probs, feed_dict)
294 | return probs
295 |
296 | def get_probs_sa(self, sess, states, actions):
297 | feed_dict = {self.input_states: states, self.input_actions: actions}
298 | probs_selected = sess.run(self.probs_selected, feed_dict)
299 | return probs_selected
300 |
301 | def update(self, sess, states, actions, probs_targets):
302 | feed_dict = {self.input_states: states,
303 | self.input_actions: actions,
304 | self.probs_targets: probs_targets}
305 | sess.run(self.update_model, feed_dict)
306 |
307 | def cat_proj(self, sess, states, actions, rewards, done, gamma=0.99):
308 | """
309 | Categorical algorithm from https://arxiv.org/abs/1707.06887
310 | """
311 |
312 | atoms_targets = rewards[:, None] + gamma * self.z * (1 - done[:, None])
313 | tz = np.clip(atoms_targets, self.v_min, self.v_max)
314 | tz_z = tz[:, None, :] - self.z[None, :, None]
315 | tz_z = np.clip((1.0 - (np.abs(tz_z) / self.delta_z)), 0, 1)
316 |
317 | probs = self.get_probs_sa(sess, states, actions)
318 | probs_targets = np.einsum('bij,bj->bi', tz_z, probs)
319 |
320 | return probs_targets
321 |
322 | ###############################################################################
323 | ########################### Quantile Deep Q-Network ###########################
324 | ###############################################################################
325 |
326 |
327 | class QuantileDeepQNetwork:
328 |
329 | def __init__(
330 | self,
331 | num_actions,
332 | state_shape=[8, 8, 5],
333 | convs=[[32, 4, 2], [64, 2, 1]],
334 | fully_connected=[128],
335 | num_atoms=50,
336 | kappa=1.0,
337 | activation_fn=tf.nn.relu,
338 | optimizer=tf.train.AdamOptimizer(2.5e-4, epsilon=0.01/32),
339 | scope="qr_dqn",
340 | reuse=False):
341 |
342 | with tf.variable_scope(scope, reuse=reuse):
343 |
344 | ################### Neural network architecture ###################
345 |
346 | input_shape = [None] + state_shape
347 | self.input_states = tf.placeholder(
348 | dtype=tf.float32, shape=input_shape)
349 |
350 | # distribution parameters
351 | tau_min = 1 / (2 * num_atoms)
352 | tau_max = 1 - tau_min
353 | tau_vector = tf.lin_space(
354 | start=tau_min, stop=tau_max, num=num_atoms)
355 |
356 | # reshape tau to matrix for fast loss calculation
357 | tau_matrix = tf.tile(tau_vector, [num_atoms])
358 | self.tau_matrix = tf.reshape(
359 | tau_matrix, shape=[num_atoms, num_atoms])
360 |
361 | # main module
362 | out = full_module(
363 | self.input_states, convs, fully_connected,
364 | num_outputs=num_actions*num_atoms, activation_fn=activation_fn)
365 | self.atoms = tf.reshape(out, shape=[-1, num_actions, num_atoms])
366 | self.q_values = tf.reduce_mean(self.atoms, axis=2)
367 |
368 | ##################### Optimization procedure ######################
369 |
370 | # convert input actions to indices for atoms and q-values selection
371 | self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None])
372 | indices_range = tf.range(tf.shape(self.input_actions)[0])
373 | action_indices = tf.stack(
374 | [indices_range, self.input_actions], axis=1)
375 |
376 | # select q-values for input actions
377 | self.q_values_selected = tf.gather_nd(
378 | self.q_values, action_indices)
379 | self.atoms_selected = tf.gather_nd(self.atoms, action_indices)
380 |
381 | # select best actions (according to q-values)
382 | self.q_argmax = tf.argmax(self.q_values, axis=1)
383 |
384 | # reshape chosen atoms to matrix for fast loss calculation
385 | atoms_matrix = tf.tile(self.atoms_selected, [1, num_atoms])
386 | self.atoms_matrix = tf.reshape(
387 | atoms_matrix, shape=[-1, num_atoms, num_atoms])
388 |
389 | # reshape target atoms to matrix for fast loss calculation
390 | self.atoms_targets = tf.placeholder(
391 | dtype=tf.float32, shape=[None, num_atoms])
392 | targets_matrix = tf.tile(self.atoms_targets, [1, num_atoms])
393 | targets_matrix = tf.reshape(
394 | targets_matrix, shape=[-1, num_atoms, num_atoms])
395 | self.targets_matrix = tf.transpose(targets_matrix, perm=[0, 2, 1])
396 |
397 | # define loss function and update rule
398 | atoms_diff = self.targets_matrix - self.atoms_matrix
399 | delta_atoms_diff = tf.where(
400 | atoms_diff < 0,
401 | tf.ones_like(atoms_diff),
402 |                     tf.zeros_like(atoms_diff))
403 | huber_weights = tf.abs(
404 | self.tau_matrix - delta_atoms_diff) / num_atoms
405 | self.loss = tf.losses.huber_loss(
406 | self.targets_matrix, self.atoms_matrix, weights=huber_weights,
407 | delta=kappa, reduction=tf.losses.Reduction.SUM)
408 | self.update_model = optimizer.minimize(self.loss)
409 |
410 | def get_q_values_s(self, sess, states):
411 | feed_dict = {self.input_states: states}
412 | q_values = sess.run(self.q_values, feed_dict)
413 | return q_values
414 |
415 | def get_q_values_sa(self, sess, states, actions):
416 | feed_dict = {self.input_states: states, self.input_actions: actions}
417 | q_values_selected = sess.run(self.q_values_selected, feed_dict)
418 | return q_values_selected
419 |
420 | def get_q_argmax(self, sess, states):
421 | feed_dict = {self.input_states: states}
422 | q_argmax = sess.run(self.q_argmax, feed_dict)
423 | return q_argmax
424 |
425 | def get_atoms_s(self, sess, states):
426 | feed_dict = {self.input_states: states}
427 | atoms = sess.run(self.atoms, feed_dict)
428 |         return atoms
429 |
430 | def get_atoms_sa(self, sess, states, actions):
431 | feed_dict = {self.input_states: states, self.input_actions: actions}
432 | atoms_selected = sess.run(self.atoms_selected, feed_dict)
433 | return atoms_selected
434 |
435 | def update(self, sess, states, actions, atoms_targets):
436 | feed_dict = {self.input_states: states,
437 | self.input_actions: actions,
438 | self.atoms_targets: atoms_targets}
439 | sess.run(self.update_model, feed_dict)
440 |
441 | ###############################################################################
442 | ############################## Soft Actor-Critic ##############################
443 | ###############################################################################
444 |
445 |
446 | class SoftActorCriticNetwork:
447 |
448 | def __init__(self, num_actions, state_shape=[8, 8, 5],
449 | convs=[[32, 4, 2], [64, 2, 1]],
450 | fully_connected=[128],
451 | activation_fn=tf.nn.relu,
452 | optimizers=[tf.train.AdamOptimizer(2.5e-4),
453 | tf.train.AdamOptimizer(2.5e-4),
454 | tf.train.AdamOptimizer(2.5e-4)],
455 | scope="sac", reuse=False):
456 |
457 | with tf.variable_scope(scope, reuse=reuse):
458 |
459 | ################### Neural network architecture ###################
460 |
461 | input_shape = [None] + state_shape
462 | self.input_states = tf.placeholder(
463 | dtype=tf.float32, shape=input_shape)
464 |
465 | self.v_values = full_module(
466 | self.input_states, convs, fully_connected,
467 | 1, activation_fn)
468 | self.q_values = full_module(
469 | self.input_states, convs, fully_connected,
470 | num_actions, activation_fn)
471 | self.p_logits = full_module(
472 | self.input_states, convs, fully_connected,
473 | num_actions, activation_fn)
474 | self.p_values = layers.softmax(self.p_logits)
475 |
476 | ##################### Optimization procedure ######################
477 |
478 |             # convert input actions to indices for p-logits and q-values selection
479 | self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None])
480 | indices_range = tf.range(tf.shape(self.input_actions)[0])
481 | action_indices = tf.stack(
482 | [indices_range, self.input_actions], axis=1)
483 |
484 | q_values_selected = tf.gather_nd(self.q_values, action_indices)
485 | p_logits_selected = tf.gather_nd(self.p_logits, action_indices)
486 |
487 | # choose best actions (according to q-values)
488 | self.q_argmax = tf.argmax(self.q_values, axis=1)
489 |
490 | # define loss function and update rule
491 | self.q_targets = tf.placeholder(dtype=tf.float32, shape=[None])
492 | self.v_targets = tf.placeholder(dtype=tf.float32, shape=[None])
493 | self.p_targets = tf.placeholder(dtype=tf.float32, shape=[None])
494 |
495 | q_loss = tf.losses.huber_loss(self.q_targets, q_values_selected)
496 | self.q_loss = tf.reduce_sum(q_loss)
497 | q_optimizer = optimizers[0]
498 |
499 | v_loss = tf.losses.huber_loss(
500 | self.v_targets[:, None], self.v_values)
501 | self.v_loss = tf.reduce_sum(v_loss)
502 | v_optimizer = optimizers[1]
503 |
504 | p_loss = tf.losses.huber_loss(self.p_targets, p_logits_selected)
505 | self.p_loss = tf.reduce_sum(p_loss)
506 | p_optimizer = optimizers[2]
507 |
508 | self.update_q_values = q_optimizer.minimize(self.q_loss)
509 | self.update_v_values = v_optimizer.minimize(self.v_loss)
510 | self.update_p_logits = p_optimizer.minimize(self.p_loss)
511 |
512 | def get_q_argmax(self, sess, states):
513 | feed_dict = {self.input_states: states}
514 | q_argmax = sess.run(self.q_argmax, feed_dict)
515 | return q_argmax
516 |
517 | def get_q_values_s(self, sess, states):
518 | feed_dict = {self.input_states: states}
519 | q_values = sess.run(self.q_values, feed_dict)
520 | return q_values
521 |
522 | def get_v_values_s(self, sess, states):
523 | feed_dict = {self.input_states: states}
524 | v_values = sess.run(self.v_values, feed_dict)
525 | return v_values
526 |
527 | def get_p_logits_s(self, sess, states):
528 | feed_dict = {self.input_states: states}
529 | p_logits = sess.run(self.p_logits, feed_dict)
530 | return p_logits
531 |
532 | def get_p_values_s(self, sess, states):
533 | feed_dict = {self.input_states: states}
534 | p_values = sess.run(self.p_values, feed_dict)
535 | return p_values
536 |
537 | def update_q(self, sess, states, actions, q_targets):
538 |
539 | feed_dict = {self.input_states: states,
540 | self.input_actions: actions,
541 | self.q_targets: q_targets}
542 | sess.run(self.update_q_values, feed_dict)
543 |
544 | def update_v(self, sess, states, v_targets):
545 |
546 | feed_dict = {self.input_states: states,
547 | self.v_targets: v_targets}
548 | sess.run(self.update_v_values, feed_dict)
549 |
550 | def update_p(self, sess, states, actions, p_targets):
551 |
552 | feed_dict = {self.input_states: states,
553 | self.input_actions: actions,
554 | self.p_targets: p_targets}
555 | sess.run(self.update_p_logits, feed_dict)
556 |
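557 | # A minimal usage sketch, assuming TensorFlow 1.x with tf.contrib available.
558 | # The scope names "dqn_demo"/"cat_demo" and the random inputs below are
559 | # illustrative assumptions and are not part of the training framework.
560 | if __name__ == "__main__":
561 |     demo_dqn = DeepQNetwork(num_actions=4, scope="dqn_demo")
562 |     demo_cat = CategoricalDeepQNetwork(num_actions=4, scope="cat_demo")
563 |     with tf.Session() as sess:
564 |         sess.run(tf.global_variables_initializer())
565 |         demo_states = np.random.rand(2, 8, 8, 5).astype(np.float32)
566 | 
567 |         # plain DQN: q-values and greedy actions for a batch of two states
568 |         print(demo_dqn.get_q_values_s(sess, demo_states))  # shape (2, 4)
569 |         print(demo_dqn.get_q_argmax(sess, demo_states))    # shape (2,)
570 | 
571 |         # C51: project the target distribution onto the fixed support z
572 |         demo_actions = np.array([0, 1], dtype=np.int32)
573 |         demo_rewards = np.array([1.0, 0.0])
574 |         demo_done = np.array([0.0, 1.0])
575 |         probs_targets = demo_cat.cat_proj(
576 |             sess, demo_states, demo_actions, demo_rewards, demo_done)
577 |         print(probs_targets.shape)  # (2, 21); each row sums to ~1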
--------------------------------------------------------------------------------
/train_agents.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Imports"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import gym\n",
17 | "from utils import *\n",
18 | "from agents import *\n",
19 | "from environments.snake.snake_env import Snake"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 | "# Snake Environment"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 |     "### Environment initialization"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "env = Snake(grid_size=(8, 8))\n",
43 | "num_actions = 3"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "### Agent training"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "# Create basic agent which consists of two networks: agent and target.\n",
60 | "# Checkpoints of networks' weights and learning curves will be saved\n",
61 | "# in \"save_path/model_name\" folder.\n",
62 | "snake_agent = DQNAgent(env, num_actions, state_shape=[8, 8, 5],\n",
63 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],\n",
64 | " save_path=\"snake_models\", model_name=\"dqn_8x8\")"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 |     "# Set basic hyperparameters (for full list see \"set_parameters\" method).\n",
74 | "# Create replay buffer and fill it with random transitions.\n",
75 | "snake_agent.set_parameters(max_episode_length=1000, replay_memory_size=100000, replay_start_size=10000,\n",
76 | " discount_factor=0.999, final_eps=0.01, annealing_steps=100000)"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {
83 | "scrolled": true
84 | },
85 | "outputs": [],
86 | "source": [
87 |     "# Set training hyperparameters (for full list see \"train\" method).\n",
88 |     "# Set gpu_id = -1 to use cpu instead of gpu, otherwise set it to gpu device id.\n",
89 | "snake_agent.train(gpu_id=-1, exploration=\"boltzmann\", save_freq=500000, max_num_epochs=1000)"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "metadata": {},
95 | "source": [
96 | "### Other agents"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "# Classic deep Q-network\n",
106 | "snake_agent = DQNAgent(env, num_actions, state_shape=[8, 8, 5],\n",
107 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],\n",
108 | " save_path=\"snake_models\", model_name=\"dqn_8x8\")\n",
109 | "\n",
110 | "# Dueling deep Q-network\n",
111 | "snake_agent = DuelDQNAgent(env, num_actions, state_shape=[8, 8, 5],\n",
112 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[64],\n",
113 | " save_path=\"snake_models\", model_name=\"dueldqn_8x8\")\n",
114 | "\n",
115 | "# Categorical deep Q-network (C51)\n",
116 | "snake_agent = CatDQNAgent(env, num_actions, state_shape=[8, 8, 5],\n",
117 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],\n",
118 | " v=(-5, 25), num_atoms=51,\n",
119 | " save_path=\"snake_models\", model_name=\"catdqn_8x8\")\n",
120 | "\n",
121 | "# Quantile regression deep Q-network (QR-DQN)\n",
122 | "snake_agent = QuantRegDQNAgent(env, num_actions, state_shape=[8, 8, 5],\n",
123 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],\n",
124 | " num_atoms=100, kappa=1.0,\n",
125 | " save_path=\"snake_models\", model_name=\"quantdqn_8x8\")\n",
126 | "\n",
127 | "# Soft Actor-Critic\n",
128 | "snake_agent = SACAgent(env, num_actions, state_shape=[8, 8, 5],\n",
129 | " convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],\n",
130 | " temperature=0.1,\n",
131 | " save_path=\"snake_models\", model_name=\"sac_8x8\")"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {
137 | "collapsed": true
138 | },
139 | "source": [
140 | "# Atari Environment"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "metadata": {},
146 | "source": [
147 |     "### Environment initialization"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "game_id = \"PongNoFrameskip-v4\"\n",
157 | "env = wrap_deepmind(gym.make(game_id))\n",
158 | "num_actions = env.unwrapped.action_space.n"
159 | ]
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "metadata": {},
164 | "source": [
165 | "### Agent training"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "metadata": {},
172 | "outputs": [],
173 | "source": [
174 | "atari_agent = DQNAgent(env, num_actions, state_shape=[84, 84, 4],\n",
175 | " convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],\n",
176 | " save_path=\"atari_models\", model_name=\"dqn_boi\")"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": null,
182 | "metadata": {},
183 | "outputs": [],
184 | "source": [
185 | "atari_agent.set_parameters(max_episode_length=100000, discount_factor=0.99, final_eps=0.01,\n",
186 | " replay_memory_size=1000000, replay_start_size=50, annealing_steps=1000000,\n",
187 | " frame_history_len=4)"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "atari_agent.train(gpu_id=-1, exploration=\"e-greedy\", save_freq=50000, \n",
197 | " max_num_epochs=1000, performance_print_freq=50)"
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "### Other agents"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "# Classic deep Q-network\n",
214 | "atari_agent = DQNAgent(env, num_actions, state_shape=[84, 84, 4],\n",
215 | " convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],\n",
216 | " save_path=\"atari_models\", model_name=\"dqn_boi\")\n",
217 | "\n",
218 | "# Dueling deep Q-network\n",
219 | "atari_agent = DuelDQNAgent(env, num_actions, state_shape=[84, 84, 4],\n",
220 | " convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[256],\n",
221 | " save_path=\"atari_models\", model_name=\"dueldqn_boi\")\n",
222 | "\n",
223 | "# Categorical deep Q-network (C51)\n",
224 | "atari_agent = CatDQNAgent(env, num_actions, state_shape=[84, 84, 4],\n",
225 | " convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],\n",
226 | " v=(-10, 10), num_atoms=51,\n",
227 | " save_path=\"atari_models\", model_name=\"catdqn_boi\")\n",
228 | "\n",
229 | "# Quantile regression deep Q-network (QR-DQN)\n",
230 | "atari_agent = QuantRegDQNAgent(env, num_actions, state_shape=[84, 84, 4],\n",
231 | " convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],\n",
232 | " num_atoms=200, kappa=1,\n",
233 | " save_path=\"atari_models\", model_name=\"quantdqn_boi\")"
234 | ]
235 | }
236 | ],
237 | "metadata": {
238 | "kernelspec": {
239 | "display_name": "Python [default]",
240 | "language": "python",
241 | "name": "python3"
242 | },
243 | "language_info": {
244 | "codemirror_mode": {
245 | "name": "ipython",
246 | "version": 3
247 | },
248 | "file_extension": ".py",
249 | "mimetype": "text/x-python",
250 | "name": "python",
251 | "nbconvert_exporter": "python",
252 | "pygments_lexer": "ipython3",
253 | "version": "3.6.5"
254 | }
255 | },
256 | "nbformat": 4,
257 | "nbformat_minor": 2
258 | }
259 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is copied/adapted from https://github.com/berkeleydeeprlcourse/homework/tree/master/hw3
3 | """
4 |
5 | import numpy as np
6 | import random
7 | from collections import namedtuple, deque
8 | import gym
9 | from gym import spaces
10 | from PIL import Image
11 |
12 | ####################################################################################################
13 | ######################################## Experience Replay #########################################
14 | ####################################################################################################
15 |
16 | def sample_n_unique(sampling_f, n):
17 | """Helper function. Given a function `sampling_f` that returns
18 | comparable objects, sample n such unique objects.
19 | """
20 | res = []
21 | while len(res) < n:
22 | candidate = sampling_f()
23 | if candidate not in res:
24 | res.append(candidate)
25 | return res
26 |
27 | class ReplayBuffer(object):
28 | def __init__(self, size, frame_history_len):
29 | """This is a memory efficient implementation of the replay buffer.
30 |         The specific memory optimizations used here are:
31 | - only store each frame once rather than k times
32 | even if every observation normally consists of k last frames
33 |         - store frames as np.uint8 (actually it is most time-efficient
34 | to cast them back to float32 on GPU to minimize memory transfer
35 | time)
36 | - store frame_t and frame_(t+1) in the same buffer.
37 |     For the typical use case in Atari deep RL with a buffer of 1M frames the total
38 | memory footprint of this buffer is 10^6 * 84 * 84 bytes ~= 7 gigabytes
39 |     Warning! Assumes that returning a frame of zeros at the beginning
40 |     of the episode, when there are fewer frames than `frame_history_len`,
41 | is acceptable.
42 | Parameters
43 | ----------
44 | size: int
45 | Max number of transitions to store in the buffer. When the buffer
46 | overflows the old memories are dropped.
47 | frame_history_len: int
48 |         Number of memories to be retrieved for each observation.
49 | """
50 | self.size = size
51 | self.frame_history_len = frame_history_len
52 |
53 | self.next_idx = 0
54 | self.num_in_buffer = 0
55 |
56 | self.obs = None
57 | self.action = None
58 | self.reward = None
59 | self.done = None
60 |
61 | self.transition = namedtuple('Transition', ('s', 'a', 'r', 's_', 'done'))
62 |
63 | def can_sample(self, batch_size):
64 | """Returns true if `batch_size` different transitions can be sampled from the buffer."""
65 | return batch_size + 1 <= self.num_in_buffer
66 |
67 | def _encode_sample(self, idxes):
68 | obs_batch = np.concatenate([self._encode_observation(idx)[None] for idx in idxes], 0)
69 | act_batch = self.action[idxes]
70 | rew_batch = self.reward[idxes]
71 | next_obs_batch = np.concatenate([self._encode_observation(idx + 1)[None] for idx in idxes], 0)
72 | done_mask = np.array([1.0 if self.done[idx] else 0.0 for idx in idxes], dtype=np.float32)
73 |
74 | return self.transition(obs_batch, act_batch, rew_batch, next_obs_batch, done_mask)
75 |
76 |
77 | def sample(self, batch_size):
78 | """Sample `batch_size` different transitions.
79 | i-th sample transition is the following:
80 | when observing `obs_batch[i]`, action `act_batch[i]` was taken,
81 | after which reward `rew_batch[i]` was received and subsequent
82 |         observation next_obs_batch[i] was observed, unless the episode
83 | was done which is represented by `done_mask[i]` which is equal
84 | to 1 if episode has ended as a result of that action.
85 | Parameters
86 | ----------
87 | batch_size: int
88 | How many transitions to sample.
89 | Returns
90 | -------
91 | obs_batch: np.array
92 | Array of shape
93 | (batch_size, img_h, img_w, img_c * frame_history_len)
94 | and dtype np.uint8
95 | act_batch: np.array
96 | Array of shape (batch_size,) and dtype np.int32
97 | rew_batch: np.array
98 | Array of shape (batch_size,) and dtype np.float32
99 | next_obs_batch: np.array
100 | Array of shape
101 | (batch_size, img_h, img_w, img_c * frame_history_len)
102 | and dtype np.uint8
103 | done_mask: np.array
104 | Array of shape (batch_size,) and dtype np.float32
105 | """
106 | assert self.can_sample(batch_size)
107 | idxes = sample_n_unique(lambda: random.randint(0, self.num_in_buffer - 2), batch_size)
108 | return self._encode_sample(idxes)
109 |
110 | def encode_recent_observation(self):
111 | """Return the most recent `frame_history_len` frames.
112 | Returns
113 | -------
114 | observation: np.array
115 | Array of shape (img_h, img_w, img_c * frame_history_len)
116 | and dtype np.uint8, where observation[:, :, i*img_c:(i+1)*img_c]
117 | encodes frame at time `t - frame_history_len + i`
118 | """
119 | assert self.num_in_buffer > 0
120 | return self._encode_observation((self.next_idx - 1) % self.size)
121 |
122 | def _encode_observation(self, idx):
123 | end_idx = idx + 1 # make noninclusive
124 | start_idx = end_idx - self.frame_history_len
125 | # this checks if we are using low-dimensional observations, such as RAM
126 | # state, in which case we just directly return the latest RAM.
127 | if len(self.obs.shape) == 2:
128 | return self.obs[end_idx-1]
129 | # if there weren't enough frames ever in the buffer for context
130 | if start_idx < 0 and self.num_in_buffer != self.size:
131 | start_idx = 0
132 | for idx in range(start_idx, end_idx - 1):
133 | if self.done[idx % self.size]:
134 | start_idx = idx + 1
135 | missing_context = self.frame_history_len - (end_idx - start_idx)
136 | # if zero padding is needed for missing context
137 |         # or we are on the boundary of the buffer
138 | if start_idx < 0 or missing_context > 0:
139 | frames = [np.zeros_like(self.obs[0]) for _ in range(missing_context)]
140 | for idx in range(start_idx, end_idx):
141 | frames.append(self.obs[idx % self.size])
142 | return np.concatenate(frames, 2)
143 | else:
144 |             # this optimization has the potential to save about 30% compute time \o/
145 | img_h, img_w = self.obs.shape[1], self.obs.shape[2]
146 | return self.obs[start_idx:end_idx].transpose(1, 2, 0, 3).reshape(img_h, img_w, -1)
147 |
148 | def store_frame(self, frame):
149 | """Store a single frame in the buffer at the next available index, overwriting
150 | old frames if necessary.
151 | Parameters
152 | ----------
153 | frame: np.array
154 | Array of shape (img_h, img_w, img_c) and dtype np.uint8
155 | the frame to be stored
156 | Returns
157 | -------
158 | idx: int
159 | Index at which the frame is stored. To be used for `store_effect` later.
160 | """
161 | if self.obs is None:
162 | self.obs = np.empty([self.size] + list(frame.shape), dtype=np.uint8)
163 | self.action = np.empty([self.size], dtype=np.int32)
164 | self.reward = np.empty([self.size], dtype=np.float32)
165 |             self.done = np.empty([self.size], dtype=np.bool_)
166 | self.obs[self.next_idx] = frame
167 |
168 | ret = self.next_idx
169 | self.next_idx = (self.next_idx + 1) % self.size
170 | self.num_in_buffer = min(self.size, self.num_in_buffer + 1)
171 |
172 | return ret
173 |
174 | def store_effect(self, idx, action, reward, done):
175 |         """Store effects of action taken after observing frame stored
176 |         at index idx. The reason `store_frame` and `store_effect` are broken
177 |         up into two functions is so that one can call `encode_recent_observation`
178 | in between.
179 |         Parameters
180 | ---------
181 | idx: int
182 | Index in buffer of recently observed frame (returned by `store_frame`).
183 | action: int
184 | Action that was performed upon observing this frame.
185 | reward: float
186 |             Reward that was received when the action was performed.
187 | done: bool
188 | True if episode was finished after performing that action.
189 | """
190 | self.action[idx] = action
191 | self.reward[idx] = reward
192 | self.done[idx] = done
193 |
194 | ####################################################################################################
195 | ######################################## DM Atari Wrappers #########################################
196 | ####################################################################################################
197 |
198 | class NoopResetEnv(gym.Wrapper):
199 | def __init__(self, env=None, noop_max=30):
200 |         """Sample initial states by taking a random number of no-ops on reset.
201 | No-op is assumed to be action 0.
202 | """
203 | super(NoopResetEnv, self).__init__(env)
204 | self.noop_max = noop_max
205 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
206 |
207 | def _reset(self):
208 | """ Do no-op action for a number of steps in [1, noop_max]."""
209 | self.env.reset()
210 | noops = np.random.randint(1, self.noop_max + 1)
211 | for _ in range(noops):
212 | obs, _, _, _ = self.env.step(0)
213 | return obs
214 |
215 | class FireResetEnv(gym.Wrapper):
216 | def __init__(self, env=None):
217 | """Take action on reset for environments that are fixed until firing."""
218 | super(FireResetEnv, self).__init__(env)
219 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
220 | assert len(env.unwrapped.get_action_meanings()) >= 3
221 |
222 | def _reset(self):
223 | self.env.reset()
224 | obs, _, _, _ = self.env.step(1)
225 | obs, _, _, _ = self.env.step(2)
226 | return obs
227 |
228 | class EpisodicLifeEnv(gym.Wrapper):
229 | def __init__(self, env=None):
230 | """Make end-of-life == end-of-episode, but only reset on true game over.
231 | Done by DeepMind for the DQN and co. since it helps value estimation.
232 | """
233 | super(EpisodicLifeEnv, self).__init__(env)
234 | self.lives = 0
235 | self.was_real_done = True
236 | self.was_real_reset = False
237 |
238 | def _step(self, action):
239 | obs, reward, done, info = self.env.step(action)
240 | self.was_real_done = done
241 | # check current lives, make loss of life terminal,
242 | # then update lives to handle bonus lives
243 | lives = self.env.unwrapped.ale.lives()
244 | if lives < self.lives and lives > 0:
245 |             # for Qbert sometimes we stay in the lives == 0 condition for a few frames
246 |             # so it's important to keep lives > 0, so that we only reset once
247 | # the environment advertises done.
248 | done = True
249 | self.lives = lives
250 | return obs, reward, done, info
251 |
252 | def _reset(self):
253 | """Reset only when lives are exhausted.
254 | This way all states are still reachable even though lives are episodic,
255 | and the learner need not know about any of this behind-the-scenes.
256 | """
257 | if self.was_real_done:
258 | obs = self.env.reset()
259 | self.was_real_reset = True
260 | else:
261 | # no-op step to advance from terminal/lost life state
262 | obs, _, _, _ = self.env.step(0)
263 | self.was_real_reset = False
264 | self.lives = self.env.unwrapped.ale.lives()
265 | return obs
266 |
267 | class MaxAndSkipEnv(gym.Wrapper):
268 | def __init__(self, env=None, skip=4):
269 | """Return only every `skip`-th frame"""
270 | super(MaxAndSkipEnv, self).__init__(env)
271 | # most recent raw observations (for max pooling across time steps)
272 | self._obs_buffer = deque(maxlen=2)
273 | self._skip = skip
274 |
275 | def _step(self, action):
276 | total_reward = 0.0
277 | done = None
278 | for _ in range(self._skip):
279 | obs, reward, done, info = self.env.step(action)
280 | self._obs_buffer.append(obs)
281 | total_reward += reward
282 | if done:
283 | break
284 |
285 | max_frame = np.max(np.stack(self._obs_buffer), axis=0)
286 |
287 | return max_frame, total_reward, done, info
288 |
289 | def _reset(self):
290 | """Clear past frame buffer and init. to first obs. from inner env."""
291 | self._obs_buffer.clear()
292 | obs = self.env.reset()
293 | self._obs_buffer.append(obs)
294 | return obs
295 |
296 | def _process_frame84(frame):
297 | img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
298 | img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
299 | img = Image.fromarray(img)
300 | resized_screen = img.resize((84, 110), Image.BILINEAR)
301 | resized_screen = np.array(resized_screen)
302 | x_t = resized_screen[18:102, :]
303 | x_t = np.reshape(x_t, [84, 84, 1])
304 | return x_t.astype(np.uint8)
305 |
306 | class ProcessFrame84(gym.Wrapper):
307 | def __init__(self, env=None):
308 | super(ProcessFrame84, self).__init__(env)
309 | self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))
310 |
311 | def _step(self, action):
312 | obs, reward, done, info = self.env.step(action)
313 | return _process_frame84(obs), reward, done, info
314 |
315 | def _reset(self):
316 | return _process_frame84(self.env.reset())
317 |
318 | class ClippedRewardsWrapper(gym.Wrapper):
319 | def _step(self, action):
320 | obs, reward, done, info = self.env.step(action)
321 | return obs, np.sign(reward), done, info
322 |
323 | def wrap_deepmind_ram(env):
324 | env = EpisodicLifeEnv(env)
325 | env = NoopResetEnv(env, noop_max=30)
326 | env = MaxAndSkipEnv(env, skip=4)
327 | if 'FIRE' in env.unwrapped.get_action_meanings():
328 | env = FireResetEnv(env)
329 | env = ClippedRewardsWrapper(env)
330 | return env
331 |
332 | def wrap_deepmind(env):
333 | assert 'NoFrameskip' in env.spec.id
334 | env = EpisodicLifeEnv(env)
335 | env = NoopResetEnv(env, noop_max=30)
336 | env = MaxAndSkipEnv(env, skip=4)
337 | if 'FIRE' in env.unwrapped.get_action_meanings():
338 | env = FireResetEnv(env)
339 | env = ProcessFrame84(env)
340 | env = ClippedRewardsWrapper(env)
341 | return env
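342 | 
343 | # A minimal usage sketch of ReplayBuffer, shown as an assumption about typical
344 | # use rather than part of the training framework: random 84x84x1 uint8 frames
345 | # stand in for real observations.
346 | if __name__ == "__main__":
347 |     buf = ReplayBuffer(size=1000, frame_history_len=4)
348 |     for _ in range(100):
349 |         frame = np.random.randint(0, 255, size=(84, 84, 1), dtype=np.uint8)
350 |         idx = buf.store_frame(frame)
351 |         stacked = buf.encode_recent_observation()
352 |         assert stacked.shape == (84, 84, 4)  # most recent frame_history_len frames
353 |         buf.store_effect(idx, action=random.randint(0, 2), reward=0.0, done=False)
354 |     if buf.can_sample(32):
355 |         batch = buf.sample(32)  # Transition(s, a, r, s_, done)
356 |         print(batch.s.shape, batch.a.shape, batch.done.shape)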
--------------------------------------------------------------------------------