├── .gitignore ├── LICENSE ├── README.md ├── base_network.py ├── bullet_cartpole.py ├── cartpole.gif ├── cartpole.png ├── ddpg_cartpole.py ├── deciles.py ├── dqn_cartpole.py ├── eg_render.png ├── event.proto ├── event_log.py ├── event_log_sample.push_in_one_dir.gif ├── exps ├── run_81.sh ├── run_82.sh ├── run_83.sh ├── run_84.sh ├── run_85.sh ├── run_86.sh ├── run_87.sh ├── run_89.sh ├── run_90.sh ├── run_91.sh ├── run_92.sh ├── run_93.sh ├── run_97.sh ├── run_98.sh └── runs_77_plot.sh ├── lrpg_cartpole.py ├── make_plots.sh ├── models ├── cart.urdf ├── ground.urdf └── pole.urdf ├── naf_cartpole.py ├── plots.R ├── random_action_agent.py ├── random_plots.R ├── replay_memory.py ├── replay_memory_test.py ├── run_diff.py ├── stitch_activations.py ├── u ├── parse_gradient_logging.py ├── parse_out.py ├── parse_out_ddpg.py ├── parse_out_eval.py ├── parse_out_eval_with_time.py ├── parse_out_naf.py ├── parse_runs.py ├── random_params.py └── test_render.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | runs 2 | event_logs 3 | event_pb2.py 4 | ckpts 5 | *h5 6 | *out 7 | *stats 8 | *tsv 9 | venv 10 | *pyc 11 | pybullet.so 12 | *hdf5 13 | *~ 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Mat Kelcey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cartpole ++ 2 | 3 | cartpole++ is a non trivial 3d version of cartpole 4 | simulated using [bullet physics](http://bulletphysics.org/) where the pole _isn't_ connected to the cart. 5 | 6 | ![cartpole](cartpole.gif) 7 | 8 | this repo contains a [gym env](https://gym.openai.com/) for this cartpole as well as implementations for training with ... 
9 | 10 | * likelihood ratio policy gradients ( [lrpg_cartpole.py](lrpg_cartpole.py) ) for the discrete control version 11 | * [deep deterministic policy gradients](http://arxiv.org/abs/1509.02971) ( [ddpg_cartpole.py](ddpg_cartpole.py) ) for the continuous control version 12 | * [normalised advantage functions](https://arxiv.org/abs/1603.00748) ( [naf_cartpole.py](naf_cartpole.py) ) also for the continuous control version 13 | 14 | we also train a [deep q network](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) from [keras-rl](https://github.com/matthiasplappert/keras-rl) as an externally implemented baseline. 15 | 16 | for more info see [the blog post](http://matpalm.com/blog/cartpole_plus_plus/). 17 | for experiments more related to potential transfer between sim and real see [drivebot](https://github.com/matpalm/drivebot). 18 | for next gen experiments in continuous control from pixels in minecraft see [malmomo](https://github.com/matpalm/malmomo) 19 | 20 | ## general environment 21 | 22 | episodes are initialised with the pole standing upright and receiving a small push in a random direction 23 | 24 | episodes are terminated when either the pole further than a set angle from vertical or 200 steps have passed 25 | 26 | there are two state representations available; a low dimensional one based on the cart & pole pose and a high dimensional one based on raw pixels (see below) 27 | 28 | there are two options for controlling the cart; a discrete and continuous method (see below) 29 | 30 | reward is simply 1.0 for each step in the episode 31 | 32 | ### states 33 | 34 | in both low and high dimensional representations we use the idea of action repeats; 35 | per env.step we apply the chosen action N times, take a state snapshot and repeat this R times. 36 | the deltas between these snapshots provides enough information 37 | to infer velocity (or acceleration (or jerk)) if the learning algorithm finds that useful to do. 38 | 39 | observation state in the low dimensional case is constructed from the poses of the cart & pole 40 | * it's shaped `(R, 2, 7)` 41 | * axis 0 represents the R repeats 42 | * axis 1 represents the object; 0=cart, 1=pole 43 | * axis 2 is the 7d pose; 3d postition + 4d quaternion orientation 44 | * this representation is usually just flattened to (R*14,) when used 45 | 46 | the high dimensional state is a rendering of the scene 47 | * it's shaped `(height, width, 3, R, C)` 48 | * axis 0 & 1 are the rendering height/width in pixels 49 | * axis 2 represents the 3 colour channels; red, green and blue 50 | * axis 3 represents the R repeats 51 | * axis 4 represents which camera the image is from; we have the option of rendering with one camera or two (located at right angles to each other) 52 | * this representation is flattened to have shape `(H, W, 3*R*C)`. we do this for ease of use of conv2d operations. (TODO: try conv3d instead) 53 | 54 | ![eg_render](eg_render.png) 55 | 56 | ### actions 57 | 58 | in the discrete case the actions are push cart left, right, up, down or do nothing. 59 | 60 | in the continuous case the action is a 2d value representing the push force in x and y direction (-1 to 1) 61 | 62 | ### rewards 63 | 64 | in all cases we give a reward of 1 for each step and terminate the episode when either 200 steps have passed or 65 | the pole has fallen too far from the z-axis 66 | 67 | ## agents 68 | 69 | ### random agent 70 | 71 | we use a random action agent (click through for video) to sanity check the setup. 
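conceptually this is just sampling an action from the action space every step and reporting the episode length; a minimal illustrative sketch of the idea (not the actual script — the real `random_action_agent.py` adds the `--initial-force`, `--actions`, `--num-eval` and `--action-type` flags used in the commands below)

```
import argparse
import bullet_cartpole

parser = argparse.ArgumentParser()
bullet_cartpole.add_opts(parser)    # provides --gui, --initial-force, --action-force, ...
opts = parser.parse_args()

env = bullet_cartpole.BulletCartpole(opts=opts, discrete_actions=True)
for _ in range(10):                 # a handful of episodes
  env.reset()
  steps, done = 0, False
  while not done:
    _state, _reward, done, _info = env.step(env.action_space.sample())
    steps += 1
  print steps                       # episode length; pipe through ./deciles.py
```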
72 | add `--gui` to any of these to get a rendering 73 | 74 | [![link](https://img.youtube.com/vi/buSAT-3Q8Zs/0.jpg)](https://www.youtube.com/watch?v=buSAT-3Q8Zs) 75 | 76 | ``` 77 | # no initial push and taking no action (action=0) results in episode timeout of 200 steps. 78 | # this is a check of the stability of the pole under no forces 79 | $ ./random_action_agent.py --initial-force=0 --actions="0" --num-eval=100 | ./deciles.py 80 | [ 200. 200. 200. 200. 200. 200. 200. 200. 200. 200. 200.] 81 | 82 | # no initial push and random actions knocks pole over 83 | $ ./random_action_agent.py --initial-force=0 --actions="0,1,2,3,4" --num-eval=100 | ./deciles.py 84 | [ 16. 22.9 26. 28. 31.6 35. 37.4 42.3 48.4 56.1 79. ] 85 | 86 | # initial push and no action knocks pole over 87 | $ ./random_action_agent.py --initial-force=55 --actions="0" --num-eval=100 | ./deciles.py 88 | [ 6. 7. 7. 8. 8.6 9. 11. 12.3 15. 21. 39. ] 89 | 90 | # initial push and random action knocks pole over 91 | $ ./random_action_agent.py --initial-force=55 --actions="0,1,2,3,4" --num-eval=100 | ./deciles.py 92 | [ 3. 5.9 7. 7.7 8. 9. 10. 11. 13. 15. 32. ] 93 | ``` 94 | 95 | ### discrete control with a deep q network 96 | 97 | ``` 98 | $ ./dqn_cartpole.py \ 99 | --num-train=2000000 --num-eval=0 \ 100 | --save-file=ckpt.h5 101 | ``` 102 | 103 | result by numbers... 104 | 105 | ``` 106 | $ ./dqn_cartpole.py \ 107 | --load-file=ckpt.h5 \ 108 | --num-train=0 --num-eval=100 \ 109 | | grep ^Episode | sed -es/.*steps:// | ./deciles.py 110 | [ 5. 35.5 49.8 63.4 79. 104.5 122. 162.6 184. 200. 200. ] 111 | ``` 112 | 113 | result visually (click through for video) 114 | 115 | [![link](https://img.youtube.com/vi/zteyMIvhn1U/0.jpg)](https://www.youtube.com/watch?v=zteyMIvhn1U) 116 | 117 | ``` 118 | $ ./dqn_cartpole.py \ 119 | --gui --delay=0.005 \ 120 | --load-file=run11_50.weights.2.h5 \ 121 | --num-train=0 --num-eval=100 122 | ``` 123 | 124 | ### discrete control with likelihood ratio policy gradient 125 | 126 | policy gradient nails it 127 | 128 | ``` 129 | $ ./lrpg_cartpole.py --rollouts-per-batch=20 --num-train-batches=100 \ 130 | --ckpt-dir=ckpts/foo 131 | ``` 132 | 133 | result by numbers... 134 | 135 | ``` 136 | # deciles 137 | [ 13. 70.6 195.8 200. 200. 200. 200. 200. 200. 200. 200. ] 138 | ``` 139 | 140 | result visually (click through for video) 141 | 142 | [![link](https://img.youtube.com/vi/aricda9gs2I/0.jpg)](https://www.youtube.com/watch?v=aricda9gs2I) 143 | 144 | ### continuous control with deep deterministic policy gradient 145 | 146 | ``` 147 | ./ddpg_cartpole.py \ 148 | --actor-hidden-layers="100,100,50" --critic-hidden-layers="100,100,50" \ 149 | --action-force=100 --action-noise-sigma=0.1 --batch-size=256 \ 150 | --max-num-actions=1000000 --ckpt-dir=ckpts/run43 151 | ``` 152 | 153 | result by numbers 154 | 155 | ``` 156 | # episode len deciles 157 | [ 30. 48. 56.8 65. 73. 86. 116.4 153.3 200. 200. 200. 
] 158 | # reward deciles 159 | [ 35.51154724 153.20243076 178.7908135 243.38630372 272.64655323 160 | 426.95298195 519.25360223 856.9702368 890.72279221 913.21068417 161 | 955.50168709] 162 | ``` 163 | 164 | result visually (click through for video) 165 | 166 | [![link](https://img.youtube.com/vi/8X05GA5ZKvQ/0.jpg)](https://www.youtube.com/watch?v=8X05GA5ZKvQ) 167 | 168 | ### low dimensional continuous control with normalised advantage functions 169 | 170 | ``` 171 | ./naf_cartpole.py --action-force=100 \ 172 | --action-repeats=3 --steps-per-repeat=4 \ 173 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 174 | ``` 175 | 176 | similiar convergence to ddpg 177 | 178 | ### high dimensional continuous control with normalised advantage functions 179 | 180 | does OK, but not perfect yet. as a human it's hard to do even... (see [the blog post](http://matpalm.com/blog/cartpole_plus_plus/)) 181 | 182 | ## general utils 183 | 184 | run a random agent, logging events to disk (outputs total rewards per episode) 185 | 186 | note: for replay logging will need to compile protobuffer `protoc event.proto --python_out=.` 187 | 188 | ``` 189 | $ ./random_action_agent.py --event-log=test.log --num-eval=10 --action-type=continuous 190 | 12 191 | 14 192 | ... 193 | ``` 194 | 195 | review event.log (either from ddpg training or from random agent) 196 | 197 | ``` 198 | $ ./event_log.py --log-file=test.log --echo 199 | event { 200 | state { 201 | cart_pose: 0.116232253611 202 | cart_pose: 0.0877446383238 203 | cart_pose: 0.0748709067702 204 | cart_pose: 1.14359036161e-05 205 | cart_pose: 5.10180834681e-05 206 | cart_pose: 0.0653914809227 207 | cart_pose: 0.997859716415 208 | pole_pose: 0.000139251351357 209 | pole_pose: -0.0611916743219 210 | pole_pose: 0.344804286957 211 | pole_pose: -0.123383037746 212 | pole_pose: 0.00611496530473 213 | pole_pose: 0.0471726879478 214 | pole_pose: 0.991218447685 215 | render { 216 | height: 120 217 | width: 160 218 | rgba: "\211PNG\r\n\032\n\000\..." 219 | } 220 | } 221 | is_terminal: false 222 | action: -0.157108291984 223 | action: 0.330988258123 224 | reward: 4.0238070488 225 | } 226 | ... 227 | ``` 228 | 229 | generate images from event.log 230 | 231 | ``` 232 | $ ./event_log.py --log-file=test.log --img-output-dir=eg_renders 233 | $ find eg_renders -type f | sort 234 | eg_renders/e_00000/s_00000.png 235 | eg_renders/e_00000/s_00001.png 236 | ... 
237 | eg_renders/e_00009/s_00018.png 238 | eg_renders/e_00009/s_00019.png 239 | ``` 240 | 241 | 1000 events in an event_log is roughly 750K for the high dim case and 100K for low dim 242 | -------------------------------------------------------------------------------- /base_network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import operator 3 | import sys 4 | import tensorflow as tf 5 | import tensorflow.contrib.slim as slim 6 | import util 7 | 8 | # TODO: move opts only used in this module to an add_opts method 9 | # requires fixing the bullet-before-slim problem though :/ 10 | 11 | IS_TRAINING = tf.placeholder(tf.bool, name="is_training") 12 | 13 | class Network(object): 14 | """Common class for handling ops for making / updating target networks.""" 15 | 16 | def __init__(self, namespace): 17 | self.namespace = namespace 18 | self.target_update_op = None 19 | 20 | def _create_variables_copy_op(self, source_namespace, affine_combo_coeff): 21 | """create an op that does updates all vars in source_namespace to target_namespace""" 22 | assert affine_combo_coeff >= 0.0 and affine_combo_coeff <= 1.0 23 | assign_ops = [] 24 | with tf.variable_scope(self.namespace, reuse=True): 25 | for src_var in tf.all_variables(): 26 | if not src_var.name.startswith(source_namespace): 27 | continue 28 | target_var_name = src_var.name.replace(source_namespace+"/", "").replace(":0", "") 29 | target_var = tf.get_variable(target_var_name) 30 | assert src_var.get_shape() == target_var.get_shape() 31 | assign_ops.append(target_var.assign_sub(affine_combo_coeff * (target_var - src_var))) 32 | single_assign_op = tf.group(*assign_ops) 33 | return single_assign_op 34 | 35 | def set_as_target_network_for(self, source_network, target_update_rate): 36 | """Create an op that will update this networks weights based on a source_network""" 37 | # first, as a one off, copy _all_ variables across. 38 | # i.e. initial target network will be a copy of source network. 39 | op = self._create_variables_copy_op(source_network.namespace, affine_combo_coeff=1.0) 40 | tf.get_default_session().run(op) 41 | # next build target update op for running later during training 42 | self.update_weights_op = self._create_variables_copy_op(source_network.namespace, 43 | target_update_rate) 44 | 45 | def update_weights(self): 46 | """called during training to update target network.""" 47 | if self.update_weights_op is None: 48 | raise Exception("not a target network? or set_source_network not yet called") 49 | return tf.get_default_session().run(self.update_weights_op) 50 | 51 | def trainable_model_vars(self): 52 | v = [] 53 | for var in tf.all_variables(): 54 | if var.name.startswith(self.namespace): 55 | v.append(var) 56 | return v 57 | 58 | def hidden_layers_starting_at(self, layer, layer_sizes, opts=None): 59 | # TODO: opts=None => will force exception on old calls.... 
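    # layer_sizes can be given either as a list of ints or as a comma separated string
    # such as "100,100,50" (the form the various --*-hidden-layers flags arrive in);
    # we build one relu fully connected layer per entry.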
60 | if not isinstance(layer_sizes, list): 61 | layer_sizes = map(int, layer_sizes.split(",")) 62 | assert len(layer_sizes) > 0 63 | for i, size in enumerate(layer_sizes): 64 | layer = slim.fully_connected(scope="h%d" % i, 65 | inputs=layer, 66 | num_outputs=size, 67 | weights_regularizer=tf.contrib.layers.l2_regularizer(0.01), 68 | activation_fn=tf.nn.relu) 69 | if opts.use_dropout: 70 | layer = slim.dropout(layer, is_training=IS_TRAINING, scope="do%d" % i) 71 | return layer 72 | 73 | def simple_conv_net_on(self, input_layer, opts): 74 | if opts.use_batch_norm: 75 | normalizer_fn = slim.batch_norm 76 | normalizer_params = { 'is_training': IS_TRAINING } 77 | else: 78 | normalizer_fn = None 79 | normalizer_params = None 80 | 81 | # optionally drop blue channel, in a simple cart pole env we only need r/g 82 | #if opts.drop_blue_channel: 83 | # input_layer = input_layer[:,:,:,0:2,:,:] 84 | 85 | # state is (batch, height, width, rgb, camera_idx, repeat) 86 | # rollup rgb, camera_idx and repeat into num_channels 87 | # i.e. (batch, height, width, rgb*camera_idx*repeat) 88 | height, width = map(int, input_layer.get_shape()[1:3]) 89 | num_channels = input_layer.get_shape()[3:].num_elements() 90 | input_layer = tf.reshape(input_layer, [-1, height, width, num_channels]) 91 | print >>sys.stderr, "input_layer", util.shape_and_product_of(input_layer) 92 | 93 | # whiten image, per channel, using batch_normalisation layer with 94 | # params calculated directly from batch. 95 | axis = list(range(input_layer.get_shape().ndims - 1)) 96 | batch_mean, batch_var = tf.nn.moments(input_layer, axis) # gives moments per channel 97 | whitened_input_layer = tf.nn.batch_normalization(input_layer, batch_mean, batch_var, 98 | scale=None, offset=None, 99 | variance_epsilon=1e-6) 100 | 101 | # TODO: num_outputs here are really dependant on the incoming channels, 102 | # which depend on the #repeats & cameras so they should be a param. 
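    # (for reference: with the default of 2 action repeats and 1 camera the input has
    #  3 * 2 * 1 = 6 channels)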
103 | model = slim.conv2d(whitened_input_layer, num_outputs=10, kernel_size=[5, 5], 104 | normalizer_fn=normalizer_fn, 105 | normalizer_params=normalizer_params, 106 | scope='conv1') 107 | model = slim.max_pool2d(model, kernel_size=[2, 2], scope='pool1') 108 | self.pool1 = model 109 | print >>sys.stderr, "pool1", util.shape_and_product_of(model) 110 | 111 | model = slim.conv2d(model, num_outputs=10, kernel_size=[5, 5], 112 | normalizer_fn=normalizer_fn, 113 | normalizer_params=normalizer_params, 114 | scope='conv2') 115 | model = slim.max_pool2d(model, kernel_size=[2, 2], scope='pool2') 116 | self.pool2 = model 117 | print >>sys.stderr, "pool2", util.shape_and_product_of(model) 118 | 119 | model = slim.conv2d(model, num_outputs=10, kernel_size=[3, 3], 120 | normalizer_fn=normalizer_fn, 121 | normalizer_params=normalizer_params, 122 | scope='conv3') 123 | model = slim.max_pool2d(model, kernel_size=[2, 2], scope='pool2') 124 | self.pool3 = model 125 | print >>sys.stderr, "pool3", util.shape_and_product_of(model) 126 | 127 | return model 128 | 129 | def input_state_network(self, input_state, opts): 130 | # TODO: use in lrpg and ddpg too 131 | if opts.use_raw_pixels: 132 | input_state = self.simple_conv_net_on(input_state, opts) 133 | flattened_input_state = slim.flatten(input_state, scope='flat') 134 | return self.hidden_layers_starting_at(flattened_input_state, opts.hidden_layers, opts) 135 | 136 | def render_convnet_activations(self, activations, filename_base): 137 | _batch, height, width, num_filters = activations.shape 138 | for f_idx in range(num_filters): 139 | single_channel = activations[0,:,:,f_idx] 140 | single_channel /= np.max(single_channel) 141 | img = np.empty((height, width, 3)) 142 | img[:,:,0] = single_channel 143 | img[:,:,1] = single_channel 144 | img[:,:,2] = single_channel 145 | util.write_img_to_png_file(img, "%s_f%02d.png" % (filename_base, f_idx)) 146 | 147 | def render_all_convnet_activations(self, step, input_state_placeholder, state): 148 | activations = tf.get_default_session().run([self.pool1, self.pool2, self.pool3], 149 | feed_dict={input_state_placeholder: [state], 150 | IS_TRAINING: False}) 151 | filename_base = "/tmp/activation_s%03d" % step 152 | self.render_convnet_activations(activations[0], filename_base + "_p0") 153 | self.render_convnet_activations(activations[1], filename_base + "_p1") 154 | self.render_convnet_activations(activations[2], filename_base + "_p2") 155 | -------------------------------------------------------------------------------- /bullet_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from collections import * 4 | import gym 5 | from gym import spaces 6 | import numpy as np 7 | import pybullet as p 8 | import sys 9 | import time 10 | 11 | np.set_printoptions(precision=3, suppress=True, linewidth=10000) 12 | 13 | def add_opts(parser): 14 | parser.add_argument('--gui', action='store_true') 15 | parser.add_argument('--delay', type=float, default=0.0) 16 | parser.add_argument('--action-force', type=float, default=50.0, 17 | help="magnitude of action force applied per step") 18 | parser.add_argument('--initial-force', type=float, default=55.0, 19 | help="magnitude of initial push, in random direction") 20 | parser.add_argument('--no-random-theta', action='store_true') 21 | parser.add_argument('--action-repeats', type=int, default=2, 22 | help="number of action repeats") 23 | parser.add_argument('--steps-per-repeat', type=int, default=5, 24 | help="number of 
sim steps per repeat") 25 | parser.add_argument('--num-cameras', type=int, default=1, 26 | help="how many camera points to render; 1 or 2") 27 | parser.add_argument('--event-log-out', type=str, default=None, 28 | help="path to record event log.") 29 | parser.add_argument('--max-episode-len', type=int, default=200, 30 | help="maximum episode len for cartpole") 31 | parser.add_argument('--use-raw-pixels', action='store_true', 32 | help="use raw pixels as state instead of cart/pole poses") 33 | parser.add_argument('--render-width', type=int, default=50, 34 | help="if --use-raw-pixels render with this width") 35 | parser.add_argument('--render-height', type=int, default=50, 36 | help="if --use-raw-pixels render with this height") 37 | parser.add_argument('--reward-calc', type=str, default='fixed', 38 | help="'fixed': 1 per step. 'angle': 2*max_angle - ox - oy. 'action': 1.5 - |action|. 'angle_action': both angle and action") 39 | 40 | def state_fields_of_pose_of(body_id): 41 | (x,y,z), (a,b,c,d) = p.getBasePositionAndOrientation(body_id) 42 | return np.array([x,y,z,a,b,c,d]) 43 | 44 | class BulletCartpole(gym.Env): 45 | 46 | def __init__(self, opts, discrete_actions): 47 | self.gui = opts.gui 48 | self.delay = opts.delay if self.gui else 0.0 49 | 50 | self.max_episode_len = opts.max_episode_len 51 | 52 | # threshold for pole position. 53 | # if absolute x or y moves outside this we finish episode 54 | self.pos_threshold = 2.0 # TODO: higher? 55 | 56 | # threshold for angle from z-axis. 57 | # if x or y > this value we finish episode. 58 | self.angle_threshold = 0.3 # radians; ~= 12deg 59 | 60 | # force to apply per action simulation step. 61 | # in the discrete case this is the fixed force applied 62 | # in the continuous case each x/y is in range (-F, F) 63 | self.action_force = opts.action_force 64 | 65 | # initial push force. this should be enough that taking no action will always 66 | # result in pole falling after initial_force_steps but not so much that you 67 | # can't recover. see also initial_force_steps. 68 | self.initial_force = opts.initial_force 69 | 70 | # number of sim steps initial force is applied. 71 | # (see initial_force) 72 | self.initial_force_steps = 30 73 | 74 | # whether we do initial push in a random direction 75 | # if false we always push with along x-axis (simplee problem, useful for debugging) 76 | self.random_theta = not opts.no_random_theta 77 | 78 | # true if action space is discrete; 5 values; no push, left, right, up & down 79 | # false if action space is continuous; fx, fy both (-action_force, action_force) 80 | self.discrete_actions = discrete_actions 81 | 82 | # 5 discrete actions: no push, left, right, up, down 83 | # 2 continuous action elements; fx & fy 84 | if self.discrete_actions: 85 | self.action_space = spaces.Discrete(5) 86 | else: 87 | self.action_space = spaces.Box(-1.0, 1.0, shape=(1, 2)) 88 | 89 | # open event log 90 | if opts.event_log_out: 91 | import event_log 92 | self.event_log = event_log.EventLog(opts.event_log_out, opts.use_raw_pixels) 93 | else: 94 | self.event_log = None 95 | 96 | # how many time to repeat each action per step(). 97 | # and how many sim steps to do per state capture 98 | # (total number of sim steps = action_repeats * steps_per_repeat 99 | self.repeats = opts.action_repeats 100 | self.steps_per_repeat = opts.steps_per_repeat 101 | 102 | # how many cameras to render? 
103 | # if 1 just render from front 104 | # if 2 render from front and 90deg side 105 | if opts.num_cameras not in [1, 2]: 106 | raise ValueError("--num-cameras must be 1 or 2") 107 | self.num_cameras = opts.num_cameras 108 | 109 | # whether we are using raw pixels for state or just pole + cart pose 110 | self.use_raw_pixels = opts.use_raw_pixels 111 | 112 | # in the use_raw_pixels is set we will be rendering 113 | self.render_width = opts.render_width 114 | self.render_height = opts.render_height 115 | 116 | # decide observation space 117 | if self.use_raw_pixels: 118 | # in high dimensional case each observation is an RGB images (H, W, 3) 119 | # we have R repeats and C cameras resulting in (H, W, 3, R, C) 120 | # final state fed to network is concatenated in depth => (H, W, 3*R*C) 121 | state_shape = (self.render_height, self.render_width, 3, 122 | self.num_cameras, self.repeats) 123 | else: 124 | # in the low dimensional case obs space for problem is (R, 2, 7) 125 | # R = number of repeats 126 | # 2 = two items; cart & pole 127 | # 7d tuple for pos + orientation pose 128 | state_shape = (self.repeats, 2, 7) 129 | float_max = np.finfo(np.float32).max 130 | self.observation_space = gym.spaces.Box(-float_max, float_max, state_shape) 131 | 132 | # check reward type 133 | assert opts.reward_calc in ['fixed', 'angle', 'action', 'angle_action'] 134 | self.reward_calc = opts.reward_calc 135 | 136 | # no state until reset. 137 | self.state = np.empty(state_shape, dtype=np.float32) 138 | 139 | # setup bullet 140 | p.connect(p.GUI if self.gui else p.DIRECT) 141 | p.setGravity(0, 0, -9.81) 142 | p.loadURDF("models/ground.urdf", 0,0,0, 0,0,0,1) 143 | self.cart = p.loadURDF("models/cart.urdf", 0,0,0.08, 0,0,0,1) 144 | self.pole = p.loadURDF("models/pole.urdf", 0,0,0.35, 0,0,0,1) 145 | 146 | def _configure(self, display=None): 147 | pass 148 | 149 | def _seed(self, seed=None): 150 | pass 151 | 152 | def _render(self, mode, close): 153 | pass 154 | 155 | def _step(self, action): 156 | if self.done: 157 | print >>sys.stderr, "calling step after done????" 158 | return np.copy(self.state), 0, True, {} 159 | 160 | info = {} 161 | 162 | # based on action decide the x and y forces 163 | fx = fy = 0 164 | if self.discrete_actions: 165 | if action == 0: 166 | pass 167 | elif action == 1: 168 | fx = self.action_force 169 | elif action == 2: 170 | fx = -self.action_force 171 | elif action == 3: 172 | fy = self.action_force 173 | elif action == 4: 174 | fy = -self.action_force 175 | else: 176 | raise Exception("unknown discrete action [%s]" % action) 177 | else: 178 | fx, fy = action[0] * self.action_force 179 | 180 | # step simulation forward. at the end of each repeat we set part of the step's 181 | # state by capture the cart & pole state in some form. 182 | for r in xrange(self.repeats): 183 | for _ in xrange(self.steps_per_repeat): 184 | p.stepSimulation() 185 | p.applyExternalForce(self.cart, -1, (fx,fy,0), (0,0,0), p.WORLD_FRAME) 186 | if self.delay > 0: 187 | time.sleep(self.delay) 188 | self.set_state_element_for_repeat(r) 189 | self.steps += 1 190 | 191 | # Check for out of bounds by position or orientation on pole. 192 | # we (re)fetch pose explicitly rather than depending on fields in state. 
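    # note: both the position and the orientation checks below are against the pole,
    # not the cart.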
193 | (x, y, _z), orient = p.getBasePositionAndOrientation(self.pole) 194 | ox, oy, _oz = p.getEulerFromQuaternion(orient) # roll / pitch / yaw 195 | if abs(x) > self.pos_threshold or abs(y) > self.pos_threshold: 196 | info['done_reason'] = 'out of position bounds' 197 | self.done = True 198 | reward = 0.0 199 | elif abs(ox) > self.angle_threshold or abs(oy) > self.angle_threshold: 200 | # TODO: probably better to do explicit angle from z? 201 | info['done_reason'] = 'out of orientation bounds' 202 | self.done = True 203 | reward = 0.0 204 | # check for end of episode (by length) 205 | if self.steps >= self.max_episode_len: 206 | info['done_reason'] = 'episode length' 207 | self.done = True 208 | 209 | # calc reward, fixed base of 1.0 210 | reward = 1.0 211 | if self.reward_calc == "angle" or self.reward_calc == "angle_action": 212 | # clip to zero since angles can be past threshold 213 | reward += max(0, 2 * self.angle_threshold - np.abs(ox) - np.abs(oy)) 214 | if self.reward_calc == "action" or self.reward_calc == "angle_action": 215 | # max norm will be sqr(2) ~= 1.4. 216 | # reward is already 1.0 to add another 0.5 as o0.1 buffer from zero 217 | reward += 0.5 - np.linalg.norm(action[0]) 218 | 219 | # log this event. 220 | # TODO in the --use-raw-pixels case would be nice to have poses in state repeats too. 221 | if self.event_log: 222 | self.event_log.add(self.state, action, reward) 223 | 224 | # return observation 225 | return np.copy(self.state), reward, self.done, info 226 | 227 | def render_rgb(self, camera_idx): 228 | cameraPos = [(0.0, 0.75, 0.75), (0.75, 0.0, 0.75)][camera_idx] 229 | targetPos = (0, 0, 0.3) 230 | cameraUp = (0, 0, 1) 231 | nearVal, farVal = 1, 20 232 | fov = 60 233 | _w, _h, rgba, _depth, _objects = p.renderImage(self.render_width, self.render_height, 234 | cameraPos, targetPos, cameraUp, 235 | nearVal, farVal, fov) 236 | # convert from 1d uint8 array to (H,W,3) hacky hardcode whitened float16 array. 237 | # TODO: for storage concerns could just store this as uint8 (which it is) 238 | # and normalise 0->1 + whiten later. 239 | rgba_img = np.reshape(np.asarray(rgba, dtype=np.float16), 240 | (self.render_height, self.render_width, 4)) 241 | rgb_img = rgba_img[:,:,:3] # slice off alpha, always 1.0 242 | rgb_img /= 255 243 | return rgb_img 244 | 245 | def set_state_element_for_repeat(self, repeat): 246 | if self.use_raw_pixels: 247 | # high dim caseis (H, W, 3, C, R) 248 | # H, W, 3 -> height x width, 3 channel RGB image 249 | # C -> camera_idx; 0 or 1 250 | # R -> repeat 251 | for camera_idx in range(self.num_cameras): 252 | self.state[:,:,:,camera_idx,repeat] = self.render_rgb(camera_idx) 253 | else: 254 | # in low dim case state is (R, 2, 7) 255 | # R -> repeat, 2 -> 2 objects (cart & pole), 7 -> 7d pose 256 | self.state[repeat][0] = state_fields_of_pose_of(self.cart) 257 | self.state[repeat][1] = state_fields_of_pose_of(self.pole) 258 | 259 | def _reset(self): 260 | # reset state 261 | self.steps = 0 262 | self.done = False 263 | 264 | # reset pole on cart in starting poses 265 | p.resetBasePositionAndOrientation(self.cart, (0,0,0.08), (0,0,0,1)) 266 | p.resetBasePositionAndOrientation(self.pole, (0,0,0.35), (0,0,0,1)) 267 | for _ in xrange(100): p.stepSimulation() 268 | 269 | # give a fixed force push in a random direction to get things going... 
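    # (or always along the +x axis if --no-random-theta was set; see random_theta above)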
270 | theta = (np.random.random() * 2 * np.pi) if self.random_theta else 0.0 271 | fx, fy = self.initial_force * np.cos(theta), self.initial_force * np.sin(theta) 272 | for _ in xrange(self.initial_force_steps): 273 | p.stepSimulation() 274 | p.applyExternalForce(self.cart, -1, (fx, fy, 0), (0, 0, 0), p.WORLD_FRAME) 275 | if self.delay > 0: 276 | time.sleep(self.delay) 277 | 278 | # bootstrap state by running for all repeats 279 | for i in xrange(self.repeats): 280 | self.set_state_element_for_repeat(i) 281 | 282 | # reset event log (if applicable) and add entry with only state 283 | if self.event_log: 284 | self.event_log.reset() 285 | self.event_log.add_just_state(self.state) 286 | 287 | # return this state 288 | return np.copy(self.state) 289 | -------------------------------------------------------------------------------- /cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matpalm/cartpoleplusplus/12bdb92d2610db61df742959f17f0c42b0da62ce/cartpole.gif -------------------------------------------------------------------------------- /cartpole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matpalm/cartpoleplusplus/12bdb92d2610db61df742959f17f0c42b0da62ce/cartpole.png -------------------------------------------------------------------------------- /ddpg_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import bullet_cartpole 4 | import collections 5 | import datetime 6 | import gym 7 | import json 8 | import numpy as np 9 | import replay_memory 10 | import signal 11 | import sys 12 | import tensorflow as tf 13 | import time 14 | import util 15 | 16 | np.set_printoptions(precision=5, threshold=10000, suppress=True, linewidth=10000) 17 | 18 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | parser.add_argument('--num-eval', type=int, default=0, 20 | help="if >0 just run this many episodes with no training") 21 | parser.add_argument('--max-num-actions', type=int, default=0, 22 | help="train for (at least) this number of actions (always finish current episode)" 23 | " ignore if <=0") 24 | parser.add_argument('--max-run-time', type=int, default=0, 25 | help="train for (at least) this number of seconds (always finish current episode)" 26 | " ignore if <=0") 27 | parser.add_argument('--ckpt-dir', type=str, default=None, help="if set save ckpts to this dir") 28 | parser.add_argument('--ckpt-freq', type=int, default=3600, help="freq (sec) to save ckpts") 29 | parser.add_argument('--batch-size', type=int, default=128, help="training batch size") 30 | parser.add_argument('--batches-per-step', type=int, default=5, 31 | help="number of batches to train per step") 32 | parser.add_argument('--dont-do-rollouts', action="store_true", 33 | help="by dft we do rollouts to generate data then train after each rollout. if this flag is set we" 34 | " dont do any rollouts. 
this only makes sense to do if --event-log-in set.") 35 | parser.add_argument('--target-update-rate', type=float, default=0.0001, 36 | help="affine combo for updating target networks each time we run a training batch") 37 | parser.add_argument('--use-batch-norm', action='store_true', 38 | help="whether to use batch norm on conv layers") 39 | parser.add_argument('--actor-hidden-layers', type=str, default="100,100,50", help="actor hidden layer sizes") 40 | parser.add_argument('--critic-hidden-layers', type=str, default="100,100,50", help="critic hidden layer sizes") 41 | parser.add_argument('--actor-learning-rate', type=float, default=0.001, help="learning rate for actor") 42 | parser.add_argument('--critic-learning-rate', type=float, default=0.01, help="learning rate for critic") 43 | parser.add_argument('--discount', type=float, default=0.99, help="discount for RHS of critic bellman equation update") 44 | parser.add_argument('--event-log-in', type=str, default=None, 45 | help="prepopulate replay memory with entries from this event log") 46 | parser.add_argument('--replay-memory-size', type=int, default=22000, help="max size of replay memory") 47 | parser.add_argument('--replay-memory-burn-in', type=int, default=1000, help="dont train from replay memory until it reaches this size") 48 | parser.add_argument('--eval-action-noise', action='store_true', help="whether to use noise during eval") 49 | parser.add_argument('--action-noise-theta', type=float, default=0.01, 50 | help="OrnsteinUhlenbeckNoise theta (rate of change) param for action exploration") 51 | parser.add_argument('--action-noise-sigma', type=float, default=0.05, 52 | help="OrnsteinUhlenbeckNoise sigma (magnitude) param for action exploration") 53 | 54 | util.add_opts(parser) 55 | 56 | bullet_cartpole.add_opts(parser) 57 | opts = parser.parse_args() 58 | sys.stderr.write("%s\n" % opts) 59 | 60 | # TODO: if we import slim _before_ building cartpole env we can't start bullet with GL gui o_O 61 | env = bullet_cartpole.BulletCartpole(opts=opts, discrete_actions=False) 62 | import base_network 63 | import tensorflow.contrib.slim as slim 64 | 65 | VERBOSE_DEBUG = False 66 | def toggle_verbose_debug(signal, frame): 67 | global VERBOSE_DEBUG 68 | VERBOSE_DEBUG = not VERBOSE_DEBUG 69 | signal.signal(signal.SIGUSR1, toggle_verbose_debug) 70 | 71 | DUMP_WEIGHTS = False 72 | def set_dump_weights(signal, frame): 73 | global DUMP_WEIGHTS 74 | DUMP_WEIGHTS = True 75 | signal.signal(signal.SIGUSR2, set_dump_weights) 76 | 77 | 78 | class ActorNetwork(base_network.Network): 79 | """ the actor represents the learnt policy mapping states to actions""" 80 | 81 | def __init__(self, namespace, input_state, action_dim): 82 | super(ActorNetwork, self).__init__(namespace) 83 | 84 | self.input_state = input_state 85 | 86 | self.exploration_noise = util.OrnsteinUhlenbeckNoise(action_dim, 87 | opts.action_noise_theta, 88 | opts.action_noise_sigma) 89 | 90 | with tf.variable_scope(namespace): 91 | opts.hidden_layers = opts.actor_hidden_layers 92 | final_hidden = self.input_state_network(self.input_state, opts) 93 | # action dim output. note: actors out is (-1, 1) and scaled in env as required. 
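      # final layer weights start tiny so the tanh output, and hence the initial policy,
      # starts out near zero.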
94 | weights_initializer = tf.random_uniform_initializer(-0.001, 0.001) 95 | self.output_action = slim.fully_connected(scope='output_action', 96 | inputs=final_hidden, 97 | num_outputs=action_dim, 98 | weights_initializer=weights_initializer, 99 | weights_regularizer=tf.contrib.layers.l2_regularizer(0.01), 100 | activation_fn=tf.nn.tanh) 101 | 102 | def init_ops_for_training(self, critic): 103 | # actors gradients are the gradients for it's output w.r.t it's vars using initial 104 | # gradients provided by critic. this requires that critic was init'd with an 105 | # input_action = actor.output_action (which is natural anyway) 106 | # we wrap the optimiser in namespace since we don't want this as part of copy to 107 | # target networks. 108 | # note that we negate the gradients from critic since we are trying to maximise 109 | # the q values (not minimise like a loss) 110 | with tf.variable_scope("optimiser"): 111 | gradients = tf.gradients(self.output_action, 112 | self.trainable_model_vars(), 113 | tf.neg(critic.q_gradients_wrt_actions())) 114 | gradients = zip(gradients, self.trainable_model_vars()) 115 | # potentially clip and wrap with debugging 116 | gradients = util.clip_and_debug_gradients(gradients, opts) 117 | # apply 118 | optimiser = tf.train.GradientDescentOptimizer(opts.actor_learning_rate) 119 | self.train_op = optimiser.apply_gradients(gradients) 120 | 121 | def action_given(self, state, add_noise=False): 122 | # feed explicitly provided state 123 | actions = tf.get_default_session().run(self.output_action, 124 | feed_dict={self.input_state: [state], 125 | base_network.IS_TRAINING: False}) 126 | 127 | # NOTE: noise is added _outside_ tf graph. we do this simply because the noisy output 128 | # is never used for any part of computation graph required for online training. it's 129 | # only used during training after being the replay buffer. 130 | if add_noise: 131 | if VERBOSE_DEBUG: 132 | pre_noise = str(actions) 133 | actions[0] += self.exploration_noise.sample() 134 | actions = np.clip(1, -1, actions) # action output is _always_ (-1, 1) 135 | if VERBOSE_DEBUG: 136 | print "TRAIN action_given pre_noise %s post_noise %s" % (pre_noise, actions) 137 | 138 | return actions 139 | 140 | def train(self, state): 141 | # training actor only requires state since we are trying to maximise the 142 | # q_value according to the critic. 143 | tf.get_default_session().run(self.train_op, 144 | feed_dict={self.input_state: state, 145 | base_network.IS_TRAINING: True}) 146 | 147 | 148 | class CriticNetwork(base_network.Network): 149 | """ the critic represents a mapping from state & actors action to a quality score.""" 150 | 151 | def __init__(self, namespace, actor): 152 | super(CriticNetwork, self).__init__(namespace) 153 | 154 | # input state to the critic is the _same_ state given to the actor. 155 | # input action to the critic is simply the output action of the actor. 156 | # even though when training we explicitly provide a new value for the 157 | # input action (via the input_action placeholder) we need to be stop the gradient 158 | # flowing to the actor since there is a path through the actor to the input_state 159 | # too, hence we need to be explicit about cutting it (otherwise training the 160 | # critic will attempt to train the actor too. 
161 | self.input_state = actor.input_state 162 | self.input_action = tf.stop_gradient(actor.output_action) 163 | 164 | with tf.variable_scope(namespace): 165 | if opts.use_raw_pixels: 166 | conv_net = self.simple_conv_net_on(self.input_state, opts) 167 | # TODO: use base_network helper 168 | hidden1 = slim.fully_connected(conv_net, 200, scope='hidden1') 169 | hidden2 = slim.fully_connected(hidden1, 50, scope='hidden2') 170 | concat_inputs = tf.concat(1, [hidden2, self.input_action]) 171 | final_hidden = slim.fully_connected(concat_inputs, 50, scope="hidden3") 172 | else: 173 | # stack of hidden layers on flattened input; (batch,2,2,7) -> (batch,28) 174 | flat_input_state = slim.flatten(self.input_state, scope='flat') 175 | concat_inputs = tf.concat(1, [flat_input_state, self.input_action]) 176 | final_hidden = self.hidden_layers_starting_at(concat_inputs, 177 | opts.critic_hidden_layers) 178 | 179 | # output from critic is a single q-value 180 | self.q_value = slim.fully_connected(scope='q_value', 181 | inputs=final_hidden, 182 | num_outputs=1, 183 | weights_regularizer=tf.contrib.layers.l2_regularizer(0.01), 184 | activation_fn=None) 185 | 186 | def init_ops_for_training(self, target_critic): 187 | # update critic using bellman equation; Q(s1, a) = reward + discount * Q(s2, A(s2)) 188 | 189 | # left hand side of bellman is just q_value, but let's be explicit about it... 190 | bellman_lhs = self.q_value 191 | 192 | # right hand side is ... 193 | # = reward + discounted q value from target actor & critic in the non terminal case 194 | # = reward # in the terminal case 195 | self.reward = tf.placeholder(shape=[None, 1], dtype=tf.float32, name="critic_reward") 196 | self.terminal_mask = tf.placeholder(shape=[None, 1], dtype=tf.float32, 197 | name="critic_terminal_mask") 198 | self.input_state_2 = target_critic.input_state 199 | bellman_rhs = self.reward + (self.terminal_mask * opts.discount * target_critic.q_value) 200 | 201 | # note: since we are NOT training target networks we stop gradients flowing to them 202 | bellman_rhs = tf.stop_gradient(bellman_rhs) 203 | 204 | # the value we are trying to mimimise is the difference between these two; the 205 | # temporal difference we use a squared loss for optimisation and, as for actor, we 206 | # wrap optimiser in a namespace so it's not picked up by target network variable 207 | # handling. 
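    # i.e. loss = mean( (Q(s1,a) - (r + terminal_mask * discount * Q'(s2, A'(s2))))^2 )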
208 | self.temporal_difference = bellman_lhs - bellman_rhs 209 | self.temporal_difference_loss = tf.reduce_mean(tf.pow(self.temporal_difference, 2)) 210 | # self.temporal_difference_loss = tf.Print(self.temporal_difference_loss, [self.temporal_difference_loss], 'temporal_difference_loss') 211 | with tf.variable_scope("optimiser"): 212 | # calc gradients 213 | optimiser = tf.train.GradientDescentOptimizer(opts.critic_learning_rate) 214 | gradients = optimiser.compute_gradients(self.temporal_difference_loss) 215 | # potentially clip and wrap with debugging tf.Print 216 | gradients = util.clip_and_debug_gradients(gradients, opts) 217 | # apply 218 | self.train_op = optimiser.apply_gradients(gradients) 219 | 220 | def q_gradients_wrt_actions(self): 221 | """ gradients for the q.value w.r.t just input_action; used for actor training""" 222 | return tf.gradients(self.q_value, self.input_action)[0] 223 | 224 | # def debug_q_value_for(self, input_state, action=None): 225 | # feed_dict = {self.input_state: input_state} 226 | # if action is not None: 227 | # feed_dict[self.input_action] = action 228 | # return np.squeeze(tf.get_default_session().run(self.q_value, feed_dict=feed_dict)) 229 | 230 | def train(self, batch): 231 | tf.get_default_session().run(self.train_op, 232 | feed_dict={self.input_state: batch.state_1, 233 | self.input_action: batch.action, 234 | self.reward: batch.reward, 235 | self.terminal_mask: batch.terminal_mask, 236 | self.input_state_2: batch.state_2, 237 | base_network.IS_TRAINING: True}) 238 | 239 | def check_loss(self, batch): 240 | return tf.get_default_session().run([self.temporal_difference_loss, 241 | self.temporal_difference, 242 | self.q_value], 243 | feed_dict={self.input_state: batch.state_1, 244 | self.input_action: batch.action, 245 | self.reward: batch.reward, 246 | self.terminal_mask: batch.terminal_mask, 247 | self.input_state_2: batch.state_2, 248 | base_network.IS_TRAINING: False}) 249 | 250 | 251 | class DeepDeterministicPolicyGradientAgent(object): 252 | def __init__(self, env): 253 | self.env = env 254 | state_shape = self.env.observation_space.shape 255 | action_dim = self.env.action_space.shape[1] 256 | 257 | # for now, with single machine synchronous training, use a replay memory for training. 258 | # this replay memory stores states in a Variable (ie potentially in gpu memory) 259 | # TODO: switch back to async training with multiple replicas (as in drivebot project) 260 | self.replay_memory = replay_memory.ReplayMemory(opts.replay_memory_size, 261 | state_shape, action_dim) 262 | 263 | # s1 and s2 placeholders 264 | batched_state_shape = [None] + list(state_shape) 265 | s1 = tf.placeholder(shape=batched_state_shape, dtype=tf.float32) 266 | s2 = tf.placeholder(shape=batched_state_shape, dtype=tf.float32) 267 | 268 | # initialise base models for actor / critic and their corresponding target networks 269 | # target_actor is never used for online sampling so doesn't need explore noise. 
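    # note: each critic is constructed on top of the corresponding actor's output action
    # (critic on actor, target_critic on target_actor).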
270 | self.actor = ActorNetwork("actor", s1, action_dim) 271 | self.critic = CriticNetwork("critic", self.actor) 272 | self.target_actor = ActorNetwork("target_actor", s2, action_dim) 273 | self.target_critic = CriticNetwork("target_critic", self.target_actor) 274 | 275 | # setup training ops; 276 | # training actor requires the critic (for getting gradients) 277 | # training critic requires target_critic (for RHS of bellman update) 278 | self.actor.init_ops_for_training(self.critic) 279 | self.critic.init_ops_for_training(self.target_critic) 280 | 281 | def post_var_init_setup(self): 282 | # prepopulate replay memory (if configured to do so) 283 | if opts.event_log_in: 284 | self.replay_memory.reset_from_event_log(opts.event_log_in) 285 | # hook networks up to their targets 286 | # ( does one off clobber to init all vars in target network ) 287 | self.target_actor.set_as_target_network_for(self.actor, opts.target_update_rate) 288 | self.target_critic.set_as_target_network_for(self.critic, opts.target_update_rate) 289 | 290 | 291 | def run_training(self, max_num_actions, max_run_time, batch_size, batches_per_step, 292 | saver_util): 293 | # log start time, in case we are limiting by time... 294 | start_time = time.time() 295 | 296 | # run for some max number of actions 297 | num_actions_taken = 0 298 | n = 0 299 | while True: 300 | rewards = [] 301 | losses = [] 302 | 303 | # run an episode 304 | if opts.dont_do_rollouts: 305 | # _not_ gathering experience online 306 | pass 307 | else: 308 | # start a new episode 309 | state_1 = self.env.reset() 310 | # prepare data for updating replay memory at end of episode 311 | initial_state = np.copy(state_1) 312 | action_reward_state_sequence = [] 313 | 314 | done = False 315 | while not done: 316 | # choose action 317 | action = self.actor.action_given(state_1, add_noise=True) 318 | # take action step in env 319 | state_2, reward, done, _ = self.env.step(action) 320 | rewards.append(reward) 321 | # cache for adding to replay memory 322 | action_reward_state_sequence.append((action, reward, np.copy(state_2))) 323 | # roll state for next step. 324 | state_1 = state_2 325 | # at end of episode update replay memory 326 | self.replay_memory.add_episode(initial_state, action_reward_state_sequence) 327 | 328 | # do a training step (after waiting for buffer to fill a bit...) 
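      # each step trains --batches-per-step sampled batches; the actor and critic are
      # updated from the same batch and the target networks then track them at
      # --target-update-rate.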
329 | if self.replay_memory.size() > opts.replay_memory_burn_in: 330 | # run a set of batches 331 | for _ in xrange(batches_per_step): 332 | batch = self.replay_memory.batch(batch_size) 333 | self.actor.train(batch.state_1) 334 | self.critic.train(batch) 335 | # update target nets 336 | self.target_actor.update_weights() 337 | self.target_critic.update_weights() 338 | # do debug (if requested) on last batch 339 | if VERBOSE_DEBUG: 340 | print "-----" 341 | #print "state_1", state_1 342 | print "action\n", batch.action.T 343 | print "reward ", batch.reward.T 344 | print "terminal_mask ", batch.terminal_mask.T 345 | #print "state_2", state_2 346 | td_loss, td, q_value = self.critic.check_loss(batch) 347 | print "temporal_difference_loss", td_loss 348 | print "temporal_difference", td.T 349 | print "q_value", q_value.T 350 | 351 | # dump some stats and progress info 352 | stats = collections.OrderedDict() 353 | stats["time"] = time.time() 354 | stats["n"] = n 355 | stats["mean_losses"] = float(np.mean(losses)) 356 | stats["total_reward"] = np.sum(rewards) 357 | stats["episode_len"] = len(rewards) 358 | stats["replay_memory_stats"] = self.replay_memory.current_stats() 359 | print "STATS %s\t%s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 360 | json.dumps(stats)) 361 | sys.stdout.flush() 362 | n += 1 363 | 364 | # save if required 365 | if saver_util is not None: 366 | saver_util.save_if_required() 367 | 368 | # emit occasional eval 369 | if VERBOSE_DEBUG or n % 10 == 0: 370 | self.run_eval(1) 371 | 372 | # dump weights once if requested 373 | global DUMP_WEIGHTS 374 | if DUMP_WEIGHTS: 375 | self.debug_dump_network_weights() 376 | DUMP_WEIGHTS = False 377 | 378 | # exit when finished 379 | num_actions_taken += len(rewards) 380 | if max_num_actions > 0 and num_actions_taken > max_num_actions: 381 | break 382 | if max_run_time > 0 and time.time() > start_time + max_run_time: 383 | break 384 | 385 | 386 | def run_eval(self, num_episodes, add_noise=False): 387 | """ run num_episodes of eval and output episode length and rewards """ 388 | for i in xrange(num_episodes): 389 | state = self.env.reset() 390 | total_reward = 0 391 | steps = 0 392 | done = False 393 | while not done: 394 | action = self.actor.action_given(state, add_noise) 395 | state, reward, done, _ = self.env.step(action) 396 | print "EVALSTEP r%s %s %s %s %s" % (i, steps, np.squeeze(action), np.linalg.norm(action), reward) 397 | total_reward += reward 398 | steps += 1 399 | print "EVAL", i, steps, total_reward 400 | sys.stdout.flush() 401 | 402 | def debug_dump_network_weights(self): 403 | fn = "/tmp/weights.%s" % time.time() 404 | with open(fn, "w") as f: 405 | f.write("DUMP time %s\n" % time.time()) 406 | for var in tf.all_variables(): 407 | f.write("VAR %s %s\n" % (var.name, var.get_shape())) 408 | f.write("%s\n" % var.eval()) 409 | print "weights written to", fn 410 | 411 | 412 | def main(): 413 | config = tf.ConfigProto() 414 | # config.gpu_options.allow_growth = True 415 | # config.log_device_placement = True 416 | with tf.Session(config=config) as sess: 417 | agent = DeepDeterministicPolicyGradientAgent(env=env) 418 | 419 | # setup saver util and either load latest ckpt or init variables 420 | saver_util = None 421 | if opts.ckpt_dir is not None: 422 | saver_util = util.SaverUtil(sess, opts.ckpt_dir, opts.ckpt_freq) 423 | else: 424 | sess.run(tf.initialize_all_variables()) 425 | 426 | for v in tf.all_variables(): 427 | print >>sys.stderr, v.name, util.shape_and_product_of(v) 428 | 429 | # now that we've either 
init'd from scratch, or loaded up a checkpoint, 430 | # we can do any required post init work. 431 | agent.post_var_init_setup() 432 | 433 | # run either eval or training 434 | if opts.num_eval > 0: 435 | agent.run_eval(opts.num_eval, opts.eval_action_noise) 436 | else: 437 | agent.run_training(opts.max_num_actions, opts.max_run_time, 438 | opts.batch_size, opts.batches_per_step, 439 | saver_util) 440 | if saver_util is not None: 441 | saver_util.force_save() 442 | 443 | env.reset() # just to flush logging, clumsy :/ 444 | 445 | if __name__ == "__main__": 446 | main() 447 | -------------------------------------------------------------------------------- /deciles.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | import numpy as np 4 | print np.percentile(map(float, sys.stdin.readlines()), 5 | np.linspace(0, 100, 11)) 6 | -------------------------------------------------------------------------------- /dqn_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # copy pasta from https://github.com/matthiasplappert/keras-rl/blob/master/examples/dqn_cartpole.py 4 | # with some extra arg parsing 5 | 6 | import numpy as np 7 | import gym 8 | 9 | from keras.models import Sequential 10 | from keras.layers import Dense, Activation, Flatten 11 | from keras.optimizers import Adam 12 | 13 | from rl.agents.dqn import DQNAgent 14 | from rl.policy import BoltzmannQPolicy 15 | from rl.memory import SequentialMemory 16 | 17 | import bullet_cartpole 18 | import argparse 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--num-train', type=int, default=100) 22 | parser.add_argument('--num-eval', type=int, default=0) 23 | parser.add_argument('--load-file', type=str, default=None) 24 | parser.add_argument('--save-file', type=str, default=None) 25 | bullet_cartpole.add_opts(parser) 26 | opts = parser.parse_args() 27 | print "OPTS", opts 28 | 29 | ENV_NAME = 'BulletCartpole' 30 | 31 | # Get the environment and extract the number of actions. 32 | env = bullet_cartpole.BulletCartpole(opts=opts, discrete_actions=True) 33 | nb_actions = env.action_space.n 34 | 35 | # Next, we build a very simple model. 36 | model = Sequential() 37 | model.add(Flatten(input_shape=(1,) + env.observation_space.shape)) 38 | model.add(Dense(32)) 39 | model.add(Activation('tanh')) 40 | #model.add(Dense(16)) 41 | #model.add(Activation('relu')) 42 | #model.add(Dense(16)) 43 | #model.add(Activation('relu')) 44 | model.add(Dense(nb_actions)) 45 | model.add(Activation('linear')) 46 | print(model.summary()) 47 | 48 | memory = SequentialMemory(limit=50000) 49 | policy = BoltzmannQPolicy() 50 | dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10, 51 | target_model_update=1e-2, policy=policy) 52 | dqn.compile(Adam(lr=1e-3), metrics=['mae']) 53 | 54 | if opts.load_file is not None: 55 | print "loading weights from from [%s]" % opts.load_file 56 | dqn.load_weights(opts.load_file) 57 | 58 | # Okay, now it's time to learn something! We visualize the training here for show, but this 59 | # slows down training quite a lot. You can always safely abort the training prematurely using 60 | # Ctrl + C. 61 | dqn.fit(env, nb_steps=opts.num_train, visualize=True, verbose=2) 62 | 63 | # After training is done, we save the final weights. 
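# (only if --save-file was given; use --load-file later to restore them for eval)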
64 | if opts.save_file is not None: 65 | print "saving weights to [%s]" % opts.save_file 66 | dqn.save_weights(opts.save_file, overwrite=True) 67 | 68 | # Finally, evaluate our algorithm for 5 episodes. 69 | dqn.test(env, nb_episodes=opts.num_eval, visualize=True) 70 | 71 | -------------------------------------------------------------------------------- /eg_render.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matpalm/cartpoleplusplus/12bdb92d2610db61df742959f17f0c42b0da62ce/eg_render.png -------------------------------------------------------------------------------- /event.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package cp; 3 | 4 | message Render { 5 | optional int32 height = 1; 6 | optional int32 width = 2; 7 | optional bytes png_bytes = 3; 8 | } 9 | 10 | message State { 11 | // 7d cart pose (px, py, pz, oa, ob, oc, od) 12 | repeated float cart_pose = 1; 13 | // 7d pole pose (px, py, pz, oa, ob, oc, od) 14 | repeated float pole_pose = 2; 15 | // 1+ renders, depending on number of cameras 16 | repeated Render render = 3; 17 | } 18 | 19 | // Event corresponds to output from one step of simulation 20 | // e.g. step(action) -> state, reward, done. 21 | // For env.reset case event has state but no action, reward 22 | // or is_terminal. 23 | // Assume last event in episode is terminal event (done=True) 24 | message Event { 25 | // N dimensional action 26 | repeated float action = 1; 27 | // 1 or more states based on action_repeats 28 | repeated State state = 2; 29 | // single reward value 30 | optional float reward = 3; 31 | } 32 | 33 | message Episode { 34 | repeated Event event = 1; 35 | } 36 | -------------------------------------------------------------------------------- /event_log.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import event_pb2 3 | import gzip 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import StringIO 7 | import struct 8 | 9 | def rgb_to_png(rgb): 10 | """convert RGB data from render to png""" 11 | sio = StringIO.StringIO() 12 | plt.imsave(sio, rgb) 13 | return sio.getvalue() 14 | 15 | def png_to_rgb(png_bytes): 16 | """convert png (from rgb_to_png) to RGB""" 17 | # note PNG is always RGBA so we need to slice off A 18 | rgba = plt.imread(StringIO.StringIO(png_bytes)) 19 | return rgba[:,:,:3] 20 | 21 | def read_state_from_event(event): 22 | """unpack state from event (i.e. 
inverse of add_state_to_event)""" 23 | if len(event.state[0].render) > 0: 24 | num_repeats = len(event.state) 25 | num_cameras = len(event.state[0].render) 26 | eg_render = event.state[0].render[0] 27 | state = np.empty((eg_render.height, eg_render.width, 3, 28 | num_cameras, num_repeats)) 29 | for r_idx in range(num_repeats): 30 | repeat = event.state[r_idx] 31 | for c_idx in range(num_cameras): 32 | png_bytes = repeat.render[c_idx].png_bytes 33 | state[:,:,:,c_idx,r_idx] = png_to_rgb(png_bytes) 34 | else: 35 | state = np.empty((len(event.state), 2, 7)) 36 | for i, s in enumerate(event.state): 37 | state[i][0] = s.cart_pose 38 | state[i][1] = s.pole_pose 39 | return state 40 | 41 | class EventLog(object): 42 | 43 | def __init__(self, path, use_raw_pixels): 44 | self.log_file = open(path, "ab") 45 | self.episode_entry = None 46 | self.use_raw_pixels = use_raw_pixels 47 | 48 | def reset(self): 49 | if self.episode_entry is not None: 50 | # *sigh* have to frame these ourselves :/ 51 | # (a long as a header-len will do...) 52 | buff = self.episode_entry.SerializeToString() 53 | if len(buff) > 0: 54 | buff_len = struct.pack('=l', len(buff)) 55 | self.log_file.write(buff_len) 56 | self.log_file.write(buff) 57 | self.log_file.flush() 58 | self.episode_entry = event_pb2.Episode() 59 | 60 | def add_state_to_event(self, state, event): 61 | """pack state into event""" 62 | if self.use_raw_pixels: 63 | # TODO: be nice to have pose info here too in the pixel case... 64 | num_repeats = state.shape[4] 65 | for r_idx in range(num_repeats): 66 | s = event.state.add() 67 | num_cameras = state.shape[3] 68 | for c_idx in range(num_cameras): 69 | render = s.render.add() 70 | render.width = state.shape[1] 71 | render.height = state.shape[0] 72 | render.png_bytes = rgb_to_png(state[:,:,:,c_idx,r_idx]) 73 | else: 74 | num_repeats = state.shape[0] 75 | for r in range(num_repeats): 76 | s = event.state.add() 77 | s.cart_pose.extend(map(float, state[r][0])) 78 | s.pole_pose.extend(map(float, state[r][1])) 79 | 80 | def add(self, state, action, reward): 81 | event = self.episode_entry.event.add() 82 | self.add_state_to_event(state, event) 83 | if isinstance(action, int): 84 | event.action.append(action) # single action 85 | else: 86 | assert action.shape[0] == 1 # never log batch operations 87 | event.action.extend(map(float, action[0])) 88 | event.reward = reward 89 | 90 | def add_just_state(self, state): 91 | event = self.episode_entry.event.add() 92 | self.add_state_to_event(state, event) 93 | 94 | 95 | class EventLogReader(object): 96 | 97 | def __init__(self, path): 98 | if path.endswith(".gz"): 99 | self.log_file = gzip.open(path, "rb") 100 | else: 101 | self.log_file = open(path, "rb") 102 | 103 | def entries(self): 104 | episode = event_pb2.Episode() 105 | while True: 106 | buff_len_bytes = self.log_file.read(4) 107 | if len(buff_len_bytes) == 0: return 108 | buff_len = struct.unpack('=l', buff_len_bytes)[0] 109 | buff = self.log_file.read(buff_len) 110 | episode.ParseFromString(buff) 111 | yield episode 112 | 113 | def make_dir(d): 114 | if not os.path.exists(d): 115 | os.makedirs(d) 116 | 117 | if __name__ == "__main__": 118 | import argparse, os, sys, Image, ImageDraw 119 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 120 | parser.add_argument('--log-file', type=str, default=None) 121 | parser.add_argument('--echo', action='store_true', help="write event to stdout") 122 | parser.add_argument('--episodes', type=str, default=None, 123 | help="if set only process 
these specific episodes (comma separated list)") 124 | parser.add_argument('--img-output-dir', type=str, default=None, 125 | help="if set output all renders to this DIR/e_NUM/s_NUM.png") 126 | parser.add_argument('--img-debug-overlay', action='store_true', 127 | help="if set overlay image with debug info") 128 | # TODO args for episode range 129 | opts = parser.parse_args() 130 | 131 | episode_whitelist = None 132 | if opts.episodes is not None: 133 | episode_whitelist = set(map(int, opts.episodes.split(","))) 134 | 135 | if opts.img_output_dir is not None: 136 | make_dir(opts.img_output_dir) 137 | 138 | total_num_read_episodes = 0 139 | total_num_read_events = 0 140 | 141 | elr = EventLogReader(opts.log_file) 142 | for episode_id, episode in enumerate(elr.entries()): 143 | if episode_whitelist is not None and episode_id not in episode_whitelist: 144 | continue 145 | if opts.echo: 146 | print "-----", episode_id 147 | print episode 148 | total_num_read_episodes += 1 149 | total_num_read_events += len(episode.event) 150 | if opts.img_output_dir is not None: 151 | dir = "%s/ep_%05d" % (opts.img_output_dir, episode_id) 152 | make_dir(dir) 153 | make_dir(dir + "/c0") # HACK: assume only max two cameras 154 | make_dir(dir + "/c1") 155 | for event_id, event in enumerate(episode.event): 156 | for state_id, state in enumerate(event.state): 157 | for camera_id, render in enumerate(state.render): 158 | assert camera_id in [0, 1], "fix hack above" 159 | # open RGB png in an image canvas 160 | img = Image.open(StringIO.StringIO(render.png_bytes)) 161 | if opts.img_debug_overlay: 162 | canvas = ImageDraw.Draw(img) 163 | # draw episode and event number in top left 164 | canvas.text((0, 0), "%d %d" % (episode_id, event_id), fill="black") 165 | # draw simple fx/fy representation in bottom right... 
166 | # a bounding box 167 | bx, by, bw = 40, 40, 10 168 | canvas.line((bx-bw,by-bw, bx+bw,by-bw, bx+bw,by+bw, bx-bw,by+bw, bx-bw,by-bw), fill="black") 169 | # then a simple fx/fy line 170 | fx, fy = event.action[0], event.action[1] 171 | canvas.line((bx,by, bx+(fx*bw), by+(fy*bw)), fill="black") 172 | # write it out 173 | img = img.resize((200, 200)) 174 | filename = "%s/c%d/e%05d_r%d.png" % (dir, camera_id, event_id, state_id) 175 | img.save(filename) 176 | print >>sys.stderr, "read", total_num_read_episodes, "episodes for a total of", total_num_read_events, "events" 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /event_log_sample.push_in_one_dir.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matpalm/cartpoleplusplus/12bdb92d2610db61df742959f17f0c42b0da62ce/event_log_sample.push_in_one_dir.gif -------------------------------------------------------------------------------- /exps/run_81.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | 4 | mkdir -p runs/81/{sgd,mom,adam}{,_bn} 5 | 6 | export ARGS="--use-raw-pixels --max-run-time=3600 --dont-do-rollouts --event-log-in=runs/80/events" 7 | 8 | export R=81/sgd 9 | ./naf_cartpole.py $ARGS \ 10 | --optimiser=GradientDescent --optimiser-args='{"learning_rate": 0.001}' \ 11 | --ckpt-dir=runs/$R/ckpts --event-log-out=runs/$R/events >runs/$R/out 2>runs/$R/err 12 | 13 | export R=81/sgd_bn 14 | ./naf_cartpole.py $ARGS \ 15 | --optimiser=GradientDescent --optimiser-args='{"learning_rate": 0.001}' \ 16 | --use-batch-norm \ 17 | --ckpt-dir=runs/$R/ckpts --event-log-out=runs/$R/events >runs/$R/out 2>runs/$R/err 18 | 19 | export R=81/mom 20 | ./naf_cartpole.py $ARGS \ 21 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.001, "momentum": 0.9}' \ 22 | --ckpt-dir=runs/$R/ckpts --event-log-out=runs/$R/events >runs/$R/out 2>runs/$R/err 23 | 24 | export R=81/mom_bn 25 | ./naf_cartpole.py $ARGS \ 26 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.001, "momentum": 0.9}' \ 27 | --use-batch-norm \ 28 | --ckpt-dir=runs/$R/ckpts --event-log-out=runs/$R/events >runs/$R/out 2>runs/$R/err 29 | 30 | export R=81/adam 31 | ./naf_cartpole.py $ARGS \ 32 | --optimiser=Adam --optimiser-args='{"learning_rate": 0.001}' \ 33 | --ckpt-dir=runs/$R/ckpts --event-log-out=runs/$R/events >runs/$R/out 2>runs/$R/err 34 | 35 | export R=81/adam_bn 36 | ./naf_cartpole.py $ARGS \ 37 | --optimiser=Adam --optimiser-args='{"learning_rate": 0.001}' \ 38 | --use-batch-norm \ 39 | --ckpt-dir=runs/$R/ckpts --event-log-out=runs/$R/events >runs/$R/out 2>runs/$R/err 40 | -------------------------------------------------------------------------------- /exps/run_82.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | 4 | R=runs/82 5 | 6 | mkdir -p $R/{sgd,mom}{,_bn} 7 | 8 | export ARGS="--use-raw-pixels --max-run-time=3600 --dont-do-rollouts --event-log-in=runs/80/events" 9 | 10 | export RR=$R/sgd 11 | ./naf_cartpole.py $ARGS \ 12 | --optimiser=GradientDescent --optimiser-args='{"learning_rate": 0.01}' \ 13 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 14 | 15 | export RR=$R/sgd_bn 16 | ./naf_cartpole.py $ARGS \ 17 | --optimiser=GradientDescent --optimiser-args='{"learning_rate": 0.01}' \ 18 | --use-batch-norm \ 19 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 
20 | 21 | export RR=$R/mom 22 | ./naf_cartpole.py $ARGS \ 23 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.01, "momentum": 0.9}' \ 24 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 25 | 26 | export RR=$R/mom_bn 27 | ./naf_cartpole.py $ARGS \ 28 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.01, "momentum": 0.9}' \ 29 | --use-batch-norm \ 30 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 31 | -------------------------------------------------------------------------------- /exps/run_83.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | 4 | R=runs/83 5 | 6 | mkdir -p $R/mom{,_bn} 7 | 8 | export ARGS="--use-raw-pixels --max-run-time=14400 --dont-do-rollouts --event-log-in=runs/80/events" 9 | 10 | export RR=$R/mom 11 | ./naf_cartpole.py $ARGS \ 12 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.01, "momentum": 0.9}' \ 13 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 14 | 15 | export RR=$R/mom_bn 16 | ./naf_cartpole.py $ARGS \ 17 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.01, "momentum": 0.9}' \ 18 | --use-batch-norm \ 19 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 20 | -------------------------------------------------------------------------------- /exps/run_84.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | unset CUDA_VISIBLE_DEVICES # ensure running on gpu 4 | 5 | R=runs/84 6 | 7 | mkdir -p $R/{naf,ddpg}_mom_bn 8 | 9 | export ARGS="--use-raw-pixels --max-run-time=14400 --dont-do-rollouts --event-log-in=runs/80/events" 10 | 11 | export RR=$R/naf_mom_bn 12 | ./naf_cartpole.py $ARGS \ 13 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.01, "momentum": 0.9}' \ 14 | --use-batch-norm \ 15 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 16 | 17 | export RR=$R/ddpg_mom_bn 18 | ./ddpg_cartpole.py $ARGS \ 19 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.01, "momentum": 0.9}' \ 20 | --use-batch-norm \ 21 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 22 | -------------------------------------------------------------------------------- /exps/run_85.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | unset CUDA_VISIBLE_DEVICES # ensure running on gpu 4 | 5 | R=runs/85 6 | 7 | mkdir -p $R/{naf,ddpg} 8 | 9 | export ARGS="--use-raw-pixels --max-run-time=14400 --use-batch-norm --action-force=100" 10 | 11 | export RR=$R/naf 12 | ./naf_cartpole.py $ARGS \ 13 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.01, "momentum": 0.9}' \ 14 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 15 | 16 | export RR=$R/ddpg 17 | ./ddpg_cartpole.py $ARGS \ 18 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.01, "momentum": 0.9}' \ 19 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 20 | -------------------------------------------------------------------------------- /exps/run_86.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | unset CUDA_VISIBLE_DEVICES # ensure running on gpu 4 | 5 | R=runs/86 6 | 7 | mkdir -p $R/{naf,ddpg} 8 | 9 | export ARGS="--max-run-time=14400 --action-force=100" 10 | 11 | export RR=$R/naf 12 | ./naf_cartpole.py $ARGS \ 13 | 
--optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 14 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 15 | 16 | export RR=$R/ddpg 17 | ./ddpg_cartpole.py $ARGS \ 18 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 19 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err 20 | -------------------------------------------------------------------------------- /exps/run_87.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | export CUDA_VISIBLE_DEVICES="" # ensure running on cpu 4 | 5 | R=runs/87 6 | 7 | mkdir -p $R/{repro1,repro2,bn,3repeats} 8 | 9 | # exactly repros run_86 10 | export RR=$R/repro1 11 | nice ./naf_cartpole.py --action-force=100 \ 12 | --action-repeats=2 --steps-per-repeat=5 \ 13 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 14 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 15 | 16 | # switchs to 6 repeats, instead of 5, so that's 12 total (like the 3x4 one below) 17 | export RR=$R/repro2 18 | nice ./naf_cartpole.py --action-force=100 \ 19 | --action-repeats=2 --steps-per-repeat=6 \ 20 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 21 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 22 | 23 | # with batch norm 24 | export RR=$R/bn 25 | nice ./naf_cartpole.py --action-force=100 \ 26 | --action-repeats=2 --steps-per-repeat=6 \ 27 | --use-batch-norm \ 28 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 29 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 30 | 31 | # with 3 action repeats (but still 12 total steps) 32 | export RR=$R/3repeats 33 | nice ./naf_cartpole.py --action-force=100 \ 34 | --action-repeats=3 --steps-per-repeat=4 \ 35 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 36 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 37 | -------------------------------------------------------------------------------- /exps/run_89.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | export CUDA_VISIBLE_DEVICES="" # ensure running on cpu 4 | 5 | R=runs/89 6 | 7 | export RR=$R/action 8 | mkdir -p $RR 9 | nice ./naf_cartpole.py --action-force=100 \ 10 | --reward-calc="action_norm" \ 11 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 12 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 13 | sleep 1 14 | 15 | export RR=$R/angle 16 | mkdir -p $RR 17 | nice ./naf_cartpole.py --action-force=100 \ 18 | --reward-calc="angles" \ 19 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 20 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 21 | sleep 1 22 | 23 | export RR=$R/both 24 | mkdir -p $RR 25 | nice ./naf_cartpole.py --action-force=100 \ 26 | --reward-calc="both" \ 27 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 28 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 29 | 30 | -------------------------------------------------------------------------------- /exps/run_90.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | export CUDA_VISIBLE_DEVICES="" # 
ensure running on cpu 4 | 5 | R=runs/90 6 | 7 | ARGS="--action-force=100 --action-repeats=3 --steps-per-repeat=4" 8 | 9 | export RR=$R/fixed 10 | mkdir -p $RR 11 | nice ./naf_cartpole.py $ARGS \ 12 | --reward-calc="fixed" \ 13 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 14 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 15 | sleep 1 16 | 17 | export RR=$R/angle 18 | mkdir -p $RR 19 | nice ./naf_cartpole.py $ARGS \ 20 | --reward-calc="angle" \ 21 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 22 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 23 | sleep 1 24 | 25 | export RR=$R/action 26 | mkdir -p $RR 27 | nice ./naf_cartpole.py $ARGS \ 28 | --reward-calc="action" \ 29 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 30 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 31 | 32 | export RR=$R/angle_action 33 | mkdir -p $RR 34 | nice ./naf_cartpole.py $ARGS \ 35 | --reward-calc="angle_action" \ 36 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 37 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 38 | 39 | -------------------------------------------------------------------------------- /exps/run_91.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | export CUDA_VISIBLE_DEVICES="" # ensure running on cpu 4 | 5 | R=runs/91 6 | 7 | export RR=$R/1_12 8 | mkdir -p $RR 9 | nice ./naf_cartpole.py --action-force=100 \ 10 | --action-repeats=1 --steps-per-repeat=12 \ 11 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 12 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 13 | 14 | export RR=$R/2_6 15 | mkdir -p $RR 16 | nice ./naf_cartpole.py --action-force=100 \ 17 | --action-repeats=2 --steps-per-repeat=6 \ 18 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 19 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 20 | 21 | export RR=$R/3_4 22 | mkdir -p $RR 23 | nice ./naf_cartpole.py --action-force=100 \ 24 | --action-repeats=3 --steps-per-repeat=4 \ 25 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 26 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 27 | 28 | export RR=$R/4_3 29 | mkdir -p $RR 30 | nice ./naf_cartpole.py --action-force=100 \ 31 | --action-repeats=4 --steps-per-repeat=3 \ 32 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 33 | --ckpt-dir=$RR/ckpts --event-log-out=$RR/events >$RR/out 2>$RR/err & 34 | -------------------------------------------------------------------------------- /exps/run_92.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | export CUDA_VISIBLE_DEVICES="" # ensure running on cpu 4 | 5 | R=runs/92 6 | 7 | export RR=$R/a 8 | mkdir -p $RR 9 | nice ./lrpg_cartpole.py --ckpt-dir=$RR/ckpts >$RR/out 2>$RR/err & 10 | 11 | export RR=$R/b 12 | mkdir -p $RR 13 | nice ./lrpg_cartpole.py --ckpt-dir=$RR/ckpts >$RR/out 2>$RR/err & 14 | 15 | 16 | -------------------------------------------------------------------------------- /exps/run_93.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | unset 
CUDA_VISIBLE_DEVICES # ensure running on gpu 4 | 5 | export R=runs/93 6 | mkdir $R 7 | 8 | ./naf_cartpole.py \ 9 | --use-raw-pixels \ 10 | --num-cameras=2 --action-repeats=3 --steps-per-repeat=4 \ 11 | --replay-memory-size=12000 \ 12 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.01, "momentum": 0.9}' \ 13 | --use-batch-norm \ 14 | --ckpt-dir=$R/ckpts --event-log-out=$R/events >$R/out 2>$R/err & 15 | -------------------------------------------------------------------------------- /exps/run_97.sh: -------------------------------------------------------------------------------- 1 | use_cpu 2 | export R=runs/97 3 | mkdir -p $R 4 | nice ./naf_cartpole.py \ 5 | --action-force=100 --action-repeats=3 --steps-per-repeat=4 \ 6 | --reward-calc="angle" \ 7 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.0001, "momentum": 0.9}' \ 8 | --ckpt-dir=$R/ckpts --event-log-out=$R/events >$R/out 2>$R/err 9 | -------------------------------------------------------------------------------- /exps/run_98.sh: -------------------------------------------------------------------------------- 1 | use_gpu 2 | R=runs/98/ 3 | mkdir $R 4 | nice ./naf_cartpole.py \ 5 | --use-raw-pixels --use-batch-norm \ 6 | --action-force=100 --action-repeats=3 --steps-per-repeat=4 --num-cameras=2 \ 7 | --replay-memory-size=200000 --replay-memory-burn-in=10000 \ 8 | --optimiser=Momentum --optimiser-args='{"learning_rate": 0.01, "momentum": 0.9}' \ 9 | --ckpt-dir=$R/ckpts --event-log-out=$R/events >$R/out 2>$R/err 10 | -------------------------------------------------------------------------------- /exps/runs_77_plot.sh: -------------------------------------------------------------------------------- 1 | cat runs/77a/out | grep EVAL | grep -v STEP | cut -f3 -d' ' | nl | sed -es/^\s*/a\\t/ > /tmp/p 2 | cat runs/77b/out | grep EVAL | grep -v STEP | cut -f3 -d' ' | nl | sed -es/^\s*/b\\t/ >> /tmp/p 3 | -------------------------------------------------------------------------------- /lrpg_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import bullet_cartpole 4 | import collections 5 | import datetime 6 | import gym 7 | import json 8 | import numpy as np 9 | import signal 10 | import sys 11 | import tensorflow as tf 12 | from tensorflow.python.ops import init_ops 13 | import time 14 | import util 15 | 16 | np.set_printoptions(precision=5, threshold=10000, suppress=True, linewidth=10000) 17 | 18 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | parser.add_argument('--num-eval', type=int, default=0, 20 | help="if >0 just run this many episodes with no training") 21 | parser.add_argument('--max-num-actions', type=int, default=0, 22 | help="train for (at least) this number of actions (always finish current episode)" 23 | " ignore if <=0") 24 | parser.add_argument('--max-run-time', type=int, default=0, 25 | help="train for (at least) this number of seconds (always finish current episode)" 26 | " ignore if <=0") 27 | parser.add_argument('--ckpt-dir', type=str, default=None, help="if set save ckpts to this dir") 28 | parser.add_argument('--ckpt-freq', type=int, default=3600, help="freq (sec) to save ckpts") 29 | parser.add_argument('--hidden-layers', type=str, default="100,50", help="hidden layer sizes") 30 | parser.add_argument('--learning-rate', type=float, default=0.0001, help="learning rate") 31 | parser.add_argument('--num-train-batches', type=int, default=10, 32 | help="number of 
training batches to run") 33 | parser.add_argument('--rollouts-per-batch', type=int, default=10, 34 | help="number of rollouts to run for each training batch") 35 | parser.add_argument('--eval-action-noise', action='store_true', help="whether to use noise during eval") 36 | 37 | util.add_opts(parser) 38 | 39 | bullet_cartpole.add_opts(parser) 40 | opts = parser.parse_args() 41 | sys.stderr.write("%s\n" % opts) 42 | assert not opts.use_raw_pixels, "TODO: add convnet from ddpg here" 43 | 44 | # TODO: if we import slim _before_ building cartpole env we can't start bullet with GL gui o_O 45 | env = bullet_cartpole.BulletCartpole(opts=opts, discrete_actions=True) 46 | import base_network 47 | import tensorflow.contrib.slim as slim 48 | 49 | VERBOSE_DEBUG = False 50 | def toggle_verbose_debug(signal, frame): 51 | global VERBOSE_DEBUG 52 | VERBOSE_DEBUG = not VERBOSE_DEBUG 53 | signal.signal(signal.SIGUSR1, toggle_verbose_debug) 54 | 55 | DUMP_WEIGHTS = False 56 | def set_dump_weights(signal, frame): 57 | global DUMP_WEIGHTS 58 | DUMP_WEIGHTS = True 59 | signal.signal(signal.SIGUSR2, set_dump_weights) 60 | 61 | 62 | class LikelihoodRatioPolicyGradientAgent(base_network.Network): 63 | 64 | def __init__(self, env): 65 | self.env = env 66 | 67 | num_actions = self.env.action_space.n 68 | 69 | # we have three place holders we'll use... 70 | # observations; used either during rollout to sample some actions, or 71 | # during training when combined with actions_taken and advantages. 72 | shape_with_batch = [None] + list(self.env.observation_space.shape) 73 | self.observations = tf.placeholder(shape=shape_with_batch, 74 | dtype=tf.float32) 75 | # the actions we took during rollout 76 | self.actions = tf.placeholder(tf.int32, name='actions') 77 | # the advantages we got from taken 'action_taken' in 'observation' 78 | self.advantages = tf.placeholder(tf.float32, name='advantages') 79 | 80 | # our model is a very simple MLP 81 | with tf.variable_scope("model"): 82 | # stack of hidden layers on flattened input; (batch,2,2,7) -> (batch,28) 83 | flat_input_state = slim.flatten(self.observations, scope='flat') 84 | final_hidden = self.hidden_layers_starting_at(flat_input_state, 85 | opts.hidden_layers) 86 | logits = slim.fully_connected(inputs=final_hidden, 87 | num_outputs=num_actions, 88 | activation_fn=None) 89 | 90 | # in the eval case just pick arg max 91 | self.action_argmax = tf.argmax(logits, 1) 92 | 93 | # for rollouts we need an op that samples actions from this 94 | # model to give a stochastic action. 95 | sample_action = tf.multinomial(logits, num_samples=1) 96 | self.sampled_action_op = tf.reshape(sample_action, shape=[]) 97 | 98 | # we are trying to maximise the product of two components... 99 | # 1) the log_p of "good" actions. 100 | # 2) the advantage term based on the rewards from actions. 101 | 102 | # first we need the log_p values for each observation for the actions we specifically 103 | # took by sampling... we first run a log_softmax over the action logits to get 104 | # probabilities. 105 | log_softmax = tf.nn.log_softmax(logits) 106 | self.debug_softmax = tf.exp(log_softmax) 107 | 108 | # we then use a mask to only select the elements of the softmaxs that correspond 109 | # to the actions we actually took. we could also do this by complex indexing and a 110 | # gather but i always think this is more natural. 
the "cost" of dealing with the 111 | # mostly zero one hot, as opposed to doing a gather on sparse indexes, isn't a big 112 | # deal when the number of observations is >> number of actions. 113 | action_mask = tf.one_hot(indices=self.actions, depth=num_actions) 114 | action_log_prob = tf.reduce_sum(log_softmax * action_mask, reduction_indices=1) 115 | 116 | # the (element wise) product of these action log_p's with the total reward of the 117 | # episode represents the quantity we want to maximise. we standardise the advantage 118 | # values so roughly 1/2 +ve / -ve as a variance control. 119 | action_mul_advantages = tf.mul(action_log_prob, 120 | util.standardise(self.advantages)) 121 | self.loss = -tf.reduce_sum(action_mul_advantages) # recall: we are maximising. 122 | with tf.variable_scope("optimiser"): 123 | # dynamically create optimiser based on opts 124 | optimiser = util.construct_optimiser(opts) 125 | # calc gradients 126 | gradients = optimiser.compute_gradients(self.loss) 127 | # potentially clip and wrap with debugging tf.Print 128 | gradients = util.clip_and_debug_gradients(gradients, opts) 129 | # apply 130 | self.train_op = optimiser.apply_gradients(gradients) 131 | 132 | def sample_action_given(self, observation, doing_eval=False): 133 | """ sample one action given observation""" 134 | if doing_eval: 135 | sao, sm = tf.get_default_session().run([self.sampled_action_op, self.debug_softmax], 136 | feed_dict={self.observations: [observation]}) 137 | print "EVAL sm ", sm, "action", sao 138 | return sao 139 | 140 | # epilson greedy "noise" will do for this simple case.. 141 | if np.random.random() < 0.1: 142 | return self.env.action_space.sample() 143 | 144 | # sample from logits 145 | return tf.get_default_session().run(self.sampled_action_op, 146 | feed_dict={self.observations: [observation]}) 147 | 148 | 149 | def rollout(self, doing_eval=False): 150 | """ run one episode collecting observations, actions and advantages""" 151 | observations, actions, rewards = [], [], [] 152 | observation = self.env.reset() 153 | done = False 154 | while not done: 155 | observations.append(observation) 156 | action = self.sample_action_given(observation, doing_eval) 157 | assert action != 5, "FAIL! (multinomial logits sampling bug?" 158 | observation, reward, done, _ = self.env.step(action) 159 | actions.append(action) 160 | rewards.append(reward) 161 | if VERBOSE_DEBUG: 162 | print "rollout: actions=%s" % (actions) 163 | return observations, actions, rewards 164 | 165 | def train(self, observations, actions, advantages): 166 | """ take one training step given observations, actions and subsequent advantages""" 167 | if VERBOSE_DEBUG: 168 | print "TRAIN" 169 | print "observations", np.stack(observations) 170 | print "actions", actions 171 | print "advantages", advantages 172 | _, loss = tf.get_default_session().run([self.train_op, self.loss], 173 | feed_dict={self.observations: observations, 174 | self.actions: actions, 175 | self.advantages: advantages}) 176 | 177 | else: 178 | _, loss = tf.get_default_session().run([self.train_op, self.loss], 179 | feed_dict={self.observations: observations, 180 | self.actions: actions, 181 | self.advantages: advantages}) 182 | return float(loss) 183 | 184 | def post_var_init_setup(self): 185 | pass 186 | 187 | def run_training(self, max_num_actions, max_run_time, rollouts_per_batch, 188 | saver_util): 189 | # log start time, in case we are limiting by time... 
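# A minimal numpy sketch of the surrogate loss constructed above, assuming
# util.standardise is the usual (x - mean) / std normalisation (that helper
# lives in util.py and isn't shown in this section):
import numpy as np

def lr_surrogate_loss(action_log_probs, advantages):
  # standardise advantages so they're roughly half positive, half negative
  adv = np.asarray(advantages, dtype=np.float32)
  adv = (adv - adv.mean()) / adv.std()
  # minimising the negative sum pushes up log prob of better-than-average actions
  return -np.sum(np.asarray(action_log_probs) * adv)

# e.g. three sampled actions whose episodes totalled 200, 10 and 150 reward:
#   lr_surrogate_loss(np.log([0.7, 0.2, 0.6]), [200.0, 10.0, 150.0])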
190 | start_time = time.time() 191 | 192 | # run for some max number of actions 193 | num_actions_taken = 0 194 | n = 0 195 | while True: 196 | total_rewards = [] 197 | losses = [] 198 | 199 | # perform a number of rollouts 200 | batch_observations, batch_actions, batch_advantages = [], [], [] 201 | 202 | for _ in xrange(rollouts_per_batch): 203 | observations, actions, rewards = self.rollout() 204 | batch_observations += observations 205 | batch_actions += actions 206 | # train with advantages, not per observation/action rewards. 207 | # _every_ observation/action in this rollout gets assigned 208 | # the _total_ reward of the episode. (crazy that this works!) 209 | batch_advantages += [sum(rewards)] * len(rewards) 210 | # keep total rewards just for debugging / stats 211 | total_rewards.append(sum(rewards)) 212 | 213 | if min(total_rewards) == max(total_rewards): 214 | # converged ?? 215 | sys.stderr.write("converged? standardisation of advantaged will barf here....\n") 216 | loss = 0 217 | else: 218 | loss = self.train(batch_observations, batch_actions, batch_advantages) 219 | losses.append(loss) 220 | 221 | # dump some stats and progress info 222 | stats = collections.OrderedDict() 223 | stats["time"] = time.time() 224 | stats["n"] = n 225 | stats["mean_losses"] = float(np.mean(losses)) 226 | stats["total_reward"] = np.sum(total_rewards) 227 | stats["episode_len"] = len(rewards) 228 | 229 | print "STATS %s\t%s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 230 | json.dumps(stats)) 231 | sys.stdout.flush() 232 | n += 1 233 | 234 | # save if required 235 | if saver_util is not None: 236 | saver_util.save_if_required() 237 | 238 | # emit occasional eval 239 | if VERBOSE_DEBUG or n % 10 == 0: 240 | self.run_eval(1) 241 | 242 | # dump weights once if requested 243 | global DUMP_WEIGHTS 244 | if DUMP_WEIGHTS: 245 | self.debug_dump_network_weights() 246 | DUMP_WEIGHTS = False 247 | 248 | # exit when finished 249 | num_actions_taken += len(rewards) 250 | if max_num_actions > 0 and num_actions_taken > max_num_actions: 251 | break 252 | if max_run_time > 0 and time.time() > start_time + max_run_time: 253 | break 254 | 255 | 256 | def run_eval(self, num_episodes, add_noise=False): 257 | for _ in xrange(num_episodes): 258 | _, _, rewards = self.rollout(doing_eval=True) 259 | print sum(rewards) 260 | 261 | def debug_dump_network_weights(self): 262 | fn = "/tmp/weights.%s" % time.time() 263 | with open(fn, "w") as f: 264 | f.write("DUMP time %s\n" % time.time()) 265 | for var in tf.all_variables(): 266 | f.write("VAR %s %s\n" % (var.name, var.get_shape())) 267 | f.write("%s\n" % var.eval()) 268 | print "weights written to", fn 269 | 270 | 271 | def main(): 272 | config = tf.ConfigProto() 273 | # config.gpu_options.allow_growth = True 274 | # config.log_device_placement = True 275 | with tf.Session(config=config) as sess: 276 | agent = LikelihoodRatioPolicyGradientAgent(env) 277 | 278 | # setup saver util and either load latest ckpt or init variables 279 | saver_util = None 280 | if opts.ckpt_dir is not None: 281 | saver_util = util.SaverUtil(sess, opts.ckpt_dir, opts.ckpt_freq) 282 | else: 283 | sess.run(tf.initialize_all_variables()) 284 | 285 | for v in tf.all_variables(): 286 | print >>sys.stderr, v.name, util.shape_and_product_of(v) 287 | 288 | # now that we've either init'd from scratch, or loaded up a checkpoint, 289 | # we can do any required post init work. 
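# Progress is emitted above as lines of the form "STATS <timestamp>\t<json>".
# The repo's own parsers live under u/ (e.g. u/parse_out.py); this is just a
# minimal stdlib sketch, assuming that line format:
import json

def stats_from(stream):
  for line in stream:
    if line.startswith("STATS"):
      _, payload = line.split("\t", 1)
      yield json.loads(payload)

# e.g. total reward per training step from a run log:
#   rewards = [s["total_reward"] for s in stats_from(open("runs/92/a/out"))]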
290 | agent.post_var_init_setup() 291 | 292 | # run either eval or training 293 | if opts.num_eval > 0: 294 | agent.run_eval(opts.num_eval, opts.eval_action_noise) 295 | else: 296 | agent.run_training(opts.max_num_actions, opts.max_run_time, 297 | opts.rollouts_per_batch, 298 | saver_util) 299 | if saver_util is not None: 300 | saver_util.force_save() 301 | 302 | env.reset() # just to flush logging, clumsy :/ 303 | 304 | if __name__ == "__main__": 305 | main() 306 | -------------------------------------------------------------------------------- /make_plots.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | mkdir /tmp/plots 3 | R --vanilla < plots.R 4 | 5 | -------------------------------------------------------------------------------- /models/cart.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 |   4 |     5 |       6 |       7 |       8 |       9 |       10 |     11 |     12 |       13 |       14 |       15 |     16 |     17 |       18 |       19 |         20 |       21 |       22 |         23 |       24 |     25 |     26 |       27 |       28 |         29 |       30 |     31 |   32 | 33 | -------------------------------------------------------------------------------- /models/ground.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 |   4 |     5 |       6 |       7 |       8 |     9 |     10 |       11 |       12 |         13 |       14 |       15 |         16 |       17 |     18 |     19 |       20 |       21 |         22 |       23 |     24 |   25 | 26 | -------------------------------------------------------------------------------- /models/pole.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 |   4 |     5 |       6 |       7 |       8 |       9 |       10 |     11 |     12 |       13 |       14 |       15 |     16 |     17 |       18 |       19 |         20 |       21 |       22 |         23 |       24 |     25 |     26 |       27 |       28 |         29 |       30 |     31 |   32 | 33 | -------------------------------------------------------------------------------- /naf_cartpole.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import bullet_cartpole 4 | import collections 5 | import datetime 6 | import gym 7 | import json 8 | import numpy as np 9 | import replay_memory 10 | import signal 11 | import sys 12 | import tensorflow as tf 13 | import time 14 | import util 15 | 16 | np.set_printoptions(precision=5, threshold=10000, suppress=True, linewidth=10000) 17 | 18 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 19 | parser.add_argument('--num-eval', type=int, default=0, 20 | help="if >0 just run this many episodes with no training") 21 | parser.add_argument('--max-num-actions', type=int, default=0, 22 | help="train for (at least) this number of actions (always finish" 23 | " current episode) ignore if <=0") 24 | parser.add_argument('--max-run-time', type=int, default=0, 25 | help="train for (at least) this number of seconds (always finish" 26 | " current episode) ignore if <=0") 27 | parser.add_argument('--ckpt-dir', type=str, default=None, 28 | help="if set save ckpts to this dir") 29 | parser.add_argument('--ckpt-freq', type=int, default=3600, help="freq (sec) to save ckpts") 30 | parser.add_argument('--batch-size', type=int, default=128, 
help="training batch size") 31 | parser.add_argument('--batches-per-step', type=int, default=5, 32 | help="number of batches to train per step") 33 | parser.add_argument('--dont-do-rollouts', action="store_true", 34 | help="by dft we do rollouts to generate data then train after each rollout. if this flag is set we" 35 | " dont do any rollouts. this only makes sense to do if --event-log-in set.") 36 | parser.add_argument('--target-update-rate', type=float, default=0.0001, 37 | help="affine combo for updating target networks each time we run a" 38 | " training batch") 39 | # TODO params per value, P, output_action networks? 40 | parser.add_argument('--share-input-state-representation', action='store_true', 41 | help="if set we have one network for processing input state that is" 42 | " shared between value, l_value and output_action networks. if" 43 | " not set each net has it's own network.") 44 | parser.add_argument('--hidden-layers', type=str, default="100,50", 45 | help="hidden layer sizes") 46 | parser.add_argument('--use-batch-norm', action='store_true', 47 | help="whether to use batch norm on conv layers") 48 | parser.add_argument('--discount', type=float, default=0.99, 49 | help="discount for RHS of bellman equation update") 50 | parser.add_argument('--event-log-in', type=str, default=None, 51 | help="prepopulate replay memory with entries from this event log") 52 | parser.add_argument('--replay-memory-size', type=int, default=22000, 53 | help="max size of replay memory") 54 | parser.add_argument('--replay-memory-burn-in', type=int, default=1000, 55 | help="dont train from replay memory until it reaches this size") 56 | parser.add_argument('--eval-action-noise', action='store_true', 57 | help="whether to use noise during eval") 58 | parser.add_argument('--action-noise-theta', type=float, default=0.01, 59 | help="OrnsteinUhlenbeckNoise theta (rate of change) param for action" 60 | " exploration") 61 | parser.add_argument('--action-noise-sigma', type=float, default=0.05, 62 | help="OrnsteinUhlenbeckNoise sigma (magnitude) param for action" 63 | " exploration") 64 | parser.add_argument('--gpu-mem-fraction', type=float, default=None, 65 | help="if not none use only this fraction of gpu memory") 66 | 67 | util.add_opts(parser) 68 | 69 | bullet_cartpole.add_opts(parser) 70 | opts = parser.parse_args() 71 | sys.stderr.write("%s\n" % opts) 72 | 73 | # TODO: check that if --dont-do-rollouts set then --event-log-in also set 74 | 75 | # TODO: if we import slim before cartpole env we can't start bullet withGL gui o_O 76 | env = bullet_cartpole.BulletCartpole(opts=opts, discrete_actions=False) 77 | import base_network 78 | import tensorflow.contrib.slim as slim 79 | 80 | VERBOSE_DEBUG = False 81 | def toggle_verbose_debug(signal, frame): 82 | global VERBOSE_DEBUG 83 | VERBOSE_DEBUG = not VERBOSE_DEBUG 84 | signal.signal(signal.SIGUSR1, toggle_verbose_debug) 85 | 86 | DUMP_WEIGHTS = False 87 | def set_dump_weights(signal, frame): 88 | global DUMP_WEIGHTS 89 | DUMP_WEIGHTS = True 90 | signal.signal(signal.SIGUSR2, set_dump_weights) 91 | 92 | 93 | class ValueNetwork(base_network.Network): 94 | """ Value network component of a NAF network. 
Created as seperate net because it has a target network.""" 95 | 96 | def __init__(self, namespace, input_state, hidden_layer_config): 97 | super(ValueNetwork, self).__init__(namespace) 98 | 99 | self.input_state = input_state 100 | 101 | with tf.variable_scope(namespace): 102 | # expose self.input_state_representation since it will be the network "shared" 103 | # by l_value & output_action network when running --share-input-state-representation 104 | self.input_state_representation = self.input_state_network(input_state, opts) 105 | self.value = slim.fully_connected(scope='fc', 106 | inputs=self.input_state_representation, 107 | num_outputs=1, 108 | weights_regularizer=tf.contrib.layers.l2_regularizer(0.01), 109 | activation_fn=None) # (batch, 1) 110 | 111 | def value_given(self, state): 112 | return tf.get_default_session().run(self.value, 113 | feed_dict={self.input_state: state, 114 | base_network.IS_TRAINING: False}) 115 | 116 | 117 | class NafNetwork(base_network.Network): 118 | 119 | def __init__(self, namespace, 120 | input_state, input_state_2, 121 | value_net, target_value_net, 122 | action_dim): 123 | super(NafNetwork, self).__init__(namespace) 124 | 125 | # noise to apply to actions during rollouts 126 | self.exploration_noise = util.OrnsteinUhlenbeckNoise(action_dim, 127 | opts.action_noise_theta, 128 | opts.action_noise_sigma) 129 | 130 | # we already have the V networks, created independently because it also 131 | # has a target network. 132 | self.value_net = value_net 133 | self.target_value_net = target_value_net 134 | 135 | # keep placeholders provided and build any others required 136 | self.input_state = input_state 137 | self.input_state_2 = input_state_2 138 | self.input_action = tf.placeholder(shape=[None, action_dim], 139 | dtype=tf.float32, name="input_action") 140 | self.reward = tf.placeholder(shape=[None, 1], 141 | dtype=tf.float32, name="reward") 142 | self.terminal_mask = tf.placeholder(shape=[None, 1], 143 | dtype=tf.float32, name="terminal_mask") 144 | 145 | # TODO: dont actually use terminal mask? 146 | 147 | with tf.variable_scope(namespace): 148 | # mu (output_action) is also a simple NN mapping input state -> action 149 | # this is our target op for inference (i.e. value that maximises Q given input_state) 150 | with tf.variable_scope("output_action"): 151 | if opts.share_input_state_representation: 152 | input_representation = value_net.input_state_representation 153 | else: 154 | input_representation = self.input_state_network(self.input_state, opts) 155 | weights_initializer = tf.random_uniform_initializer(-0.001, 0.001) 156 | self.output_action = slim.fully_connected(scope='fc', 157 | inputs=input_representation, 158 | num_outputs=action_dim, 159 | weights_initializer=weights_initializer, 160 | weights_regularizer=tf.contrib.layers.l2_regularizer(0.01), 161 | activation_fn=tf.nn.tanh) # (batch, action_dim) 162 | 163 | # A (advantage) is a bit more work and has three components... 164 | # first the u / mu difference. note: to use in a matmul we need 165 | # to convert this vector into a matrix by adding an "unused" 166 | # trailing dimension 167 | u_mu_diff = self.input_action - self.output_action # (batch, action_dim) 168 | u_mu_diff = tf.expand_dims(u_mu_diff, -1) # (batch, action_dim, 1) 169 | 170 | # next we have P = L(x).L(x)_T where L is the values of lower triangular 171 | # matrix with diagonals exp'd. yikes! 
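# A small numpy sketch of the construction that follows, to make the per-row
# slicing easier to follow: the flat l_values vector is unpacked row by row into
# a lower triangular L with an exponentiated diagonal, and P = L.L^T is then
# positive definite by construction. (Sketch only; the graph version below does
# the same thing with tf.slice / tf.concat per row.)
import numpy as np

def l_values_to_P(l_values, action_dim):
  L = np.zeros((action_dim, action_dim))
  offset = 0
  for row in range(action_dim):
    L[row, :row] = l_values[offset:offset + row]   # elements left of the diagonal
    L[row, row] = np.exp(l_values[offset + row])   # exp'd diagonal element
    offset += row + 1                              # this row consumed row+1 values
  return L.dot(L.T)

# e.g. a 2d action needs 2*(2+1)/2 = 3 l_values: l_values_to_P([0.1, -0.3, 0.2], 2)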
172 | 173 | # first the L lower triangular values; a network on top of the input state 174 | num_l_values = (action_dim*(action_dim+1))/2 175 | with tf.variable_scope("l_values"): 176 | if opts.share_input_state_representation: 177 | input_representation = value_net.input_state_representation 178 | else: 179 | input_representation = self.input_state_network(self.input_state, opts) 180 | l_values = slim.fully_connected(scope='fc', 181 | inputs=input_representation, 182 | num_outputs=num_l_values, 183 | weights_regularizer=tf.contrib.layers.l2_regularizer(0.01), 184 | activation_fn=None) 185 | 186 | # we will convert these l_values into a matrix one row at a time. 187 | rows = [] 188 | 189 | self._l_values = l_values 190 | 191 | # each row is made of three components; 192 | # 1) the lower part of the matrix, i.e. elements to the left of diagonal 193 | # 2) the single diagonal element (that we exponentiate) 194 | # 3) the upper part of the matrix; all zeros 195 | batch_size = tf.shape(l_values)[0] 196 | row_idx = 0 197 | for row_idx in xrange(action_dim): 198 | row_offset_in_l = (row_idx*(row_idx+1))/2 199 | lower = tf.slice(l_values, begin=(0, row_offset_in_l), size=(-1, row_idx)) 200 | diag = tf.exp(tf.slice(l_values, begin=(0, row_offset_in_l+row_idx), size=(-1, 1))) 201 | upper = tf.zeros((batch_size, action_dim - tf.shape(lower)[1] - 1)) # -1 for diag 202 | rows.append(tf.concat(1, [lower, diag, upper])) 203 | # full L matrix is these rows packed. 204 | L = tf.pack(rows, 0) 205 | # and since leading axis in l was always the batch 206 | # we need to transpose it back to axis0 again 207 | L = tf.transpose(L, (1, 0, 2)) # (batch_size, action_dim, action_dim) 208 | self.check_L = tf.check_numerics(L, "L") 209 | 210 | # P is L.L_T 211 | L_T = tf.transpose(L, (0, 2, 1)) # TODO: update tf & use batch_matrix_transpose 212 | P = tf.batch_matmul(L, L_T) # (batch_size, action_dim, action_dim) 213 | 214 | # can now calculate advantage 215 | u_mu_diff_T = tf.transpose(u_mu_diff, (0, 2, 1)) 216 | advantage = -0.5 * tf.batch_matmul(u_mu_diff_T, tf.batch_matmul(P, u_mu_diff)) # (batch, 1, 1) 217 | # and finally we need to reshape off the axis we added to be able to matmul 218 | self.advantage = tf.reshape(advantage, [-1, 1]) # (batch, 1) 219 | 220 | # Q is value + advantage 221 | self.q_value = value_net.value + self.advantage 222 | 223 | # target y is reward + discounted target value 224 | # TODO: pull discount out 225 | self.target_y = self.reward + (self.terminal_mask * opts.discount * \ 226 | target_value_net.value) 227 | self.target_y = tf.stop_gradient(self.target_y) 228 | 229 | # loss is squared difference that we want to minimise. 230 | self.loss = tf.reduce_mean(tf.pow(self.q_value - self.target_y, 2)) 231 | with tf.variable_scope("optimiser"): 232 | # dynamically create optimiser based on opts 233 | optimiser = util.construct_optimiser(opts) 234 | # calc gradients 235 | gradients = optimiser.compute_gradients(self.loss) 236 | # potentially clip and wrap with debugging tf.Print 237 | gradients = util.clip_and_debug_gradients(gradients, opts) 238 | # apply 239 | self.train_op = optimiser.apply_gradients(gradients) 240 | 241 | # sanity checks (in the dependent order) 242 | checks = [] 243 | for op, name in [(l_values, 'l_values'), (L,'L'), (self.loss, 'loss')]: 244 | checks.append(tf.check_numerics(op, name)) 245 | self.check_numerics = tf.group(*checks) 246 | 247 | def action_given(self, state, add_noise): 248 | # NOTE: noise is added _outside_ tf graph. 
we do this simply because the noisy output 249 | # is never used for any part of computation graph required for online training. it's 250 | # only used during training after being the replay buffer. 251 | actions = tf.get_default_session().run(self.output_action, 252 | feed_dict={self.input_state: [state], 253 | base_network.IS_TRAINING: False}) 254 | 255 | if add_noise: 256 | if VERBOSE_DEBUG: 257 | pre_noise = str(actions) 258 | actions[0] += self.exploration_noise.sample() 259 | actions = np.clip(1, -1, actions) # action output is _always_ (-1, 1) 260 | if VERBOSE_DEBUG: 261 | print "TRAIN action_given pre_noise %s post_noise %s" % (pre_noise, actions) 262 | return actions 263 | 264 | def train(self, batch): 265 | _, _, l = tf.get_default_session().run([self.check_numerics, self.train_op, self.loss], 266 | feed_dict={self.input_state: batch.state_1, 267 | self.input_action: batch.action, 268 | self.reward: batch.reward, 269 | self.terminal_mask: batch.terminal_mask, 270 | self.input_state_2: batch.state_2, 271 | base_network.IS_TRAINING: True}) 272 | return l 273 | 274 | def debug_values(self, batch): 275 | values = tf.get_default_session().run([self._l_values, self.loss, self.value_net.value, 276 | self.advantage, self.target_value_net.value], 277 | feed_dict={self.input_state: batch.state_1, 278 | self.input_action: batch.action, 279 | self.reward: batch.reward, 280 | self.terminal_mask: batch.terminal_mask, 281 | self.input_state_2: batch.state_2, 282 | base_network.IS_TRAINING: False}) 283 | values = [np.squeeze(v) for v in values] 284 | return values 285 | 286 | 287 | class NormalizedAdvantageFunctionAgent(object): 288 | def __init__(self, env): 289 | self.env = env 290 | state_shape = self.env.observation_space.shape 291 | action_dim = self.env.action_space.shape[1] 292 | 293 | # for now, with single machine synchronous training, use a replay memory for training. 294 | # TODO: switch back to async training with multiple replicas (as in drivebot project) 295 | self.replay_memory = replay_memory.ReplayMemory(opts.replay_memory_size, 296 | state_shape, action_dim) 297 | 298 | # s1 and s2 placeholders 299 | batched_state_shape = [None] + list(state_shape) 300 | s1 = tf.placeholder(shape=batched_state_shape, dtype=tf.float32) 301 | s2 = tf.placeholder(shape=batched_state_shape, dtype=tf.float32) 302 | 303 | # initialise base models for value & naf networks. value subportion of net is 304 | # explicitly created seperate because it has a target network note: in the case of 305 | # --share-input-state-representation the input state network of the value_net will 306 | # be reused by the naf.l_value and naf.output_actions net 307 | self.value_net = ValueNetwork("value", s1, opts.hidden_layers) 308 | self.target_value_net = ValueNetwork("target_value", s2, opts.hidden_layers) 309 | self.naf = NafNetwork("naf", s1, s2, 310 | self.value_net, self.target_value_net, 311 | action_dim) 312 | 313 | def post_var_init_setup(self): 314 | # prepopulate replay memory (if configured to do so) 315 | # TODO: rewrite!!! 316 | if opts.event_log_in: 317 | self.replay_memory.reset_from_event_log(opts.event_log_in) 318 | # hook networks up to their targets 319 | # ( does one off clobber to init all vars in target network ) 320 | self.target_value_net.set_as_target_network_for(self.value_net, 321 | opts.target_update_rate) 322 | 323 | def run_training(self, max_num_actions, max_run_time, batch_size, batches_per_step, 324 | saver_util): 325 | # log start time, in case we are limiting by time... 
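# The exploration noise added in action_given() above comes from
# util.OrnsteinUhlenbeckNoise, which isn't shown in this section. A minimal
# sketch of the usual discrete-time OU process, with theta/sigma matching the
# --action-noise-theta / --action-noise-sigma flags (assumed implementation):
import numpy as np

class OrnsteinUhlenbeckNoiseSketch(object):
  def __init__(self, dim, theta=0.01, sigma=0.05):
    self.theta, self.sigma = theta, sigma
    self.state = np.zeros(dim)

  def sample(self):
    # mean revert towards zero (theta) plus gaussian jitter (sigma)
    self.state += -self.theta * self.state + self.sigma * np.random.randn(*self.state.shape)
    return self.state

# usage: noisy_action = np.clip(action + OrnsteinUhlenbeckNoiseSketch(2).sample(), -1.0, 1.0)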
326 | start_time = time.time() 327 | 328 | # run for some max number of actions 329 | num_actions_taken = 0 330 | n = 0 331 | while True: 332 | rewards = [] 333 | losses = [] 334 | 335 | # run an episode 336 | if opts.dont_do_rollouts: 337 | # _not_ gathering experience online 338 | pass 339 | else: 340 | # start a new episode 341 | state_1 = self.env.reset() 342 | # prepare data for updating replay memory at end of episode 343 | initial_state = np.copy(state_1) 344 | action_reward_state_sequence = [] 345 | episode_start = time.time() 346 | done = False 347 | while not done: 348 | # choose action 349 | action = self.naf.action_given(state_1, add_noise=True) 350 | # take action step in env 351 | state_2, reward, done, _ = self.env.step(action) 352 | rewards.append(reward) 353 | # cache for adding to replay memory 354 | action_reward_state_sequence.append((action, reward, np.copy(state_2))) 355 | # roll state for next step. 356 | state_1 = state_2 357 | # at end of episode update replay memory 358 | print "episode_took", time.time() - episode_start, len(rewards) 359 | 360 | replay_add_start = time.time() 361 | self.replay_memory.add_episode(initial_state, action_reward_state_sequence) 362 | print "replay_took", time.time() - replay_add_start 363 | 364 | # do a training step (after waiting for buffer to fill a bit...) 365 | if self.replay_memory.size() > opts.replay_memory_burn_in: 366 | # run a set of batches 367 | for _ in xrange(batches_per_step): 368 | batch_start = time.time() 369 | batch = self.replay_memory.batch(batch_size) 370 | losses.append(self.naf.train(batch)) 371 | print "batch_took", time.time() - batch_start 372 | # update target nets 373 | self.target_value_net.update_weights() 374 | # do debug (if requested) on last batch 375 | if VERBOSE_DEBUG: 376 | print "-----" 377 | print "> BATCH" 378 | print "state_1", batch.state_1.T 379 | print "action\n", batch.action.T 380 | print "reward ", batch.reward.T 381 | print "terminal_mask ", batch.terminal_mask.T 382 | print "state_2", batch.state_2.T 383 | print "< BATCH" 384 | l_values, l, v, a, vp = self.naf.debug_values(batch) 385 | print "> BATCH DEBUG VALUES" 386 | print "l_values\n", l_values.T 387 | print "loss\t", l 388 | print "val\t" , np.mean(v), "\t", v.T 389 | print "adv\t", np.mean(a), "\t", a.T 390 | print "val'\t", np.mean(vp), "\t", vp.T 391 | print "< BATCH DEBUG VALUES" 392 | 393 | # dump some stats and progress info 394 | stats = collections.OrderedDict() 395 | stats["time"] = time.time() 396 | stats["n"] = n 397 | stats["mean_losses"] = float(np.mean(losses)) 398 | stats["total_reward"] = np.sum(rewards) 399 | stats["episode_len"] = len(rewards) 400 | stats["replay_memory_stats"] = self.replay_memory.current_stats() 401 | print "STATS %s\t%s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 402 | json.dumps(stats)) 403 | sys.stdout.flush() 404 | n += 1 405 | 406 | # save if required 407 | if saver_util is not None: 408 | saver_util.save_if_required() 409 | 410 | # emit occasional eval 411 | if VERBOSE_DEBUG or n % 10 == 0: 412 | self.run_eval(1) 413 | 414 | # dump weights once if requested 415 | global DUMP_WEIGHTS 416 | if DUMP_WEIGHTS: 417 | self.debug_dump_network_weights() 418 | DUMP_WEIGHTS = False 419 | 420 | # exit when finished 421 | num_actions_taken += len(rewards) 422 | if max_num_actions > 0 and num_actions_taken > max_num_actions: 423 | break 424 | if max_run_time > 0 and time.time() > start_time + max_run_time: 425 | break 426 | 427 | def run_eval(self, num_episodes, add_noise=False): 428 | 
""" run num_episodes of eval and output episode length and rewards """ 429 | for i in xrange(num_episodes): 430 | state = self.env.reset() 431 | total_reward = 0 432 | steps = 0 433 | done = False 434 | while not done: 435 | action = self.naf.action_given(state, add_noise) 436 | state, reward, done, _ = self.env.step(action) 437 | print "EVALSTEP e%d s%d action=%s (l2=%s) => reward %s" % (i, steps, action, 438 | np.linalg.norm(action), reward) 439 | total_reward += reward 440 | steps += 1 441 | if False: # RENDER ALL STATES / ACTIVATIONS to /tmp 442 | self.naf.render_all_convnet_activations(steps, self.naf.input_state, state) 443 | util.render_state_to_png(steps, state) 444 | util.render_action_to_png(steps, action) 445 | print "EVAL", i, steps, total_reward 446 | sys.stdout.flush() 447 | 448 | def debug_dump_network_weights(self): 449 | fn = "/tmp/weights.%s" % time.time() 450 | with open(fn, "w") as f: 451 | f.write("DUMP time %s\n" % time.time()) 452 | for var in tf.all_variables(): 453 | f.write("VAR %s %s\n" % (var.name, var.get_shape())) 454 | f.write("%s\n" % var.eval()) 455 | print "weights written to", fn 456 | 457 | 458 | def main(): 459 | config = tf.ConfigProto() 460 | # config.gpu_options.allow_growth = True 461 | # config.log_device_placement = True 462 | if opts.gpu_mem_fraction is not None: 463 | config.gpu_options.per_process_gpu_memory_fraction = opts.gpu_mem_fraction 464 | with tf.Session(config=config) as sess: 465 | agent = NormalizedAdvantageFunctionAgent(env=env) 466 | 467 | # setup saver util and either load latest ckpt or init variables 468 | saver_util = None 469 | if opts.ckpt_dir is not None: 470 | saver_util = util.SaverUtil(sess, opts.ckpt_dir, opts.ckpt_freq) 471 | else: 472 | sess.run(tf.initialize_all_variables()) 473 | 474 | for v in tf.all_variables(): 475 | print >>sys.stderr, v.name, util.shape_and_product_of(v) 476 | 477 | # now that we've either init'd from scratch, or loaded up a checkpoint, 478 | # we can do any required post init work. 479 | agent.post_var_init_setup() 480 | 481 | # run either eval or training 482 | if opts.num_eval > 0: 483 | agent.run_eval(opts.num_eval, opts.eval_action_noise) 484 | else: 485 | agent.run_training(opts.max_num_actions, opts.max_run_time, 486 | opts.batch_size, opts.batches_per_step, 487 | saver_util) 488 | if saver_util is not None: 489 | saver_util.force_save() 490 | 491 | env.reset() # just to flush logging, clumsy :/ 492 | 493 | if __name__ == "__main__": 494 | main() 495 | -------------------------------------------------------------------------------- /plots.R: -------------------------------------------------------------------------------- 1 | library(ggplot2) 2 | 3 | df = read.delim("/tmp/f", s=" ", h=F, col.names=c("run", "length", "reward")) 4 | df$n = 1:nrow(df) 5 | head(df) 6 | ggplot(df, aes()) 7 | 8 | # df = read.delim("/tmp/actions", h=T, sep=" ") 9 | # png("/tmp/plots/00a_pre_noise_x_y_scatter.png", width=300, height=300) 10 | # ggplot(df[df$type=='pre',], aes(x, y)) + geom_bin2d() + labs(title="x pre noise") 11 | # dev.off() 12 | # png("/tmp/plots/00b_post_noise_x_y_scatter.png", width=300, height=300) 13 | # ggplot(df[df$type=='post',], aes(x, y)) + geom_bin2d() + labs(title="x post noise") 14 | # dev.off() 15 | # png("/tmp/plots/00c_x_over_time.png", width=640, height=400) 16 | # ggplot(df, aes(episode, x)) + geom_point(alpha=0.1) + geom_smooth() + facet_grid(type~.) 
+ labs(title="x over time") 17 | # dev.off() 18 | # png("/tmp/plots/00d_y_over_time.png", width=640, height=400) 19 | # ggplot(df, aes(episode, y)) + geom_point(alpha=0.1) + geom_smooth() + facet_grid(type~.) + labs(title="yx over time") 20 | # dev.off() 21 | 22 | df = read.delim("/tmp/q_values", h=T, sep=" ") 23 | png("/tmp/plots/05a_action_q_values.png", width=640, height=320) 24 | ggplot(df, aes(episode, q_value)) + geom_point(alpha=0.2, aes(color=net_type)) + geom_smooth(aes(color=net_type)) + labs(title="q values over time") 25 | dev.off() 26 | 27 | df = read.delim("/tmp/episode_stats", h=T, sep=" ") 28 | png("/tmp/plots/06a_episode_len.png", width=640, height=320) 29 | ggplot(df, aes(episode, len)) + geom_point(alpha=0.2) + geom_smooth() + labs(title="episode len") 30 | dev.off() 31 | png("/tmp/plots/06b_episode_rewards.png", width=640, height=320) 32 | ggplot(df, aes(episode, total_reward)) + geom_point(alpha=0.2) + geom_smooth() + labs(title="episode total reward") 33 | dev.off() 34 | png("/tmp/plots/06c_episode_stats.png", width=320, height=320) 35 | ggplot(df, aes(len, total_reward)) + geom_point(alpha=0.2) + labs(title="episode step vs reward") 36 | dev.off() 37 | 38 | df = read.delim("/tmp/eval", h=T, sep=" ") 39 | png("/tmp/plots/07a_eval_episode_len.png", width=640, height=320) 40 | ggplot(df, aes(episode, steps)) + geom_point(alpha=0.2) + geom_smooth() + labs(title="eval episode len") 41 | dev.off() 42 | png("/tmp/plots/07b_eval_total_reward.png", width=640, height=320) 43 | ggplot(df, aes(episode, total_reward)) + geom_point(alpha=0.2) + geom_smooth() + labs(title="eval total reward") 44 | dev.off() 45 | 46 | # df = read.delim("/tmp/batch_num_terminal", h=T, sep=" ") 47 | # png("/tmp/plots/08_batch_num_terminal.png", width=640, height=320) 48 | # ggplot(df, aes(episode, batch_num_terminals)) + geom_point(alpha=0.2) + geom_smooth() + labs(# title="batch num terminal") 49 | # dev.off() 50 | 51 | df = read.delim("/tmp/gradient_l2_norms", sep=" ") 52 | png("/tmp/plots/09_gradient_l2_norms.png", width=640, height=320) 53 | ggplot(df, aes(time, l2_norm)) + 54 | geom_point(alpha=0.1, aes(color=source)) + 55 | geom_smooth(aes(color=source)) 56 | dev.off() 57 | 58 | df = read.delim("/tmp/q_loss", h=T, sep=" ") 59 | png("/tmp/plots/10_q_loss.png", width=640, height=320) 60 | ggplot(df, aes(episode, q_loss)) + geom_point(alpha=0.1) + geom_smooth() + labs(title="critic training q loss") 61 | dev.off() 62 | 63 | # df = read.delim("/tmp/replay_memory_size", h=F) 64 | # df$n = 1:nrow(df) 65 | # png("/tmp/plots/09_replay_memory_size.png", width=640, height=320) 66 | # ggplot(df, aes(n, V1)) + geom_point() + labs(title="replay memory size") 67 | # dev.off() 68 | -------------------------------------------------------------------------------- /random_action_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import bullet_cartpole 4 | import random 5 | import time 6 | 7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | parser.add_argument('--actions', type=str, default='0,1,2,3,4', 9 | help='comma seperated list of actions to pick from, if env is discrete') 10 | parser.add_argument('--num-eval', type=int, default=1000) 11 | parser.add_argument('--action-type', type=str, default='discrete', 12 | help="either 'discrete' or 'continuous'") 13 | bullet_cartpole.add_opts(parser) 14 | opts = parser.parse_args() 15 | 16 | actions = map(int, opts.actions.split(",")) 17 
| 18 | if opts.action_type == 'discrete': 19 | discrete_actions = True 20 | elif opts.action_type == 'continuous': 21 | discrete_actions = False 22 | else: 23 | raise Exception("Unknown action type [%s]" % opts.action_type) 24 | 25 | env = bullet_cartpole.BulletCartpole(opts=opts, discrete_actions=discrete_actions) 26 | 27 | for _ in xrange(opts.num_eval): 28 | env.reset() 29 | done = False 30 | total_reward = 0 31 | steps = 0 32 | while not done: 33 | if discrete_actions: 34 | action = random.choice(actions) 35 | else: 36 | action = env.action_space.sample() 37 | _state, reward, done, info = env.step(action) 38 | steps += 1 39 | total_reward += reward 40 | if opts.max_episode_len is not None and steps > opts.max_episode_len: 41 | break 42 | print total_reward 43 | 44 | env.reset() # hack to flush last event log if required 45 | 46 | -------------------------------------------------------------------------------- /random_plots.R: -------------------------------------------------------------------------------- 1 | library(ggplot2) 2 | library(plyr) 3 | library(lubridate) 4 | library(reshape) 5 | 6 | # coloured on single plot 7 | df = read.delim("/tmp/f", h=F, sep=" ", col.names=c("run", "episode_len", "reward")) 8 | df = ddply(df, .(run), mutate, n=seq_along(episode_len)) # adds n seq number distinst per run 9 | ggplot(df, aes(n, reward)) + 10 | geom_point(alpha=0.1, aes(color=run)) + 11 | geom_smooth(aes(color=run)) 12 | 13 | # grid 14 | df = read.delim("/tmp/f", h=F, sep=" ", col.names=c("run", "episode_len", "reward")) 15 | df = ddply(df, .(run), mutate, n=seq_along(episode_len)) # adds n seq number distinst per run 16 | ggplot(df, aes(n, reward)) + 17 | geom_point(alpha=0.5) + 18 | geom_smooth() + facet_grid(~run) 19 | 20 | # density 21 | df = read.delim("/tmp/g", h=F, sep=" ") 22 | df$time_per_step = df$V2 / df$V3 23 | head(df) 24 | ggplot(df, aes(time_per_step*200)) + geom_histogram() 25 | 26 | x = seq(0, 1.41, 0.005) 27 | y = (1.5 - x) ** 5 28 | plot(x, y) 29 | 30 | x = seq(0, 0.6, 0.001) 31 | y = 7 * (0.4 + 0.6 - x) ** 10 32 | plot(x, y) 33 | 34 | df = data.frame() 35 | df$ 36 | df$ 37 | ggplot(df, aes(x, y)) + geom_point() 38 | 39 | df = data.frame() 40 | df$ 41 | df$ 42 | head(df) 43 | ggplot(df, aes(x, y)) + geom_point() 44 | 45 | df = read.delim("/home/mat/dev/cartpole++/rewards.action", h=F) 46 | ggplot(df, aes(V1)) + geom_density() 47 | 48 | df = read.delim("/tmp/f", h=F, sep="\t", col.names=c("dts", "eval")) 49 | df$dts = ymd_hms(df$dts) 50 | ggplot(df, aes(dts, eval)) + geom_point(alpha=0.2) + geom_smooth() 51 | 52 | df = read.delim("/tmp/q", h=F, col.names=c("R", "angles", "actions")) 53 | df = df[c("angles", "actions")] 54 | df$both = df$angles + df$actions 55 | df$n = seq(1:nrow(df)) 56 | df <- melt(df, id=c("n")) 57 | ggplot(df, aes(n, value)) + geom_point(aes(color=variable)) 58 | 59 | df = read.delim("/tmp/outs", h=F, sep=" ", col.names=c("run", "n", "r")) 60 | summary(df) 61 | ggplot(df, aes(n, r)) + 62 | geom_point(alpha=0.1, aes(color=run)) + 63 | geom_smooth(aes(color=run)) + 64 | facet_grid(~run) -------------------------------------------------------------------------------- /replay_memory.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import event_log 3 | import numpy as np 4 | import sys 5 | import tensorflow as tf 6 | import time 7 | import util 8 | 9 | Batch = collections.namedtuple("Batch", "state_1 action reward terminal_mask state_2") 10 | 11 | class ReplayMemory(object): 12 | def 
__init__(self, buffer_size, state_shape, action_dim, load_factor=1.5): 13 | assert load_factor >= 1.5, "load_factor has to be at least 1.5" 14 | self.buffer_size = buffer_size 15 | self.state_shape = state_shape 16 | self.insert = 0 17 | self.full = False 18 | 19 | # the elements of the replay memory. each event represents a row in the following 20 | # five matrices. 21 | self.state_1_idx = np.empty(buffer_size, dtype=np.int32) 22 | self.action = np.empty((buffer_size, action_dim), dtype=np.float32) 23 | self.reward = np.empty((buffer_size, 1), dtype=np.float32) 24 | self.terminal_mask = np.empty((buffer_size, 1), dtype=np.float32) 25 | self.state_2_idx = np.empty(buffer_size, dtype=np.int32) 26 | 27 | # states themselves, since they can either be state_1 or state_2 in an event 28 | # are stored in a separate matrix. it is sized fractionally larger than the replay 29 | # memory since a rollout of length n contains n+1 states. 30 | self.state_buffer_size = int(buffer_size*load_factor) 31 | shape = [self.state_buffer_size] + list(state_shape) 32 | self.state = np.empty(shape, dtype=np.float16) 33 | 34 | # keep track of free slots in state buffer 35 | self.state_free_slots = list(range(self.state_buffer_size)) 36 | 37 | # some stats 38 | self.stats = collections.Counter() 39 | 40 | def reset_from_event_log(self, log_file): 41 | elr = event_log.EventLogReader(log_file) 42 | num_episodes = 0 43 | num_events = 0 44 | start = time.time() 45 | for episode in elr.entries(): 46 | initial_state = None 47 | action_reward_state_sequence = [] 48 | for event_id, event in enumerate(episode.event): 49 | if event_id == 0: 50 | assert len(event.action) == 0 51 | assert not event.HasField("reward") 52 | initial_state = event_log.read_state_from_event(event) 53 | else: 54 | action_reward_state_sequence.append((event.action, event.reward, 55 | event_log.read_state_from_event(event))) 56 | num_events += 1 57 | num_episodes += 1 58 | self.add_episode(initial_state, action_reward_state_sequence) 59 | if self.full: 60 | break 61 | print >>sys.stderr, "reset_from_event_log \"%s\" num_episodes=%d num_events=%d took %s sec" % (log_file, num_episodes, num_events, time.time()-start) 62 | 63 | def add_episode(self, initial_state, action_reward_state_sequence): 64 | self.stats['>add_episode'] += 1 65 | assert len(action_reward_state_sequence) > 0 66 | state_1_idx = self.state_free_slots.pop(0) 67 | self.state[state_1_idx] = initial_state 68 | for n, (action, reward, state_2) in enumerate(action_reward_state_sequence): 69 | terminal = n == len(action_reward_state_sequence)-1 70 | state_2_idx = self._add(state_1_idx, action, reward, terminal, state_2) 71 | state_1_idx = state_2_idx 72 | 73 | def _add(self, s1_idx, a, r, t, s2): 74 | # print ">add s1_idx=%s, a=%s, r=%s, t=%s" % (s1_idx, a, r, t) 75 | 76 | self.stats['>add'] += 1 77 | assert s1_idx >= 0, s1_idx 78 | assert s1_idx < self.state_buffer_size, s1_idx 79 | assert s1_idx not in self.state_free_slots, s1_idx 80 | 81 | if self.full: 82 | # are are about to overwrite an existing entry. 83 | # we always free the state_1 slot we are about to clobber... 
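      # (invariant: each stored transition "owns" its state_1 slot, while its state_2 slot
      #  is shared with the following transition's state_1 unless the transition was
      #  terminal. so on overwrite we can always free state_1, but we free state_2 only for
      #  terminal entries; a non-terminal state_2 will be freed later as another entry's
      #  state_1.)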
84 | self.state_free_slots.append(self.state_1_idx[self.insert]) 85 | # print "full; so free slot", self.state_1_idx[self.insert] 86 | # and we free the state_2 slot also if the slot is a terminal event 87 | # (since that implies no other event uses this state_2 as a state_1) 88 | # self.stats['cache_evicted_s1'] += 1 89 | if self.terminal_mask[self.insert] == 0: 90 | self.state_free_slots.append(self.state_2_idx[self.insert]) 91 | # print "also, since terminal, free", self.state_2_idx[self.insert] 92 | self.stats['cache_evicted_s2'] += 1 93 | 94 | # add s1, a, r 95 | self.state_1_idx[self.insert] = s1_idx 96 | self.action[self.insert] = a 97 | self.reward[self.insert] = r 98 | 99 | # if terminal we set terminal mask to 0.0 representing the masking of the righthand 100 | # side of the bellman equation 101 | self.terminal_mask[self.insert] = 0.0 if t else 1.0 102 | 103 | # state_2 is fully provided so we need to prepare a new slot for it 104 | s2_idx = self.state_free_slots.pop(0) 105 | self.state_2_idx[self.insert] = s2_idx 106 | self.state[s2_idx] = s2 107 | 108 | # move insert ptr forward 109 | self.insert += 1 110 | if self.insert >= self.buffer_size: 111 | self.insert = 0 112 | self.full = True 113 | 114 | # print "batch'] += 1 133 | idxs = self.random_indexes(batch_size) 134 | return Batch(np.copy(self.state[self.state_1_idx[idxs]]), 135 | np.copy(self.action[idxs]), 136 | np.copy(self.reward[idxs]), 137 | np.copy(self.terminal_mask[idxs]), 138 | np.copy(self.state[self.state_2_idx[idxs]])) 139 | 140 | def dump(self): 141 | print ">>>> dump" 142 | print "insert", self.insert 143 | print "full?", self.full 144 | print "state free slots", util.collapsed_successive_ranges(self.state_free_slots) 145 | if self.insert==0 and not self.full: 146 | print "EMPTY!" 147 | else: 148 | idxs = range(self.buffer_size if self.full else self.insert) 149 | for idx in idxs: 150 | print "idx", idx, 151 | print "state_1_idx", self.state_1_idx[idx], 152 | print "state_1", self.state[self.state_1_idx[idx]] 153 | print "action", self.action[idx], 154 | print "reward", self.reward[idx], 155 | print "terminal_mask", self.terminal_mask[idx], 156 | print "state_2_idx", self.state_2_idx[idx] 157 | print "state_2", self.state[self.state_2_idx[idx]] 158 | print "<<<< dump" 159 | 160 | def current_stats(self): 161 | current_stats = dict(self.stats) 162 | current_stats["free_slots"] = len(self.state_free_slots) 163 | return current_stats 164 | 165 | 166 | if __name__ == "__main__": 167 | # LATE NIGHT SUPER HACK SOAK TEST. I WILL PAY FOR THIS HACK LATER !!!! 
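# (aside: how a learner consumes a Batch from this memory -- the terminal_mask zeroes the
#  bootstrapped term of the bellman target,
#      target = reward + gamma * terminal_mask * V_target(state_2)
#  where gamma and V_target are stand-ins for the agent's own discount factor and target
#  value network; the real versions live in the agents, e.g. naf_cartpole.py / ddpg_cartpole.py.)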
168 | rm = ReplayMemory(buffer_size=43, state_shape=(2,3), action_dim=2) 169 | def s(i): # state for insert i 170 | i = (i * 10) % 199 171 | return [[i+1,0,0],[0,0,0]] 172 | def ars(i): # action, reward, state_2 for insert i 173 | return ((i,0), i, s(i)) 174 | def FAILDOG(b, i, d): # dump batch and rm in case of assertion 175 | print "FAILDOG", i, d 176 | print b 177 | rm.dump() 178 | assert False 179 | def check_batch_valid(b): # check batch is valid by consistency of how we build elements 180 | for i in range(3): 181 | r = int(b.reward[i][0]) 182 | if b.state_1[i][0][0] != (((r-1)*10)%199)+1: FAILDOG(b, i, "s1") 183 | if b.action[i][0] != r: FAILDOG(b, i, "r") 184 | if b.terminal_mask[i] != (0 if r in terminals else 1): FAILDOG(b, i, "r") 185 | if b.state_2[i][0][0] != ((r*10)%199)+1: FAILDOG(b, i, "s2") 186 | terminals = set() 187 | i = 0 188 | import random 189 | while True: 190 | initial_state = s(i) 191 | action_reward_state_sequence = [] 192 | episode_len = int(3 + (random.random() * 5)) 193 | for _ in range(episode_len): 194 | i += 1 195 | action_reward_state_sequence.append(ars(i)) 196 | rm.add_episode(initial_state, action_reward_state_sequence) 197 | terminals.add(i) 198 | print rm.stats 199 | for _ in range(7): check_batch_valid(rm.batch(13)) 200 | i += 1 201 | -------------------------------------------------------------------------------- /replay_memory_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import numpy as np 3 | import random 4 | import tensorflow as tf 5 | import time 6 | import unittest 7 | from util import StopWatch 8 | from replay_memory import ReplayMemory 9 | 10 | class TestReplayMemory(unittest.TestCase): 11 | def setUp(self): 12 | self.sess = tf.Session() 13 | self.rm = ReplayMemory(self.sess, buffer_size=3, state_shape=(2, 3), action_dim=2, load_factor=2) 14 | self.sess.run(tf.initialize_all_variables()) 15 | 16 | def assert_np_eq(self, a, b): 17 | self.assertTrue(np.all(np.equal(a, b))) 18 | 19 | def test_empty_memory(self): 20 | # api 21 | self.assertEqual(self.rm.size(), 0) 22 | self.assertEqual(self.rm.random_indexes(), []) 23 | b = self.rm.batch(4) 24 | self.assertEqual(len(b), 5) 25 | for i in range(5): 26 | self.assertEqual(len(b[i]), 0) 27 | 28 | # internals 29 | self.assertEqual(self.rm.insert, 0) 30 | self.assertEqual(self.rm.full, False) 31 | 32 | def test_adds_to_full(self): 33 | # add entries to full 34 | initial_state = [[11,12,13], [14,15,16]] 35 | action_reward_state = [(17, 18, [[21,22,23],[24,25,26]]), 36 | (27, 28, [[31,32,33],[34,35,36]]), 37 | (37, 38, [[41,42,43],[44,45,46]])] 38 | self.rm.add_episode(initial_state, action_reward_state) 39 | 40 | # api 41 | self.assertEqual(self.rm.size(), 3) 42 | # random_idxs are valid 43 | idxs = self.rm.random_indexes(n=100) 44 | self.assertEqual(len(idxs), 100) 45 | self.assertEquals(sorted(set(idxs)), [0,1,2]) 46 | # batch returns values 47 | 48 | # internals 49 | self.assertEqual(self.rm.insert, 0) 50 | self.assertEqual(self.rm.full, True) 51 | # check state contains these entries 52 | state = self.rm.sess.run(self.rm.state) 53 | self.assertEqual(state[0][0][0], 11) 54 | self.assertEqual(state[1][0][0], 21) 55 | self.assertEqual(state[2][0][0], 31) 56 | self.assertEqual(state[3][0][0], 41) 57 | 58 | def test_adds_over_full(self): 59 | def s_for(i): 60 | return (np.array(range(1,7))+(10*i)).reshape(2, 3) 61 | 62 | # add one episode of 5 states; 0X -> 4X 63 | initial_state = s_for(0) 64 | action_reward_state = [] 65 | for i 
in range(1, 5): 66 | a, r, s2 = (i*10)+7, (i*10)+8, s_for(i) 67 | action_reward_state.append((a, r, s2)) 68 | self.rm.add_episode(initial_state, action_reward_state) 69 | # add another episode of 4 states; 5X -> 8X 70 | initial_state = s_for(5) 71 | action_reward_state = [] 72 | for i in range(6, 9): 73 | a, r, s2 = (i*10)+7, (i*10)+8, s_for(i) 74 | action_reward_state.append((a, r, s2)) 75 | self.rm.add_episode(initial_state, action_reward_state) 76 | 77 | # api 78 | self.assertEqual(self.rm.size(), 3) 79 | # random_idxs are valid 80 | idxs = self.rm.random_indexes(n=100) 81 | self.assertEqual(len(idxs), 100) 82 | self.assertEquals(sorted(set(idxs)), [0,1,2]) 83 | # fetch a batch, of all items 84 | batch = self.rm.batch(idxs=[0,1,2]) 85 | self.assert_np_eq(batch.reward, [[88], [68], [78]]) 86 | self.assert_np_eq(batch.terminal_mask, [[0], [1], [1]]) 87 | 88 | def test_large_var(self): 89 | ### python replay_memory_test.py TestReplayMemory.test_large_var 90 | 91 | s = StopWatch() 92 | 93 | state_shape = (50, 50, 6) 94 | s.reset() 95 | rm = ReplayMemory(self.sess, buffer_size=10000, state_shape=state_shape, action_dim=2, load_factor=1.5) 96 | self.sess.run(tf.initialize_all_variables()) 97 | print "cstr_and_init", s.time() 98 | 99 | bs1, bs1i, bs2, bs2i = rm.batch_ops() 100 | 101 | # build a simple, useless, net that uses state_1 & state_2 idxs 102 | # we want this to reduce to a single value to minimise data coming 103 | # back from GPU 104 | added_states = bs1 + bs2 105 | total_value = tf.reduce_sum(added_states) 106 | 107 | def random_s(): 108 | return np.random.random(state_shape) 109 | 110 | for i in xrange(10): 111 | # add an episode to rm 112 | episode_len = random.choice([5,7,9,10,15]) 113 | initial_state = random_s() 114 | action_reward_state = [] 115 | for i in range(i+1, i+episode_len+1): 116 | a, r, s2 = (i*10)+7, (i*10)+8, random_s() 117 | action_reward_state.append((a, r, s2)) 118 | start = time.time() 119 | s.reset() 120 | rm.add_episode(initial_state, action_reward_state) 121 | t = s.time() 122 | num_states = len(action_reward_state)+1 123 | print "add_episode_time", t, "#states=", num_states, "=> s/state", t/num_states 124 | i += episode_len + 1 125 | 126 | # get a random batch state 127 | b = rm.batch(batch_size=128) 128 | s.reset() 129 | x = self.sess.run(total_value, feed_dict={bs1i: b.state_1_idx, 130 | bs2i: b.state_2_idx}) 131 | print "fetch_and_run", x, s.time() 132 | 133 | 134 | def test_soak(self): 135 | state_shape = (50,50,6) 136 | rm = ReplayMemory(self.sess, buffer_size=10000, 137 | state_shape=state_shape, action_dim=2, load_factor=1.5) 138 | self.sess.run(tf.initialize_all_variables()) 139 | def s_for(i): 140 | return np.random.random(state_shape) 141 | import random 142 | i = 0 143 | for e in xrange(10000): 144 | # add an episode to rm 145 | episode_len = random.choice([5,7,9,10,15]) 146 | initial_state = s_for(i) 147 | action_reward_state = [] 148 | for i in range(i+1, i+episode_len+1): 149 | a, r, s2 = (i*10)+7, (i*10)+8, s_for(i) 150 | action_reward_state.append((a, r, s2)) 151 | rm.add_episode(initial_state, action_reward_state) 152 | i += episode_len + 1 153 | # dump 154 | print rm.current_stats() 155 | # fetch a batch, of all items, but do nothing with it. 
156 | _ = rm.batch(idxs=range(10)) 157 | 158 | 159 | 160 | if __name__ == '__main__': 161 | unittest.main() 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | -------------------------------------------------------------------------------- /run_diff.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # slurp the argparse debug lines from two 'run/NNN/err' files and show a side by side diff 3 | 4 | import sys, re 5 | 6 | def config(f): 7 | c = {} 8 | for line in open("runs/%s/err" % f, "r"): 9 | m = re.match("^Namespace\((.*)\)$", line) 10 | if m: 11 | config = m.group(1) 12 | while config: 13 | if "," in config: 14 | m = re.match("^((.*?)=(.*?),\s+)", config) 15 | pair, key, value = m.groups() 16 | if value.startswith("'"): 17 | # value is a string, need special handling since it might contain a comma! 18 | m = re.match("^((.*?)='(.*?)',\s+)", config) 19 | pair, key, value = m.groups() 20 | else: 21 | # final entry 22 | pair = config 23 | key, value = config.split("=") 24 | # ignore run related keys 25 | if key not in ['ckpt_dir', 'event_log_out']: 26 | c[key] = value 27 | config = config.replace(pair, "") 28 | return c 29 | assert False, "no namespace in file?" 30 | 31 | c1 = config(sys.argv[1]) 32 | c2 = config(sys.argv[2]) 33 | 34 | #unset_defaults = {"use_dropout": "False"} 35 | 36 | data = [] 37 | max_widths = [0,0,0] 38 | for key in sorted(set(c1.keys()).union(c2.keys())): 39 | c1v = c1[key] if key in c1 else " "#unset_defaults[key] 40 | c2v = c2[key] if key in c2 else " "#unset_defaults[key] 41 | data.append((key, c1v, c2v)) 42 | max_widths[0] = max(max_widths[0], len(key)) 43 | max_widths[1] = max(max_widths[1], len(c1v)) 44 | max_widths[2] = max(max_widths[2], len(c2v)) 45 | for k, c1, c2 in data: 46 | format_str = "%%%ds %%%ds %%%ds %%s" % tuple(max_widths) 47 | star = " ***" if c1 != c2 else " " 48 | print format_str % (k, c1, c2, star) 49 | -------------------------------------------------------------------------------- /stitch_activations.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import glob 3 | import Image, ImageDraw, ImageChops 4 | import matplotlib.pyplot as plt 5 | import os 6 | 7 | # hacktastic stitching of activation renderings from an eval run 8 | 9 | step = 1 10 | while True: 11 | if not os.path.isfile("/tmp/state_s%03d_c0_r0.png" % step): 12 | print "rendered to step", step 13 | break 14 | background = Image.new('RGB', (600, 250), (0, 0, 0)) 15 | canvas = ImageDraw.Draw(background) 16 | canvas.text((0, 0), str(step)) 17 | canvas.text((30, 0), "c0") 18 | canvas.text((85, 0), "c1") 19 | canvas.text((0, 30), "r0") 20 | canvas.text((0, 85), "r1") 21 | canvas.text((55, 130), "diffs") 22 | # draw cameras and repeats up in top corner, with a difference below them 23 | for c in [0, 1]: 24 | r0 = Image.open("/tmp/state_s%03d_c%d_r0.png" % (step, c)) 25 | r1 = Image.open("/tmp/state_s%03d_c%d_r1.png" % (step, c)) 26 | background.paste(r0, (15+c*55, 15)) 27 | background.paste(r1, (15+c*55, 70)) 28 | diff = ImageChops.invert(ImageChops.difference(r0, r1)) 29 | background.paste(diff, (15+c*55, 145)) 30 | # 3 conv layers, 10 filters each 31 | for p in range(3): 32 | for f in range(10): 33 | act = Image.open("/tmp/activation_s%03d_p%d_f%02d.png" % (step, p, f)) 34 | act = act.resize((40, 40)) 35 | background.paste(act, (130+(f*45), 20+(p*45))) 36 | # down the 
bottom draw in the force magnitude 37 | background.paste(Image.open("/tmp/action_%03d.png" % step), (250, 170)) 38 | # write image 39 | background.save("/tmp/viz_%03d.png" % step) 40 | step += 1 41 | 42 | -------------------------------------------------------------------------------- /u/parse_gradient_logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys, re 3 | from collections import Counter 4 | freq = Counter() 5 | nth = 5 6 | for line in sys.stdin: 7 | m = re.match(".* gradient (.*?) l2_norm \[(.*?)\]", line) 8 | if not m: continue 9 | name, val = m.groups() 10 | freq[name] += 1 11 | if freq[name] % nth == 0: 12 | print "%d\t%s\t%s" % (freq[name], name, val) 13 | 14 | -------------------------------------------------------------------------------- /u/parse_out.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse, sys, re, json 3 | import numpy as np 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('files', metavar='F', type=str, nargs='+', 7 | help='files to process') 8 | parser.add_argument('--pg-emit-all', action='store_true', 9 | help="if set emit all rewards in pg case. if false just emit stats") 10 | parser.add_argument('--nth', type=int, default=1, help="emit every nth") 11 | opts = parser.parse_args() 12 | 13 | # fields we output 14 | KEYS = ["run_id", "exp", "replica", "episode", "loss", "r_min", "r_mean", "r_max"] 15 | 16 | # tsv header 17 | print "\t".join(KEYS) 18 | 19 | # emit a record. 20 | n = 0 21 | def emit(data): 22 | global n 23 | if n % opts.nth == 0: 24 | print "\t".join(map(str, [data[key] for key in KEYS])) 25 | n += 1 26 | 27 | for filename in opts.files: 28 | run_id = filename.replace(".out", "").replace(".stats", "") 29 | run_info = {"run_id": run_id, "exp": run_id[:-1], "replica": run_id[-1]} 30 | 31 | for line in open(filename, "r"): 32 | if line.startswith("STATS"): 33 | # looks like entry from the policy gradient code... 34 | try: 35 | d = json.loads(re.sub("STATS.*?\t", "", line)) 36 | d.update(run_info) 37 | d['episode'] = d['batch'] 38 | d['r_min'] = np.min(d['rewards']) 39 | d['r_mean'] = np.mean(d['rewards']) 40 | d['r_max'] = np.max(d['rewards']) 41 | emit(d) 42 | except ValueError: 43 | pass # partial line? 44 | else: 45 | # old file format? or just noise? ignore.... 
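      # (for reference, the STATS lines this script does parse look like
      #  "STATS <anything up to the first tab>\t{json}", where the json is expected to
      #  carry at least the "batch", "loss" and "rewards" keys used above.)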
46 | pass 47 | -------------------------------------------------------------------------------- /u/parse_out_ddpg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse, sys, re, json 3 | import numpy as np 4 | from collections import Counter 5 | 6 | #f_actions= open("/tmp/actions", "w") 7 | #f_actions.write("time type x y\n") 8 | f_q_values = open("/tmp/q_values", "w") 9 | f_q_values.write("episode net_type q_value\n") 10 | f_episode_len = open("/tmp/episode_stats", "w") 11 | f_episode_len.write("episode len total_reward\n") 12 | f_eval = open("/tmp/eval", "w") 13 | f_eval.write("episode steps total_reward\n") 14 | #f_batch_num_terminal = open("/tmp/batch_num_terminal", "w") 15 | #f_batch_num_terminal.write("time batch_num_terminals\n") 16 | f_gradient_l2_norms = open("/tmp/gradient_l2_norms", "w") 17 | f_gradient_l2_norms.write("time source l2_norm\n") 18 | f_q_loss = open("/tmp/q_loss", "w") 19 | f_q_loss.write("episode q_loss\n") 20 | 21 | freq = Counter() 22 | emit_freq = {"EVAL": 1, "ACTOR_L2_NORM": 1, "CRITIC_L2_NORM": 20, 23 | "Q LOSS": 1, "EXPECTED_Q_VALUES": 1} 24 | def should_emit(tag): 25 | freq[tag] += 1 26 | return freq[tag] % (emit_freq[tag] if tag in emit_freq else 100) == 0 27 | 28 | n_parse_errors = 0 29 | 30 | episode = None 31 | 32 | for line in sys.stdin: 33 | line = line.strip() 34 | 35 | if line.startswith("STATS"): 36 | cols = line.split("\t") 37 | assert len(cols) == 2, line 38 | try: 39 | d = json.loads(cols[1]) 40 | if should_emit("EPISODE_LEN"): 41 | episode = d["episode"] 42 | total_reward = d["total_reward"] 43 | episode_len = d["episode_len"] if "episode_len" in d else total_reward 44 | time = d["time"] 45 | f_episode_len.write("%s %s %s\n" % (episode, episode_len, total_reward)) 46 | except ValueError: 47 | # interleaving output :/ 48 | n_parse_errors += 1 49 | 50 | if "actor gradient" in line and should_emit("ACTOR_L2_NORM"): 51 | m = re.match(".*actor gradient (.*) l2_norm pre \[(.*?)\]", line) 52 | var_id, norm = m.groups() 53 | f_gradient_l2_norms.write("%s actor_%s %s\n" % (freq["ACTOR_L2_NORM"], var_id, norm)) 54 | continue 55 | 56 | if "critic gradient l2_norm" in line and should_emit("CRITIC_L2_NORM"): 57 | norm = re.sub(".*\[", "", line).replace("]", "") 58 | f_gradient_l2_norms.write("%s critic %s\n" % (freq["CRITIC_L2_NORM"], norm)) 59 | continue 60 | 61 | # elif line.startswith("ACTIONS") and should_emit("ACTIONS"): 62 | # m = re.match("ACTIONS\t\[(.*), (.*)\]\t\[(.*), (.*)\]", line) 63 | # if m: 64 | # pre_x, pre_y, post_x, post_y = m.groups() 65 | # f_actions.write("%s pre %s %s\n" % (time, pre_x, pre_y)) 66 | # f_actions.write("%s post %s %s\n" % (time, post_x, post_y)) 67 | 68 | if "temporal_difference_loss" in line and should_emit("Q LOSS"): 69 | m = re.match(".*temporal_difference_loss\[(.*?)\]", line) 70 | tdl, = m.groups() 71 | f_q_loss.write("%s %s\n" % (freq["Q LOSS"], tdl)) 72 | continue 73 | 74 | # everything else requires episode for keying 75 | if episode is None: 76 | continue 77 | 78 | elif line.startswith("EXPECTED_Q_VALUES") and should_emit("EXPECTED_Q_VALUES"): 79 | cols = line.split(" ") 80 | assert len(cols) == 3 81 | assert cols[0] == "EXPECTED_Q_VALUES" 82 | f_q_values.write("%s main %f\n" % (episode, float(cols[1]))) 83 | f_q_values.write("%s target %f\n" % (episode, float(cols[2]))) 84 | 85 | elif line.startswith("EVAL") and \ 86 | not line.startswith("EVALSTEP") and \ 87 | should_emit("EVAL"): 88 | cols = line.split(" ") 89 | if len(cols) == 2: # 
OLD FORMAT 90 | tag, steps = cols 91 | total_reward = steps 92 | elif len(cols) == 3: 93 | tag, episode, steps, total_reward = cols 94 | elif len(cols) == 4: 95 | tag, _, steps, total_reward = cols 96 | else: 97 | assert False, line 98 | assert tag == "EVAL" 99 | assert steps >= 0 100 | assert total_reward >= 0 101 | f_eval.write("%s %s %s\n" % (episode, steps, total_reward)) 102 | 103 | # elif line.startswith("NUM_TERMINALS_IN_BATCH") and should_emit("NUM_TERMINALS_IN_BATCH"): 104 | # cols = line.split(" ") 105 | # assert len(cols) == 2 106 | # assert cols[0] == "NUM_TERMINALS_IN_BATCH" 107 | # f_batch_num_terminal.write("%s %f\n" % (episode, float(cols[1]))) 108 | 109 | 110 | print "n_parse_errors", n_parse_errors 111 | print freq 112 | 113 | -------------------------------------------------------------------------------- /u/parse_out_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # grep EVAL $* | grep -v EVALSTEP | perl -plne's/runs\///;s/:EVAL / /;s/\/out//;' | cut -f1,3,4 -d' ' 3 | from multiprocessing import Pool 4 | import sys, re 5 | 6 | def process(filename): 7 | col0 = filename.replace("runs/", "").replace("/out", "") 8 | time = None 9 | first_time = None 10 | for line in open(filename).readlines(): 11 | if line.startswith("STATS"): 12 | m = re.match(".*\"time\": (.*?),", line) 13 | time = float(m.group(1)) 14 | if first_time is None: first_time = time 15 | continue 16 | if not line.startswith("EVAL"): continue 17 | if line.startswith("EVALSTEP"): continue 18 | if time is None: continue 19 | sys.stdout.write("%s %s %s" % (col0, (time-first_time), line.replace("EVAL 0 ", ""))) 20 | sys.stdout.flush() 21 | 22 | #if __name__ == '__main__': 23 | # p = Pool(5) 24 | # p.map(process, sys.argv[1:]) 25 | for filename in sys.argv[1:]: 26 | process(filename) 27 | 28 | -------------------------------------------------------------------------------- /u/parse_out_eval_with_time.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys, re, json 3 | import numpy as np 4 | 5 | timestamp = None 6 | for line in sys.stdin: 7 | if line.startswith("STATS"): 8 | m = re.match("STATS (.*)\t", line) 9 | timestamp = m.group(1) 10 | elif line.startswith("EVAL"): 11 | if "STEP" not in line: 12 | if timestamp is None: 13 | continue 14 | cols = line.split(" ") 15 | assert len(cols) == 4 16 | total_reward = cols[2] 17 | print "\t".join([timestamp, total_reward]) 18 | 19 | -------------------------------------------------------------------------------- /u/parse_out_naf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | def second_col(l): 5 | cols = l.split("\t") 6 | assert len(cols) == 2 or len(cols) == 3 7 | return float(cols[1]) 8 | 9 | print "\t".join(["n", "loss", "val", "adv", "t_val"]) 10 | 11 | loss = val = adv = t_val = None 12 | n = 0 13 | for line in sys.stdin: 14 | if line.startswith("--"): 15 | if loss is not None: 16 | print "\t".join(map(str, [n, loss, val, adv, t_val])) 17 | loss = val = adv = t_val = None 18 | n += 1 19 | elif line.startswith("loss"): 20 | loss = second_col(line) 21 | elif line.startswith("val'"): 22 | t_val = second_col(line) 23 | elif line.startswith("val"): 24 | val = second_col(line) 25 | elif line.startswith("adv"): 26 | adv = second_col(line) 27 | 28 | -------------------------------------------------------------------------------- /u/parse_runs.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys, re 3 | import numpy as np 4 | 5 | params = [ 6 | "--actor-hidden-layers", 7 | "--critic-hidden-layers", 8 | "--action-force", 9 | "--batch-size", 10 | "--target-update-rate", 11 | "--actor-learning-rate", 12 | "--critic-learning-rate", 13 | "--actor-gradient-clip", 14 | "--critic-gradient-clip", 15 | "--action-noise-theta", 16 | "--action-noise-sigma", 17 | ] 18 | params = [p.replace("--","").replace("-","_") for p in params] 19 | 20 | header = ["run_id"] + params + ["eval_type", "eval_val"] 21 | print "\t".join(header) 22 | 23 | for run_id in sys.stdin: 24 | run_id = run_id.strip() 25 | row = [run_id] 26 | 27 | for line in open("ckpts/%s/err" % run_id, "r").readlines(): 28 | if line.startswith("Namespace"): 29 | line = re.sub(r"^Namespace\(", "", line).strip() 30 | line = re.sub(r"\)$", "", line) 31 | kvs = {} 32 | for kv in line.split(", "): 33 | k, v = kv.split("=") 34 | kvs[k] = v.replace("'", "") 35 | for p in params: 36 | row.append("_"+kvs[p]) 37 | assert len(row) == 12, len(row) 38 | 39 | evals = [] 40 | for line in open("ckpts/%s/out" % run_id, "r"): 41 | if line.startswith("EVAL"): 42 | cols = line.strip().split(" ") 43 | assert len(cols) == 4 44 | assert cols[0] == "EVAL" 45 | evals.append(float(cols[3])) 46 | assert len(evals) > 10 47 | evals = evals[-10:] 48 | 49 | print "\t".join(map(str, row + ["min", np.min(evals)])) 50 | print "\t".join(map(str, row + ["mean", np.mean(evals)])) 51 | print "\t".join(map(str, row + ["max", np.max(evals)])) 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /u/random_params.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import random 4 | 5 | flags = { 6 | "--actor-hidden-layers": ["50", "100,50", "100,100,50"], 7 | "--critic-hidden-layers": ["50", "100,50", "100,100,50"], 8 | "--action-force": [50, 100], 9 | "--batch-size": [64, 128, 256], 10 | "--target-update-rate": [0.01, 0.001, 0.0001], 11 | "--actor-learning-rate": [0.01, 0.001], 12 | "--critic-learning-rate": [0.01, 0.001], 13 | "--actor-gradient-clip": [None, 50, 100, 200], 14 | "--critic-gradient-clip": [None, 50, 100, 200], 15 | "--action-noise-theta": [ 0.1, 0.01, 0.001 ], 16 | "--action-noise-sigma": [ 0.1, 0.2, 0.5] 17 | } 18 | 19 | while True: 20 | run_id = "run_%s" % str(random.random())[2:] 21 | cmd = "mkdir ckpts/%s;" % run_id 22 | cmd += " ./ddpg_cartpole.py" 23 | cmd += " --max-run-time=3600" 24 | cmd += " --ckpt-dir=ckpts/%s/ckpts" % run_id 25 | for flag, values in flags.iteritems(): 26 | value = random.choice(values) 27 | if value is not None: 28 | cmd += " %s=%s" % (flag, value) 29 | cmd += " >ckpts/%s/out 2>ckpts/%s/err" % (run_id, run_id) 30 | print cmd 31 | 32 | -------------------------------------------------------------------------------- /u/test_render.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import pybullet as p 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | p.connect(p.DIRECT) 7 | #p.loadURDF("models/ground.urdf", 0,0,0, 0,0,0,1) 8 | #p.loadURDF("models/cart.urdf", 0,0,0.08, 0,0,0,1) 9 | p.loadURDF("models/pole.urdf", 0,0,0.35, 0,0,0,1) 10 | cameraPos = (0.75, 0.75, 0.75) 11 | targetPos = (0, 0, 0.2) 12 | cameraUp = (0, 0, 1) 13 | nearVal, farVal = 1, 20 14 | fov = 60 15 | _w, _h, rgba, _depth, _objects = 
p.renderImage(50, 50, cameraPos, targetPos, cameraUp, nearVal, farVal, fov) 16 | rgba_img = np.reshape(np.asarray(rgba, dtype=np.float32), (50, 50, 4)) 17 | rgb_img = rgba_img[:,:,:3] 18 | rgb_img /= 255 19 | plt.imsave("/tmp/ppp.png", rgb_img) 20 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import datetime, os, time, yaml, sys 3 | import json 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import StringIO 7 | import tensorflow as tf 8 | import time 9 | 10 | def add_opts(parser): 11 | parser.add_argument('--gradient-clip', type=float, default=5, 12 | help="do global clipping to this norm") 13 | parser.add_argument('--print-gradients', action='store_true', 14 | help="whether to verbose print all gradients and l2 norms") 15 | parser.add_argument('--optimiser', type=str, default="GradientDescent", 16 | help="tf.train.XXXOptimizer to use") 17 | parser.add_argument('--optimiser-args', type=str, default="{\"learning_rate\": 0.001}", 18 | help="json serialised args for optimiser constructor") 19 | parser.add_argument('--use-dropout', action='store_true', 20 | help="include a dropout layers after each fully connected layer") 21 | 22 | class StopWatch: 23 | def reset(self): 24 | self.start = time.time() 25 | def time(self): 26 | return time.time() - self.start 27 | # def __enter__(self): 28 | # self.start = time.time() 29 | # return self 30 | # def __exit__(self, *args): 31 | # self.time = time.time() - self.start 32 | 33 | def l2_norm(tensor): 34 | """(row wise) l2 norm of a tensor""" 35 | return tf.sqrt(tf.reduce_sum(tf.pow(tensor, 2))) 36 | 37 | def standardise(tensor): 38 | """standardise a tensor""" 39 | # is std_dev not an op in tensorflow?!? i must be taking crazy pills... 40 | mean = tf.reduce_mean(tensor) 41 | variance = tf.reduce_mean(tf.square(tensor - mean)) 42 | std_dev = tf.sqrt(variance) 43 | return (tensor - mean) / std_dev 44 | 45 | def clip_and_debug_gradients(gradients, opts): 46 | # extract just the gradients temporarily for global clipping and then rezip 47 | if opts.gradient_clip is not None: 48 | just_gradients, variables = zip(*gradients) 49 | just_gradients, _ = tf.clip_by_global_norm(just_gradients, opts.gradient_clip) 50 | gradients = zip(just_gradients, variables) 51 | # verbose debugging 52 | if opts.print_gradients: 53 | for i, (gradient, variable) in enumerate(gradients): 54 | if gradient is not None: 55 | gradients[i] = (tf.Print(gradient, [l2_norm(gradient)], 56 | "gradient %s l2_norm " % variable.name), variable) 57 | # done 58 | return gradients 59 | 60 | def collapsed_successive_ranges(values): 61 | """reduce an array, e.g. 
[2,3,4,5,13,14,15], to its successive ranges [2-5, 13-15]""" 62 | last, start, out = None, None, [] 63 | for value in values: 64 | if start is None: 65 | start = value 66 | elif value != last + 1: 67 | out.append("%d-%d" % (start, last)) 68 | start = value 69 | last = value 70 | out.append("%d-%d" % (start, last)) 71 | return ", ".join(out) 72 | 73 | def construct_optimiser(opts): 74 | optimiser_cstr = eval("tf.train.%sOptimizer" % opts.optimiser) 75 | args = json.loads(opts.optimiser_args) 76 | return optimiser_cstr(**args) 77 | 78 | def shape_and_product_of(t): 79 | shape_product = 1 80 | for dim in t.get_shape(): 81 | try: 82 | shape_product *= int(dim) 83 | except TypeError: 84 | # Dimension(None) 85 | pass 86 | return "%s #%s" % (t.get_shape(), shape_product) 87 | 88 | class SaverUtil(object): 89 | def __init__(self, sess, ckpt_dir="/tmp", save_freq=60): 90 | self.sess = sess 91 | var_list = [v for v in tf.all_variables() if not "replay_memory" in v.name] 92 | self.saver = tf.train.Saver(var_list=var_list, max_to_keep=1000) 93 | self.ckpt_dir = ckpt_dir 94 | if not os.path.exists(self.ckpt_dir): 95 | os.makedirs(self.ckpt_dir) 96 | assert save_freq > 0 97 | self.save_freq = save_freq 98 | self.load_latest_ckpt_or_init_if_none() 99 | 100 | def load_latest_ckpt_or_init_if_none(self): 101 | """loads latests ckpt from dir. if there are non run init variables.""" 102 | # if no latest checkpoint init vars and return 103 | ckpt_info_file = "%s/checkpoint" % self.ckpt_dir 104 | if os.path.isfile(ckpt_info_file): 105 | # load latest ckpt 106 | info = yaml.load(open(ckpt_info_file, "r")) 107 | assert 'model_checkpoint_path' in info 108 | most_recent_ckpt = "%s/%s" % (self.ckpt_dir, info['model_checkpoint_path']) 109 | sys.stderr.write("loading ckpt %s\n" % most_recent_ckpt) 110 | self.saver.restore(self.sess, most_recent_ckpt) 111 | self.next_scheduled_save_time = time.time() + self.save_freq 112 | else: 113 | # no latest ckpts, init and force a save now 114 | sys.stderr.write("no latest ckpt in %s, just initing vars...\n" % self.ckpt_dir) 115 | self.sess.run(tf.initialize_all_variables()) 116 | self.force_save() 117 | 118 | def force_save(self): 119 | """force a save now.""" 120 | dts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') 121 | new_ckpt = "%s/ckpt.%s" % (self.ckpt_dir, dts) 122 | sys.stderr.write("saving ckpt %s\n" % new_ckpt) 123 | start_time = time.time() 124 | self.saver.save(self.sess, new_ckpt) 125 | print "save_took", time.time() - start_time 126 | self.next_scheduled_save_time = time.time() + self.save_freq 127 | 128 | def save_if_required(self): 129 | """check if save is required based on time and if so, save.""" 130 | if time.time() >= self.next_scheduled_save_time: 131 | self.force_save() 132 | 133 | 134 | class OrnsteinUhlenbeckNoise(object): 135 | """generate time correlated noise for action exploration""" 136 | 137 | def __init__(self, dim, theta=0.01, sigma=0.2, max_magnitude=1.5): 138 | # dim: dimensionality of returned noise 139 | # theta: how quickly the value moves; near zero => slow, near one => fast 140 | # 0.01 gives very roughly 2/3 peaks troughs over ~1000 samples 141 | # sigma: maximum range of values; 0.2 gives approximately the range (-1.5, 1.5) 142 | # which is useful for shifting the output of a tanh which is (-1, 1) 143 | # max_magnitude: max +ve / -ve value to clip at. dft clip at 1.5 (again for 144 | # adding to output from tanh. we do this since sigma gives no guarantees 145 | # regarding min/max values. 
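    #
    # the update in sample() below is a discretised ornstein-uhlenbeck process with mu = 0
    # and the timestep folded into the constants:
    #     x_{t+1} = x_t + theta * (0 - x_t) + sigma * N(0, I)
    # i.e. theta pulls the state back towards zero while sigma injects fresh gaussian noise.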
146 | self.dim = dim 147 | self.theta = theta 148 | self.sigma = sigma 149 | self.max_magnitude = max_magnitude 150 | self.state = np.zeros(self.dim) 151 | 152 | def sample(self): 153 | self.state += self.theta * -self.state 154 | self.state += self.sigma * np.random.randn(self.dim) 155 | self.state = np.clip(self.max_magnitude, -self.max_magnitude, self.state) 156 | return np.copy(self.state) 157 | 158 | 159 | def write_img_to_png_file(img, filename): 160 | png_bytes = StringIO.StringIO() 161 | plt.imsave(png_bytes, img) 162 | print "writing", filename 163 | with open(filename, "w") as f: 164 | f.write(png_bytes.getvalue()) 165 | 166 | # some hacks for writing state to disk for visualisation 167 | def render_state_to_png(step, state, split_channels=False): 168 | height, width, num_channels, num_cameras, num_repeats = state.shape 169 | for c_idx in range(num_cameras): 170 | for r_idx in range(num_repeats): 171 | if split_channels: 172 | for channel in range(num_channels): 173 | single_channel = state[:,:,channel,c_idx,r_idx] 174 | img = np.zeros((height, width, 3)) 175 | img[:,:,channel] = single_channel 176 | write_img_to_png_file(img, "/tmp/state_s%03d_ch%s_c%s_r%s.png" % (step, channel, c_idx, r_idx)) 177 | else: 178 | img = np.empty((height, width, 3)) 179 | img[:,:,0] = state[:,:,0,c_idx,r_idx] 180 | img[:,:,1] = state[:,:,1,c_idx,r_idx] 181 | img[:,:,2] = state[:,:,2,c_idx,r_idx] 182 | write_img_to_png_file(img, "/tmp/state_s%03d_c%s_r%s.png" % (step, c_idx, r_idx)) 183 | 184 | def render_action_to_png(step, action): 185 | import Image, ImageDraw 186 | img = Image.new('RGB', (50, 50), (50, 50, 50)) 187 | canvas = ImageDraw.Draw(img) 188 | lx, ly = int(25+(action[0][0]*25)), int(25+(action[0][1]*25)) 189 | canvas.line((25,25, lx,ly), fill="black") 190 | img.save("/tmp/action_%03d.png" % step) 191 | --------------------------------------------------------------------------------
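A closing note on how the helpers in util.py presumably fit together in the agents: an
optimiser is constructed from the --optimiser / --optimiser-args flags and its gradients
are passed through clip_and_debug_gradients before being applied. A minimal sketch (the
scalar `loss` tensor and the `opts` namespace are stand-ins for whatever the agent defines):

    optimiser = util.construct_optimiser(opts)   # e.g. --optimiser=Adam --optimiser-args='{"learning_rate": 0.0001}'
    gradients = optimiser.compute_gradients(loss)                # list of (gradient, variable) pairs
    gradients = util.clip_and_debug_gradients(gradients, opts)   # global norm clipping + optional l2 norm printing
    train_op = optimiser.apply_gradients(gradients)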