├── .gitignore
├── README.md
├── atari_breakout
│   ├── agents
│   │   ├── DQN.py
│   │   ├── Double.py
│   │   ├── Dueling.py
│   │   ├── layers.py
│   │   ├── losses.py
│   │   └── memories.py
│   ├── arguments.py
│   ├── logger.py
│   ├── main.py
│   └── visualization.py
├── atari_breakout_run.sh
└── images
    └── atari-breakout-D3QN.gif
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | __pycache__
3 | .python-version
4 | .vscode/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # keras-only-RL
2 | 
3 | Implementing Reinforcement Learning from A to Z using Keras only.
4 | 
14 | 
15 | # How to run
16 | 
17 | ## Example: Run Atari Breakout with the D3QN agent
18 | 
19 | ```bash
20 | sh atari_breakout_run.sh --double=True --dueling=True
21 | ```
22 | 
23 | ![demo](./images/atari-breakout-D3QN.gif)
24 | 
25 | ## Help
26 | 
27 | ```bash
28 | sh atari_breakout_run.sh -h
29 | ```
30 | ```
31 | usage: main.py [-h] [--e E] [--double D] [--dueling B]
32 | 
33 | Some hyperparameters
34 | 
35 | optional arguments:
36 |   -h, --help   show this help message and exit
37 |   --e E        Total episodes
38 |   --double D   Enable Double DQN
39 |   --dueling B  Enable Dueling DQN
40 | ```
41 | 
42 | # To the Rainbow
43 | 
44 | | Technique | Problem | How to solve it |
45 | | --- | --- | --- |
46 | | DQN | Non-stationary targets make learning unstable | Fixed Q-targets |
47 | |  | Correlation between samples biases the weights | Replay memory |
48 | | Double ~ | The maximum estimator causes over-estimation | Double estimators |
49 | | Dueling ~ | Some states have inherently low (or high) value regardless of the action | *Q(s, a) = V(s) + A(s, a)* |
50 | 
51 | 
71 | 
--------------------------------------------------------------------------------
/atari_breakout/agents/DQN.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Input, Dense, Conv2D, Flatten
2 | from keras.models import Model
3 | from keras.optimizers import RMSprop
4 | import numpy as np
5 | import random
6 | 
7 | from agents.memories import ReplayMemory
8 | from agents.losses import Huber_loss
9 | 
10 | 
11 | class DQNAgent:
12 |     def __init__(self, action_size):
13 |         self.render = True
14 |         self.load_model_dir = False
15 | 
16 |         # Environment settings
17 |         self.state_size = (84, 84, 4)
18 |         self.action_size = action_size
19 | 
20 |         # Epsilon
21 |         self.epsilon = 1.  # Current epsilon
22 |         self.epsilon_start, self.epsilon_end = 1.0, 0.1
23 |         self.exploration_steps = 1000000.
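        # Linear annealing: epsilon is decreased from epsilon_start (1.0) to
        # epsilon_end (0.1) over exploration_steps (1e6) steps, i.e. _decaying()
        # subtracts (1.0 - 0.1) / 1e6 = 9e-7 from epsilon on each learn() call.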
24 | self.epsilon_decay = (self.epsilon_start - self.epsilon_end) / self.exploration_steps 25 | 26 | # Training 27 | self.no_op_steps = 30 28 | self.batch_size = 32 29 | self.train_start = 50000 30 | self.update_target_rate = 10000 31 | self.discount_factor = 0.99 32 | self.memory = ReplayMemory(400000) 33 | 34 | # Build model 35 | self.optimizer = RMSprop(lr=0.00025, epsilon=0.01) 36 | self.model_for_train, self.model = self._build_model() 37 | _, self.target_model = self._build_model() 38 | self.update_target_model() 39 | if self.load_model_dir: 40 | # TODO: mkdir 41 | self.load_model("./save_model/breakout.h5") 42 | 43 | # for logging 44 | # TODO: Tensorboard 45 | self.avg_q_max, self.avg_loss = 0.0, 0.0 46 | 47 | def _build_model(self): 48 | x = Input(shape=self.state_size, name='input') 49 | h = Conv2D(32, (8, 8), strides=(4, 4), activation='relu')(x) 50 | h = Conv2D(64, (4, 4), strides=(2, 2), activation='relu')(h) 51 | h = Conv2D(64, (3, 3), strides=(1, 1), activation='relu')(h) 52 | h = Flatten()(h) 53 | h = Dense(512, activation='relu')(h) 54 | y_pred = Dense(self.action_size, activation='linear', name='y_pred')(h) 55 | 56 | # for custom loss function 57 | y_true = Input(shape=(self.action_size, ), name='y_true') 58 | model_for_train = Model( 59 | inputs=[x, y_true], 60 | outputs=y_pred, 61 | name='model_for_train' 62 | ) 63 | model_for_train.add_loss(Huber_loss(y_true, y_pred)) 64 | model_for_train.compile(loss=None, optimizer=self.optimizer) 65 | 66 | model_for_using = Model( 67 | inputs=x, 68 | outputs=y_pred, 69 | name='model_for_using' 70 | ) 71 | 72 | model_for_using.summary() 73 | return model_for_train, model_for_using 74 | 75 | def update_target_model(self): 76 | self.target_model.set_weights(self.model.get_weights()) 77 | 78 | def get_action(self, state): 79 | if np.random.rand() <= self.epsilon: 80 | return random.randrange(self.action_size) 81 | else: 82 | state = np.float32(state / 255.) 83 | q_value = self.model.predict(state)[0] 84 | return np.argmax(q_value) 85 | 86 | def append_sample(self, state, action, reward, next_state, done): 87 | self.memory.push(state, action, reward, next_state, done) 88 | 89 | def _decaying(self): 90 | if self.epsilon > self.epsilon_end: 91 | self.epsilon -= self.epsilon_decay 92 | 93 | def learn(self): 94 | if len(self.memory) < self.train_start: 95 | return 96 | 97 | self._decaying() 98 | 99 | batch_size = min(len(self.memory), self.batch_size) 100 | 101 | states = np.zeros((batch_size, self.state_size[0], self.state_size[1], self.state_size[2])) 102 | next_states = np.zeros((batch_size, self.state_size[0], self.state_size[1], self.state_size[2])) 103 | actions, rewards, dones = [], [], [] # Actually, not done but dead. 104 | 105 | experiences = self.memory.sample(batch_size) 106 | for i, experience in enumerate(experiences): 107 | # 'State', 'Action', 'Reward', 'Next_state', 'Done' 108 | states[i] = np.float32(experience[0] / 255.) 109 | actions.append(experience[1]) 110 | rewards.append(experience[2]) 111 | next_states[i] = np.float32(experience[3] / 255.) 
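            # Frames are stored in replay memory as uint8 to save space (see
            # pre_processing in main.py); dividing by 255. above rescales them
            # to [0, 1] float32 before they are fed to the networks.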
112 | dones.append(experience[4]) 113 | 114 | target_values = self.target_model.predict(next_states) 115 | targets = np.zeros((batch_size, self.action_size, )) 116 | for i in range(batch_size): 117 | targets[i] = target_values[i] 118 | action = actions[i] 119 | if dones[i]: 120 | targets[i][action] = rewards[i] 121 | else: 122 | targets[i][action] = rewards[i] + self.discount_factor * np.amax(target_values[i]) 123 | 124 | metrics = self.model_for_train.fit( 125 | [states, targets], 126 | batch_size=batch_size, 127 | epochs=1, 128 | verbose=0 129 | ) 130 | 131 | self.avg_loss += metrics.history['loss'][0] 132 | 133 | def save_model(self, _dir): 134 | self.model.save_weights(_dir) 135 | 136 | def load_model(self, _dir): 137 | self.model.load_weights(_dir) 138 | -------------------------------------------------------------------------------- /atari_breakout/agents/Double.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from agents.DQN import DQNAgent 4 | 5 | 6 | class DDQNAgent(DQNAgent): 7 | def learn(self): 8 | if len(self.memory) < self.train_start: 9 | return 10 | 11 | self._decaying() 12 | 13 | batch_size = min(len(self.memory), self.batch_size) 14 | 15 | states = np.zeros((batch_size, self.state_size[0], self.state_size[1], self.state_size[2])) 16 | next_states = np.zeros((batch_size, self.state_size[0], self.state_size[1], self.state_size[2])) 17 | actions, rewards, dones = [], [], [] # Actually, not done but dead. 18 | 19 | experiences = self.memory.sample(batch_size) 20 | for i, experience in enumerate(experiences): 21 | # 'State', 'Action', 'Reward', 'Next_state', 'Done' 22 | states[i] = np.float32(experience[0] / 255.) 23 | actions.append(experience[1]) 24 | rewards.append(experience[2]) 25 | next_states[i] = np.float32(experience[3] / 255.) 26 | dones.append(experience[4]) 27 | 28 | values = self.model.predict(next_states) # DDQN 29 | target_values = self.target_model.predict(next_states) 30 | targets = np.zeros((batch_size, self.action_size, )) 31 | for i in range(batch_size): 32 | targets[i] = target_values[i] 33 | action = actions[i] 34 | if dones[i]: 35 | targets[i][action] = rewards[i] 36 | else: 37 | selected_action = np.argmax(values[i]) 38 | targets[i][action] = rewards[i] + self.discount_factor * targets[i][selected_action] 39 | 40 | metrics = self.model_for_train.fit( 41 | [states, targets], 42 | batch_size=batch_size, 43 | epochs=1, 44 | verbose=0 45 | ) 46 | 47 | self.avg_loss += metrics.history['loss'][0] 48 | -------------------------------------------------------------------------------- /atari_breakout/agents/Dueling.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | from keras.layers import Input, Dense, Conv2D, Flatten, Lambda, Add 3 | from keras.models import Model 4 | 5 | from agents.losses import Huber_loss 6 | 7 | from agents.DQN import DQNAgent 8 | from agents.Double import DDQNAgent 9 | 10 | 11 | class DuelingAgent(DQNAgent): 12 | def _build_model(self): 13 | x = Input(shape=self.state_size, name='input') 14 | h = Conv2D(32, (8, 8), strides=(4, 4), activation='relu')(x) 15 | h = Conv2D(64, (4, 4), strides=(2, 2), activation='relu')(h) 16 | h = Conv2D(64, (3, 3), strides=(1, 1), activation='relu')(h) 17 | flatten = Flatten()(h) 18 | 19 | """ 20 | # Directly summing V and A gives us no guarantees that the A will actually predict the V. 
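        # (Given only Q, V and A are not uniquely identifiable: adding a constant
        #  to V and subtracting it from A leaves Q(s, a) = V(s) + A(s, a) unchanged.)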
21 | # Instead we combine them with criterion - Max or Avg 22 | # This(Avg) loses the original semantics of V and A (c.f. Max) . 23 | # But on the other hand it increases the stability of the optimization. 24 | # Ref: Dueling Network Architectures for Deep Reinforcement Learning 25 | """ 26 | # Advantage function 27 | A = Dense(512, activation='relu')(flatten) 28 | A = Dense(self.action_size, activation='linear', name='advanced')(A) 29 | A = Lambda( 30 | lambda a: a[:, :] - K.mean(a[:, :], keepdims=True), 31 | output_shape=(self.action_size, ) 32 | )(A) 33 | 34 | # Value function 35 | V = Dense(512, activation='relu')(flatten) 36 | V = Dense(1, activation='linear', name='value')(V) 37 | V = Lambda( 38 | lambda v: K.expand_dims(v[:, 0], -1), 39 | output_shape=(self.action_size, ) 40 | )(V) 41 | 42 | y_pred = Add()([V, A]) # tensor shape broadcasting 43 | 44 | # for custom loss function 45 | y_true = Input(shape=(self.action_size, ), name='y_true') 46 | model_for_train = Model( 47 | inputs=[x, y_true], 48 | outputs=y_pred, 49 | name='model_for_train' 50 | ) 51 | model_for_train.add_loss(Huber_loss(y_true, y_pred)) 52 | model_for_train.compile(loss=None, optimizer=self.optimizer) 53 | 54 | model_for_using = Model( 55 | inputs=x, 56 | outputs=y_pred, 57 | name='model_for_using' 58 | ) 59 | 60 | model_for_using.summary() 61 | return model_for_train, model_for_using 62 | 63 | 64 | class D3QNAgent(DDQNAgent): 65 | def _build_model(self): 66 | x = Input(shape=self.state_size, name='input') 67 | h = Conv2D(32, (8, 8), strides=(4, 4), activation='relu')(x) 68 | h = Conv2D(64, (4, 4), strides=(2, 2), activation='relu')(h) 69 | h = Conv2D(64, (3, 3), strides=(1, 1), activation='relu')(h) 70 | flatten = Flatten()(h) 71 | 72 | """ 73 | # Directly summing V and A gives us no guarantees that the A will actually predict the V. 74 | # Instead we combine them with criterion - Max or Avg 75 | # This(Avg) loses the original semantics of V and A (c.f. Max) . 76 | # But on the other hand it increases the stability of the optimization. 
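        # D3QN = Double DQN + Dueling DQN: this dueling head is trained with the
        # Double-DQN targets inherited from DDQNAgent.learn().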
77 | # Ref: Dueling Network Architectures for Deep Reinforcement Learning 78 | """ 79 | # Advantage function 80 | A = Dense(512, activation='relu')(flatten) 81 | A = Dense(self.action_size, activation='linear', name='advanced')(A) 82 | A = Lambda( 83 | lambda a: a[:, :] - K.mean(a[:, :], keepdims=True), 84 | output_shape=(self.action_size, ) 85 | )(A) 86 | 87 | # Value function 88 | V = Dense(512, activation='relu')(flatten) 89 | V = Dense(1, activation='linear', name='value')(V) 90 | V = Lambda( 91 | lambda v: K.expand_dims(v[:, 0], -1), 92 | output_shape=(self.action_size, ) 93 | )(V) 94 | 95 | y_pred = Add()([V, A]) # tensor shape broadcasting 96 | 97 | # for custom loss function 98 | y_true = Input(shape=(self.action_size, ), name='y_true') 99 | model_for_train = Model( 100 | inputs=[x, y_true], 101 | outputs=y_pred, 102 | name='model_for_train' 103 | ) 104 | model_for_train.add_loss(Huber_loss(y_true, y_pred)) 105 | model_for_train.compile(loss=None, optimizer=self.optimizer) 106 | 107 | model_for_using = Model( 108 | inputs=x, 109 | outputs=y_pred, 110 | name='model_for_using' 111 | ) 112 | 113 | model_for_using.summary() 114 | return model_for_train, model_for_using 115 | -------------------------------------------------------------------------------- /atari_breakout/agents/layers.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | from keras import initializers, activations 3 | from keras.engine.topology import Layer 4 | 5 | 6 | # TODO: variables convert into private 7 | # Buile own layer 8 | # Ref: https://keras.io/ko/layers/writing-your-own-keras-layers/ 9 | class NoisyDense(Layer): 10 | def __init__(self, 11 | output_dim, 12 | activation=None, 13 | **kwargs): 14 | self.output_dim = output_dim 15 | self.activation = activations.get(activation) 16 | 17 | self.e_i = None 18 | self.e_j = None 19 | 20 | super(NoisyDense, self).__init__(**kwargs) 21 | 22 | def build(self, input_shape): 23 | # assert isinstance(input_shape, list) 24 | 25 | self.input_dim = input_shape[-1] 26 | 27 | # Factorised NoisyNet 28 | # Ref: https://github.com/jakegrigsby/keras-rl/blob/master/rl/layers.py 29 | sqrt_inputs = self.input_dim ** (1 / 2) 30 | self.sigma_initializer = initializers.Constant(value=0.5 / sqrt_inputs) 31 | self.mu_initializer = initializers.RandomUniform(minval=(-1.0 / sqrt_inputs), maxval=(1.0 / sqrt_inputs)) 32 | 33 | # Learnable parameters 34 | # TODO: constraint, regularizer 35 | self.mu_weight = self.add_weight(name='mu_weights', 36 | shape=(self.input_dim, self.output_dim), 37 | initializer=self.mu_initializer, 38 | trainable=True) 39 | 40 | self.sigma_weight = self.add_weight(name='sigma_weights', 41 | shape=(self.input_dim, self.output_dim), 42 | initializer=self.sigma_initializer, 43 | trainable=True) 44 | 45 | self.mu_bias = self.add_weight(name='mu_bias', 46 | shape=(self.output_dim,), 47 | initializer=self.mu_initializer, 48 | trainable=True) 49 | 50 | self.sigma_bias = self.add_weight(name='sigma_bias', 51 | shape=(self.output_dim,), 52 | initializer=self.sigma_initializer, 53 | trainable=True) 54 | 55 | self.reset_noise() 56 | super(NoisyDense, self).build(input_shape) 57 | 58 | def call(self, x): 59 | # assert isinstance(x, list) 60 | 61 | # Factorised Gaussian noise: 62 | def f(e): 63 | return K.sign(e) * (K.sqrt(K.abs(e))) 64 | eW = f(self.e_i) * f(self.e_j) 65 | eB = f(self.e_j) 66 | 67 | # # Independent Gaussian noise: 68 | # eW = self.e_i 69 | # eB = self.e_j 70 | 71 | noise_injected_weights = 
self.mu_weight + (self.sigma_weight * eW) 72 | noise_injected_bias = self.mu_bias + (self.sigma_bias * eB) 73 | output = K.bias_add(K.dot(x, noise_injected_weights), noise_injected_bias) 74 | 75 | if self.activation != None: 76 | output = self.activation(output) 77 | return output 78 | 79 | def compute_output_shape(self, input_shape): 80 | # assert isinstance(input_shape, list) 81 | 82 | output_shape = list(input_shape) 83 | output_shape[-1] = self.output_dim 84 | return tuple(output_shape) 85 | 86 | # Ref: https://github.com/LuEE-C/Noisy-A3C-Keras 87 | def reset_noise(self): 88 | # Random variables 89 | # sample from noise distribution 90 | self.e_i = K.random_normal((self.input_dim, self.output_dim)) 91 | self.e_j = K.random_normal((self.output_dim,)) 92 | -------------------------------------------------------------------------------- /atari_breakout/agents/losses.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | 3 | 4 | def _Huber(y_true, y_pred): 5 | error = K.abs(y_true - y_pred) 6 | quadratic_part = K.clip(error, 0.0, 1.0) 7 | linear_part = error - quadratic_part 8 | return 0.5 * K.square(quadratic_part) + linear_part 9 | 10 | 11 | def Huber_loss(y_true, y_pred): 12 | return K.mean(_Huber(y_true, y_pred)) 13 | 14 | 15 | def PER_MSE_loss(y_true, y_pred, importances): 16 | error = K.abs(y_true - y_pred) 17 | L = K.square(error) 18 | return K.mean(L * importances) 19 | 20 | 21 | def PER_Huber_loss(y_true, y_pred, importances): 22 | return K.mean(_Huber(y_true, y_pred) * importances) 23 | -------------------------------------------------------------------------------- /atari_breakout/agents/memories.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import namedtuple, deque 3 | import random 4 | 5 | 6 | class ReplayMemory: 7 | def __init__(self, buffer_size): 8 | self.memory = deque(maxlen=buffer_size) 9 | self.Experience = namedtuple('Experience', ['State', 'Action', 'Reward', 'Next_state', 'Done']) 10 | 11 | def __len__(self): 12 | return len(self.memory) 13 | 14 | def push(self, state, action, reward, next_state, done): 15 | e = self.Experience(state, action, reward, next_state, done) 16 | self.memory.append(e) 17 | 18 | def sample(self, batch_size): 19 | return random.sample(self.memory, k=batch_size) 20 | 21 | 22 | class PrioritizedReplayMemory: 23 | def __init__(self, buffer_size): 24 | self.memory = deque(maxlen=buffer_size) 25 | self.Experience = namedtuple('Experience', ['State', 'Action', 'Reward', 'Next_state', 'Done']) 26 | 27 | self.priority = deque(maxlen=buffer_size) 28 | self.max_priority = 1.0 29 | 30 | def __len__(self): 31 | return len(self.memory) 32 | 33 | def push(self, state, action, reward, next_state, done, p=None): 34 | if p == None: 35 | p = self.max_priority # maximal priority 36 | 37 | e = self.Experience(state, action, reward, next_state, done) 38 | self.memory.append(e) 39 | self.priority.append(p) 40 | 41 | def sample(self, batch_size, beta): 42 | p_sum = np.sum(self.priority) 43 | prob = self.priority / p_sum 44 | 45 | indices = random.choices(range(len(prob)), k=batch_size, weights=prob) 46 | samples = np.array(self.memory)[indices] 47 | 48 | w = (len(self.priority) * prob) ** (-beta) 49 | importances = np.array(w)[indices] 50 | importances /= max(importances) # normalization 51 | 52 | return indices, samples, importances 53 | 54 | def update_priority(self, idx, p): 55 | self.priority[idx] = p 56 | 
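
# Proportional prioritization as implemented above:
#   sampling probability:  P(i) = p_i / sum_j p_j
#   importance weight:     w_i = (N * P(i)) ** (-beta), normalized by max_i w_i
# New transitions are pushed with the running max priority so that they are
# likely to be replayed shortly after being added.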
--------------------------------------------------------------------------------
/atari_breakout/arguments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | 
4 | def str2bool(v):
5 |     # argparse passes option values as strings and bool('False') is True,
6 |     # so parse the string explicitly to keep `--double=False` working.
7 |     if isinstance(v, bool):
8 |         return v
9 |     if v.lower() in ('true', 't', 'yes', 'y', '1'):
10 |         return True
11 |     if v.lower() in ('false', 'f', 'no', 'n', '0'):
12 |         return False
13 |     raise argparse.ArgumentTypeError('Boolean value expected.')
14 | 
15 | 
16 | def parser():
17 |     parser = argparse.ArgumentParser(description='Some hyperparameters')
18 |     # parser.add_argument('--sum', dest='accumulate', action='store_const',
19 |     #                     const=sum, default=max,
20 |     #                     help='sum the integers (default: find the max)')
21 | 
22 |     parser.add_argument('--e', metavar='E', type=int, default=50000,
23 |                         help='Total episodes')
24 |     parser.add_argument('--double', metavar='D', type=str2bool, default=False,
25 |                         help='Enable Double DQN')
26 |     parser.add_argument('--dueling', metavar='B', type=str2bool, default=False,
27 |                         help='Enable Dueling DQN')
28 | 
29 |     args = parser.parse_args()
30 |     return args
31 | 
32 | 
33 | if __name__ == "__main__":
34 |     args = parser()
35 |     print(args)
--------------------------------------------------------------------------------
/atari_breakout/logger.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukepark327/keras-only-RL/f24c968f3a4bb4531a5c07df08374a9e8e656777/atari_breakout/logger.py
--------------------------------------------------------------------------------
/atari_breakout/main.py:
--------------------------------------------------------------------------------
1 | """ References
2 | https://github.com/rlcode/reinforcement-learning/blob/master/3-atari/1-breakout/breakout_dqn.py
3 | https://github.com/PacktPublishing/Advanced-Deep-Learning-with-Keras/blob/master/chapter9-drl/dqn-cartpole-9.6.1.py
4 | """
5 | 
6 | import numpy as np
7 | import gym
8 | import random
9 | from skimage.color import rgb2gray
10 | from skimage.transform import resize
11 | 
12 | from agents.DQN import DQNAgent
13 | from agents.Double import DDQNAgent
14 | from agents.Dueling import DuelingAgent, D3QNAgent
15 | 
16 | import arguments
17 | 
18 | 
19 | def pre_processing(observe):
20 |     processed_observe = np.uint8(resize(  # float --> integer (to reduce the size of replay memory)
21 |         rgb2gray(observe),  # 210*160*3 (color) --> 84*84 (grayscale)
22 |         (84, 84),
23 |         mode='constant') * 255)
24 |     return processed_observe
25 | 
26 | 
27 | if __name__ == "__main__":
28 |     # arguments
29 |     args = arguments.parser()
30 |     EPISODES = args.e  # Number of episodes
31 |     Double = args.double  # Double DQN
32 |     Dueling = args.dueling  # Dueling DQN
33 | 
34 |     print("> Setting:", args)
35 | 
36 |     env = gym.make('BreakoutDeterministic-v4')
37 | 
38 |     if Double is True:
39 |         if Dueling is True:
40 |             agent = D3QNAgent(action_size=3)
41 |         else:
42 |             agent = DDQNAgent(action_size=3)
43 |     elif Dueling is True:  # Double is False
44 |         agent = DuelingAgent(action_size=3)
45 |     else:
46 |         agent = DQNAgent(action_size=3)
47 | 
48 |     scores, episodes, global_step = [], [], 0
49 |     for e in range(EPISODES):
50 |         done, dead = False, False
51 |         step, score, start_life = 0, 0, 5  # 1 episode = 5 lives
52 | 
53 |         observe = env.reset()
54 | 
55 |         # Do nothing at the start of the episode to avoid sub-optimal starting states
56 |         for _ in range(random.randint(1, agent.no_op_steps)):
57 |             observe, _, _, _ = env.step(1)
58 | 
59 |         # At the start of the episode there is no preceding frame,
60 |         # so just copy the initial state to build the history
61 |         state = pre_processing(observe)
62 |         history = np.stack((state, state, state, state), axis=2)
63 |         history = np.reshape([history], (1, 84, 84, 4))
64 | 
65 |         while not done:
66 |             if agent.render:
67 |                 env.render()
68 | 
69 |             global_step += 1
70 |             step += 1
71 | 
72 |             # get an action for the current history and take one step in the environment
73 |             action = agent.get_action(history)
74 | 
75 |             # change action to real_action
76 |             if action == 0:
77 |                 real_action = 1
78 |             elif action == 1:
79 |                 real_action = 2
80 |             else:
81 |                 real_action = 3
82 | 
83 |             observe, reward, done, info = env.step(real_action)
84 | 
85 |             # pre-process the observation --> history
86 |             next_state = pre_processing(observe)
87 |             next_state = np.reshape([next_state], (1, 84, 84, 1))
88 |             next_history = np.append(next_state, history[:, :, :, :3], axis=3)
89 | 
90 |             agent.avg_q_max += np.amax(agent.model.predict(np.float32(history / 255.))[0])
91 | 
92 |             # if the agent misses the ball it is dead --> but the episode is not over yet
93 |             if start_life > info['ale.lives']:
94 |                 dead = True
95 |                 start_life = info['ale.lives']
96 | 
97 |             reward = np.clip(reward, -1., 1.)
98 | 
99 |             agent.append_sample(history, action, reward, next_history, dead)
100 |             agent.learn()
101 | 
102 |             if global_step % agent.update_target_rate == 0:
103 |                 agent.update_target_model()
104 | 
105 |             score += reward
106 | 
107 |             # if the agent is dead, keep the old history; otherwise roll it forward
108 |             if dead:
109 |                 dead = False
110 |             else:
111 |                 history = next_history
112 | 
113 |             # if done, log the episode statistics
114 |             if done:
115 |                 print("episode:", e,
116 |                       "\tglobal_step:", global_step,
117 |                       "\tmemory_len: ", len(agent.memory),
118 |                       "\tscore:", score,
119 |                       "\tepsilon:", round(agent.epsilon, 6),
120 |                       "\tavg_q_max:", round(agent.avg_q_max / float(step), 6),
121 |                       "\tavg_loss:", round(agent.avg_loss / float(step), 6))
122 | 
123 |                 agent.avg_q_max, agent.avg_loss = 0.0, 0.0
124 | 
125 |         if e % 1000 == 0:
126 |             # TODO: mkdir
127 |             # agent.save_model("./save_model/breakout_dqn.h5")
128 |             pass
129 | 
--------------------------------------------------------------------------------
/atari_breakout/visualization.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukepark327/keras-only-RL/f24c968f3a4bb4531a5c07df08374a9e8e656777/atari_breakout/visualization.py
--------------------------------------------------------------------------------
/atari_breakout_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | RUN_PATH=`dirname $0`
4 | # echo $RUN_PATH
5 | 
6 | python $RUN_PATH/atari_breakout/main.py $@
--------------------------------------------------------------------------------
/images/atari-breakout-D3QN.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lukepark327/keras-only-RL/f24c968f3a4bb4531a5c07df08374a9e8e656777/images/atari-breakout-D3QN.gif
--------------------------------------------------------------------------------