├── ddpg.pyc
├── filter_env.pyc
├── ou_noise.pyc
├── critic_network.pyc
├── replay_buffer.pyc
├── actor_network_bn.pyc
├── figures
│   └── addpg_res.PNG
├── auto_run.sh
├── replay_buffer.py
├── LICENSE
├── ou_noise.py
├── README.md
├── filter_env.py
├── ddpg.py
├── critic_network.py
├── actor_network_bn.py
└── gym_addpg.py

/ddpg.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/Asynchronous-DDPG_distributed_tensorflow/HEAD/ddpg.pyc
--------------------------------------------------------------------------------
/filter_env.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/Asynchronous-DDPG_distributed_tensorflow/HEAD/filter_env.pyc
--------------------------------------------------------------------------------
/ou_noise.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/Asynchronous-DDPG_distributed_tensorflow/HEAD/ou_noise.pyc
--------------------------------------------------------------------------------
/critic_network.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/Asynchronous-DDPG_distributed_tensorflow/HEAD/critic_network.pyc
--------------------------------------------------------------------------------
/replay_buffer.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/Asynchronous-DDPG_distributed_tensorflow/HEAD/replay_buffer.pyc
--------------------------------------------------------------------------------
/actor_network_bn.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/Asynchronous-DDPG_distributed_tensorflow/HEAD/actor_network_bn.pyc
--------------------------------------------------------------------------------
/figures/addpg_res.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jsikyoon/Asynchronous-DDPG_distributed_tensorflow/HEAD/figures/addpg_res.PNG
--------------------------------------------------------------------------------
/auto_run.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Note: gym_addpg.py adds 1 to both counts internally, so ps_num=0 and
# worker_num=3 launch 1 parameter server and 4 workers.
ps_num=0
worker_num=3

for i in `eval echo {0..$ps_num}`
do
  python gym_addpg.py --ps_hosts_num=$ps_num --worker_hosts_num=$worker_num --job_name=ps --task_index=$i &
done

for i in `eval echo {0..$worker_num}`
do
  python gym_addpg.py --ps_hosts_num=$ps_num --worker_hosts_num=$worker_num --job_name=worker --task_index=$i &
done
--------------------------------------------------------------------------------
/replay_buffer.py:
--------------------------------------------------------------------------------
from collections import deque
import random

class ReplayBuffer(object):

    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.num_experiences = 0
        self.buffer = deque()

    def get_batch(self, batch_size):
        # Randomly sample batch_size examples
        return random.sample(self.buffer, batch_size)

    def size(self):
        return self.buffer_size

    def add(self, state, action, reward, new_state, done):
        experience = (state, action, reward, new_state, done)
        if self.num_experiences < self.buffer_size:
            self.buffer.append(experience)
            self.num_experiences += 1
        else:
            self.buffer.popleft()
            self.buffer.append(experience)

    def count(self):
        # if buffer is full, return buffer size
        # otherwise, return experience counter
        return self.num_experiences

    def erase(self):
        self.buffer = deque()
        self.num_experiences = 0
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Jaesik Yoon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/ou_noise.py:
--------------------------------------------------------------------------------
# --------------------------------------
# Ornstein-Uhlenbeck Noise
# Author: Flood Sung
# Date: 2016.5.4
# Reference: https://github.com/rllab/rllab/blob/master/rllab/exploration_strategies/ou_strategy.py
# --------------------------------------

import numpy as np
import numpy.random as nr

class OUNoise:
    """Ornstein-Uhlenbeck process for temporally correlated exploration noise."""
    def __init__(self,action_dimension,mu=0, theta=0.15, sigma=0.2):
        self.action_dimension = action_dimension
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.state = np.ones(self.action_dimension) * self.mu
        self.reset()

    def reset(self):
        self.state = np.ones(self.action_dimension) * self.mu

    def noise(self):
        x = self.state
        dx = self.theta * (self.mu - x) + self.sigma * nr.randn(len(x))
        self.state = x + dx
        return self.state

if __name__ == '__main__':
    ou = OUNoise(3)
    states = []
    for i in range(1000):
        states.append(ou.noise())
    import matplotlib.pyplot as plt

    plt.plot(states)
    plt.show()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Asynchronous-DDPG_distributed_tensorflow
===========

Distributed TensorFlow implementation of asynchronous DDPG.

The implementation runs on TensorFlow 1.2.1.

The DDPG code is based on songrotek's repo: https://github.com/songrotek/DDPG.git

One of the most common pain points of reinforcement learning is its long training time. A3C was proposed to address this by training the agent in parallel. For DDPG, one of the strongest algorithms for continuous-action tasks, there has been comparatively little research on parallel learning. One example is the Intentional Unintentional (IU) agent, which learns several tasks simultaneously (https://arxiv.org/abs/1707.03300). Here, I validate parallel learning of DDPG in a simpler setting than the IU agent's: each worker learns a single task, and after every few episodes its training information is merged into the parameter server.

GYM Reacher-v1 game
-------------------

```
./auto_run.sh
```

You need to set your hostname and the starting port number in gym_addpg.py. The number of parameter servers and workers can be set in the auto_run.sh script.
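
For reference, auto_run.sh simply launches one gym_addpg.py process per parameter server and per worker. A hand-launched equivalent is sketched below; `myhost` is a placeholder for whatever you pass via `--hostname`, and note that gym_addpg.py adds 1 to both `--ps_hosts_num` and `--worker_hosts_num` internally, so the values below give 1 parameter server and 4 workers:

```bash
# 1 parameter server (ports are assigned starting at --st_port_num, default 2222)
python gym_addpg.py --hostname=myhost --st_port_num=2222 --ps_hosts_num=0 --worker_hosts_num=3 --job_name=ps --task_index=0 &

# 4 workers (task indices 0..3)
for i in 0 1 2 3
do
  python gym_addpg.py --hostname=myhost --st_port_num=2222 --ps_hosts_num=0 --worker_hosts_num=3 --job_name=worker --task_index=$i &
done
```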

### Settings

Almost all settings are the same as songrotek's, except the learning rate of the critic network.
The numbers of parameter servers and workers are 1 and 4, respectively.

### Results

![ADDPG results](https://github.com/jaesik817/Asynchronous-DDPG_distributed_tensorflow/blob/master/figures/addpg_res.PNG)
--------------------------------------------------------------------------------
/filter_env.py:
--------------------------------------------------------------------------------
import numpy as np
import gym

def makeFilteredEnv(env):
    """ create a new environment class with actions and states normalized to [-1,1] """
    acsp = env.action_space
    obsp = env.observation_space
    if not type(acsp)==gym.spaces.box.Box:
        raise RuntimeError('Environment with continuous action space (i.e. Box) required.')
    if not type(obsp)==gym.spaces.box.Box:
        raise RuntimeError('Environment with continuous observation space (i.e. Box) required.')

    env_type = type(env)

    class FilteredEnv(env_type):
        def __init__(self):
            self.__dict__.update(env.__dict__) # transfer properties

            # Observation space
            if np.any(obsp.high < 1e10):
                h = obsp.high
                l = obsp.low
                sc = h-l
                self.o_c = (h+l)/2.
                self.o_sc = sc / 2.
            else:
                self.o_c = np.zeros_like(obsp.high)
                self.o_sc = np.ones_like(obsp.high)

            # Action space
            h = acsp.high
            l = acsp.low
            sc = (h-l)
            self.a_c = (h+l)/2.
            self.a_sc = sc / 2.

            # Rewards
            self.r_sc = 0.1
            self.r_c = 0.

            # Special cases
            if self.spec.id == "Reacher-v1":
                self.o_sc[6] = 40.
                self.o_sc[7] = 20.
                self.r_sc = 200.
                self.r_c = 0.

            # Check and assign transformed spaces
            self.observation_space = gym.spaces.Box(self.filter_observation(obsp.low),
                                                    self.filter_observation(obsp.high))
            self.action_space = gym.spaces.Box(-np.ones_like(acsp.high),np.ones_like(acsp.high))
            def assertEqual(a,b): assert np.all(a == b), "{} != {}".format(a,b)
            assertEqual(self.filter_action(self.action_space.low), acsp.low)
            assertEqual(self.filter_action(self.action_space.high), acsp.high)

        def filter_observation(self,obs):
            return (obs-self.o_c) / self.o_sc

        def filter_action(self,action):
            return self.a_sc*action+self.a_c

        def filter_reward(self,reward):
            ''' has to be applied manually otherwise it makes the reward_threshold invalid '''
            return self.r_sc*reward+self.r_c

        def step(self,action):

            ac_f = np.clip(self.filter_action(action),self.action_space.low,self.action_space.high)

            obs, reward, term, info = env_type.step(self,ac_f) # super function

            reward=self.filter_reward(reward);

            obs_f = self.filter_observation(obs)

            return obs_f, reward, term, info

    fenv = FilteredEnv()

    print('True action space: ' + str(acsp.low) + ', ' + str(acsp.high))
    print('True state space: ' + str(obsp.low) + ', ' + str(obsp.high))
    print('Filtered action space: ' + str(fenv.action_space.low) + ', ' + str(fenv.action_space.high))
    print('Filtered state space: ' + str(fenv.observation_space.low) + ', ' + str(fenv.observation_space.high))

    return fenv
--------------------------------------------------------------------------------
/ddpg.py:
--------------------------------------------------------------------------------
# -----------------------------------
# Deep Deterministic Policy Gradient
# Author: Flood Sung
# Date: 2016.5.4
# -----------------------------------
import gym
import tensorflow as tf
import numpy as np
from ou_noise import OUNoise
from critic_network import CriticNetwork
from actor_network_bn import ActorNetwork
from replay_buffer import ReplayBuffer

# Hyper Parameters:

REPLAY_BUFFER_SIZE = 1000000
REPLAY_START_SIZE = 10000
BATCH_SIZE = 64
GAMMA = 0.99


class DDPG:
    """Deep Deterministic Policy Gradient agent."""
    def __init__(self, env,device):
        self.name = 'DDPG' # name for uploading results
        self.environment = env
        self.device=device
        # Randomly initialize actor network and critic network
        # with both their target networks
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

        with tf.device(self.device):
            self.actor_network = ActorNetwork(self.state_dim,self.action_dim)
            self.critic_network = CriticNetwork(self.state_dim,self.action_dim)

        # initialize replay buffer
        self.replay_buffer = ReplayBuffer(REPLAY_BUFFER_SIZE)

        # Initialize a random process (Ornstein-Uhlenbeck) for action exploration
        self.exploration_noise = OUNoise(self.action_dim)

    def set_sess(self,sess):
        self.actor_network.set_sess(sess);
        self.critic_network.set_sess(sess);

    def train(self):
        #print "train step",self.time_step
        # Sample a random minibatch of N transitions from replay buffer
        minibatch = self.replay_buffer.get_batch(BATCH_SIZE)
        state_batch = np.asarray([data[0] for data in minibatch])
        action_batch = np.asarray([data[1] for data in minibatch])
        reward_batch = np.asarray([data[2] for data in minibatch])
        next_state_batch = np.asarray([data[3] for data in minibatch])
        done_batch = np.asarray([data[4] for data in minibatch])

        # for action_dim = 1
        action_batch = np.resize(action_batch,[BATCH_SIZE,self.action_dim])

        # Calculate y_batch

        next_action_batch = self.actor_network.target_actions(next_state_batch)
        q_value_batch = self.critic_network.target_q(next_state_batch,next_action_batch)
        y_batch = []
        for i in range(len(minibatch)):
            if done_batch[i]:
                y_batch.append(reward_batch[i])
            else:
                y_batch.append(reward_batch[i] + GAMMA * q_value_batch[i])
        y_batch = np.resize(y_batch,[BATCH_SIZE,1])
        # Update critic by minimizing the loss L
        self.critic_network.train(y_batch,state_batch,action_batch)

        # Update the actor policy using the sampled gradient:
        action_batch_for_gradients = self.actor_network.actions(state_batch)
        q_gradient_batch = self.critic_network.gradients(state_batch,action_batch_for_gradients)

        self.actor_network.train(q_gradient_batch,state_batch)

        # Update the target networks
        self.actor_network.update_target()
        self.critic_network.update_target()

    def noise_action(self,state):
        # Select action a_t according to the current policy and exploration noise
        action = self.actor_network.action(state)
        return action+self.exploration_noise.noise()

    def action(self,state):
        action = self.actor_network.action(state)
        return action

    def perceive(self,state,action,reward,next_state,done):
        # Store transition (s_t,a_t,r_t,s_{t+1}) in replay buffer
        self.replay_buffer.add(state,action,reward,next_state,done)

        # Store transitions to replay start size then start training
        if self.replay_buffer.count() > REPLAY_START_SIZE:
            self.train()

        #if self.time_step % 10000 == 0:
            #self.actor_network.save_network(self.time_step)
            #self.critic_network.save_network(self.time_step)

        # Re-initialize the random process when an episode ends
        if done:
            self.exploration_noise.reset()
--------------------------------------------------------------------------------
/critic_network.py:
--------------------------------------------------------------------------------

import tensorflow as tf
import numpy as np
import math


LAYER1_SIZE = 400
LAYER2_SIZE = 300
LEARNING_RATE = 1e-3
TAU = 0.001
L2 = 0.01

class CriticNetwork:
    """Critic (Q-value) network together with its target network."""
    def __init__(self,state_dim,action_dim):
        self.time_step = 0
        # create q network
        self.state_input,\
        self.action_input,\
        self.q_value_output,\
        self.net = self.create_q_network(state_dim,action_dim)

        # create target q network (the same structure with q network)
        self.target_state_input,\
        self.target_action_input,\
        self.target_q_value_output,\
        self.target_update = self.create_target_q_network(state_dim,action_dim,self.net)

        self.create_training_method()

        #self.update_target()

    def set_sess(self,sess):
        self.sess=sess;
        #self.sess.run(tf.initialize_all_variables())
        #self.update_target()

    def create_training_method(self):
        # Define training optimizer
        self.y_input = tf.placeholder("float",[None,1])
        weight_decay = tf.add_n([L2 * tf.nn.l2_loss(var) for var in self.net])
        self.cost = tf.reduce_mean(tf.square(self.y_input - self.q_value_output)) + weight_decay
        self.optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(self.cost)
        self.action_gradients = tf.gradients(self.q_value_output,self.action_input)

    def create_q_network(self,state_dim,action_dim):
        # the layer size could be changed
        layer1_size = LAYER1_SIZE
        layer2_size = LAYER2_SIZE

        state_input = tf.placeholder("float",[None,state_dim])
        action_input = tf.placeholder("float",[None,action_dim])

        W1 = self.variable([state_dim,layer1_size],state_dim)
        b1 = self.variable([layer1_size],state_dim)
        W2 = self.variable([layer1_size,layer2_size],layer1_size+action_dim)
        W2_action = self.variable([action_dim,layer2_size],layer1_size+action_dim)
        b2 = self.variable([layer2_size],layer1_size+action_dim)
        W3 = tf.Variable(tf.random_uniform([layer2_size,1],-3e-3,3e-3))
        b3 = tf.Variable(tf.random_uniform([1],-3e-3,3e-3))

        layer1 = tf.nn.relu(tf.matmul(state_input,W1) + b1)
        layer2 = tf.nn.relu(tf.matmul(layer1,W2) + tf.matmul(action_input,W2_action) + b2)
        q_value_output = tf.identity(tf.matmul(layer2,W3) + b3)

        return state_input,action_input,q_value_output,[W1,b1,W2,W2_action,b2,W3,b3]

    def create_target_q_network(self,state_dim,action_dim,net):
        state_input = tf.placeholder("float",[None,state_dim])
        action_input = tf.placeholder("float",[None,action_dim])

        ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
        target_update = ema.apply(net)
        target_net = [ema.average(x) for x in net]

        layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
        layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
        q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

        return state_input,action_input,q_value_output,target_update

    def update_target(self):
        self.sess.run(self.target_update)

    def train(self,y_batch,state_batch,action_batch):
        self.time_step += 1
        self.sess.run(self.optimizer,feed_dict={
            self.y_input:y_batch,
            self.state_input:state_batch,
            self.action_input:action_batch
            })

    def gradients(self,state_batch,action_batch):
        return self.sess.run(self.action_gradients,feed_dict={
            self.state_input:state_batch,
            self.action_input:action_batch
            })[0]

    def target_q(self,state_batch,action_batch):
        return self.sess.run(self.target_q_value_output,feed_dict={
            self.target_state_input:state_batch,
            self.target_action_input:action_batch
            })

    def q_value(self,state_batch,action_batch):
        return self.sess.run(self.q_value_output,feed_dict={
            self.state_input:state_batch,
            self.action_input:action_batch})

    # f fan-in size
    def variable(self,shape,f):
        return tf.Variable(tf.random_uniform(shape,-1/math.sqrt(f),1/math.sqrt(f)))
    '''
    def load_network(self):
        self.saver = tf.train.Saver()
        checkpoint = tf.train.get_checkpoint_state("saved_critic_networks")
        if checkpoint and checkpoint.model_checkpoint_path:
            self.saver.restore(self.sess, checkpoint.model_checkpoint_path)
            print "Successfully loaded:", checkpoint.model_checkpoint_path
        else:
            print "Could not find old network weights"

    def save_network(self,time_step):
        print 'save critic-network...',time_step
        self.saver.save(self.sess, 'saved_critic_networks/' + 'critic-network', global_step = time_step)
    '''
--------------------------------------------------------------------------------
/actor_network_bn.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
import numpy as np
import math


# Hyper Parameters
LAYER1_SIZE = 400
LAYER2_SIZE = 300
LEARNING_RATE = 1e-4
TAU = 0.001
BATCH_SIZE = 64

class ActorNetwork:
    """Actor (policy) network with batch normalization and its target network."""
    def __init__(self,state_dim,action_dim):

        self.state_dim = state_dim
        self.action_dim = action_dim
        # create actor network
        self.state_input,self.action_output,self.net,self.is_training = self.create_network(state_dim,action_dim)

        # create target actor network
        self.target_state_input,self.target_action_output,self.target_update,self.target_is_training = self.create_target_network(state_dim,action_dim,self.net)

        # define training rules
        self.create_training_method()

        #self.update_target()
        #self.load_network()

    def set_sess(self,sess):
        self.sess=sess;
        #self.sess.run(tf.initialize_all_variables())
        #self.update_target()

    def create_training_method(self):
        self.q_gradient_input = tf.placeholder("float",[None,self.action_dim])
        self.parameters_gradients = tf.gradients(self.action_output,self.net,-self.q_gradient_input)
        self.optimizer = tf.train.AdamOptimizer(LEARNING_RATE).apply_gradients(zip(self.parameters_gradients,self.net))

    def create_network(self,state_dim,action_dim):
        layer1_size = LAYER1_SIZE
        layer2_size = LAYER2_SIZE

        state_input = tf.placeholder("float",[None,state_dim])
        is_training = tf.placeholder(tf.bool)

        W1 = self.variable([state_dim,layer1_size],state_dim)
        b1 = self.variable([layer1_size],state_dim)
        W2 = self.variable([layer1_size,layer2_size],layer1_size)
        b2 = self.variable([layer2_size],layer1_size)
        W3 = tf.Variable(tf.random_uniform([layer2_size,action_dim],-3e-3,3e-3))
        b3 = tf.Variable(tf.random_uniform([action_dim],-3e-3,3e-3))

        layer0_bn = self.batch_norm_layer(state_input,training_phase=is_training,scope_bn='batch_norm_0',activation=tf.identity)
        layer1 = tf.matmul(layer0_bn,W1) + b1
        layer1_bn = self.batch_norm_layer(layer1,training_phase=is_training,scope_bn='batch_norm_1',activation=tf.nn.relu)
        layer2 = tf.matmul(layer1_bn,W2) + b2
        layer2_bn = self.batch_norm_layer(layer2,training_phase=is_training,scope_bn='batch_norm_2',activation=tf.nn.relu)

        action_output = tf.tanh(tf.matmul(layer2_bn,W3) + b3)

        return state_input,action_output,[W1,b1,W2,b2,W3,b3],is_training

    def create_target_network(self,state_dim,action_dim,net):
        state_input = tf.placeholder("float",[None,state_dim])
        is_training = tf.placeholder(tf.bool)
        ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
        target_update = ema.apply(net)
        target_net = [ema.average(x) for x in net]

        layer0_bn = self.batch_norm_layer(state_input,training_phase=is_training,scope_bn='target_batch_norm_0',activation=tf.identity)

        layer1 = tf.matmul(layer0_bn,target_net[0]) + target_net[1]
        layer1_bn = self.batch_norm_layer(layer1,training_phase=is_training,scope_bn='target_batch_norm_1',activation=tf.nn.relu)
        layer2 = tf.matmul(layer1_bn,target_net[2]) + target_net[3]
        layer2_bn = self.batch_norm_layer(layer2,training_phase=is_training,scope_bn='target_batch_norm_2',activation=tf.nn.relu)

        action_output = tf.tanh(tf.matmul(layer2_bn,target_net[4]) + target_net[5])

        return state_input,action_output,target_update,is_training

    def update_target(self):
        self.sess.run(self.target_update)

    def train(self,q_gradient_batch,state_batch):
        self.sess.run(self.optimizer,feed_dict={
            self.q_gradient_input:q_gradient_batch,
            self.state_input:state_batch,
            self.is_training: True
            })

    def actions(self,state_batch):
        return self.sess.run(self.action_output,feed_dict={
            self.state_input:state_batch,
            self.is_training: True
            })

    def action(self,state):
        return self.sess.run(self.action_output,feed_dict={
            self.state_input:[state],
            self.is_training: False
            })[0]


    def target_actions(self,state_batch):
        return self.sess.run(self.target_action_output,feed_dict={
            self.target_state_input: state_batch,
            self.target_is_training: True
            })

    # f fan-in size
    def variable(self,shape,f):
        return tf.Variable(tf.random_uniform(shape,-1/math.sqrt(f),1/math.sqrt(f)))


    def batch_norm_layer(self,x,training_phase,scope_bn,activation=None):
        return tf.cond(training_phase,
            lambda: tf.contrib.layers.batch_norm(x, activation_fn=activation, center=True, scale=True,
                updates_collections=None,is_training=True, reuse=None,scope=scope_bn,decay=0.9, epsilon=1e-5),
            lambda: tf.contrib.layers.batch_norm(x, activation_fn=activation, center=True, scale=True,
                updates_collections=None,is_training=False, reuse=True,scope=scope_bn,decay=0.9, epsilon=1e-5))
    '''
    def load_network(self):
        self.saver = tf.train.Saver()
        checkpoint = tf.train.get_checkpoint_state("saved_actor_networks")
        if checkpoint and checkpoint.model_checkpoint_path:
            self.saver.restore(self.sess, checkpoint.model_checkpoint_path)
            print "Successfully loaded:", checkpoint.model_checkpoint_path
        else:
            print "Could not find old network weights"
    def save_network(self,time_step):
        print 'save actor-network...',time_step
        self.saver.save(self.sess, 'saved_actor_networks/' + 'actor-network', global_step = time_step)
    '''
--------------------------------------------------------------------------------
/gym_addpg.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import random
import os
import time
import sys
import argparse
import filter_env
from gym import wrappers
from ddpg import *
import gc
gc.enable()

FLAGS=None;
ENV_NAME = 'Reacher-v1'
EPISODES = 100000
local_step=1
TEST=10

def train():
    # parameter server and worker information
    ps_hosts = np.zeros(FLAGS.ps_hosts_num,dtype=object);
    worker_hosts = np.zeros(FLAGS.worker_hosts_num,dtype=object);
    port_num=FLAGS.st_port_num;
    for i in range(FLAGS.ps_hosts_num):
        ps_hosts[i]=str(FLAGS.hostname)+":"+str(port_num);
        port_num+=1;
    for i in range(FLAGS.worker_hosts_num):
        worker_hosts[i]=str(FLAGS.hostname)+":"+str(port_num);
        port_num+=1;
    ps_hosts=list(ps_hosts);
    worker_hosts=list(worker_hosts);
    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)
    if FLAGS.job_name == "ps":
        server.join();
    elif FLAGS.job_name == "worker":
        device=tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster);

        #tf.set_random_seed(1);
        # env and model call
        env = filter_env.makeFilteredEnv(gym.make(ENV_NAME))
        agent = DDPG(env,device)

        # prepare session
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            global_step = tf.get_variable('global_step',[],initializer=tf.constant_initializer(0),trainable=False);
            global_step_ph=tf.placeholder(global_step.dtype,shape=global_step.get_shape());
            global_step_ops=global_step.assign(global_step_ph);
            score = tf.get_variable('score',[],initializer=tf.constant_initializer(-21),trainable=False);
            score_ph=tf.placeholder(score.dtype,shape=score.get_shape());
            score_ops=score.assign(score_ph);
            init_op=tf.global_variables_initializer();
            # summary for tensorboard
            tf.summary.scalar("score", score);
            summary_op = tf.summary.merge_all()
            saver = tf.train.Saver();

        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                 global_step=global_step,
                                 logdir=FLAGS.log_dir,
                                 summary_op=summary_op,
                                 saver=saver,
                                 init_op=init_op)

        with sv.managed_session(server.target) as sess:
            agent.set_sess(sess);
            while True:
                if sess.run([global_step])[0] > EPISODES:
                    break
                score=0;
                for ls in range(local_step):
                    state = env.reset();
                    for step in xrange(env.spec.timestep_limit):
                        action = agent.noise_action(state)
                        next_state,reward,done,_ = env.step(action)
                        agent.perceive(state,action,reward,next_state,done)
                        state = next_state
                        if done:
                            break;
                for i in xrange(TEST):
                    state = env.reset()
                    for j in xrange(env.spec.timestep_limit):
                        #env.render()
                        action = agent.action(state) # direct action for test
                        state,reward,done,_ = env.step(action)
                        score += reward
                        if done:
                            break
                sess.run(global_step_ops,{global_step_ph:sess.run([global_step])[0]+local_step});
                sess.run(score_ops,{score_ph:score/TEST/200});
                print(str(FLAGS.task_index)+","+str(sess.run([global_step])[0])+","+str(score/TEST/200));
            sv.stop();
            print("Done");

def main(_):
    #os.system("rm -rf "+FLAGS.log_dir);
    FLAGS.ps_hosts_num+=1;
    FLAGS.worker_hosts_num+=1;
    train()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")
    # Flags for defining the tf.train.ClusterSpec
    parser.add_argument(
        "--ps_hosts_num",
        type=int,
        default=5,
        help="The Number of Parameter Servers"
    )
    parser.add_argument(
        "--worker_hosts_num",
        type=int,
        default=10,
        help="The Number of Workers"
    )
    parser.add_argument(
        "--hostname",
        type=str,
        default="jaesik-System-Product-Name",
        help="The Hostname of the machine"
    )
    parser.add_argument(
        "--st_port_num",
        type=int,
        default=2222,
        help="The start port number of ps and worker servers"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        help="One of 'ps', 'worker'"
    )
    # Flags for defining the tf.train.Server
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )
    # Log folder
    parser.add_argument(
        "--log_dir",
        type=str,
        default="/tmp/addpg_log/",
        help="log folder name"
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
--------------------------------------------------------------------------------