├── .gitignore ├── README.md ├── cartpole_networks.py ├── cartpole_test.py ├── imgs └── 128_graph.png ├── neural_network.py ├── train_gan_q_learning.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | Testing.ipynb 2 | .ipynb_checkpoints/ 3 | __pycache__/ 4 | *.ipynb 5 | logs/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This code implements the "GAN Q-Learning" algorithm found in https://arxiv.org/abs/1805.04874. 2 | 3 | ## Modifications From Paper 4 | 5 | * The published algorithm contains a typo (in the form of the discriminator loss), which is corrected in this implementation. 6 | 7 | * Currently, on the CartPole environment the discriminator (eventually) learns to discriminate perfectly against the generator, even before the generator has learned the actual distribution. I've experimented with different hyperparameters, but the behavior persists. For example, even when I update the generator 10 times per discriminator update, the training graph still looks as follows 8 | 9 | ![graph](imgs/128_graph.png) 10 | 11 | ## Final Results 12 | 13 | In the end, I was unable to reproduce the results given in the paper, since my computer couldn't sweep enough hyperparameters. After verifying that the algorithm was implemented correctly, I found that the classic problems of training GANs arose. In particular, the discriminator easily overfit the reward distribution, so the generator got stuck and the reward function couldn't be learned. Even with significant architecture modifications, these problems persisted. 14 | -------------------------------------------------------------------------------- /cartpole_networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import neural_network as nn 4 | 5 | class Generator(nn.Generator): 6 | """ 7 | Example OpenAI-Gym Generator architecture.
8 | """ 9 | def __init__(self, sess): 10 | """ 11 | Args 12 | ---- 13 | sess : the tensorflow session to be used 14 | """ 15 | self.sess_ = sess 16 | with tf.variable_scope('gen'): 17 | self.input_state_ = tf.placeholder(tf.float32, shape=[None, 4], name='input_state') 18 | self.input_seed_ = tf.placeholder(tf.float32, shape=[None, 1], name='input_seed') 19 | self.concat = tf.concat([self.input_state_, self.input_seed_], 1, name='concat') 20 | self.hidden = tf.layers.dense(self.concat, 8, activation=tf.nn.relu, name='hidden') 21 | self.output_ = tf.layers.dense(self.hidden, 2, name='output') 22 | self.sess.run(tf.global_variables_initializer()) 23 | 24 | @property 25 | def input_state(self): 26 | """ 27 | The input state of shape [None, 4] 28 | 29 | Returns 30 | ------- 31 | A placeholder tensor: the input state's placeholder tensor 32 | """ 33 | return self.input_state_ 34 | 35 | @property 36 | def output(self): 37 | """ 38 | The outputted action distribution of shape [None, 2] 39 | 40 | Returns 41 | ------- 42 | A tensor: the output tensor 43 | """ 44 | return self.output_ 45 | 46 | @property 47 | def sess(self): 48 | """ 49 | The session used to create the graph 50 | 51 | Returns 52 | ------- 53 | A session: the graph's session 54 | """ 55 | return self.sess_ 56 | 57 | @property 58 | def input_seed(self): 59 | """ 60 | The input random seed 61 | 62 | Returns 63 | ------- 64 | A placeholder: the input seed's placeholder tensor 65 | """ 66 | return self.input_seed_ 67 | 68 | @property 69 | def trainable_variables(self): 70 | """ 71 | A list of the trainable variables in our generator 72 | 73 | Returns 74 | ------- 75 | A list of tensors: the trainable variables in this graph 76 | """ 77 | return tf.trainable_variables('gen') 78 | 79 | class Discriminator(nn.Discriminator): 80 | """ 81 | Example OpenAI-Gym Discriminator Architecture 82 | """ 83 | def __init__(self, sess): 84 | """ 85 | Args 86 | ---- 87 | sess : the tensorflow session to be used 88 | """ 89 | self.sess_ = sess 90 | with tf.variable_scope('dis'): 91 | self.input_state_ = tf.placeholder(tf.float32, shape=[None, 4], name='input_state') 92 | self.input_reward_ = tf.placeholder(tf.float32, shape=[None], name='input_reward') 93 | self.input_action_ = tf.placeholder(tf.float32, shape=[None, 1], name='input_action') 94 | self.input_reward_exp = tf.expand_dims(self.input_reward_, axis=-1, name='input_reward_expanded') 95 | self.concat = tf.concat([self.input_state_, self.input_reward_exp, self.input_action_], axis=1, name='concat') 96 | self.hidden = tf.layers.dense(self.concat, 8, activation=tf.nn.relu, name='hidden') 97 | self.output_ = tf.layers.dense(self.hidden, 1, activation=tf.sigmoid, name='output') 98 | self.sess.run(tf.global_variables_initializer()) 99 | 100 | @property 101 | def input_state(self): 102 | """ 103 | The input state of shape [None, 4] 104 | 105 | Returns 106 | ------- 107 | A placeholder tensor: the input state's placeholder tensor 108 | """ 109 | return self.input_state_ 110 | 111 | @property 112 | def input_action(self): 113 | """ 114 | The input action of shape [None, 1] 115 | 116 | Returns 117 | ------- 118 | A placeholder tensor: the input action's placeholder tensor 119 | """ 120 | return self.input_action_ 121 | 122 | @property 123 | def output(self): 124 | """ 125 | The probability output of shape [None, 1] 126 | 127 | Returns 128 | ------- 129 | A tensor: the output's tensor 130 | """ 131 | return self.output_ 132 | 133 | @property 134 | def sess(self): 135 | """ 136 | The session used to create a 
graph 137 | 138 | Returns 139 | ------- 140 | A session: the graph's session 141 | """ 142 | return self.sess_ 143 | 144 | @property 145 | def input_reward(self): 146 | """ 147 | The input reward 148 | 149 | Returns 150 | ------- 151 | A placeholder tensor: the input reward's tensor 152 | """ 153 | return self.input_reward_ 154 | 155 | @property 156 | def trainable_variables(self): 157 | """ 158 | A list of the trainable variables in our discriminator 159 | 160 | Returns 161 | ------- 162 | A list of tensors: the trainable variables in this graph 163 | """ 164 | return tf.trainable_variables('dis') 165 | 166 | class Discriminator_copy(nn.Discriminator_copy): 167 | """ 168 | Example OpenAI-Gym Discriminator Copying method 169 | """ 170 | def __init__(self, dis, new_rew_input): 171 | """ 172 | Initializes a discriminator_copy object 173 | 174 | Args 175 | ---- 176 | dis (Discriminator) : The discriminator to copy 177 | new_rew_input (tf.placeholder) : a new reward input. 178 | """ 179 | self.sess_ = dis.sess 180 | 181 | #reuse the variables 182 | with tf.variable_scope('dis', reuse=tf.AUTO_REUSE): 183 | self.input_state_ = tf.placeholder(tf.float32, shape=[None, 4], name='input_state') 184 | self.input_reward_ = new_rew_input 185 | self.input_action_ = tf.placeholder(tf.float32, shape=[None, 1], name='input_action') 186 | self.input_reward_exp = tf.expand_dims(self.input_reward_, axis=-1, name='input_reward_expanded') 187 | self.concat = tf.concat([self.input_state_, self.input_reward_exp, self.input_action_], axis=1, name='concat_copy') 188 | self.hidden_ker = tf.get_variable('hidden/kernel') 189 | self.hidden_bias = tf.get_variable('hidden/bias') 190 | self.output_ker = tf.get_variable('output/kernel') 191 | self.output_bias = tf.get_variable('output/bias') 192 | 193 | self.hidden = tf.nn.relu(tf.matmul(self.concat, self.hidden_ker) + self.hidden_bias) #apply the same ReLU activation as the original discriminator's hidden layer 194 | self.output_ = tf.sigmoid(tf.matmul(self.hidden, self.output_ker) + self.output_bias) 195 | 196 | @property 197 | def input_state(self): 198 | """ 199 | The input state of shape [None, 4] 200 | 201 | Returns 202 | ------- 203 | A placeholder tensor: the input state's placeholder tensor 204 | """ 205 | return self.input_state_ 206 | 207 | @property 208 | def input_action(self): 209 | """ 210 | The input action of shape [None, 1] 211 | 212 | Returns 213 | ------- 214 | A placeholder tensor: the input action's placeholder tensor 215 | """ 216 | return self.input_action_ 217 | 218 | @property 219 | def output(self): 220 | """ 221 | The probability output of shape [None, 1] 222 | 223 | Returns 224 | ------- 225 | A tensor: the output's tensor 226 | """ 227 | return self.output_ 228 | 229 | @property 230 | def sess(self): 231 | """ 232 | The session used to create a graph 233 | 234 | Returns 235 | ------- 236 | A session: the graph's session 237 | """ 238 | return self.sess_ 239 | 240 | @property 241 | def input_reward(self): 242 | """ 243 | The input reward 244 | 245 | Returns 246 | ------- 247 | A placeholder tensor: the input reward's tensor 248 | """ 249 | return self.input_reward_ 250 | 251 | @property 252 | def trainable_variables(self): 253 | """ 254 | A list of the trainable variables in our discriminator 255 | 256 | Returns 257 | ------- 258 | A list of tensors: the trainable variables in this graph 259 | """ 260 | return tf.trainable_variables('dis') 261 | -------------------------------------------------------------------------------- /cartpole_test.py: -------------------------------------------------------------------------------- 1 | import
tensorflow as tf 2 | import train_gan_q_learning as train 3 | import cartpole_networks as networks 4 | import gym 5 | 6 | def main(): 7 | sess = tf.Session() 8 | gen = networks.Generator(sess) 9 | dis = networks.Discriminator(sess) 10 | dis_copy = networks.Discriminator_copy 11 | 12 | env = gym.make('CartPole-v0') 13 | train.learn(env, 14 | sess, 15 | 1000, #episodes 16 | 10000, #buffer_size 17 | 0.99, #reward_discount 18 | dis, 19 | dis_copy, 20 | gen, 21 | n_gen=5, 22 | log_dir='./logs/') 23 | 24 | if __name__ == '__main__' : main() 25 | -------------------------------------------------------------------------------- /imgs/128_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/louaaron/GAN-Q-Learning/73ea79840c69df4bfeb78cfe75861d4b60a72ce7/imgs/128_graph.png -------------------------------------------------------------------------------- /neural_network.py: -------------------------------------------------------------------------------- 1 | class Generator(object): 2 | """ 3 | Interface for a generator. The generator should take in 4 | a state and a random seed and output a reward distribution 5 | over actions 6 | """ 7 | 8 | @property 9 | def input_state(self): 10 | """ 11 | The input state 12 | 13 | Returns 14 | ------- 15 | A placeholder tensor: the input state's placeholder tensor 16 | """ 17 | pass 18 | 19 | @property 20 | def output(self): 21 | """ 22 | The outputted reward distribution over actions 23 | 24 | Returns 25 | ------- 26 | A tensor: the output tensor 27 | """ 28 | pass 29 | 30 | @property 31 | def sess(self): 32 | """ 33 | The session used to create the graph 34 | 35 | Returns 36 | ------- 37 | A session: the graph's session 38 | """ 39 | pass 40 | 41 | @property 42 | def input_seed(self): 43 | """ 44 | The input random seed 45 | 46 | Returns 47 | ------- 48 | A placeholder: the input seed's placeholder tensor 49 | """ 50 | pass 51 | 52 | @property 53 | def trainable_variables(self): 54 | """ 55 | A list of the trainable variables in our generator 56 | 57 | Returns 58 | ------- 59 | A list of tensors: the trainable variables in this graph 60 | """ 61 | pass 62 | 63 | class Discriminator(object): 64 | """ 65 | Interface for a discriminator.
The discriminator should take in 66 | a state, action, and expected reward and return a probability 67 | value 68 | """ 69 | 70 | @property 71 | def input_state(self): 72 | """ 73 | The input state 74 | 75 | Returns 76 | ------- 77 | A placeholder tensor: the input state's placeholder tensor 78 | """ 79 | pass 80 | 81 | @property 82 | def input_action(self): 83 | """ 84 | The input action 85 | 86 | Returns 87 | ------- 88 | A placeholder tensor: the input action's placeholder tensor 89 | """ 90 | pass 91 | 92 | @property 93 | def output(self): 94 | """ 95 | The probability output 96 | 97 | Returns 98 | ------- 99 | A tensor: the output's tensor 100 | """ 101 | pass 102 | 103 | @property 104 | def sess(self): 105 | """ 106 | The session used to create a graph 107 | 108 | Returns 109 | ------- 110 | A session: the graph's session 111 | """ 112 | pass 113 | 114 | @property 115 | def input_reward(self): 116 | """ 117 | The input reward 118 | 119 | Returns 120 | ------- 121 | A placeholder tensor: the input reward's tensor 122 | """ 123 | pass 124 | 125 | @property 126 | def trainable_variables(self): 127 | """ 128 | A list of the trainable variables in our generator 129 | 130 | Returns 131 | ------- 132 | A list of tensors: the trainable variables in this graph 133 | """ 134 | pass 135 | 136 | class Discriminator_copy(object): 137 | """ 138 | Interface for copying a discriminator (used for Loss function). 139 | The discriminator_copy object should be initialized by a discriminator 140 | and a new reward placeholder. This new discriminator should share weights 141 | and other variables with the original dis, but should be run on the 142 | new_rew_input. 143 | """ 144 | 145 | def __init__(self, dis, new_rew_input): 146 | """ 147 | Initializes a discriminator_copy object 148 | 149 | Args 150 | ---- 151 | dis (Discriminator) : The discriminator to copy 152 | new_rew_input (tf.placeholder) : a new reward input. 
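Usage sketch (mirroring the call in train_gan_q_learning.learn, which substitutes the generator's predicted reward for the real reward input): gen_dis = dis_copy(dis, tf.reduce_max(gen.output, axis=1))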
153 | """ 154 | pass 155 | 156 | @property 157 | def input_state(self): 158 | """ 159 | The input state 160 | 161 | Returns 162 | ------- 163 | A placeholder tensor: the input state's placeholder tensor 164 | """ 165 | pass 166 | 167 | @property 168 | def input_action(self): 169 | """ 170 | The input action 171 | 172 | Returns 173 | ------- 174 | A placeholder tensor: the input action's placeholder tensor 175 | """ 176 | pass 177 | 178 | @property 179 | def output(self): 180 | """ 181 | The probability output 182 | 183 | Returns 184 | ------- 185 | A tensor: the output tensor 186 | """ 187 | pass 188 | 189 | @property 190 | def sess(self): 191 | """ 192 | The session used to create a graph 193 | 194 | Returns 195 | ------- 196 | A session: the graph's session 197 | """ 198 | pass 199 | 200 | @property 201 | def trainable_variables(self): 202 | """ 203 | A list of the trainable variables in our discriminator 204 | 205 | Returns 206 | ------- 207 | A list of tensors: the trainable variables in this graph 208 | """ 209 | pass -------------------------------------------------------------------------------- /train_gan_q_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import utils 4 | import gym 5 | 6 | def learn(env, 7 | sess, 8 | episodes, 9 | buffer_size, 10 | reward_discount, 11 | dis, 12 | dis_copy, 13 | gen, 14 | learning_rate=0.0005, 15 | optimizer=tf.train.RMSPropOptimizer, 16 | n_dis=1, 17 | n_gen=1, 18 | lambda_=0, 19 | batch_size=64, 20 | log_dir=None): 21 | """ 22 | Code for the algorithm found in https://arxiv.org/abs/1805.04874 23 | GAN Q-Learning learns a probability distribution for Z(s, a), the distributional 24 | value function (Q(s, a) is the special case where the distribution is a point mass). 25 | 26 | Note that the algorithm described in figure 1 of the paper has some typos, which 27 | are corrected here.
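For reference, the corrected (WGAN-style) objectives as implemented below, writing the batch means as expectations: L_D = E[D(s, a, max_a' G(s, a', z))] - E[D(s, a, y)] + lambda_ * E[(dD(s, a, x_hat)/dx_hat - 1)^2] and L_G = -E[D(s, a, max_a' G(s, a', z))], where y = r for terminal transitions, y = r + reward_discount * max_a' G(s', a', z) otherwise, and x_hat is a uniform random interpolation between y and the generator's prediction.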
28 | 29 | Args 30 | ---- 31 | env (gym.env) : 32 | The environment for training 33 | sess (tf.Session) : 34 | The session of both the discriminator and generator 35 | episodes (int) : 36 | The number of episodes to train the algorithm on 37 | buffer_size (int) : 38 | The size of the buffer 39 | reward_discount (float) : 40 | The amount of future reward to consider 41 | dis (neural_network.Discriminator) : 42 | The architecture of the discriminator 43 | dis_copy (neural_network.Discriminator_copy) : 44 | The architecture of the discriminator copier 45 | gen (neural_network.Generator) : 46 | The architecture of the generator 47 | learning_rate (float - 0.0005) : 48 | The learning rate 49 | optimizer (tf.train.Optimizer - tf.train.RMSPropOptimizer) : 50 | The optimization initialization function 51 | n_dis (int - 1) : 52 | The number of discriminator updates per episode 53 | n_gen (int - 1) : 54 | The number of generator updates per episode 55 | lambda_ (float - 0) : 56 | The gradient penalty coefficient (0 disables the penalty, i.e. plain WGAN optimization) 57 | batch_size (int - 64) : 58 | The batch_size for training 59 | log_dir (str - None) : 60 | writer output directory if not None 61 | """ 62 | z_shape = gen.input_seed.get_shape().as_list()[1:] 63 | 64 | #Assertion statements (make sure session remains the same across graphs) 65 | assert sess == dis.sess 66 | assert sess == gen.sess 67 | 68 | #Reset environment 69 | last_obs = env.reset() 70 | 71 | #The gradient for loss function 72 | grad_val_ph = tf.placeholder(tf.float32, shape=dis.input_reward.get_shape()) 73 | grad_dis = dis_copy(dis, grad_val_ph) 74 | 75 | #The generator-discriminator for loss function 76 | gen_dis = dis_copy(dis, tf.reduce_max(gen.output, axis=1)) 77 | 78 | #loss functions 79 | dis_loss = tf.reduce_mean(tf.squeeze(gen_dis.output)) - tf.reduce_mean(tf.squeeze(dis.output)) \ 80 | + lambda_ * tf.reduce_mean(tf.square(tf.gradients(grad_dis.output, grad_val_ph)[0] - 1)) 81 | 82 | gen_loss = tf.reduce_mean(-tf.squeeze(gen_dis.output)) 83 | 84 | #optimization 85 | optim = optimizer(learning_rate=learning_rate) 86 | dis_min_op = optim.minimize(dis_loss, var_list=dis.trainable_variables) 87 | gen_min_op = optim.minimize(gen_loss, var_list=gen.trainable_variables) 88 | 89 | #buffer 90 | buffer = utils.ReplayBuffer(buffer_size, 1) 91 | 92 | #writer (optional) 93 | if log_dir is not None: 94 | writer = tf.summary.FileWriter(log_dir) 95 | dis_summ = tf.summary.scalar('discriminator loss', dis_loss) 96 | gen_summ = tf.summary.scalar('generator loss', gen_loss) 97 | rew_ph = tf.placeholder(tf.float32, shape=()) #episode rewards are floats 98 | rew_summ = tf.summary.scalar('average reward', rew_ph) 99 | else: 100 | writer = None 101 | 102 | #initialize all vars 103 | sess.run(tf.global_variables_initializer()) 104 | 105 | #training algorithm 106 | 107 | #trackers for writer 108 | rew_tracker = 0 109 | dis_tracker = 0 110 | gen_tracker = 0 111 | 112 | #debug print 113 | print(dis.trainable_variables) 114 | print(gen.trainable_variables) 115 | 116 | #number of episodes to train 117 | for _ in range(episodes): 118 | #loop through all the steps 119 | rew_agg = 0 120 | for _ in range(env._max_episode_steps): 121 | gen_seed = np.random.normal(0, 1, size=z_shape) 122 | action_results = sess.run(gen.output, feed_dict={ 123 | gen.input_state : np.array([last_obs]), 124 | gen.input_seed : np.array([gen_seed]) 125 | })[0] 126 | optimal_action = np.argmax(action_results) 127 | 128 | next_obs, reward, done, _ = env.step(optimal_action) 129 | rew_agg += reward 130 | idx = buffer.store_frame(last_obs) 131 |
buffer.store_effect(idx, optimal_action, reward, done) 132 | 133 | if done: 134 | if writer is not None: 135 | rew_writer = sess.run(rew_summ, feed_dict={rew_ph : rew_agg}) 136 | writer.add_summary(rew_writer, rew_tracker) 137 | rew_tracker += 1 138 | rew_agg = 0 139 | last_obs = env.reset() 140 | else: 141 | last_obs = next_obs 142 | 143 | if not buffer.can_sample(batch_size): 144 | continue 145 | 146 | #update discriminator n_dis times 147 | for _ in range(n_dis): 148 | obs_batch, act_batch, rew_batch, next_obs_batch, done_batch = ( 149 | buffer.sample(batch_size) 150 | ) 151 | batch_z = np.random.normal(0, 1, size=[batch_size] + z_shape) 152 | batch_y = [] 153 | for i in range(batch_size): 154 | if done_batch[i]: 155 | batch_y.append(rew_batch[i]) 156 | else: 157 | expected_ar = sess.run(gen.output, feed_dict={ 158 | gen.input_state : np.array([next_obs_batch[i]]), #Bellman target bootstraps from the next state 159 | gen.input_seed: np.array([batch_z[i]]) 160 | }) 161 | future_reward = np.max(expected_ar) 162 | batch_y.append(rew_batch[i] + reward_discount * future_reward) 163 | batch_y = np.array(batch_y) 164 | epsilons = np.random.uniform(0, 1, batch_size) 165 | predict_x = [] 166 | for i in range(batch_size): 167 | predict_x.append(epsilons[i] * batch_y[i] + (1 - epsilons[i]) * 168 | np.max(sess.run(gen.output, feed_dict={ 169 | gen.input_state : np.array([obs_batch[i]]), 170 | gen.input_seed : np.array([batch_z[i]])}))) 171 | predict_x = np.array(predict_x) 172 | act_batch = np.expand_dims(act_batch, -1) 173 | 174 | sess.run(dis_min_op, feed_dict={ 175 | gen.input_seed : batch_z, 176 | gen.input_state : obs_batch, 177 | gen_dis.input_state : obs_batch, 178 | gen_dis.input_action : act_batch, 179 | dis.input_reward : batch_y, 180 | dis.input_state : obs_batch, 181 | dis.input_action : act_batch, 182 | grad_dis.input_state : obs_batch, 183 | grad_dis.input_action : act_batch, 184 | grad_val_ph : predict_x 185 | }) 186 | 187 | if writer is not None: 188 | dis_writer = sess.run(dis_summ, feed_dict={ 189 | gen.input_seed : batch_z, 190 | gen.input_state : obs_batch, 191 | gen_dis.input_state : obs_batch, 192 | gen_dis.input_action : act_batch, 193 | dis.input_reward : batch_y, 194 | dis.input_state : obs_batch, 195 | dis.input_action : act_batch, 196 | grad_dis.input_state : obs_batch, 197 | grad_dis.input_action : act_batch, 198 | grad_val_ph : predict_x 199 | }) 200 | writer.add_summary(dis_writer, dis_tracker) 201 | dis_tracker += 1 202 | 203 | #update the generator n_gen times 204 | for _ in range(n_gen): 205 | obs_batch, act_batch, _, _, _ = (buffer.sample(batch_size)) 206 | batch_z = np.random.normal(0, 1, size=[batch_size] + z_shape) 207 | act_batch = np.expand_dims(act_batch, -1) 208 | sess.run(gen_min_op, feed_dict={ 209 | gen.input_seed : batch_z, 210 | gen.input_state : obs_batch, 211 | gen_dis.input_state : obs_batch, 212 | gen_dis.input_action: act_batch 213 | }) 214 | 215 | if writer is not None: 216 | gen_writer = sess.run(gen_summ, feed_dict={ 217 | gen.input_seed : batch_z, 218 | gen.input_state : obs_batch, 219 | gen_dis.input_state : obs_batch, 220 | gen_dis.input_action: act_batch 221 | }) 222 | writer.add_summary(gen_writer, gen_tracker) 223 | gen_tracker += 1 224 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import random 4 | 5 | # Borrowed from Berkeley's CS 294 Hw 3 6 | # https://github.com/berkeleydeeprlcourse/homework/tree/master/ 7
| 8 | def sample_n_unique(sampling_f, n): 9 | """Helper function. Given a function `sampling_f` that returns 10 | comparable objects, sample n such unique objects. 11 | """ 12 | res = [] 13 | while len(res) < n: 14 | candidate = sampling_f() 15 | if candidate not in res: 16 | res.append(candidate) 17 | return res 18 | 19 | class ReplayBuffer(object): 20 | def __init__(self, size, frame_history_len): 21 | """This is a memory efficient implementation of the replay buffer. 22 | 23 | The specific memory optimizations used here are: 24 | - only store each frame once rather than k times 25 | even if every observation normally consists of k last frames 26 | - store frames as np.uint8 (in practice it is most time-efficient 27 | to cast them back to float32 on the GPU to minimize memory transfer 28 | time) 29 | - store frame_t and frame_(t+1) in the same buffer. 30 | 31 | For the typical Atari Deep RL use case of a buffer with 1M frames, the total 32 | memory footprint of this buffer is 10^6 * 84 * 84 bytes ~= 7 gigabytes 33 | 34 | Warning! Assumes that returning a frame of zeros at the beginning 35 | of the episode, when there are fewer frames than `frame_history_len`, 36 | is acceptable. 37 | 38 | Parameters 39 | ---------- 40 | size: int 41 | Max number of transitions to store in the buffer. When the buffer 42 | overflows the old memories are dropped. 43 | frame_history_len: int 44 | Number of memories to be retrieved for each observation. 45 | """ 46 | self.size = size 47 | self.frame_history_len = frame_history_len 48 | 49 | self.next_idx = 0 50 | self.num_in_buffer = 0 51 | 52 | self.obs = None 53 | self.action = None 54 | self.reward = None 55 | self.done = None 56 | 57 | def can_sample(self, batch_size): 58 | """Returns true if `batch_size` different transitions can be sampled from the buffer.""" 59 | return batch_size + 1 <= self.num_in_buffer 60 | 61 | def _encode_sample(self, idxes): 62 | obs_batch = np.concatenate([self._encode_observation(idx)[None] for idx in idxes], 0) 63 | act_batch = self.action[idxes] 64 | rew_batch = self.reward[idxes] 65 | next_obs_batch = np.concatenate([self._encode_observation(idx + 1)[None] for idx in idxes], 0) 66 | done_mask = np.array([1.0 if self.done[idx] else 0.0 for idx in idxes], dtype=np.float32) 67 | 68 | return obs_batch, act_batch, rew_batch, next_obs_batch, done_mask 69 | 70 | 71 | def sample(self, batch_size): 72 | """Sample `batch_size` different transitions. 73 | 74 | The i-th sampled transition is the following: 75 | 76 | when observing `obs_batch[i]`, action `act_batch[i]` was taken, 77 | after which reward `rew_batch[i]` was received and subsequent 78 | observation next_obs_batch[i] was observed, unless the episode 79 | was done, which is represented by `done_mask[i]`, which is equal 80 | to 1 if the episode ended as a result of that action. 81 | 82 | Parameters 83 | ---------- 84 | batch_size: int 85 | How many transitions to sample.
86 | 87 | Returns 88 | ------- 89 | obs_batch: np.array 90 | Array of shape 91 | (batch_size, img_h, img_w, img_c * frame_history_len) 92 | and dtype np.uint8 93 | act_batch: np.array 94 | Array of shape (batch_size,) and dtype np.int32 95 | rew_batch: np.array 96 | Array of shape (batch_size,) and dtype np.float32 97 | next_obs_batch: np.array 98 | Array of shape 99 | (batch_size, img_h, img_w, img_c * frame_history_len) 100 | and dtype np.uint8 101 | done_mask: np.array 102 | Array of shape (batch_size,) and dtype np.float32 103 | """ 104 | assert self.can_sample(batch_size) 105 | idxes = sample_n_unique(lambda: random.randint(0, self.num_in_buffer - 2), batch_size) 106 | return self._encode_sample(idxes) 107 | 108 | def encode_recent_observation(self): 109 | """Return the most recent `frame_history_len` frames. 110 | 111 | Returns 112 | ------- 113 | observation: np.array 114 | Array of shape (img_h, img_w, img_c * frame_history_len) 115 | and dtype np.uint8, where observation[:, :, i*img_c:(i+1)*img_c] 116 | encodes frame at time `t - frame_history_len + i` 117 | """ 118 | assert self.num_in_buffer > 0 119 | return self._encode_observation((self.next_idx - 1) % self.size) 120 | 121 | def _encode_observation(self, idx): 122 | end_idx = idx + 1 # make noninclusive 123 | start_idx = end_idx - self.frame_history_len 124 | # this checks if we are using low-dimensional observations, such as RAM 125 | # state, in which case we just directly return the latest RAM. 126 | if len(self.obs.shape) == 2: 127 | return self.obs[end_idx-1] 128 | # if there weren't enough frames ever in the buffer for context 129 | if start_idx < 0 and self.num_in_buffer != self.size: 130 | start_idx = 0 131 | for idx in range(start_idx, end_idx - 1): 132 | if self.done[idx % self.size]: 133 | start_idx = idx + 1 134 | missing_context = self.frame_history_len - (end_idx - start_idx) 135 | # if zero padding is needed for missing context 136 | # or we are on the boundry of the buffer 137 | if start_idx < 0 or missing_context > 0: 138 | frames = [np.zeros_like(self.obs[0]) for _ in range(missing_context)] 139 | for idx in range(start_idx, end_idx): 140 | frames.append(self.obs[idx % self.size]) 141 | return np.concatenate(frames, 2) 142 | else: 143 | # this optimization has potential to saves about 30% compute time \o/ 144 | img_h, img_w = self.obs.shape[1], self.obs.shape[2] 145 | return self.obs[start_idx:end_idx].transpose(1, 2, 0, 3).reshape(img_h, img_w, -1) 146 | 147 | def store_frame(self, frame): 148 | """Store a single frame in the buffer at the next available index, overwriting 149 | old frames if necessary. 150 | 151 | Parameters 152 | ---------- 153 | frame: np.array 154 | Array of shape (img_h, img_w, img_c) and dtype np.uint8 155 | the frame to be stored 156 | 157 | Returns 158 | ------- 159 | idx: int 160 | Index at which the frame is stored. To be used for `store_effect` later. 
161 | """ 162 | if self.obs is None: 163 | self.obs = np.empty([self.size] + list(frame.shape), dtype=frame.dtype) #use the frame's own dtype so float observations (e.g. CartPole states) aren't truncated to uint8 164 | self.action = np.empty([self.size], dtype=np.int32) 165 | self.reward = np.empty([self.size], dtype=np.float32) 166 | self.done = np.empty([self.size], dtype=np.bool) 167 | self.obs[self.next_idx] = frame 168 | 169 | ret = self.next_idx 170 | self.next_idx = (self.next_idx + 1) % self.size 171 | self.num_in_buffer = min(self.size, self.num_in_buffer + 1) 172 | 173 | return ret 174 | 175 | def store_effect(self, idx, action, reward, done): 176 | """Store effects of action taken after observing frame stored 177 | at index idx. The reason `store_frame` and `store_effect` are broken 178 | up into two functions is so that one can call `encode_recent_observation` 179 | in between. 180 | 181 | Parameters 182 | --------- 183 | idx: int 184 | Index in buffer of recently observed frame (returned by `store_frame`). 185 | action: int 186 | Action that was performed upon observing this frame. 187 | reward: float 188 | Reward that was received when the action was performed. 189 | done: bool 190 | True if episode was finished after performing that action. 191 | """ 192 | self.action[idx] = action 193 | self.reward[idx] = reward 194 | self.done[idx] = done --------------------------------------------------------------------------------
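A quick usage sketch of the replay buffer API documented above (a minimal example with illustrative names, assuming a CartPole-style 1-D observation and the classic gym step/reset API used elsewhere in this repo):

import gym
import utils

env = gym.make('CartPole-v0')
buffer = utils.ReplayBuffer(1000, 1)  #capacity of 1000 transitions, single-frame history

obs = env.reset()
for _ in range(200):
    idx = buffer.store_frame(obs)                   #store the observation and remember its index
    action = env.action_space.sample()              #random policy, purely for illustration
    next_obs, reward, done, _ = env.step(action)
    buffer.store_effect(idx, action, reward, done)  #record the action, reward, and terminal flag
    obs = env.reset() if done else next_obs

if buffer.can_sample(64):
    obs_b, act_b, rew_b, next_obs_b, done_b = buffer.sample(64)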