├── .gitignore ├── LICENSE ├── README.md ├── agents ├── __init__.py ├── _async.py ├── _n_step_q.py ├── _sarsa.py ├── agent.py ├── deep_q.py ├── experience.py ├── history.py └── statistic.py ├── assets ├── A1_A4_double_dueling.png └── corridor_result.png ├── environments ├── __init__.py ├── corridor.py └── environment.py ├── main.py ├── networks ├── __init__.py ├── cnn.py ├── layers.py ├── mlp.py └── network.py ├── test.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # misc 2 | logs 3 | 4 | # data 5 | samples 6 | *checkpoints/ 7 | *.npy 8 | *.pkl 9 | *.tgz 10 | *.zip 11 | *.tar.gz 12 | 13 | 14 | # Created by https://www.gitignore.io/api/python,vim 15 | 16 | ### IPythonNotebook ### 17 | ## Temporary data 18 | .ipynb_checkpoints/ 19 | 20 | ### Python ### 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | env/ 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *,cover 66 | .hypothesis/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | 82 | ### Vim ### 83 | [._]*.s[a-w][a-z] 84 | [._]s[a-w][a-z] 85 | *.un~ 86 | Session.vim 87 | .netrwhist 88 | *~ 89 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Taehoon Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Deep Reinforcement Learning in TensorFlow
 2 | 
 3 | TensorFlow implementation of Deep Reinforcement Learning papers. This implementation contains:
 4 | 
 5 | [1] [Playing Atari with Deep Reinforcement Learning](http://arxiv.org/abs/1312.5602)
 6 | [2] [Human-Level Control through Deep Reinforcement Learning](http://home.uchicago.edu/~arij/journalclub/papers/2015_Mnih_et_al.pdf)
 7 | [3] [Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461)
 8 | [4] [Dueling Network Architectures for Deep Reinforcement Learning](http://arxiv.org/abs/1511.06581)
 9 | [5] [Prioritized Experience Replay](http://arxiv.org/pdf/1511.05952v3.pdf) (in progress)
10 | [6] [Deep Exploration via Bootstrapped DQN](http://arxiv.org/abs/1602.04621) (in progress)
11 | [7] [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783) (in progress)
12 | [8] [Continuous Deep Q-Learning with Model-based Acceleration](http://arxiv.org/abs/1603.00748) (in progress)
13 | 
14 | 
15 | ## Requirements
16 | 
17 | - Python 2.7
18 | - [gym](https://github.com/openai/gym)
19 | - [tqdm](https://github.com/tqdm/tqdm)
20 | - [OpenCV2](http://opencv.org/) or [Scipy](https://www.scipy.org/)
21 | - [TensorFlow 0.12.0](https://www.tensorflow.org/)
22 | 
23 | 
24 | ## Usage
25 | 
26 | First, install the prerequisites with:
27 | 
28 |     $ pip install -U 'gym[all]' tqdm scipy
29 | 
30 | Don't forget to also install the latest
31 | [TensorFlow](https://www.tensorflow.org/). Also note that you need to install
32 | the dependencies of [`doom-py`](https://github.com/openai/doom-py), which is
33 | required by `gym[all]`.
34 | 
35 | Train the DQN model described in [[1]](#deep-reinforcement-learning-in-tensorflow) without a GPU:
36 | 
37 |     $ python main.py --network_header_type=nips --env_name=Breakout-v0 --use_gpu=False
38 | 
39 | Train the DQN model described in [[2]](#deep-reinforcement-learning-in-tensorflow):
40 | 
41 |     $ python main.py --network_header_type=nature --env_name=Breakout-v0
42 | 
43 | Train the Double DQN model described in [[3]](#deep-reinforcement-learning-in-tensorflow):
44 | 
45 |     $ python main.py --double_q=True --env_name=Breakout-v0
46 | 
47 | Train the Dueling network with Double Q-learning described in [[4]](#deep-reinforcement-learning-in-tensorflow):
48 | 
49 |     $ python main.py --double_q=True --network_output_type=dueling --env_name=Breakout-v0
50 | 
51 | Train the MLP model described in [[4]](#deep-reinforcement-learning-in-tensorflow) on the corridor environment (useful for debugging):
52 | 
53 |     $ python main.py --network_header_type=mlp --network_output_type=normal --observation_dims='[16]' --env_name=CorridorSmall-v5 --t_learn_start=0.1 --learning_rate_decay_step=0.1 --history_length=1 --n_action_repeat=1 --t_ep_end=10 --display=True --learning_rate=0.025 --learning_rate_minimum=0.0025
54 |     $ python main.py --network_header_type=mlp --network_output_type=normal --double_q=True --observation_dims='[16]' --env_name=CorridorSmall-v5 --t_learn_start=0.1 --learning_rate_decay_step=0.1 --history_length=1 --n_action_repeat=1 --t_ep_end=10 --display=True --learning_rate=0.025 --learning_rate_minimum=0.0025
55 |     $ python main.py --network_header_type=mlp --network_output_type=dueling --observation_dims='[16]' --env_name=CorridorSmall-v5 --t_learn_start=0.1 --learning_rate_decay_step=0.1 --history_length=1 --n_action_repeat=1 --t_ep_end=10 --display=True --learning_rate=0.025 --learning_rate_minimum=0.0025
56 |     $ python main.py --network_header_type=mlp --network_output_type=dueling --double_q=True --observation_dims='[16]' --env_name=CorridorSmall-v5 --t_learn_start=0.1 --learning_rate_decay_step=0.1 --history_length=1 --n_action_repeat=1 --t_ep_end=10 --display=True --learning_rate=0.025 --learning_rate_minimum=0.0025
57 | 
58 | 
59 | ## Results
60 | 
61 | Result of `Corridor-v5` in [[4]](#deep-reinforcement-learning-in-tensorflow) for DQN (purple), DDQN (red), Dueling DQN (green), Dueling DDQN (blue).
62 | 
63 | ![model](assets/corridor_result.png)
64 | 
65 | Result of `Breakout-v0` for DQN without frame-skip (white-blue), DQN with frame-skip (light purple), Dueling DDQN (dark blue).
66 | 
67 | ![model](assets/A1_A4_double_dueling.png)
68 | 
69 | The hyperparameters and gradient clipping are not implemented exactly as in [[4]](#deep-reinforcement-learning-in-tensorflow).
70 | 
71 | 
72 | ## References
73 | 
74 | - [DQN-tensorflow](https://github.com/devsisters/DQN-tensorflow)
75 | - [DeepMind's code](https://sites.google.com/a/deepmind.com/dqn/)
76 | 
77 | 
78 | ## Author
79 | 
80 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io/)
81 | 
--------------------------------------------------------------------------------
/agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/deep-rl-tensorflow/95d3e2dde77d4a7a393ec418fe3537094d08c2ba/agents/__init__.py
--------------------------------------------------------------------------------
/agents/_async.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import time
 3 | import numpy as np
 4 | import tensorflow as tf
 5 | from threading import Thread
 6 | from logging import getLogger
 7 | 
 8 | from .agent import Agent
 9 | from .history import History
10 | from .experience import Experience
11 | 
12 | logger = getLogger(__name__)
13 | 
14 | class Async(Agent):
15 |   def __init__(self, sess, pred_network, env, stat, conf, target_network=None):
16 |     super(Async, self).__init__(sess, pred_network, env, stat, conf, target_network=target_network)
17 | 
18 |     raise Exception("[!] 
Not fully implemented yet") 19 | 20 | # Optimizer 21 | with tf.variable_scope('optimizer'): 22 | self.targets = tf.placeholder('float32', [None], name='target_q_t') 23 | self.actions = tf.placeholder('int64', [None], name='action') 24 | 25 | actions_one_hot = tf.one_hot(self.actions, self.env.action_size, 1.0, 0.0, name='action_one_hot') 26 | pred_q = tf.reduce_sum(self.pred_network.outputs * actions_one_hot, reduction_indices=1, name='q_acted') 27 | 28 | self.delta = self.targets - pred_q 29 | if self.max_delta and self.min_delta: 30 | self.delta = tf.clip_by_value(self.delta, self.min_delta, self.max_delta, name='clipped_delta') 31 | 32 | self.loss = tf.reduce_mean(tf.square(self.delta), name='loss') 33 | 34 | self.learning_rate_op = tf.maximum(self.learning_rate_minimum, 35 | tf.train.exponential_decay( 36 | self.learning_rate, 37 | self.stat.t_op, 38 | self.learning_rate_decay_step, 39 | self.learning_rate_decay, 40 | staircase=True)) 41 | 42 | optimizer = tf.train.RMSPropOptimizer( 43 | self.learning_rate_op, momentum=0.95, epsilon=0.01) 44 | 45 | grads_and_vars = optimizer.compute_gradients(self.loss) 46 | for idx, (grad, var) in enumerate(grads_and_vars): 47 | if grad is not None: 48 | grads_and_vars[idx] = (tf.clip_by_norm(grad, self.max_grad_norm), var) 49 | self.optim = optimizer.apply_gradients(grads_and_vars) 50 | 51 | 52 | def train(self, max_t, global_t): 53 | self.global_t = global_t 54 | 55 | # 0. Prepare training 56 | state, reward, terminal = self.env.new_random_game() 57 | self.observe(state, reward, terminal) 58 | 59 | while True: 60 | if global_t[0] > self.t_train_max: 61 | break 62 | 63 | # 1. Predict 64 | action = self.predict(state) 65 | # 2. Step 66 | state, reward, terminal = self.env.step(action, is_training=True) 67 | # 3. Observe 68 | self.observe(state, reward, terminal) 69 | 70 | if terminal: 71 | observation, reward, terminal = self.new_game() 72 | 73 | global_t[0] += 1 74 | 75 | def train_with_log(self, max_t, global_t): 76 | from tqdm import tqdm 77 | 78 | for _ in tqdm(range(max_t), ncols=70, initial=int(global_t[0])): 79 | if global_t[0] > self.t_train_max: 80 | break 81 | 82 | # 1. Predict 83 | action = self.predict(state) 84 | # 2. Step 85 | state, reward, terminal = self.env.step(-1, is_training=True) 86 | # 3. Observe 87 | self.observe(state, reward, terminal) 88 | 89 | if terminal: 90 | observation, reward, terminal = self.new_game() 91 | 92 | global_t[0] += 1 93 | 94 | if self.stat: 95 | self.stat.on_step(self.t, action, reward, terminal, 96 | ep, q, loss, is_update, self.learning_rate_op) 97 | 98 | def observe(self, s_t, r_t, terminal): 99 | self.prev_r[self.t] = max(self.min_reward, min(self.max_reward, r_t)) 100 | 101 | if (terminal and self.t_start < self.t) or self.t - self.t_start == self.t_max: 102 | r = {} 103 | 104 | lr = (self.t_train_max - self.global_t[0] + 1) / self.t_train_max * self.learning_rate 105 | 106 | if terminal: 107 | r[self.t] = 0. 
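      # --- Added commentary (not part of the original source) ---
      # The rollout buffer is flushed here either on a terminal state or after
      # t_max steps. `lr` above anneals the learning rate linearly toward zero
      # as global_t approaches t_train_max. The branch below sets the bootstrap
      # value R[T]: zero for terminal states, otherwise the value estimate of
      # the last state fetched via partial_run; the loop that follows then
      # accumulates the n-step return backwards as R[t] = r[t] + gamma * R[t+1].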
108 | else: 109 | r[self.t] = self.sess.partial_run( 110 | self.partial_graph, 111 | self.networks[self.t_start - self.t].value, 112 | )[0][0] 113 | 114 | for t in range(self.t - 1, self.t_start - 1, -1): 115 | r[t] = self.prev_r[t] + self.gamma * r[t + 1] 116 | 117 | data = {} 118 | data.update({ 119 | self.networks[t].R: [r[t + self.t_start]] for t in range(len(self.prev_r) - 1) 120 | }) 121 | data.update({ 122 | self.networks[t].true_log_policy: 123 | [self.prev_log_policy[t + self.t_start]] for t in range(len(self.prev_r) - 1) 124 | }) 125 | data.update({ 126 | self.learning_rate_op: lr, 127 | }) 128 | 129 | # 1. Update accumulated gradients 130 | if not self.writer: 131 | self.sess.partial_run(self.partial_graph, 132 | [self.add_accum_grads[t] for t in range(len(self.prev_r) - 1)], data) 133 | else: 134 | results = self.sess.partial_run(self.partial_graph, 135 | [self.value_policy_summary] + [self.add_accum_grads[t] for t in range(len(self.prev_r) - 1)], data) 136 | 137 | summary_str = results[0] 138 | self.writer.add_summary(summary_str, self.global_t[0]) 139 | 140 | # 2. Update global w with accumulated gradients 141 | self.sess.run(self.apply_gradient, { 142 | self.learning_rate_op: lr, 143 | }) 144 | 145 | # 3. Reset accumulated gradients to zero 146 | self.sess.run(self.reset_accum_grad) 147 | 148 | # 4. Copy weights of global_network to local_network 149 | self.networks[0].copy_from_global() 150 | 151 | self.prev_r = {self.t: self.prev_r[self.t]} 152 | self.t_start = self.t 153 | 154 | del self.partial_graph 155 | -------------------------------------------------------------------------------- /agents/_n_step_q.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | from logging import getLogger 6 | 7 | from .agent import Agent 8 | from .history import History 9 | from .experience import Experience 10 | 11 | logger = getLogger(__name__) 12 | 13 | class NStepQ(Agent): 14 | def __init__(self, sess, pred_network, env, stat, conf, target_network=None): 15 | super(DeepQ, self).__init__(sess, pred_network, target_network, env, stat, conf) 16 | 17 | raise Exception("[!] 
Not fully implemented yet") 18 | 19 | # Optimizer 20 | with tf.variable_scope('optimizer'): 21 | self.targets = tf.placeholder('float32', [None], name='target_q_t') 22 | self.actions = tf.placeholder('int64', [None], name='action') 23 | 24 | actions_one_hot = tf.one_hot(self.actions, self.env.action_size, 1.0, 0.0, name='action_one_hot') 25 | pred_q = tf.reduce_sum(self.pred_network.outputs * actions_one_hot, reduction_indices=1, name='q_acted') 26 | 27 | self.delta = self.targets - pred_q 28 | if self.max_delta and self.min_delta: 29 | self.delta = tf.clip_by_value(self.delta, self.min_delta, self.max_delta, name='clipped_delta') 30 | 31 | self.loss = tf.reduce_mean(tf.square(self.delta), name='loss') 32 | 33 | self.learning_rate_op = tf.maximum(self.learning_rate_minimum, 34 | tf.train.exponential_decay( 35 | self.learning_rate, 36 | self.stat.t_op, 37 | self.learning_rate_decay_step, 38 | self.learning_rate_decay, 39 | staircase=True)) 40 | 41 | optimizer = tf.train.RMSPropOptimizer( 42 | self.learning_rate_op, momentum=0.95, epsilon=0.01) 43 | 44 | grads_and_vars = optimizer.compute_gradients(self.loss) 45 | for idx, (grad, var) in enumerate(grads_and_vars): 46 | if grad is not None: 47 | grads_and_vars[idx] = (tf.clip_by_norm(grad, self.max_grad_norm), var) 48 | self.optim = optimizer.apply_gradients(grads_and_vars) 49 | 50 | # Add accumulated gradients for n-step Q-learning 51 | def make_accumulated_gradients(self): 52 | reset_accum_grads = [] 53 | new_grads_and_vars = [] 54 | 55 | # 1. Prepare accum_grads 56 | self.accum_grads = {} 57 | self.add_accum_grads = {} 58 | 59 | for step, network in enumerate(self.networks): 60 | grads_and_vars = self.global_optim.compute_gradients(network.total_loss, network.w.values()) 61 | _add_accum_grads = [] 62 | 63 | for grad, var in tuple(grads_and_vars): 64 | if grad is not None: 65 | shape = grad.get_shape().as_list() 66 | 67 | name = 'accum/%s' % "/".join(var.name.split(':')[0].split('/')[-3:]) 68 | if step == 0: 69 | self.accum_grads[name] = tf.Variable( 70 | tf.zeros(shape), trainable=False, name=name) 71 | 72 | global_v = global_var[re.sub(r'.*\/A3C_\d+\/', '', var.name)] 73 | new_grads_and_vars.append((tf.clip_by_norm(self.accum_grads[name].ref(), self.max_grad_norm), global_v)) 74 | 75 | reset_accum_grads.append(self.accum_grads[name].assign(tf.zeros(shape))) 76 | 77 | _add_accum_grads.append(tf.assign_add(self.accum_grads[name], grad)) 78 | 79 | # 2. Add gradient to accum_grads 80 | self.add_accum_grads[step] = tf.group(*_add_accum_grads) 81 | 82 | def observe(self, observation, reward, action, terminal): 83 | reward = max(self.min_r, min(self.max_r, reward)) 84 | 85 | self.history.add(observation) 86 | self.experience.add(observation, reward, action, terminal) 87 | 88 | # q, loss, is_update 89 | result = [], 0, False 90 | 91 | if self.t > self.t_learn_start: 92 | if self.t % self.t_train_freq == 0: 93 | result = self.q_learning_minibatch() 94 | 95 | if self.t % self.t_target_q_update_freq == self.t_target_q_update_freq - 1: 96 | self.update_target_q_network() 97 | 98 | return result 99 | 100 | def q_learning_minibatch(self): 101 | if self.experience.count < self.history_length: 102 | return [], 0, False 103 | else: 104 | s_t, action, reward, s_t_plus_1, terminal = self.experience.sample() 105 | 106 | terminal = np.array(terminal) + 0. 107 | 108 | # Deep Q-learning 109 | max_q_t_plus_1 = self.target_network.calc_max_outputs(s_t_plus_1) 110 | target_q_t = (1. 
- terminal) * self.discount_r * max_q_t_plus_1 + reward 111 | 112 | _, q_t, loss = self.sess.run([self.optim, self.pred_network.outputs, self.loss], { 113 | self.targets: target_q_t, 114 | self.actions: action, 115 | self.pred_network.inputs: s_t, 116 | }) 117 | 118 | return q_t, loss, True 119 | -------------------------------------------------------------------------------- /agents/_sarsa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | from logging import getLogger 6 | 7 | from .agent import Agent 8 | from .history import History 9 | from .experience import Experience 10 | 11 | logger = getLogger(__name__) 12 | 13 | class SARSA(Agent): 14 | def __init__(self, sess, pred_network, 15 | env, stat, conf, target_network=None, policy_network=None): 16 | super(DeepQ, self).__init__(sess, pred_network, target_network, policy_network, env, stat, conf) 17 | 18 | raise Exception("policy is not implemented yet") 19 | 20 | # Optimizer 21 | with tf.variable_scope('optimizer'): 22 | self.targets = tf.placeholder('float32', [None], name='target_q_t') 23 | self.actions = tf.placeholder('int64', [None], name='action') 24 | 25 | actions_one_hot = tf.one_hot(self.actions, self.env.action_size, 1.0, 0.0, name='action_one_hot') 26 | pred_q = tf.reduce_sum(self.pred_network.outputs * actions_one_hot, reduction_indices=1, name='q_acted') 27 | 28 | self.delta = self.targets - pred_q 29 | if self.max_delta and self.min_delta: 30 | self.delta = tf.clip_by_value(self.delta, self.min_delta, self.max_delta, name='clipped_delta') 31 | 32 | self.loss = tf.reduce_mean(tf.square(self.delta), name='loss') 33 | 34 | self.learning_rate_op = tf.maximum(self.learning_rate_minimum, 35 | tf.train.exponential_decay( 36 | self.learning_rate, 37 | self.stat.t_op, 38 | self.learning_rate_decay_step, 39 | self.learning_rate_decay, 40 | staircase=True)) 41 | 42 | optimizer = tf.train.RMSPropOptimizer( 43 | self.learning_rate_op, momentum=0.95, epsilon=0.01) 44 | 45 | grads_and_vars = optimizer.compute_gradients(self.loss) 46 | for idx, (grad, var) in enumerate(grads_and_vars): 47 | if grad is not None: 48 | grads_and_vars[idx] = (tf.clip_by_norm(grad, self.max_grad_norm), var) 49 | self.optim = optimizer.apply_gradients(grads_and_vars) 50 | 51 | def observe(self, observation, reward, action, terminal): 52 | reward = max(self.min_r, min(self.max_r, reward)) 53 | 54 | self.history.add(observation) 55 | self.experience.add(observation, reward, action, terminal) 56 | 57 | # q, loss, is_update 58 | result = [], 0, False 59 | 60 | if self.t > self.t_learn_start: 61 | if self.t % self.t_train_freq == 0: 62 | result = self.q_learning_minibatch() 63 | 64 | if self.t % self.t_target_q_update_freq == self.t_target_q_update_freq - 1: 65 | self.update_target_q_network() 66 | 67 | return result 68 | 69 | def q_learning_minibatch(self): 70 | if self.experience.count < self.history_length: 71 | return [], 0, False 72 | else: 73 | s_t, action, reward, s_t_plus_1, terminal = self.experience.sample() 74 | 75 | terminal = np.array(terminal) + 0. 76 | 77 | # Deep Q-learning 78 | max_q_t_plus_1 = self.target_network.calc_max_outputs(s_t_plus_1) 79 | target_q_t = (1. 
- terminal) * self.discount_r * max_q_t_plus_1 + reward 80 | 81 | _, q_t, loss = self.sess.run([self.optim, self.pred_network.outputs, self.loss], { 82 | self.targets: target_q_t, 83 | self.actions: action, 84 | self.pred_network.inputs: s_t, 85 | }) 86 | 87 | return q_t, loss, True 88 | -------------------------------------------------------------------------------- /agents/agent.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import time 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | import tensorflow as tf 7 | from logging import getLogger 8 | 9 | from .history import History 10 | from .experience import Experience 11 | 12 | logger = getLogger(__name__) 13 | 14 | def get_time(): 15 | return time.strftime("%Y-%m-%d_%H:%M:%S", time.gmtime()) 16 | 17 | class Agent(object): 18 | def __init__(self, sess, pred_network, env, stat, conf, target_network=None): 19 | self.sess = sess 20 | self.stat = stat 21 | 22 | self.ep_start = conf.ep_start 23 | self.ep_end = conf.ep_end 24 | self.history_length = conf.history_length 25 | self.t_ep_end = conf.t_ep_end 26 | self.t_learn_start = conf.t_learn_start 27 | self.t_train_freq = conf.t_train_freq 28 | self.t_target_q_update_freq = conf.t_target_q_update_freq 29 | self.env_name = conf.env_name 30 | 31 | self.discount_r = conf.discount_r 32 | self.min_r = conf.min_r 33 | self.max_r = conf.max_r 34 | self.min_delta = conf.min_delta 35 | self.max_delta = conf.max_delta 36 | self.max_grad_norm = conf.max_grad_norm 37 | self.observation_dims = conf.observation_dims 38 | 39 | self.learning_rate = conf.learning_rate 40 | self.learning_rate_minimum = conf.learning_rate_minimum 41 | self.learning_rate_decay = conf.learning_rate_decay 42 | self.learning_rate_decay_step = conf.learning_rate_decay_step 43 | 44 | # network 45 | self.double_q = conf.double_q 46 | self.pred_network = pred_network 47 | self.target_network = target_network 48 | self.target_network.create_copy_op(self.pred_network) 49 | 50 | self.env = env 51 | self.history = History(conf.data_format, 52 | conf.batch_size, conf.history_length, conf.observation_dims) 53 | self.experience = Experience(conf.data_format, 54 | conf.batch_size, conf.history_length, conf.memory_size, conf.observation_dims) 55 | 56 | if conf.random_start: 57 | self.new_game = self.env.new_random_game 58 | else: 59 | self.new_game = self.env.new_game 60 | 61 | def train(self, t_max): 62 | tf.global_variables_initializer().run() 63 | 64 | self.stat.load_model() 65 | self.target_network.run_copy() 66 | 67 | start_t = self.stat.get_t() 68 | observation, reward, terminal = self.new_game() 69 | 70 | for _ in range(self.history_length): 71 | self.history.add(observation) 72 | 73 | for self.t in tqdm(range(start_t, t_max), ncols=70, initial=start_t): 74 | ep = (self.ep_end + 75 | max(0., (self.ep_start - self.ep_end) 76 | * (self.t_ep_end - max(0., self.t - self.t_learn_start)) / self.t_ep_end)) 77 | 78 | # 1. predict 79 | action = self.predict(self.history.get(), ep) 80 | # 2. act 81 | observation, reward, terminal, info = self.env.step(action, is_training=True) 82 | # 3. 
observe 83 | q, loss, is_update = self.observe(observation, reward, action, terminal) 84 | 85 | logger.debug("a: %d, r: %d, t: %d, q: %.4f, l: %.2f" % \ 86 | (action, reward, terminal, np.mean(q), loss)) 87 | 88 | if self.stat: 89 | self.stat.on_step(self.t, action, reward, terminal, 90 | ep, q, loss, is_update, self.learning_rate_op) 91 | if terminal: 92 | observation, reward, terminal = self.new_game() 93 | 94 | def play(self, test_ep, n_step=10000, n_episode=100): 95 | tf.initialize_all_variables().run() 96 | 97 | self.stat.load_model() 98 | self.target_network.run_copy() 99 | 100 | if not self.env.display: 101 | gym_dir = '/tmp/%s-%s' % (self.env_name, get_time()) 102 | env = gym.wrappers.Monitor(self.env.env, gym_dir) 103 | 104 | best_reward, best_idx, best_count = 0, 0, 0 105 | try: 106 | itr = xrange(n_episode) 107 | except NameError: 108 | itr = range(n_episode) 109 | for idx in itr: 110 | observation, reward, terminal = self.new_game() 111 | current_reward = 0 112 | 113 | for _ in range(self.history_length): 114 | self.history.add(observation) 115 | 116 | for self.t in tqdm(range(n_step), ncols=70): 117 | # 1. predict 118 | action = self.predict(self.history.get(), test_ep) 119 | # 2. act 120 | observation, reward, terminal, info = self.env.step(action, is_training=False) 121 | # 3. observe 122 | q, loss, is_update = self.observe(observation, reward, action, terminal) 123 | 124 | logger.debug("a: %d, r: %d, t: %d, q: %.4f, l: %.2f" % \ 125 | (action, reward, terminal, np.mean(q), loss)) 126 | current_reward += reward 127 | 128 | if terminal: 129 | break 130 | 131 | if current_reward > best_reward: 132 | best_reward = current_reward 133 | best_idx = idx 134 | best_count = 0 135 | elif current_reward == best_reward: 136 | best_count += 1 137 | 138 | print ("="*30) 139 | print (" [%d] Best reward : %d (dup-percent: %d/%d)" % (best_idx, best_reward, best_count, n_episode)) 140 | print ("="*30) 141 | 142 | #if not self.env.display: 143 | #gym.upload(gym_dir, writeup='https://github.com/devsisters/DQN-tensorflow', api_key='') 144 | 145 | def predict(self, s_t, ep): 146 | if random.random() < ep: 147 | action = random.randrange(self.env.action_size) 148 | else: 149 | action = self.pred_network.calc_actions([s_t])[0] 150 | return action 151 | 152 | def q_learning_minibatch_test(self): 153 | s_t = np.array([[[ 0., 0., 0., 0.], 154 | [ 0., 0., 0., 0.], 155 | [ 0., 0., 0., 0.], 156 | [ 1., 0., 0., 0.]]], dtype=np.uint8) 157 | s_t_plus_1 = np.array([[[ 0., 0., 0., 0.], 158 | [ 0., 0., 0., 0.], 159 | [ 1., 0., 0., 0.], 160 | [ 0., 0., 0., 0.]]], dtype=np.uint8) 161 | s_t = s_t.reshape([1, 1] + self.observation_dims) 162 | s_t_plus_1 = s_t_plus_1.reshape([1, 1] + self.observation_dims) 163 | 164 | action = [3] 165 | reward = [1] 166 | terminal = [0] 167 | 168 | terminal = np.array(terminal) + 0. 169 | max_q_t_plus_1 = self.target_network.calc_max_outputs(s_t_plus_1) 170 | target_q_t = (1. 
- terminal) * self.discount_r * max_q_t_plus_1 + reward 171 | 172 | _, q_t, a, loss = self.sess.run([ 173 | self.optim, self.pred_network.outputs, self.pred_network.actions, self.loss 174 | ], { 175 | self.targets: target_q_t, 176 | self.actions: action, 177 | self.pred_network.inputs: s_t, 178 | }) 179 | 180 | logger.info("q: %s, a: %d, l: %.2f" % (q_t, a, loss)) 181 | 182 | def update_target_q_network(self): 183 | assert self.target_network != None 184 | self.target_network.run_copy() 185 | -------------------------------------------------------------------------------- /agents/deep_q.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | from logging import getLogger 6 | 7 | from .agent import Agent 8 | 9 | logger = getLogger(__name__) 10 | 11 | class DeepQ(Agent): 12 | def __init__(self, sess, pred_network, env, stat, conf, target_network=None): 13 | super(DeepQ, self).__init__(sess, pred_network, env, stat, conf, target_network=target_network) 14 | 15 | # Optimizer 16 | with tf.variable_scope('optimizer'): 17 | self.targets = tf.placeholder('float32', [None], name='target_q_t') 18 | self.actions = tf.placeholder('int64', [None], name='action') 19 | 20 | actions_one_hot = tf.one_hot(self.actions, self.env.action_size, 1.0, 0.0, name='action_one_hot') 21 | pred_q = tf.reduce_sum(self.pred_network.outputs * actions_one_hot, reduction_indices=1, name='q_acted') 22 | 23 | self.delta = self.targets - pred_q 24 | self.clipped_error = tf.where(tf.abs(self.delta) < 1.0, 25 | 0.5 * tf.square(self.delta), 26 | tf.abs(self.delta) - 0.5, name='clipped_error') 27 | 28 | self.loss = tf.reduce_mean(self.clipped_error, name='loss') 29 | 30 | self.learning_rate_op = tf.maximum(self.learning_rate_minimum, 31 | tf.train.exponential_decay( 32 | self.learning_rate, 33 | self.stat.t_op, 34 | self.learning_rate_decay_step, 35 | self.learning_rate_decay, 36 | staircase=True)) 37 | 38 | optimizer = tf.train.RMSPropOptimizer( 39 | self.learning_rate_op, momentum=0.95, epsilon=0.01) 40 | 41 | if self.max_grad_norm != None: 42 | grads_and_vars = optimizer.compute_gradients(self.loss) 43 | for idx, (grad, var) in enumerate(grads_and_vars): 44 | if grad is not None: 45 | grads_and_vars[idx] = (tf.clip_by_norm(grad, self.max_grad_norm), var) 46 | self.optim = optimizer.apply_gradients(grads_and_vars) 47 | else: 48 | self.optim = optimizer.minimize(self.loss) 49 | 50 | def observe(self, observation, reward, action, terminal): 51 | reward = max(self.min_r, min(self.max_r, reward)) 52 | 53 | self.history.add(observation) 54 | self.experience.add(observation, reward, action, terminal) 55 | 56 | # q, loss, is_update 57 | result = [], 0, False 58 | 59 | if self.t > self.t_learn_start: 60 | if self.t % self.t_train_freq == 0: 61 | result = self.q_learning_minibatch() 62 | 63 | if self.t % self.t_target_q_update_freq == self.t_target_q_update_freq - 1: 64 | self.update_target_q_network() 65 | 66 | return result 67 | 68 | def q_learning_minibatch(self): 69 | if self.experience.count < self.history_length: 70 | return [], 0, False 71 | else: 72 | s_t, action, reward, s_t_plus_1, terminal = self.experience.sample() 73 | 74 | terminal = np.array(terminal) + 0. 
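      # --- Added commentary (not part of the original source) ---
      # `terminal` is cast to float above so it can mask the bootstrap term.
      # The two branches below build the TD target for each sampled transition:
      #   DQN        : y = r + (1 - terminal) * discount_r * max_a' Q_target(s', a')
      #   Double DQN : y = r + (1 - terminal) * discount_r * Q_target(s', argmax_a' Q_pred(s', a'))
      # Double Q-learning selects the next action with the online (pred) network
      # but evaluates it with the target network, which reduces overestimation bias.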
75 | 76 | if self.double_q: 77 | # Double Q-learning 78 | pred_action = self.pred_network.calc_actions(s_t_plus_1) 79 | q_t_plus_1_with_pred_action = self.target_network.calc_outputs_with_idx( 80 | s_t_plus_1, [[idx, pred_a] for idx, pred_a in enumerate(pred_action)]) 81 | target_q_t = (1. - terminal) * self.discount_r * q_t_plus_1_with_pred_action + reward 82 | else: 83 | # Deep Q-learning 84 | max_q_t_plus_1 = self.target_network.calc_max_outputs(s_t_plus_1) 85 | target_q_t = (1. - terminal) * self.discount_r * max_q_t_plus_1 + reward 86 | 87 | _, q_t, loss = self.sess.run([self.optim, self.pred_network.outputs, self.loss], { 88 | self.targets: target_q_t, 89 | self.actions: action, 90 | self.pred_network.inputs: s_t, 91 | }) 92 | 93 | return q_t, loss, True 94 | -------------------------------------------------------------------------------- /agents/experience.py: -------------------------------------------------------------------------------- 1 | """Modification of https://github.com/tambetm/simple_dqn/blob/master/src/replay_memory.py""" 2 | 3 | import random 4 | import numpy as np 5 | 6 | class Experience(object): 7 | def __init__(self, data_format, batch_size, history_length, memory_size, observation_dims): 8 | self.data_format = data_format 9 | self.batch_size = batch_size 10 | self.history_length = history_length 11 | self.memory_size = memory_size 12 | 13 | self.actions = np.empty(self.memory_size, dtype=np.uint8) 14 | self.rewards = np.empty(self.memory_size, dtype=np.int8) 15 | self.observations = np.empty([self.memory_size] + observation_dims, dtype=np.uint8) 16 | self.terminals = np.empty(self.memory_size, dtype=np.bool) 17 | 18 | # pre-allocate prestates and poststates for minibatch 19 | self.prestates = np.empty([self.batch_size, self.history_length] + observation_dims, dtype = np.float16) 20 | self.poststates = np.empty([self.batch_size, self.history_length] + observation_dims, dtype = np.float16) 21 | 22 | self.count = 0 23 | self.current = 0 24 | 25 | def add(self, observation, reward, action, terminal): 26 | self.actions[self.current] = action 27 | self.rewards[self.current] = reward 28 | self.observations[self.current, ...] = observation 29 | self.terminals[self.current] = terminal 30 | self.count = max(self.count, self.current + 1) 31 | self.current = (self.current + 1) % self.memory_size 32 | 33 | def sample(self): 34 | indexes = [] 35 | while len(indexes) < self.batch_size: 36 | while True: 37 | index = random.randint(self.history_length, self.count - 1) 38 | if index >= self.current and index - self.history_length < self.current: 39 | continue 40 | if self.terminals[(index - self.history_length):index].any(): 41 | continue 42 | break 43 | 44 | self.prestates[len(indexes), ...] = self.retreive(index - 1) 45 | self.poststates[len(indexes), ...] = self.retreive(index) 46 | indexes.append(index) 47 | 48 | actions = self.actions[indexes] 49 | rewards = self.rewards[indexes] 50 | terminals = self.terminals[indexes] 51 | 52 | if self.data_format == 'NHWC' and len(self.prestates.shape) == 4: 53 | return np.transpose(self.prestates, (0, 2, 3, 1)), actions, \ 54 | rewards, np.transpose(self.poststates, (0, 2, 3, 1)), terminals 55 | else: 56 | return self.prestates, actions, rewards, self.poststates, terminals 57 | 58 | def retreive(self, index): 59 | index = index % self.count 60 | if index >= self.history_length - 1: 61 | return self.observations[(index - (self.history_length - 1)):(index + 1), ...] 
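    # --- Added commentary (not part of the original source) ---
    # Fast path above: when the index is far enough from the start of the
    # circular buffer, the most recent `history_length` observations form one
    # contiguous slice. The fallback below handles wrap-around by gathering
    # the indices modulo `self.count`.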
62 | else: 63 | indexes = [(index - i) % self.count for i in reversed(range(self.history_length))] 64 | return self.observations[indexes, ...] 65 | -------------------------------------------------------------------------------- /agents/history.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class History: 4 | def __init__(self, data_format, batch_size, history_length, screen_dims): 5 | self.data_format = data_format 6 | self.history = np.zeros([history_length] + screen_dims, dtype=np.float32) 7 | 8 | def add(self, screen): 9 | self.history[:-1] = self.history[1:] 10 | self.history[-1] = screen 11 | 12 | def reset(self): 13 | self.history *= 0 14 | 15 | def get(self): 16 | if self.data_format == 'NHWC' and len(self.history.shape) == 3: 17 | return np.transpose(self.history, (1, 2, 0)) 18 | else: 19 | return self.history 20 | -------------------------------------------------------------------------------- /agents/statistic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | class Statistic(object): 6 | def __init__(self, sess, t_test, t_learn_start, model_dir, variables, max_to_keep=20): 7 | self.sess = sess 8 | self.t_test = t_test 9 | self.t_learn_start = t_learn_start 10 | 11 | self.reset() 12 | self.max_avg_ep_reward = 0 13 | 14 | with tf.variable_scope('t'): 15 | self.t_op = tf.Variable(0, trainable=False, name='t') 16 | self.t_add_op = self.t_op.assign_add(1) 17 | 18 | self.model_dir = model_dir 19 | self.saver = tf.train.Saver(list(variables) + [self.t_op], max_to_keep=max_to_keep) 20 | self.writer = tf.summary.FileWriter('./logs/%s' % self.model_dir, self.sess.graph) 21 | 22 | with tf.variable_scope('summary'): 23 | scalar_summary_tags = [ 24 | 'average/reward', 'average/loss', 'average/q', 25 | 'episode/max_reward', 'episode/min_reward', 'episode/avg_reward', 26 | 'episode/num_of_game', 'training/learning_rate', 'training/epsilon', 27 | ] 28 | 29 | self.summary_placeholders = {} 30 | self.summary_ops = {} 31 | 32 | for tag in scalar_summary_tags: 33 | self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_')) 34 | self.summary_ops[tag] = tf.summary.scalar(tag, self.summary_placeholders[tag]) 35 | 36 | histogram_summary_tags = ['episode/rewards', 'episode/actions'] 37 | 38 | for tag in histogram_summary_tags: 39 | self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_')) 40 | self.summary_ops[tag] = tf.summary.histogram(tag, self.summary_placeholders[tag]) 41 | 42 | 43 | def reset(self): 44 | self.num_game = 0 45 | self.update_count = 0 46 | self.ep_reward = 0. 47 | self.total_loss = 0. 48 | self.total_reward = 0. 49 | self.actions = [] 50 | self.total_q = [] 51 | self.ep_rewards = [] 52 | 53 | def on_step(self, t, action, reward, terminal, 54 | ep, q, loss, is_update, learning_rate_op): 55 | if t >= self.t_learn_start: 56 | self.total_q.extend(q) 57 | self.actions.append(action) 58 | 59 | self.total_loss += loss 60 | self.total_reward += reward 61 | 62 | if terminal: 63 | self.num_game += 1 64 | self.ep_rewards.append(self.ep_reward) 65 | self.ep_reward = 0. 
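      # --- Added commentary (not part of the original source) ---
      # Per-episode returns accumulate in `ep_reward` and are appended to
      # `ep_rewards` on terminal steps. Every `t_test` steps the block further
      # below averages these statistics, writes the TensorBoard summaries, and
      # saves a checkpoint whenever the new average episode reward is at least
      # 90% of the best average seen so far.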
66 | else: 67 | self.ep_reward += reward 68 | 69 | if is_update: 70 | self.update_count += 1 71 | 72 | if t % self.t_test == self.t_test - 1 and self.update_count != 0: 73 | avg_q = np.mean(self.total_q) 74 | avg_loss = self.total_loss / self.update_count 75 | avg_reward = self.total_reward / self.t_test 76 | 77 | try: 78 | max_ep_reward = np.max(self.ep_rewards) 79 | min_ep_reward = np.min(self.ep_rewards) 80 | avg_ep_reward = np.mean(self.ep_rewards) 81 | except: 82 | max_ep_reward, min_ep_reward, avg_ep_reward = 0, 0, 0 83 | 84 | print ('\navg_r: %.4f, avg_l: %.6f, avg_q: %3.6f, avg_ep_r: %.4f, max_ep_r: %.4f, min_ep_r: %.4f, # game: %d' \ 85 | % (avg_reward, avg_loss, avg_q, avg_ep_reward, max_ep_reward, min_ep_reward, self.num_game)) 86 | 87 | if self.max_avg_ep_reward * 0.9 <= avg_ep_reward: 88 | assert t == self.get_t() 89 | 90 | self.save_model(t) 91 | 92 | self.max_avg_ep_reward = max(self.max_avg_ep_reward, avg_ep_reward) 93 | 94 | self.inject_summary({ 95 | 'average/q': avg_q, 96 | 'average/loss': avg_loss, 97 | 'average/reward': avg_reward, 98 | 'episode/max_reward': max_ep_reward, 99 | 'episode/min_reward': min_ep_reward, 100 | 'episode/avg_reward': avg_ep_reward, 101 | 'episode/num_of_game': self.num_game, 102 | 'episode/actions': self.actions, 103 | 'episode/rewards': self.ep_rewards, 104 | 'training/learning_rate': learning_rate_op.eval(session=self.sess), 105 | 'training/epsilon': ep, 106 | }, t) 107 | 108 | self.reset() 109 | 110 | self.t_add_op.eval(session=self.sess) 111 | 112 | def inject_summary(self, tag_dict, t): 113 | summary_str_lists = self.sess.run([self.summary_ops[tag] for tag in tag_dict.keys()], { 114 | self.summary_placeholders[tag]: value for tag, value in tag_dict.items() 115 | }) 116 | for summary_str in summary_str_lists: 117 | self.writer.add_summary(summary_str, t) 118 | 119 | def get_t(self): 120 | return self.t_op.eval(session=self.sess) 121 | 122 | def save_model(self, t): 123 | print(" [*] Saving checkpoints...") 124 | model_name = type(self).__name__ 125 | 126 | if not os.path.exists(self.model_dir): 127 | os.makedirs(self.model_dir) 128 | self.saver.save(self.sess, self.model_dir, global_step=t) 129 | 130 | def load_model(self): 131 | ckpt = tf.train.get_checkpoint_state(self.model_dir) 132 | if ckpt and ckpt.model_checkpoint_path: 133 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 134 | fname = os.path.join(self.model_dir, ckpt_name) 135 | self.saver.restore(self.sess, fname) 136 | print(" [*] Load SUCCESS: %s" % fname) 137 | return True 138 | else: 139 | print(" [!] 
Load FAILED: %s" % self.model_dir) 140 | return False 141 | -------------------------------------------------------------------------------- /assets/A1_A4_double_dueling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/deep-rl-tensorflow/95d3e2dde77d4a7a393ec418fe3537094d08c2ba/assets/A1_A4_double_dueling.png -------------------------------------------------------------------------------- /assets/corridor_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/deep-rl-tensorflow/95d3e2dde77d4a7a393ec418fe3537094d08c2ba/assets/corridor_result.png -------------------------------------------------------------------------------- /environments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/deep-rl-tensorflow/95d3e2dde77d4a7a393ec418fe3537094d08c2ba/environments/__init__.py -------------------------------------------------------------------------------- /environments/corridor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | try: 4 | from StringIO import StringIO 5 | except ImportError: 6 | from io import StringIO 7 | from gym import utils 8 | from gym import spaces 9 | from gym.envs.toy_text import discrete 10 | from gym.envs.registration import register 11 | 12 | MAPS = { 13 | "4x4": [ 14 | "HHHD", 15 | "FFFF", 16 | "FHHH", 17 | "SHHH", 18 | ], 19 | "9x9": [ 20 | "HHHHHHHHD", 21 | "HHHHHHHHF", 22 | "HHHHHHHHF", 23 | "HHHHHHHHF", 24 | "FFFFFFFFF", 25 | "FHHHHHHHH", 26 | "FHHHHHHHH", 27 | "FHHHHHHHH", 28 | "SHHHHHHHH", 29 | ], 30 | } 31 | 32 | class CorridorEnv(discrete.DiscreteEnv): 33 | """ 34 | The surface is described using a grid like the following 35 | 36 | HHHD 37 | FFFF 38 | SHHH 39 | AHHH 40 | 41 | S : starting point, safe 42 | F : frozen surface, safe 43 | H : hole, fall to your doom 44 | A : adjacent goal 45 | D : distant goal 46 | 47 | The episode ends when you reach the goal or fall in a hole. 48 | You receive a reward of 1 if you reach the adjacent goal, 49 | 10 if you reach the distant goal, and zero otherwise. 
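    Note (added commentary, not part of the original docstring): the transition
    table built in `__init__` below actually assigns a reward of -1 for falling
    into a hole and +1 for each move onto a frozen 'F' square, in addition to
    the goal rewards described above; all other transitions give 0.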
50 | """ 51 | metadata = {'render.modes': ['human', 'ansi']} 52 | 53 | def __init__(self, desc=None, map_name="9x9", n_actions=5): 54 | if desc is None and map_name is None: 55 | raise ValueError('Must provide either desc or map_name') 56 | elif desc is None: 57 | desc = MAPS[map_name] 58 | self.desc = desc = np.asarray(desc, dtype='c') 59 | self.nrow, self.ncol = nrow, ncol = desc.shape 60 | 61 | self.action_space = spaces.Discrete(n_actions) 62 | self.observation_space = spaces.Discrete(desc.size) 63 | 64 | n_state = nrow * ncol 65 | 66 | isd = (desc == 'S').ravel().astype('float64') 67 | isd /= isd.sum() 68 | 69 | P = {s : {a : [] for a in xrange(n_actions)} for s in xrange(n_state)} 70 | 71 | def to_s(row, col): 72 | return row*ncol + col 73 | def inc(row, col, a): 74 | if a == 0: # left 75 | col = max(col-1,0) 76 | elif a == 1: # down 77 | row = min(row+1, nrow-1) 78 | elif a == 2: # right 79 | col = min(col+1, ncol-1) 80 | elif a == 3: # up 81 | row = max(row-1, 0) 82 | 83 | return (row, col) 84 | 85 | for row in xrange(nrow): 86 | for col in xrange(ncol): 87 | s = to_s(row, col) 88 | for a in xrange(n_actions): 89 | li = P[s][a] 90 | newrow, newcol = inc(row, col, a) 91 | newstate = to_s(newrow, newcol) 92 | letter = desc[newrow, newcol] 93 | done = letter in 'DAH' 94 | rew = 1.0 if letter == 'A' \ 95 | else 10.0 if letter == 'D' \ 96 | else -1.0 if letter == 'H' \ 97 | else 1.0 if (newrow != row or newcol != col) and letter == 'F' \ 98 | else 0.0 99 | li.append((1.0/3.0, newstate, rew, done)) 100 | 101 | super(CorridorEnv, self).__init__(nrow * ncol, n_actions, P, isd) 102 | 103 | def _render(self, mode='human', close=False): 104 | if close: 105 | return 106 | 107 | outfile = StringIO.StringIO() if mode == 'ansi' else sys.stdout 108 | 109 | row, col = self.s // self.ncol, self.s % self.ncol 110 | desc = self.desc.tolist() 111 | desc[row][col] = utils.colorize(desc[row][col], "red", highlight=True) 112 | 113 | outfile.write("\n".join("".join(row) for row in desc)+"\n") 114 | if self.lastaction is not None: 115 | outfile.write(" ({})\n".format(self.get_action_meanings()[self.lastaction])) 116 | else: 117 | outfile.write("\n") 118 | 119 | return outfile 120 | 121 | def get_action_meanings(self): 122 | return [["Left", "Down", "Right", "Up"][i] if i < 4 else "NoOp" for i in xrange(self.action_space.n)] 123 | 124 | register( 125 | id='CorridorSmall-v5', 126 | entry_point='environments.corridor:CorridorEnv', 127 | kwargs={ 128 | 'map_name': '4x4', 129 | 'n_actions': 5 130 | }, 131 | timestep_limit=100, 132 | ) 133 | 134 | register( 135 | id='CorridorSmall-v10', 136 | entry_point='environments.corridor:CorridorEnv', 137 | kwargs={ 138 | 'map_name': '4x4', 139 | 'n_actions': 10 140 | }, 141 | timestep_limit=100, 142 | ) 143 | 144 | register( 145 | id='CorridorBig-v5', 146 | entry_point='environments.corridor:CorridorEnv', 147 | kwargs={ 148 | 'map_name': '9x9', 149 | 'n_actions': 5 150 | }, 151 | timestep_limit=100, 152 | ) 153 | 154 | register( 155 | id='CorridorBig-v10', 156 | entry_point='environments.corridor:CorridorEnv', 157 | kwargs={ 158 | 'map_name': '9x9', 159 | 'n_actions': 10 160 | }, 161 | timestep_limit=100, 162 | ) 163 | -------------------------------------------------------------------------------- /environments/environment.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import random 3 | import logging 4 | import numpy as np 5 | 6 | from .corridor import CorridorEnv 7 | 8 | try: 9 | import scipy.misc 10 | imresize = 
scipy.misc.imresize 11 | imwrite = scipy.misc.imsave 12 | except: 13 | import cv2 14 | imresize = cv2.resize 15 | imwrite = cv2.imwrite 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | class Environment(object): 20 | def __init__(self, env_name, n_action_repeat, max_random_start, 21 | observation_dims, data_format, display, use_cumulated_reward=False): 22 | self.env = gym.make(env_name) 23 | 24 | self.n_action_repeat = n_action_repeat 25 | self.max_random_start = max_random_start 26 | self.action_size = self.env.action_space.n 27 | 28 | self.display = display 29 | self.data_format = data_format 30 | self.observation_dims = observation_dims 31 | self.use_cumulated_reward = use_cumulated_reward 32 | 33 | if hasattr(self.env, 'get_action_meanings'): 34 | logger.info("Using %d actions : %s" % (self.action_size, ", ".join(self.env.get_action_meanings()))) 35 | 36 | def new_game(self): 37 | return self.preprocess(self.env.reset()), 0, False 38 | 39 | def new_random_game(self): 40 | return self.new_game() 41 | 42 | def step(self, action, is_training=False): 43 | observation, reward, terminal, info = self.env.step(action) 44 | if self.display: self.env.render() 45 | return self.preprocess(observation), reward, terminal, info 46 | 47 | def preprocess(self): 48 | raise NotImplementedError() 49 | 50 | class ToyEnvironment(Environment): 51 | def preprocess(self, obs): 52 | new_obs = np.zeros([self.env.observation_space.n]) 53 | new_obs[obs] = 1 54 | return new_obs 55 | 56 | class AtariEnvironment(Environment): 57 | def __init__(self, env_name, n_action_repeat, max_random_start, 58 | observation_dims, data_format, display, use_cumulated_reward): 59 | super(AtariEnvironment, self).__init__(env_name, 60 | n_action_repeat, max_random_start, observation_dims, data_format, display, use_cumulated_reward) 61 | 62 | def new_game(self, from_random_game=False): 63 | screen = self.env.reset() 64 | screen, reward, terminal, _ = self.env.step(0) 65 | 66 | if self.display: 67 | self.env.render() 68 | 69 | if from_random_game: 70 | return screen, 0, False 71 | else: 72 | self.lives = self.env.unwrapped.ale.lives() 73 | terminal = False 74 | return self.preprocess(screen, terminal), 0, terminal 75 | 76 | def new_random_game(self): 77 | screen, reward, terminal = self.new_game(True) 78 | 79 | for idx in range(random.randrange(self.max_random_start)): 80 | screen, reward, terminal, _ = self.env.step(0) 81 | 82 | if terminal: logger.warning("warning: terminal signal received after %d 0-steps", idx) 83 | 84 | if self.display: 85 | self.env.render() 86 | 87 | self.lives = self.env.unwrapped.ale.lives() 88 | 89 | terminal = False 90 | return self.preprocess(screen, terminal), 0, terminal 91 | 92 | def step(self, action, is_training): 93 | if action == -1: 94 | # Step with random action 95 | action = self.env.action_space.sample() 96 | 97 | cumulated_reward = 0 98 | 99 | for _ in range(self.n_action_repeat): 100 | screen, reward, terminal, _ = self.env.step(action) 101 | cumulated_reward += reward 102 | current_lives = self.env.unwrapped.ale.lives() 103 | 104 | if is_training and self.lives > current_lives: 105 | terminal = True 106 | 107 | if terminal: break 108 | 109 | if self.display: 110 | self.env.render() 111 | 112 | if not terminal: 113 | self.lives = current_lives 114 | 115 | if self.use_cumulated_reward: 116 | return self.preprocess(screen, terminal), cumulated_reward, terminal, {} 117 | else: 118 | return self.preprocess(screen, terminal), reward, terminal, {} 119 | 120 | def preprocess(self, raw_screen, 
terminal): 121 | y = 0.2126 * raw_screen[:, :, 0] + 0.7152 * raw_screen[:, :, 1] + 0.0722 * raw_screen[:, :, 2] 122 | y = y.astype(np.uint8) 123 | y_screen = imresize(y, self.observation_dims) 124 | return y_screen 125 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import random 3 | import logging 4 | import tensorflow as tf 5 | 6 | from utils import get_model_dir 7 | from networks.cnn import CNN 8 | from networks.mlp import MLPSmall 9 | from agents.statistic import Statistic 10 | from environments.environment import ToyEnvironment, AtariEnvironment 11 | 12 | flags = tf.app.flags 13 | 14 | # Deep q Network 15 | flags.DEFINE_boolean('use_gpu', True, 'Whether to use gpu or not. gpu use NHWC and gpu use NCHW for data_format') 16 | flags.DEFINE_string('agent_type', 'DQN', 'The type of agent [DQN]') 17 | flags.DEFINE_boolean('double_q', False, 'Whether to use double Q-learning') 18 | flags.DEFINE_string('network_header_type', 'nips', 'The type of network header [mlp, nature, nips]') 19 | flags.DEFINE_string('network_output_type', 'normal', 'The type of network output [normal, dueling]') 20 | 21 | # Environment 22 | flags.DEFINE_string('env_name', 'Breakout-v0', 'The name of gym environment to use') 23 | flags.DEFINE_integer('n_action_repeat', 1, 'The number of actions to repeat') 24 | flags.DEFINE_integer('max_random_start', 30, 'The maximum number of NOOP actions at the beginning of an episode') 25 | flags.DEFINE_integer('history_length', 4, 'The length of history of observation to use as an input to DQN') 26 | flags.DEFINE_integer('max_r', +1, 'The maximum value of clipped reward') 27 | flags.DEFINE_integer('min_r', -1, 'The minimum value of clipped reward') 28 | flags.DEFINE_string('observation_dims', '[80, 80]', 'The dimension of gym observation') 29 | flags.DEFINE_boolean('random_start', True, 'Whether to start with random state') 30 | flags.DEFINE_boolean('use_cumulated_reward', False, 'Whether to use cumulated reward or not') 31 | 32 | # Training 33 | flags.DEFINE_boolean('is_train', True, 'Whether to do training or testing') 34 | flags.DEFINE_integer('max_delta', None, 'The maximum value of delta') 35 | flags.DEFINE_integer('min_delta', None, 'The minimum value of delta') 36 | flags.DEFINE_float('ep_start', 1., 'The value of epsilon at start in e-greedy') 37 | flags.DEFINE_float('ep_end', 0.01, 'The value of epsilnon at the end in e-greedy') 38 | flags.DEFINE_integer('batch_size', 32, 'The size of batch for minibatch training') 39 | flags.DEFINE_integer('max_grad_norm', None, 'The maximum norm of gradient while updating') 40 | flags.DEFINE_float('discount_r', 0.99, 'The discount factor for reward') 41 | 42 | # Timer 43 | flags.DEFINE_integer('t_train_freq', 4, '') 44 | 45 | # Below numbers will be multiplied by scale 46 | flags.DEFINE_integer('scale', 10000, 'The scale for big numbers') 47 | flags.DEFINE_integer('memory_size', 100, 'The size of experience memory (*= scale)') 48 | flags.DEFINE_integer('t_target_q_update_freq', 1, 'The frequency of target network to be updated (*= scale)') 49 | flags.DEFINE_integer('t_test', 1, 'The maximum number of t while training (*= scale)') 50 | flags.DEFINE_integer('t_ep_end', 100, 'The time when epsilon reach ep_end (*= scale)') 51 | flags.DEFINE_integer('t_train_max', 5000, 'The maximum number of t while training (*= scale)') 52 | flags.DEFINE_float('t_learn_start', 5, 'The time when to begin training 
(*= scale)') 53 | flags.DEFINE_float('learning_rate_decay_step', 5, 'The learning rate of training (*= scale)') 54 | 55 | # Optimizer 56 | flags.DEFINE_float('learning_rate', 0.00025, 'The learning rate of training') 57 | flags.DEFINE_float('learning_rate_minimum', 0.00025, 'The minimum learning rate of training') 58 | flags.DEFINE_float('learning_rate_decay', 0.96, 'The decay of learning rate of training') 59 | flags.DEFINE_float('decay', 0.99, 'Decay of RMSProp optimizer') 60 | flags.DEFINE_float('momentum', 0.0, 'Momentum of RMSProp optimizer') 61 | flags.DEFINE_float('gamma', 0.99, 'Discount factor of return') 62 | flags.DEFINE_float('beta', 0.01, 'Beta of RMSProp optimizer') 63 | 64 | # Debug 65 | flags.DEFINE_boolean('display', False, 'Whether to do display the game screen or not') 66 | flags.DEFINE_string('log_level', 'INFO', 'Log level [DEBUG, INFO, WARNING, ERROR, CRITICAL]') 67 | flags.DEFINE_integer('random_seed', 123, 'Value of random seed') 68 | flags.DEFINE_string('tag', '', 'The name of tag for a model, only for debugging') 69 | flags.DEFINE_boolean('allow_soft_placement', True, 'Whether to use part or all of a GPU') 70 | #flags.DEFINE_string('gpu_fraction', '1/1', 'idx / # of gpu fraction e.g. 1/3, 2/3, 3/3') 71 | 72 | # Internal 73 | # It is forbidden to set a flag that is not defined 74 | flags.DEFINE_string('data_format', 'NCHW', 'INTERNAL USED ONLY') 75 | 76 | def calc_gpu_fraction(fraction_string): 77 | idx, num = fraction_string.split('/') 78 | idx, num = float(idx), float(num) 79 | 80 | fraction = 1 / (num - idx + 1) 81 | print (" [*] GPU : %.4f" % fraction) 82 | return fraction 83 | 84 | conf = flags.FLAGS 85 | 86 | if conf.agent_type == 'DQN': 87 | from agents.deep_q import DeepQ 88 | TrainAgent = DeepQ 89 | else: 90 | raise ValueError('Unknown agent_type: %s' % conf.agent_type) 91 | 92 | logger = logging.getLogger() 93 | logger.propagate = False 94 | logger.setLevel(conf.log_level) 95 | 96 | # set random seed 97 | tf.set_random_seed(conf.random_seed) 98 | random.seed(conf.random_seed) 99 | 100 | def main(_): 101 | # preprocess 102 | conf.observation_dims = eval(conf.observation_dims) 103 | 104 | for flag in ['memory_size', 't_target_q_update_freq', 't_test', 105 | 't_ep_end', 't_train_max', 't_learn_start', 'learning_rate_decay_step']: 106 | setattr(conf, flag, getattr(conf, flag) * conf.scale) 107 | 108 | if conf.use_gpu: 109 | conf.data_format = 'NCHW' 110 | else: 111 | conf.data_format = 'NHWC' 112 | 113 | model_dir = get_model_dir(conf, 114 | ['use_gpu', 'max_random_start', 'n_worker', 'is_train', 'memory_size', 'gpu_fraction', 115 | 't_save', 't_train', 'display', 'log_level', 'random_seed', 'tag', 'scale']) 116 | 117 | # start 118 | #gpu_options = tf.GPUOptions( 119 | # per_process_gpu_memory_fraction=calc_gpu_fraction(conf.gpu_fraction)) 120 | 121 | sess_config = tf.ConfigProto( 122 | log_device_placement=False, allow_soft_placement=conf.allow_soft_placement) 123 | sess_config.gpu_options.allow_growth = conf.allow_soft_placement 124 | 125 | with tf.Session(config=sess_config) as sess: 126 | if any(name in conf.env_name for name in ['Corridor', 'FrozenLake']) : 127 | env = ToyEnvironment(conf.env_name, conf.n_action_repeat, 128 | conf.max_random_start, conf.observation_dims, 129 | conf.data_format, conf.display, conf.use_cumulated_reward) 130 | else: 131 | env = AtariEnvironment(conf.env_name, conf.n_action_repeat, 132 | conf.max_random_start, conf.observation_dims, 133 | conf.data_format, conf.display, conf.use_cumulated_reward) 134 | 135 | if 
conf.network_header_type in ['nature', 'nips']: 136 | pred_network = CNN(sess=sess, 137 | data_format=conf.data_format, 138 | history_length=conf.history_length, 139 | observation_dims=conf.observation_dims, 140 | output_size=env.env.action_space.n, 141 | network_header_type=conf.network_header_type, 142 | name='pred_network', trainable=True) 143 | target_network = CNN(sess=sess, 144 | data_format=conf.data_format, 145 | history_length=conf.history_length, 146 | observation_dims=conf.observation_dims, 147 | output_size=env.env.action_space.n, 148 | network_header_type=conf.network_header_type, 149 | name='target_network', trainable=False) 150 | elif conf.network_header_type == 'mlp': 151 | pred_network = MLPSmall(sess=sess, 152 | data_format=conf.data_format, 153 | observation_dims=conf.observation_dims, 154 | history_length=conf.history_length, 155 | output_size=env.env.action_space.n, 156 | hidden_activation_fn=tf.sigmoid, 157 | network_output_type=conf.network_output_type, 158 | name='pred_network', trainable=True) 159 | target_network = MLPSmall(sess=sess, 160 | data_format=conf.data_format, 161 | observation_dims=conf.observation_dims, 162 | history_length=conf.history_length, 163 | output_size=env.env.action_space.n, 164 | hidden_activation_fn=tf.sigmoid, 165 | network_output_type=conf.network_output_type, 166 | name='target_network', trainable=False) 167 | else: 168 | raise ValueError('Unkown network_header_type: %s' % (conf.network_header_type)) 169 | 170 | stat = Statistic(sess, conf.t_test, conf.t_learn_start, model_dir, pred_network.var.values()) 171 | agent = TrainAgent(sess, pred_network, env, stat, conf, target_network=target_network) 172 | 173 | if conf.is_train: 174 | agent.train(conf.t_train_max) 175 | else: 176 | agent.play(conf.ep_end) 177 | 178 | if __name__ == '__main__': 179 | tf.app.run() 180 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/deep-rl-tensorflow/95d3e2dde77d4a7a393ec418fe3537094d08c2ba/networks/__init__.py -------------------------------------------------------------------------------- /networks/cnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | from .layers import * 5 | from .network import Network 6 | 7 | class CNN(Network): 8 | def __init__(self, sess, 9 | data_format, 10 | history_length, 11 | observation_dims, 12 | output_size, 13 | trainable=True, 14 | hidden_activation_fn=tf.nn.relu, 15 | output_activation_fn=None, 16 | weights_initializer=initializers.xavier_initializer(), 17 | biases_initializer=tf.constant_initializer(0.1), 18 | value_hidden_sizes=[512], 19 | advantage_hidden_sizes=[512], 20 | network_output_type='dueling', 21 | network_header_type='nips', 22 | name='CNN'): 23 | super(CNN, self).__init__(sess, name) 24 | 25 | if data_format == 'NHWC': 26 | self.inputs = tf.placeholder('float32', 27 | [None] + observation_dims + [history_length], name='inputs') 28 | elif data_format == 'NCHW': 29 | self.inputs = tf.placeholder('float32', 30 | [None, history_length] + observation_dims, name='inputs') 31 | else: 32 | raise ValueError("unknown data_format : %s" % data_format) 33 | 34 | self.var = {} 35 | self.l0 = tf.div(self.inputs, 255.) 
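    # --- Added commentary (not part of the original source) ---
    # Raw screens arrive as uint8 values in [0, 255]; dividing by 255 above
    # rescales them to [0, 1] before the convolutional stack. The 'nature'
    # header below is the three-conv + 512-unit head from the Nature DQN paper,
    # while the smaller 'nips' header is the two-conv + 256-unit head from the
    # earlier NIPS workshop paper.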
36 | 37 | with tf.variable_scope(name): 38 | if network_header_type.lower() == 'nature': 39 | self.l1, self.var['l1_w'], self.var['l1_b'] = conv2d(self.l0, 40 | 32, [8, 8], [4, 4], weights_initializer, biases_initializer, 41 | hidden_activation_fn, data_format, name='l1_conv') 42 | self.l2, self.var['l2_w'], self.var['l2_b'] = conv2d(self.l1, 43 | 64, [4, 4], [2, 2], weights_initializer, biases_initializer, 44 | hidden_activation_fn, data_format, name='l2_conv') 45 | self.l3, self.var['l3_w'], self.var['l3_b'] = conv2d(self.l2, 46 | 64, [3, 3], [1, 1], weights_initializer, biases_initializer, 47 | hidden_activation_fn, data_format, name='l3_conv') 48 | self.l4, self.var['l4_w'], self.var['l4_b'] = \ 49 | linear(self.l3, 512, weights_initializer, biases_initializer, 50 | hidden_activation_fn, trainable, name='l4_linear') 51 | layer = self.l4 52 | elif network_header_type.lower() == 'nips': 53 | self.l1, self.var['l1_w'], self.var['l1_b'] = conv2d(self.l0, 54 | 16, [8, 8], [4, 4], weights_initializer, biases_initializer, 55 | hidden_activation_fn, data_format, name='l1_conv') 56 | self.l2, self.var['l2_w'], self.var['l2_b'] = conv2d(self.l1, 57 | 32, [4, 4], [2, 2], weights_initializer, biases_initializer, 58 | hidden_activation_fn, data_format, name='l2_conv') 59 | self.l3, self.var['l3_w'], self.var['l3_b'] = \ 60 | linear(self.l2, 256, weights_initializer, biases_initializer, 61 | hidden_activation_fn, trainable, name='l3_linear') 62 | layer = self.l3 63 | else: 64 | raise ValueError('Unknown network_header_type: %s' % network_header_type) 65 | 66 | self.build_output_ops(layer, network_output_type, 67 | value_hidden_sizes, advantage_hidden_sizes, output_size, 68 | weights_initializer, biases_initializer, hidden_activation_fn, 69 | output_activation_fn, trainable) 70 | -------------------------------------------------------------------------------- /networks/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from functools import reduce 3 | from tensorflow.contrib.layers.python.layers import initializers 4 | 5 | def conv2d(x, 6 | output_dim, 7 | kernel_size, 8 | stride, 9 | weights_initializer=tf.contrib.layers.xavier_initializer(), 10 | biases_initializer=tf.zeros_initializer, 11 | activation_fn=tf.nn.relu, 12 | data_format='NHWC', 13 | padding='VALID', 14 | name='conv2d', 15 | trainable=True): 16 | with tf.variable_scope(name): 17 | if data_format == 'NCHW': 18 | stride = [1, 1, stride[0], stride[1]] 19 | kernel_shape = [kernel_size[0], kernel_size[1], x.get_shape()[1], output_dim] 20 | elif data_format == 'NHWC': 21 | stride = [1, stride[0], stride[1], 1] 22 | kernel_shape = [kernel_size[0], kernel_size[1], x.get_shape()[-1], output_dim] 23 | 24 | w = tf.get_variable('w', kernel_shape, 25 | tf.float32, initializer=weights_initializer, trainable=trainable) 26 | conv = tf.nn.conv2d(x, w, stride, padding, data_format=data_format) 27 | 28 | b = tf.get_variable('b', [output_dim], 29 | tf.float32, initializer=biases_initializer, trainable=trainable) 30 | out = tf.nn.bias_add(conv, b, data_format) 31 | 32 | if activation_fn != None: 33 | out = activation_fn(out) 34 | 35 | return out, w, b 36 | 37 | def linear(input_, 38 | output_size, 39 | weights_initializer=initializers.xavier_initializer(), 40 | biases_initializer=tf.zeros_initializer, 41 | activation_fn=None, 42 | trainable=True, 43 | name='linear'): 44 | shape = input_.get_shape().as_list() 45 | 46 | if len(shape) > 2: 47 | input_ = tf.reshape(input_, [-1, reduce(lambda x, y: x *
y, shape[1:])]) 48 | shape = input_.get_shape().as_list() 49 | 50 | with tf.variable_scope(name): 51 | w = tf.get_variable('w', [shape[1], output_size], tf.float32, 52 | initializer=weights_initializer, trainable=trainable) 53 | b = tf.get_variable('b', [output_size], 54 | initializer=biases_initializer, trainable=trainable) 55 | out = tf.nn.bias_add(tf.matmul(input_, w), b) 56 | 57 | if activation_fn != None: 58 | return activation_fn(out), w, b 59 | else: 60 | return out, w, b 61 | 62 | def batch_sample(probs, name='batch_sample'): 63 | with tf.variable_scope(name): 64 | uniform = tf.random_uniform(tf.shape(probs), minval=0, maxval=1) 65 | samples = tf.argmax(probs - uniform, dimension=1) 66 | return samples 67 | -------------------------------------------------------------------------------- /networks/mlp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .layers import * 4 | from .network import Network 5 | 6 | class MLPSmall(Network): 7 | def __init__(self, sess, 8 | data_format, 9 | observation_dims, 10 | history_length, 11 | output_size, 12 | trainable=True, 13 | batch_size=None, 14 | weights_initializer=initializers.xavier_initializer(), 15 | biases_initializer=tf.zeros_initializer, 16 | hidden_activation_fn=tf.nn.relu, 17 | output_activation_fn=None, 18 | hidden_sizes=[50, 50, 50], 19 | value_hidden_sizes=[25], 20 | advantage_hidden_sizes=[25], 21 | network_output_type='dueling', 22 | name='MLPSmall'): 23 | super(MLPSmall, self).__init__(sess, name) 24 | 25 | with tf.variable_scope(name): 26 | if data_format == 'NHWC': 27 | layer = self.inputs = tf.placeholder( 28 | 'float32', [batch_size] + observation_dims + [history_length], 'inputs') 29 | elif data_format == 'NCHW': 30 | layer = self.inputs = tf.placeholder( 31 | 'float32', [batch_size, history_length] + observation_dims, 'inputs') 32 | else: 33 | raise ValueError("unknown data_format : %s" % data_format) 34 | 35 | if len(layer.get_shape().as_list()) == 3: 36 | assert layer.get_shape().as_list()[1] == 1 37 | layer = tf.reshape(layer, [-1] + layer.get_shape().as_list()[2:]) 38 | 39 | for idx, hidden_size in enumerate(hidden_sizes): 40 | w_name, b_name = 'w_%d' % idx, 'b_%d' % idx 41 | 42 | layer, self.var[w_name], self.var[b_name] = \ 43 | linear(layer, hidden_size, weights_initializer, 44 | biases_initializer, hidden_activation_fn, trainable, name='lin_%d' % idx) 45 | 46 | self.build_output_ops(layer, network_output_type, 47 | value_hidden_sizes, advantage_hidden_sizes, output_size, 48 | weights_initializer, biases_initializer, hidden_activation_fn, 49 | output_activation_fn, trainable) 50 | -------------------------------------------------------------------------------- /networks/network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .layers import * 4 | 5 | class Network(object): 6 | def __init__(self, sess, name): 7 | self.sess = sess 8 | self.copy_op = None 9 | self.name = name 10 | self.var = {} 11 | 12 | def build_output_ops(self, input_layer, network_output_type, 13 | value_hidden_sizes, advantage_hidden_sizes, output_size, 14 | weights_initializer, biases_initializer, hidden_activation_fn, 15 | output_activation_fn, trainable): 16 | if network_output_type == 'normal': 17 | self.outputs, self.var['w_out'], self.var['b_out'] = \ 18 | linear(input_layer, output_size, weights_initializer, 19 | biases_initializer, output_activation_fn, trainable, name='out') 20 | elif 
network_output_type == 'dueling': 21 | # Dueling Network 22 | assert len(value_hidden_sizes) != 0 and len(advantage_hidden_sizes) != 0 23 | 24 | layer = input_layer 25 | for idx, hidden_size in enumerate(value_hidden_sizes): 26 | w_name, b_name = 'val_w_%d' % idx, 'val_b_%d' % idx 27 | 28 | layer, self.var[w_name], self.var[b_name] = \ 29 | linear(layer, hidden_size, weights_initializer, 30 | biases_initializer, hidden_activation_fn, trainable, name='val_lin_%d' % idx) 31 | 32 | self.value, self.var['val_w_out'], self.var['val_w_b'] = \ 33 | linear(layer, 1, weights_initializer, 34 | biases_initializer, output_activation_fn, trainable, name='val_lin_out') 35 | 36 | layer = input_layer 37 | for idx, hidden_size in enumerate(advantage_hidden_sizes): 38 | w_name, b_name = 'adv_w_%d' % idx, 'adv_b_%d' % idx 39 | 40 | layer, self.var[w_name], self.var[b_name] = \ 41 | linear(layer, hidden_size, weights_initializer, 42 | biases_initializer, hidden_activation_fn, trainable, name='adv_lin_%d' % idx) 43 | 44 | self.advantage, self.var['adv_w_out'], self.var['adv_w_b'] = \ 45 | linear(layer, output_size, weights_initializer, 46 | biases_initializer, output_activation_fn, trainable, name='adv_lin_out') 47 | 48 | # Simple Dueling 49 | # self.outputs = self.value + self.advantage 50 | 51 | # Max Dueling 52 | # self.outputs = self.value + (self.advantage - 53 | # tf.reduce_max(self.advantage, reduction_indices=1, keep_dims=True)) 54 | 55 | # Average Dueling 56 | self.outputs = self.value + (self.advantage - 57 | tf.reduce_mean(self.advantage, reduction_indices=1, keep_dims=True)) 58 | 59 | self.max_outputs = tf.reduce_max(self.outputs, reduction_indices=1) 60 | self.outputs_idx = tf.placeholder('int32', [None, None], 'outputs_idx') 61 | self.outputs_with_idx = tf.gather_nd(self.outputs, self.outputs_idx) 62 | self.actions = tf.argmax(self.outputs, axis=1) 63 | 64 | def run_copy(self): 65 | if self.copy_op is None: 66 | raise Exception("call `create_copy_op` before `run_copy`") 67 | else: 68 | self.sess.run(self.copy_op) 69 | 70 | def create_copy_op(self, network): 71 | with tf.variable_scope(self.name): 72 | copy_ops = [] 73 | 74 | for name in self.var.keys(): 75 | copy_op = self.var[name].assign(network.var[name]) 76 | copy_ops.append(copy_op) 77 | 78 | self.copy_op = tf.group(*copy_ops, name='copy_op') 79 | 80 | def calc_actions(self, observation): 81 | return self.actions.eval({self.inputs: observation}, session=self.sess) 82 | 83 | def calc_outputs(self, observation): 84 | return self.outputs.eval({self.inputs: observation}, session=self.sess) 85 | 86 | def calc_max_outputs(self, observation): 87 | return self.max_outputs.eval({self.inputs: observation}, session=self.sess) 88 | 89 | def calc_outputs_with_idx(self, observation, idx): 90 | return self.outputs_with_idx.eval( 91 | {self.inputs: observation, self.outputs_idx: idx}, session=self.sess) 92 | --------------------------------------------------------------------------------
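A note on the target-network update above: `create_copy_op` builds one `assign` per shared variable key and groups them into a single op, and `run_copy` executes that group; the agent in agents/deep_q.py (not shown in this section) presumably calls `run_copy` every `t_target_q_update_freq` steps. The snippet below is only a minimal stand-alone sketch of that grouped-assign pattern; `pred_var` and `target_var` are illustrative stand-ins, not the repository's objects.

import tensorflow as tf

# Illustrative stand-ins for pred_network.var and target_network.var:
# two dictionaries of variables that share the same keys.
pred_var = {'w': tf.Variable(tf.random_normal([4, 2]), name='pred_w'),
            'b': tf.Variable(tf.zeros([2]), name='pred_b')}
target_var = {'w': tf.Variable(tf.zeros([4, 2]), trainable=False, name='target_w'),
              'b': tf.Variable(tf.zeros([2]), trainable=False, name='target_b')}

# Same pattern as Network.create_copy_op: one assign per key, grouped into one op.
copy_op = tf.group(*[target_var[k].assign(pred_var[k]) for k in target_var])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(copy_op)  # analogue of target_network.run_copy()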
/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # DQN 4 | python main.py --network_header_type=mlp --network_output_type=normal --observation_dims='[64]' --env_name=FrozenLake8x8-v0 --t_learn_start=0.1 --learning_rate_decay_step=0.1 --history_length=1 --n_action_repeat=1 --t_ep_end=50 --display=True --use_gpu=False 5 | 6 | # Dueling DQN 7 | python main.py --network_header_type=mlp --network_output_type=dueling --observation_dims='[64]' --env_name=FrozenLake8x8-v0 --t_learn_start=0.1 --learning_rate_decay_step=0.1 --history_length=1 --n_action_repeat=1 --t_ep_end=50 --display=True --use_gpu=False 8 | 9 | # DDQN 10 | python main.py --network_header_type=mlp --network_output_type=normal --double_q=True --observation_dims='[64]' --env_name=FrozenLake8x8-v0 --t_learn_start=0.1 --learning_rate_decay_step=0.1 --history_length=1 --n_action_repeat=1 --t_ep_end=50 --display=True --use_gpu=False 11 | 12 | # Dueling DDQN 13 | python main.py --network_header_type=mlp --network_output_type=dueling --double_q=True --observation_dims='[64]' --env_name=FrozenLake8x8-v0 --t_learn_start=0.1 --learning_rate_decay_step=0.1 --history_length=1 --n_action_repeat=1 --t_ep_end=50 --display=True --use_gpu=False 14 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import tensorflow as tf 4 | from six.moves import range 5 | from logging import getLogger 6 | 7 | logger = getLogger(__name__) 8 | 9 | def get_model_dir(config, exceptions=None): 10 | keys = dir(config) 11 | keys.sort() 12 | keys.remove('env_name') 13 | keys = ['env_name'] + keys 14 | 15 | names = [config.env_name] 16 | for key in keys: 17 | # Skip flags that do not affect the model 18 | if not exceptions or key not in exceptions: 19 | value = getattr(config, key) 20 | names.append( 21 | "%s=%s" % (key, ",".join([str(i) for i in value]) 22 | if type(value) == list else value)) 23 | 24 | return os.path.join('checkpoints', *names) + '/' 25 | 26 | def timeit(f): 27 | def timed(*args, **kwargs): 28 | start_time = time.time() 29 | result = f(*args, **kwargs) 30 | end_time = time.time() 31 | 32 | logger.info("%s : %2.2f sec" % (f.__name__, end_time - start_time)) 33 | return result 34 | return timed 35 | --------------------------------------------------------------------------------
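As a closing reference for the dueling head in networks/network.py, the average-dueling aggregation can be sanity-checked with a few lines of NumPy. This is only an illustrative sketch (NumPy stands in for the TensorFlow ops and the array values are invented); it mirrors Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)):

import numpy as np

def dueling_q(value, advantage):
    # Average dueling: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)),
    # the same combination computed for self.outputs in networks/network.py.
    return value + (advantage - advantage.mean(axis=1, keepdims=True))

value = np.array([[1.0], [0.5]])           # V(s), shape [batch, 1]
advantage = np.array([[2.0, 0.0, -2.0],    # A(s, a), shape [batch, n_actions]
                      [1.0, 1.0, -2.0]])
q = dueling_q(value, advantage)
print(q)            # [[ 3.   1.  -1. ], [ 1.5  1.5 -1.5]]
print(q.argmax(1))  # greedy actions, the analogue of self.actions

Subtracting the mean advantage rather than the max is the aggregation recommended in [4]; it removes the constant offset that would otherwise make the value and advantage streams unidentifiable, while leaving the greedy action unchanged.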