├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── Pendulum-v0_2016-07-15.png
│   └── algorithm.png
├── main.py
├── run_mujoco.sh
├── src
│   ├── __init__.py
│   ├── exploration.py
│   ├── naf.py
│   ├── network.py
│   ├── ops.py
│   ├── statistic.py
│   └── utils.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # misc
2 | logs
3 | 
4 | # data
5 | samples
6 | *checkpoints/
7 | *.npy
8 | *.pkl
9 | *.tgz
10 | *.zip
11 | *.tar.gz
12 | 
13 | 
14 | # Created by https://www.gitignore.io/api/python,vim
15 | 
16 | ### IPythonNotebook ###
17 | ## Temporary data
18 | .ipynb_checkpoints/
19 | 
20 | ### Python ###
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 | 
26 | # C extensions
27 | *.so
28 | 
29 | # Distribution / packaging
30 | .Python
31 | env/
32 | build/
33 | develop-eggs/
34 | dist/
35 | downloads/
36 | eggs/
37 | .eggs/
38 | lib/
39 | lib64/
40 | parts/
41 | sdist/
42 | var/
43 | *.egg-info/
44 | .installed.cfg
45 | *.egg
46 | 
47 | # PyInstaller
48 | # Usually these files are written by a python script from a template
49 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
50 | *.manifest
51 | *.spec
52 | 
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 | 
57 | # Unit test / coverage reports
58 | htmlcov/
59 | .tox/
60 | .coverage
61 | .coverage.*
62 | .cache
63 | nosetests.xml
64 | coverage.xml
65 | *,cover
66 | .hypothesis/
67 | 
68 | # Translations
69 | *.mo
70 | *.pot
71 | 
72 | # Django stuff:
73 | *.log
74 | 
75 | # Sphinx documentation
76 | docs/_build/
77 | 
78 | # PyBuilder
79 | target/
80 | 
81 | 
82 | ### Vim ###
83 | [._]*.s[a-w][a-z]
84 | [._]s[a-w][a-z]
85 | *.un~
86 | Session.vim
87 | .netrwhist
88 | *~
89 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2016 Taehoon Kim
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Normalized Advantage Functions (NAF) in TensorFlow
2 | 
3 | TensorFlow implementation of [Continuous Deep Q-Learning with Model-based Acceleration](http://arxiv.org/abs/1603.00748).
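
NAF makes Q-learning work with continuous actions by decomposing the Q-function as Q(x, u) = V(x) + A(x, u), where the advantage A(x, u) = -1/2 (u - mu(x))^T P(x) (u - mu(x)) is quadratic in the action u and P(x) = L(x) L(x)^T is positive-definite, with L(x) a state-dependent lower-triangular matrix predicted by the network. Because A(x, u) peaks at u = mu(x), the greedy action is simply the network output mu(x), so no separate actor or inner optimization loop is needed. The snippet below is a minimal NumPy sketch of this decomposition for a single state; the names `naf_q`, `mu`, `L`, `v` are illustrative only, and the actual TensorFlow graph is built in `src/network.py`:

    import numpy as np

    def naf_q(u, mu, L, v):
        # Q(x, u) = V(x) + A(x, u) for one state, given the network outputs mu(x), L(x), V(x)
        P = np.dot(L, L.T)                    # P(x) = L(x) L(x)^T is positive-definite
        d = (u - mu).reshape(-1, 1)           # deviation of the action from mu(x)
        advantage = -0.5 * float(np.dot(d.T, np.dot(P, d)))
        return v + advantage                  # A(x, mu(x)) = 0, so argmax_u Q(x, u) = mu(x)

    mu = np.array([0.1, -0.3])                # toy 2-dimensional action
    L = np.array([[0.5, 0.0],
                  [0.2, 0.8]])                # lower-triangular with positive diagonal
    print(naf_q(mu, mu, L, 1.7))              # 1.7: the greedy action attains V(x)
    print(naf_q(mu + 0.5, mu, L, 1.7))        # strictly less than 1.7

In training, L(x) is assembled from the network's `l` output with exponentiated diagonal entries so that P(x) stays positive-definite; that is what the loop over `action_size` in `src/network.py` constructs.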
4 | 5 | ![algorithm](https://github.com/carpedm20/naf-tensorflow/blob/master/assets/algorithm.png) 6 | 7 | 8 | ## Requirements 9 | 10 | - Python 2.7 11 | - [gym](https://github.com/openai/gym) 12 | - [TensorFlow](https://www.tensorflow.org/) 0.9+ 13 | 14 | 15 | ## Usage 16 | 17 | First, install prerequisites with: 18 | 19 | $ pip install tqdm gym[all] 20 | 21 | To train a model for an environment with a continuous action space: 22 | 23 | $ python main.py --env_name=Pendulum-v0 --is_train=True 24 | $ python main.py --env_name=Pendulum-v0 --is_train=True --display=True 25 | 26 | To test and record the screens with gym: 27 | 28 | $ python main.py --env_name=Pendulum-v0 --is_train=False 29 | $ python main.py --env_name=Pendulum-v0 --is_train=False --display=True 30 | 31 | 32 | ## Results 33 | 34 | Training details of `Pendulum-v0` with different hyperparameters. 35 | 36 | $ python main.py --env_name=Pendulum-v0 # dark green 37 | $ python main.py --env_name=Pendulum-v0 --action_fn=tanh # light green 38 | $ python main.py --env_name=Pendulum-v0 --use_batch_norm=True # yellow 39 | $ python main.py --env_name=Pendulum-v0 --use_seperate_networks=True # green 40 | 41 | ![Pendulum-v0_2016-07-15](https://github.com/carpedm20/naf-tensorflow/blob/master/assets/Pendulum-v0_2016-07-15.png) 42 | 43 | 44 | ## References 45 | 46 | - [rllab](https://github.com/rllab/rllab.git) 47 | - [keras implementation](https://gym.openai.com/evaluations/eval_CzoNQdPSAm0J3ikTBSTCg) 48 | 49 | 50 | ## Author 51 | 52 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io/) 53 | -------------------------------------------------------------------------------- /assets/Pendulum-v0_2016-07-15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NAF-tensorflow/5754bd40fe135f79272b333ba2b911b02ca293f7/assets/Pendulum-v0_2016-07-15.png -------------------------------------------------------------------------------- /assets/algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NAF-tensorflow/5754bd40fe135f79272b333ba2b911b02ca293f7/assets/algorithm.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import logging 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from src.naf import NAF 7 | from src.network import Network 8 | from src.statistic import Statistic 9 | from src.exploration import OUExploration, BrownianExploration, LinearDecayExploration 10 | from utils import get_model_dir, preprocess_conf 11 | 12 | flags = tf.app.flags 13 | 14 | # environment 15 | flags.DEFINE_string('env_name', 'Pendulum-v0', 'name of environment') 16 | 17 | # network 18 | flags.DEFINE_string('hidden_dims', '[100, 100]', 'dimension of hidden layers') 19 | flags.DEFINE_boolean('use_batch_norm', False, 'use batch normalization or not') 20 | flags.DEFINE_boolean('clip_action', False, 'whether to clip an action with given bound') 21 | flags.DEFINE_boolean('use_seperate_networks', False, 'use seperate networks for mu, V and A') 22 | flags.DEFINE_string('hidden_w', 'uniform_big', 'weight initialization of hidden layers [uniform_small, uniform_big, he]') 23 | flags.DEFINE_string('hidden_fn', 'tanh', 'activation function of hidden layer [none, tanh, relu]') 24 | flags.DEFINE_string('action_w', 'uniform_big', 'weight initilization of action layer 
[uniform_small, uniform_big, he]') 25 | flags.DEFINE_string('action_fn', 'tanh', 'activation function of action layer [none, tanh, relu]') 26 | flags.DEFINE_string('w_reg', 'none', 'weight regularization [none, l1, l2]') 27 | flags.DEFINE_float('w_reg_scale', 0.001, 'scale of regularization') 28 | 29 | # exploration 30 | flags.DEFINE_float('noise_scale', 0.3, 'scale of noise') 31 | flags.DEFINE_string('noise', 'linear_decay', 'type of noise exploration [ou, linear_decay, brownian]') 32 | 33 | # training 34 | flags.DEFINE_float('tau', 0.001, 'tau of soft target update') 35 | flags.DEFINE_float('discount', 0.99, 'discount factor of Q-learning') 36 | flags.DEFINE_float('learning_rate', 1e-3, 'value of learning rate') 37 | flags.DEFINE_integer('batch_size', 100, 'The size of batch for minibatch training') 38 | flags.DEFINE_integer('max_steps', 200, 'maximum # of steps for each episode') 39 | flags.DEFINE_integer('update_repeat', 10, 'maximum # of q-learning updates for each step') 40 | flags.DEFINE_integer('max_episodes', 10000, 'maximum # of episodes to train') 41 | 42 | # Debug 43 | flags.DEFINE_boolean('is_train', True, 'training or testing') 44 | flags.DEFINE_integer('random_seed', 123, 'random seed') 45 | flags.DEFINE_boolean('monitor', False, 'monitor the training or not') 46 | flags.DEFINE_boolean('display', False, 'display the game screen or not') 47 | flags.DEFINE_string('log_level', 'INFO', 'log level [DEBUG, INFO, WARNING, ERROR, CRITICAL]') 48 | 49 | conf = flags.FLAGS 50 | 51 | logger = logging.getLogger() 52 | logger.propagate = False 53 | logger.setLevel(conf.log_level) 54 | 55 | # set random seed 56 | tf.set_random_seed(conf.random_seed) 57 | np.random.seed(conf.random_seed) 58 | 59 | def main(_): 60 | model_dir = get_model_dir(conf, 61 | ['is_train', 'random_seed', 'monitor', 'display', 'log_level']) 62 | 63 | preprocess_conf(conf) 64 | 65 | with tf.Session() as sess: 66 | # environment 67 | env = gym.make(conf.env_name) 68 | env.seed(conf.random_seed) 69 | 70 | assert isinstance(env.observation_space, gym.spaces.Box), \ 71 | "observation space must be continuous" 72 | assert isinstance(env.action_space, gym.spaces.Box), \ 73 | "action space must be continuous" 74 | 75 | # exploration strategy 76 | if conf.noise == 'ou': 77 | strategy = OUExploration(env, sigma=conf.noise_scale) 78 | elif conf.noise == 'brownian': 79 | strategy = BrownianExploration(env, conf.noise_scale) 80 | elif conf.noise == 'linear_decay': 81 | strategy = LinearDecayExploration(env) 82 | else: 83 | raise ValueError('Unkown exploration strategy: %s' % conf.noise) 84 | 85 | # networks 86 | shared_args = { 87 | 'sess': sess, 88 | 'input_shape': env.observation_space.shape, 89 | 'action_size': env.action_space.shape[0], 90 | 'hidden_dims': conf.hidden_dims, 91 | 'use_batch_norm': conf.use_batch_norm, 92 | 'use_seperate_networks': conf.use_seperate_networks, 93 | 'hidden_w': conf.hidden_w, 'action_w': conf.action_w, 94 | 'hidden_fn': conf.hidden_fn, 'action_fn': conf.action_fn, 95 | 'w_reg': conf.w_reg, 96 | } 97 | 98 | logger.info("Creating prediction network...") 99 | pred_network = Network( 100 | scope='pred_network', **shared_args 101 | ) 102 | 103 | logger.info("Creating target network...") 104 | target_network = Network( 105 | scope='target_network', **shared_args 106 | ) 107 | target_network.make_soft_update_from(pred_network, conf.tau) 108 | 109 | # statistic 110 | stat = Statistic(sess, conf.env_name, model_dir, pred_network.variables, conf.update_repeat) 111 | 112 | agent = NAF(sess, env, strategy, 
pred_network, target_network, stat,
113 |                 conf.discount, conf.batch_size, conf.learning_rate,
114 |                 conf.max_steps, conf.update_repeat, conf.max_episodes)
115 | 
116 |     agent.run(conf.monitor, conf.display, conf.is_train)
117 | 
118 | if __name__ == '__main__':
119 |   tf.app.run()
120 | 
--------------------------------------------------------------------------------
/run_mujoco.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | echo_and_run() { echo "$@"; "$@"; }
4 | 
5 | for env in "InvertedPendulum-v1" "InvertedDoublePendulum-v1" "Reacher-v1" "HalfCheetah-v1" "Swimmer-v1" "Hopper-v1" "Walker2d-v1" "Ant-v1" "HumanoidStandup-v1"; do
6 |   echo_and_run python main.py --env_name=$env &
7 | done
8 | 
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/NAF-tensorflow/5754bd40fe135f79272b333ba2b911b02ca293f7/src/__init__.py
--------------------------------------------------------------------------------
/src/exploration.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.random as nr
3 | 
4 | class Exploration(object):
5 |   def __init__(self, env):
6 |     self.action_size = env.action_space.shape[0]
7 | 
8 |   def add_noise(self, action, info={}):
9 |     pass
10 | 
11 |   def reset(self):
12 |     pass
13 | 
14 | class OUExploration(Exploration):
15 |   # Reference: https://github.com/rllab/rllab/blob/master/rllab/exploration_strategies/ou_strategy.py
16 | 
17 |   def __init__(self, env, sigma=0.3, mu=0, theta=0.15):
18 |     super(OUExploration, self).__init__(env)
19 | 
20 |     self.mu = mu
21 |     self.theta = theta
22 |     self.sigma = sigma
23 | 
24 |     self.state = np.ones(self.action_size) * self.mu
25 |     self.reset()
26 | 
27 |   def add_noise(self, action, info={}):
28 |     x = self.state
29 |     dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x))
30 |     self.state = x + dx
31 | 
32 |     return action + self.state
33 | 
34 |   def reset(self):
35 |     self.state = np.ones(self.action_size) * self.mu
36 | 
37 | class LinearDecayExploration(Exploration):
38 |   def __init__(self, env):
39 |     super(LinearDecayExploration, self).__init__(env)
40 | 
41 |   def add_noise(self, action, info={}):
42 |     return action + np.random.randn(self.action_size) / (info['idx_episode'] + 1)
43 | 
44 | class BrownianExploration(Exploration):
45 |   def __init__(self, env, noise_scale):
46 |     super(BrownianExploration, self).__init__(env)
47 | 
48 |     raise Exception('not implemented yet')
49 | 
--------------------------------------------------------------------------------
/src/naf.py:
--------------------------------------------------------------------------------
1 | from logging import getLogger
2 | logger = getLogger(__name__)
3 | 
4 | import numpy as np
5 | import tensorflow as tf
6 | from tensorflow.contrib.framework import get_variables
7 | 
8 | from .utils import get_timestamp
9 | 
10 | class NAF(object):
11 |   def __init__(self, sess,
12 |                env, strategy, pred_network, target_network, stat,
13 |                discount, batch_size, learning_rate,
14 |                max_steps, update_repeat, max_episodes):
15 |     self.sess = sess
16 |     self.env = env
17 |     self.strategy = strategy
18 |     self.pred_network = pred_network
19 |     self.target_network = target_network
20 |     self.stat = stat
21 | 
22 |     self.discount = discount
23 |     self.batch_size = batch_size
24 |     self.learning_rate = learning_rate
25 |
self.action_size = env.action_space.shape[0] 26 | 27 | self.max_steps = max_steps 28 | self.update_repeat = update_repeat 29 | self.max_episodes = max_episodes 30 | 31 | self.prestates = [] 32 | self.actions = [] 33 | self.rewards = [] 34 | self.poststates = [] 35 | self.terminals = [] 36 | 37 | with tf.name_scope('optimizer'): 38 | self.target_y = tf.placeholder(tf.float32, [None], name='target_y') 39 | self.loss = tf.reduce_mean(tf.squared_difference(self.target_y, tf.squeeze(self.pred_network.Q)), name='loss') 40 | 41 | self.optim = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss) 42 | 43 | def run(self, monitor=False, display=False, is_train=True): 44 | self.stat.load_model() 45 | self.target_network.hard_copy_from(self.pred_network) 46 | 47 | if monitor: 48 | self.env.monitor.start('/tmp/%s-%s' % (self.stat.env_name, get_timestamp())) 49 | 50 | for self.idx_episode in xrange(self.max_episodes): 51 | state = self.env.reset() 52 | 53 | for t in xrange(0, self.max_steps): 54 | if display: self.env.render() 55 | 56 | # 1. predict 57 | action = self.predict(state) 58 | 59 | # 2. step 60 | self.prestates.append(state) 61 | state, reward, terminal, _ = self.env.step(action) 62 | self.poststates.append(state) 63 | 64 | terminal = True if t == self.max_steps - 1 else terminal 65 | 66 | # 3. perceive 67 | if is_train: 68 | q, v, a, l = self.perceive(state, reward, action, terminal) 69 | 70 | if self.stat: 71 | self.stat.on_step(action, reward, terminal, q, v, a, l) 72 | 73 | if terminal: 74 | self.strategy.reset() 75 | break 76 | 77 | if monitor: 78 | self.env.monitor.close() 79 | 80 | def run2(self, monitor=False, display=False, is_train=True): 81 | target_y = tf.placeholder(tf.float32, [None], name='target_y') 82 | loss = tf.reduce_mean(tf.squared_difference(target_y, tf.squeeze(self.pred_network.Q)), name='loss') 83 | 84 | optim = tf.train.AdamOptimizer(self.learning_rate).minimize(loss) 85 | 86 | self.stat.load_model() 87 | self.target_network.hard_copy_from(self.pred_network) 88 | 89 | # replay memory 90 | prestates = [] 91 | actions = [] 92 | rewards = [] 93 | poststates = [] 94 | terminals = [] 95 | 96 | # the main learning loop 97 | total_reward = 0 98 | for i_episode in xrange(self.max_episodes): 99 | observation = self.env.reset() 100 | episode_reward = 0 101 | 102 | for t in xrange(self.max_steps): 103 | if display: 104 | self.env.render() 105 | 106 | # predict the mean action from current observation 107 | x_ = np.array([observation]) 108 | u_ = self.pred_network.mu.eval({self.pred_network.x: x_})[0] 109 | 110 | action = u_ + np.random.randn(1) / (i_episode + 1) 111 | 112 | prestates.append(observation) 113 | actions.append(action) 114 | 115 | observation, reward, done, info = self.env.step(action) 116 | episode_reward += reward 117 | 118 | rewards.append(reward); poststates.append(observation); terminals.append(done) 119 | 120 | if len(prestates) > 10: 121 | loss_ = 0 122 | for k in xrange(self.update_repeat): 123 | if len(prestates) > self.batch_size: 124 | indexes = np.random.choice(len(prestates), size=self.batch_size) 125 | else: 126 | indexes = range(len(prestates)) 127 | 128 | # Q-update 129 | v_ = self.target_network.V.eval({self.target_network.x: np.array(poststates)[indexes]}) 130 | y_ = np.array(rewards)[indexes] + self.discount * np.squeeze(v_) 131 | 132 | tmp1, tmp2 = np.array(prestates)[indexes], np.array(actions)[indexes] 133 | loss_ += l_ 134 | 135 | self.target_network.soft_update_from(self.pred_network) 136 | 137 | if done: 138 | break 139 | 140 | 
print "average loss:", loss_/k 141 | print "Episode {} finished after {} timesteps, reward {}".format(i_episode + 1, t + 1, episode_reward) 142 | total_reward += episode_reward 143 | 144 | print "Average reward per episode {}".format(total_reward / self.episodes) 145 | 146 | def predict(self, state): 147 | u = self.pred_network.predict([state])[0] 148 | 149 | return self.strategy.add_noise(u, {'idx_episode': self.idx_episode}) 150 | 151 | def perceive(self, state, reward, action, terminal): 152 | self.rewards.append(reward) 153 | self.actions.append(action) 154 | 155 | return self.q_learning_minibatch() 156 | 157 | def q_learning_minibatch(self): 158 | q_list = [] 159 | v_list = [] 160 | a_list = [] 161 | l_list = [] 162 | 163 | for iteration in xrange(self.update_repeat): 164 | if len(self.rewards) >= self.batch_size: 165 | indexes = np.random.choice(len(self.rewards), size=self.batch_size) 166 | else: 167 | indexes = np.arange(len(self.rewards)) 168 | 169 | x_t = np.array(self.prestates)[indexes] 170 | x_t_plus_1 = np.array(self.poststates)[indexes] 171 | r_t = np.array(self.rewards)[indexes] 172 | u_t = np.array(self.actions)[indexes] 173 | 174 | v = self.target_network.predict_v(x_t_plus_1, u_t) 175 | target_y = self.discount * np.squeeze(v) + r_t 176 | 177 | _, l, q, v, a = self.sess.run([ 178 | self.optim, self.loss, 179 | self.pred_network.Q, self.pred_network.V, self.pred_network.A, 180 | ], { 181 | self.target_y: target_y, 182 | self.pred_network.x: x_t, 183 | self.pred_network.u: u_t, 184 | self.pred_network.is_train: True, 185 | }) 186 | 187 | q_list.extend(q) 188 | v_list.extend(v) 189 | a_list.extend(a) 190 | l_list.append(l) 191 | 192 | self.target_network.soft_update_from(self.pred_network) 193 | 194 | logger.debug("q: %s, v: %s, a: %s, l: %s" \ 195 | % (np.mean(q), np.mean(v), np.mean(a), np.mean(l))) 196 | 197 | return np.sum(q_list), np.sum(v_list), np.sum(a_list), np.sum(l_list) 198 | -------------------------------------------------------------------------------- /src/network.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | logger = getLogger(__name__) 3 | 4 | import tensorflow as tf 5 | from tensorflow.contrib.framework import get_variables 6 | 7 | from .ops import * 8 | 9 | class Network: 10 | def __init__(self, sess, input_shape, action_size, hidden_dims, 11 | use_batch_norm, use_seperate_networks, 12 | hidden_w, action_w, hidden_fn, action_fn, w_reg, 13 | scope='NAF'): 14 | self.sess = sess 15 | with tf.variable_scope(scope): 16 | x = tf.placeholder(tf.float32, (None,) + tuple(input_shape), name='observations') 17 | u = tf.placeholder(tf.float32, (None, action_size), name='actions') 18 | is_train = tf.placeholder(tf.bool, name='is_train') 19 | 20 | hid_outs = {} 21 | with tf.name_scope('hidden'): 22 | if use_seperate_networks: 23 | logger.info("Creating seperate networks for v, l, and mu") 24 | 25 | for scope in ['v', 'l', 'mu']: 26 | with tf.variable_scope(scope): 27 | if use_batch_norm: 28 | h = batch_norm(x, is_training=is_train) 29 | else: 30 | h = x 31 | 32 | for idx, hidden_dim in enumerate(hidden_dims): 33 | h = fc(h, hidden_dim, is_train, hidden_w, weight_reg=w_reg, 34 | activation_fn=hidden_fn, use_batch_norm=use_batch_norm, scope='hid%d' % idx) 35 | hid_outs[scope] = h 36 | else: 37 | logger.info("Creating shared networks for v, l, and mu") 38 | 39 | if use_batch_norm: 40 | h = batch_norm(x, is_training=is_train) 41 | else: 42 | h = x 43 | 44 | for idx, hidden_dim in 
enumerate(hidden_dims): 45 | h = fc(h, hidden_dim, is_train, hidden_w, weight_reg=w_reg, 46 | activation_fn=hidden_fn, use_batch_norm=use_batch_norm, scope='hid%d' % idx) 47 | hid_outs['v'], hid_outs['l'], hid_outs['mu'] = h, h, h 48 | 49 | with tf.name_scope('value'): 50 | V = fc(hid_outs['v'], 1, is_train, 51 | hidden_w, use_batch_norm=use_batch_norm, scope='V') 52 | 53 | with tf.name_scope('advantage'): 54 | l = fc(hid_outs['l'], (action_size * (action_size + 1))/2, is_train, hidden_w, 55 | use_batch_norm=use_batch_norm, scope='l') 56 | mu = fc(hid_outs['mu'], action_size, is_train, action_w, 57 | activation_fn=action_fn, use_batch_norm=use_batch_norm, scope='mu') 58 | 59 | pivot = 0 60 | rows = [] 61 | for idx in xrange(action_size): 62 | count = action_size - idx 63 | 64 | diag_elem = tf.exp(tf.slice(l, (0, pivot), (-1, 1))) 65 | non_diag_elems = tf.slice(l, (0, pivot+1), (-1, count-1)) 66 | row = tf.pad(tf.concat((diag_elem, non_diag_elems), 1), ((0, 0), (idx, 0))) 67 | rows.append(row) 68 | 69 | pivot += count 70 | 71 | L = tf.transpose(tf.stack(rows, axis=1), (0, 2, 1)) 72 | P = tf.matmul(L, tf.transpose(L, (0, 2, 1))) 73 | 74 | tmp = tf.expand_dims(u - mu, -1) 75 | A = -tf.matmul(tf.transpose(tmp, [0, 2, 1]), tf.matmul(P, tmp))/2 76 | A = tf.reshape(A, [-1, 1]) 77 | 78 | with tf.name_scope('Q'): 79 | Q = A + V 80 | 81 | with tf.name_scope('optimization'): 82 | self.target_y = tf.placeholder(tf.float32, [None], name='target_y') 83 | self.loss = tf.reduce_mean(tf.squared_difference(self.target_y, tf.squeeze(Q)), name='loss') 84 | 85 | self.is_train = is_train 86 | self.variables = get_variables(scope) 87 | 88 | self.x, self.u, self.mu, self.V, self.Q, self.P, self.A = x, u, mu, V, Q, P, A 89 | 90 | def predict_v(self, x, u): 91 | return self.sess.run(self.V, { 92 | self.x: x, self.u: u, self.is_train: False, 93 | }) 94 | 95 | def predict(self, state): 96 | return self.sess.run(self.mu, { 97 | self.x: state, self.is_train: False 98 | }) 99 | 100 | def update(self, optim, target_v, x_t, u_t): 101 | _, q, v, a, l = self.sess.run([ 102 | optim, self.Q, self.V, self.A, self.loss 103 | ], { 104 | self.target_y: target_v, 105 | self.x: x_t, 106 | self.u: u_t, 107 | self.is_train: True, 108 | }) 109 | return q, v, a, l 110 | 111 | def make_soft_update_from(self, network, tau): 112 | logger.info("Creating ops for soft target update...") 113 | assert len(network.variables) == len(self.variables), \ 114 | "target and prediction network should have same # of variables" 115 | 116 | self.assign_op = {} 117 | for from_, to_ in zip(network.variables, self.variables): 118 | if 'BatchNorm' in to_.name: 119 | self.assign_op[to_.name] = to_.assign(from_) 120 | else: 121 | self.assign_op[to_.name] = to_.assign(tau * from_ + (1-tau) * to_) 122 | 123 | def hard_copy_from(self, network): 124 | logger.info("Creating ops for hard target update...") 125 | assert len(network.variables) == len(self.variables), \ 126 | "target and prediction network should have same # of variables" 127 | 128 | for from_, to_ in zip(network.variables, self.variables): 129 | self.sess.run(to_.assign(from_)) 130 | 131 | def soft_update_from(self, network): 132 | for variable in self.variables: 133 | self.sess.run(self.assign_op[variable.name]) 134 | return True 135 | -------------------------------------------------------------------------------- /src/ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow.contrib.layers import fully_connected 4 | 
# from tensorflow.contrib.layers import initializers 5 | from tensorflow.contrib.layers import l1_regularizer 6 | from tensorflow.contrib.layers import l2_regularizer 7 | from tensorflow.contrib.layers import batch_norm 8 | 9 | random_uniform_big = tf.random_uniform_initializer(-0.05, 0.05) 10 | random_uniform_small = tf.random_uniform_initializer(-3e-4, 3e-4) 11 | # he_uniform = initializers.variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False) 12 | he_uniform = tf.contrib.layers.variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False) 13 | 14 | def fc(layer, output_size, is_training, 15 | weight_init, weight_reg=None, activation_fn=None, 16 | use_batch_norm=False, scope='fc'): 17 | if use_batch_norm: 18 | batch_norm_args = { 19 | 'normalizer_fn': batch_norm, 20 | 'normalizer_params': { 21 | 'is_training': is_training, 22 | } 23 | } 24 | else: 25 | batch_norm_args = {} 26 | 27 | with tf.variable_scope(scope): 28 | return fully_connected( 29 | layer, 30 | num_outputs=output_size, 31 | activation_fn=activation_fn, 32 | weights_initializer=weight_init, 33 | weights_regularizer=weight_reg, 34 | biases_initializer=tf.constant_initializer(0.0), 35 | scope=scope, 36 | **batch_norm_args 37 | ) 38 | 39 | from tensorflow.contrib.framework.python.ops import add_arg_scope 40 | from tensorflow.contrib.framework.python.ops import variables 41 | from tensorflow.contrib.layers.python.layers import initializers 42 | from tensorflow.contrib.layers.python.layers import utils 43 | 44 | from tensorflow.python.framework import dtypes 45 | from tensorflow.python.framework import ops 46 | from tensorflow.python.ops import array_ops 47 | from tensorflow.python.ops import control_flow_ops 48 | from tensorflow.python.ops import init_ops 49 | from tensorflow.python.ops import nn 50 | from tensorflow.python.ops import standard_ops 51 | from tensorflow.python.ops import variable_scope 52 | from tensorflow.python.training import moving_averages 53 | 54 | @add_arg_scope 55 | def batch_norm(inputs, 56 | decay=0.999, 57 | center=True, 58 | scale=False, 59 | epsilon=0.001, 60 | updates_collections=ops.GraphKeys.UPDATE_OPS, 61 | is_training=True, 62 | reuse=None, 63 | variables_collections=None, 64 | outputs_collections=None, 65 | trainable=True, 66 | scope=None): 67 | """Code modification of tensorflow/contrib/layers/python/layers/layers.py 68 | """ 69 | with variable_scope.variable_op_scope([inputs], 70 | scope, 'BatchNorm', reuse=reuse) as sc: 71 | inputs = ops.convert_to_tensor(inputs) 72 | inputs_shape = inputs.get_shape() 73 | inputs_rank = inputs_shape.ndims 74 | if inputs_rank is None: 75 | raise ValueError('Inputs %s has undefined rank.' % inputs.name) 76 | dtype = inputs.dtype.base_dtype 77 | axis = list(range(inputs_rank - 1)) 78 | params_shape = inputs_shape[-1:] 79 | if not params_shape.is_fully_defined(): 80 | raise ValueError('Inputs %s has undefined last dimension %s.' % ( 81 | inputs.name, params_shape)) 82 | # Allocate parameters for the beta and gamma of the normalization. 
83 | beta, gamma = None, None 84 | if center: 85 | beta_collections = utils.get_variable_collections(variables_collections, 86 | 'beta') 87 | beta = variables.model_variable('beta', 88 | shape=params_shape, 89 | dtype=dtype, 90 | initializer=init_ops.zeros_initializer, 91 | collections=beta_collections, 92 | trainable=trainable) 93 | if scale: 94 | gamma_collections = utils.get_variable_collections(variables_collections, 95 | 'gamma') 96 | gamma = variables.model_variable('gamma', 97 | shape=params_shape, 98 | dtype=dtype, 99 | initializer=init_ops.ones_initializer, 100 | collections=gamma_collections, 101 | trainable=trainable) 102 | # Create moving_mean and moving_variance variables and add them to the 103 | # appropiate collections. 104 | moving_mean_collections = utils.get_variable_collections( 105 | variables_collections, 'moving_mean') 106 | moving_mean = variables.model_variable( 107 | 'moving_mean', 108 | shape=params_shape, 109 | dtype=dtype, 110 | initializer=init_ops.zeros_initializer, 111 | trainable=False, 112 | collections=moving_mean_collections) 113 | moving_variance_collections = utils.get_variable_collections( 114 | variables_collections, 'moving_variance') 115 | moving_variance = variables.model_variable( 116 | 'moving_variance', 117 | shape=params_shape, 118 | dtype=dtype, 119 | initializer=init_ops.ones_initializer, 120 | trainable=False, 121 | collections=moving_variance_collections) 122 | 123 | # Calculate the moments based on the individual batch. 124 | mean, variance = nn.moments(inputs, axis, shift=moving_mean) 125 | # Update the moving_mean and moving_variance moments. 126 | update_moving_mean = moving_averages.assign_moving_average( 127 | moving_mean, mean, decay) 128 | update_moving_variance = moving_averages.assign_moving_average( 129 | moving_variance, variance, decay) 130 | if updates_collections is None: 131 | # Make sure the updates are computed here. 132 | with ops.control_dependencies([update_moving_mean, 133 | update_moving_variance]): 134 | outputs = nn.batch_normalization( 135 | inputs, mean, variance, beta, gamma, epsilon) 136 | else: 137 | # Collect the updates to be computed later. 
138 | ops.add_to_collections(updates_collections, update_moving_mean) 139 | ops.add_to_collections(updates_collections, update_moving_variance) 140 | outputs = nn.batch_normalization( 141 | inputs, mean, variance, beta, gamma, epsilon) 142 | 143 | test_outputs = nn.batch_normalization( 144 | inputs, moving_mean, moving_variance, beta, gamma, epsilon) 145 | 146 | outputs = tf.cond(is_training, lambda: outputs, lambda: test_outputs) 147 | outputs.set_shape(inputs_shape) 148 | 149 | return utils.collect_named_outputs(outputs_collections, sc.name, outputs) 150 | -------------------------------------------------------------------------------- /src/statistic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | from logging import getLogger 5 | 6 | logger = getLogger(__name__) 7 | 8 | class Statistic(object): 9 | def __init__(self, sess, env_name, model_dir, variables, max_update_per_step, max_to_keep=20): 10 | self.sess = sess 11 | self.env_name = env_name 12 | self.max_update_per_step = max_update_per_step 13 | 14 | self.reset() 15 | self.max_avg_r = None 16 | 17 | with tf.variable_scope('t'): 18 | self.t_op = tf.Variable(0, trainable=False, name='t') 19 | self.t_add_op = self.t_op.assign_add(1) 20 | 21 | self.model_dir = model_dir 22 | self.saver = tf.train.Saver(variables + [self.t_op], max_to_keep=max_to_keep) 23 | self.writer = tf.summary.FileWriter('./logs/%s' % self.model_dir, self.sess.graph) 24 | 25 | with tf.variable_scope('summary'): 26 | scalar_summary_tags = ['total r', 'avg r', 'avg q', 'avg v', 'avg a', 'avg l'] 27 | 28 | self.summary_placeholders = {} 29 | self.summary_ops = {} 30 | 31 | for tag in scalar_summary_tags: 32 | self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_')) 33 | self.summary_ops[tag] = tf.summary.scalar('%s/%s' % (self.env_name, tag), self.summary_placeholders[tag]) 34 | 35 | def reset(self): 36 | self.total_q = 0. 37 | self.total_v = 0. 38 | self.total_a = 0. 39 | self.total_l = 0. 
40 | 
41 |     self.ep_step = 0
42 |     self.ep_rewards = []
43 | 
44 |   def on_step(self, action, reward, terminal, q, v, a, l):
45 |     self.t = self.t_add_op.eval(session=self.sess)
46 | 
47 |     self.total_q += q
48 |     self.total_v += v
49 |     self.total_a += a
50 |     self.total_l += l
51 | 
52 |     self.ep_step += 1
53 |     self.ep_rewards.append(reward)
54 | 
55 |     if terminal:
56 |       avg_q = self.total_q / self.ep_step / self.max_update_per_step
57 |       avg_v = self.total_v / self.ep_step / self.max_update_per_step
58 |       avg_a = self.total_a / self.ep_step / self.max_update_per_step
59 |       avg_l = self.total_l / self.ep_step / self.max_update_per_step
60 | 
61 |       avg_r = np.mean(self.ep_rewards)
62 |       total_r = np.sum(self.ep_rewards)
63 | 
64 |       logger.info('t: %d, R: %.3f, r: %.3f, q: %.3f, v: %.3f, a: %.3f, l: %.3f' \
65 |           % (self.t, total_r, avg_r, avg_q, avg_v, avg_a, avg_l))
66 | 
67 |       if self.max_avg_r is None:
68 |         self.max_avg_r = avg_r
69 | 
70 |       if self.max_avg_r * 0.9 <= avg_r:
71 |         self.save_model(self.t)
72 |         self.max_avg_r = max(self.max_avg_r, avg_r)
73 | 
74 |       self.inject_summary({
75 |         'total r': total_r, 'avg r': avg_r,
76 |         'avg q': avg_q, 'avg v': avg_v, 'avg a': avg_a, 'avg l': avg_l,
77 |       }, self.t)
78 | 
79 |       self.reset()
80 | 
81 |   def inject_summary(self, tag_dict, t):
82 |     summary_str_lists = self.sess.run([self.summary_ops[tag] for tag in tag_dict.keys()], {
83 |       self.summary_placeholders[tag]: value for tag, value in tag_dict.items()
84 |     })
85 |     for summary_str in summary_str_lists:
86 |       self.writer.add_summary(summary_str, t)
87 | 
88 |   def save_model(self, t):
89 |     logger.info("Saving checkpoints...")
90 |     model_name = type(self).__name__
91 | 
92 |     if not os.path.exists(self.model_dir):
93 |       os.makedirs(self.model_dir)
94 |     self.saver.save(self.sess, self.model_dir, global_step=t)
95 | 
96 |   def load_model(self):
97 |     logger.info("Loading checkpoints...")
98 |     tf.initialize_all_variables().run()
99 | 
100 |     ckpt = tf.train.get_checkpoint_state(self.model_dir)
101 |     if ckpt and ckpt.model_checkpoint_path:
102 |       ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
103 |       fname = os.path.join(self.model_dir, ckpt_name)
104 |       self.saver.restore(self.sess, fname)
105 |       logger.info("Load SUCCESS: %s" % fname)
106 |     else:
107 |       logger.info("Load FAILED: %s" % self.model_dir)
108 | 
109 |     self.t = self.t_add_op.eval(session=self.sess)
110 | 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import dateutil.tz
3 | 
4 | def get_timestamp():
5 |   now = datetime.datetime.now(dateutil.tz.tzlocal())
6 |   return now.strftime('%Y_%m_%d_%H_%M_%S')
7 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pprint
3 | import tensorflow as tf
4 | 
5 | from src.network import *
6 | 
7 | pp = pprint.PrettyPrinter().pprint
8 | 
9 | def get_model_dir(config, exceptions=None):
10 | 
11 |   attrs = config.__flags
12 |   pp(attrs)
13 | 
14 |   keys = attrs.keys()
15 |   keys.sort()
16 |   keys.remove('env_name')
17 |   keys = ['env_name'] + keys
18 | 
19 |   names = []
20 |   for key in keys:
21 |     # Only use useful flags
22 |     if key not in exceptions:
23 |       names.append("%s=%s" % (key, ",".join([str(i) for i in attrs[key]])
24 |           if type(attrs[key]) == list else attrs[key]))
25 |   return os.path.join('checkpoints', *names) + '/'
26 | 
27 | def preprocess_conf(conf):
28 |   options = conf.__flags
29 | 
30 |   for option, value in options.items():
31 |     option = option.lower()
32 |     value = value.value
33 | 
34 |     if option == 'hidden_dims':
35 |       conf.hidden_dims = eval(conf.hidden_dims)
36 |     elif option == 'w_reg':
37 |       if value == 'l1':
38 |         w_reg = l1_regularizer(conf.w_reg_scale)
39 |       elif value == 'l2':
40 |         w_reg = l2_regularizer(conf.w_reg_scale)
41 |       elif value == 'none':
42 |         w_reg = None
43 |       else:
44 |         raise ValueError('Wrong weight regularizer %s: %s' % (option, value))
45 |       conf.w_reg = w_reg
46 |     elif option.endswith('_w'):
47 |       if value == 'uniform_small':
48 |         weights_initializer = random_uniform_small
49 |       elif value == 'uniform_big':
50 |         weights_initializer = random_uniform_big
51 |       elif value == 'he':
52 |         weights_initializer = he_uniform
53 |       else:
54 |         raise ValueError('Wrong %s: %s' % (option, value))
55 |       setattr(conf, option, weights_initializer)
56 |     elif option.endswith('_fn'):
57 |       if value == 'tanh':
58 |         activation_fn = tf.nn.tanh
59 |       elif value == 'relu':
60 |         activation_fn = tf.nn.relu
61 |       elif value == 'none':
62 |         activation_fn = None
63 |       else:
64 |         raise ValueError('Wrong %s: %s' % (option, value))
65 |       setattr(conf, option, activation_fn)
66 | 
--------------------------------------------------------------------------------