├── .gitignore ├── LICENSE ├── README.md ├── a3c.py ├── assets ├── feature-control-bptt-100.png ├── feature-control-video-10M.gif ├── feature-control-video-160M.gif ├── feature-control-video-27M.gif ├── feature-control-video-50M.gif ├── feature-control-video-90M.gif ├── intrinsic_feature.png ├── intrinsic_pixel.png ├── model.png └── pixel-control-bptt-100.png ├── envs.py ├── model.py ├── ops.py ├── test.py ├── train.py └── worker.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Tensorflow logs / datasets / results 2 | logs/ 3 | datasets/ 4 | results/ 5 | 6 | # Temporary files 7 | *.zip 8 | *.swp 9 | *~ 10 | .DS_Store 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | env/ 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # dotenv 94 | .env 95 | 96 | # virtualenv 97 | .venv 98 | venv/ 99 | ENV/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Youngwoon Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Feature Control as Intrinsic Motivation for Hierarchical RL in Tensorflow 2 | 3 | As part of the implementation series of [Joseph Lim's group at USC](http://www-bcf.usc.edu/~limjj/), our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers, and publicly share them with concise reports. Please visit our [group github site](https://github.com/gitlimlab) for other projects. 4 | 5 | This project is implemented by [Youngwoon Lee](https://github.com/youngwoon), and the code has been reviewed by [Shao-Hua Sun](https://github.com/shaohua0116) before being published. 6 | 7 | ## Description 8 | This repo is a [Tensorflow](https://www.tensorflow.org/) implementation of Feature-control and Pixel-control agents on Montezuma's Revenge, as described in [Feature Control as Intrinsic Motivation for Hierarchical Reinforcement Learning](https://arxiv.org/abs/1705.06769). 9 | 10 | This paper focuses on hierarchical reinforcement learning (HRL), where a task can be decomposed into a hierarchy of subproblems or subtasks such that higher-level parent tasks invoke lower-level child tasks as if they were primitive actions. To tackle this problem, the paper presents an HRL framework that overcomes the issue of sparse rewards by extracting intrinsic rewards from changes in consecutive observations. The key idea is that, given an agent's intention, the model learns a set of features, each of which indicates whether the corresponding intention has been achieved. In other words, the agent learns a set of skills that change future observations in a certain direction. For example, one skill we want to learn in Montezuma's Revenge is picking up the key. The success of this skill can be judged by the presence of the key: once the key disappears from the observation, the agent receives a reward because the feature encoding the key's presence has changed. These skills are called subgoals of the meta-controller. 11 | 12 |

![model](assets/model.png) 13 | 14 |
15 | 16 | The proposed model consists of two controllers: a meta-controller and a sub-controller. As a high-level planner, the meta-controller sets a subgoal that it wants to achieve over the next 100 timesteps; the sub-controller, in turn, figures out the action sequence that accomplishes this subgoal. The meta-controller decides the next subgoal based on the previous subgoal, the current observation, and the reward it gathered while pursuing the previous subgoal. The sub-controller then computes actions from the previous reward, the previous action, the current observation, and the given subgoal. To capture temporal information, the policy networks of both the meta-controller and the sub-controller use LSTMs. 17 | 18 | The intrinsic reward of a feature-control agent is the change of the k-th feature map of the second convolutional layer between two consecutive frames, relative to the change over all feature maps. The paper also proposes a pixel-control agent, which computes its intrinsic reward from the pixel-value changes in a selected region of the frame; both agents are included in this repo. 19 | 20 |
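For reference, the intrinsic reward boils down to the computation sketched below, which mirrors `compute_intrinsic` in `a3c.py` (the function and argument names here are illustrative, and `eta` is the intrinsic reward scale, 0.05 in this repo):

```python
import numpy as np

def intrinsic_reward(feat, last_feat, subgoal, intrinsic_type, eta=0.05):
    """Simplified sketch of compute_intrinsic() in a3c.py.
    feat/last_feat: per-feature-map means of the 2nd conv layer ('feature')
    or the resized 84x84 observation ('pixel')."""
    if intrinsic_type == 'feature':
        # change of the subgoal-th feature map, relative to the change of all feature maps
        diff = np.abs(feat - last_feat)
        return eta * diff[subgoal] / (np.sum(diff) + 1e-10)
    else:
        # relative squared pixel change inside the 14x14 block selected by the subgoal
        diff = np.square(feat - last_feat)
        mask = np.zeros(diff.shape)
        if subgoal < 36:                      # subgoals 0..35 index a 6x6 grid of blocks
            y, x = subgoal // 6, subgoal % 6
            mask[y * 14:(y + 1) * 14, x * 14:(x + 1) * 14] = 1
        return eta * np.sum(mask * diff) / (np.sum(diff) + 1e-10)
```

A feature-control agent therefore has 32 subgoals (one per feature map of the second convolutional layer), while a pixel-control agent has 37 (36 image blocks plus one subgoal that selects no block).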

![intrinsic_feature](assets/intrinsic_feature.png) 21 | ![intrinsic_pixel](assets/intrinsic_pixel.png) 22 |
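During training, the two controllers are interleaved roughly as in the condensed sketch below (see `A3C._meta_process` and `A3C._sub_process` in `a3c.py`); rollout storage and the A3C updates are omitted, and `env`, `meta_policy`, `sub_policy`, and `intrinsic_reward` are placeholders for the corresponding pieces of this repo (`intrinsic_reward` stands for the computation sketched above, applied to the encoder features of the two observations):

```python
def run_episode(env, meta_policy, sub_policy, intrinsic_reward, bptt=100, beta=0.75):
    """Condensed two-level control loop (A3C rollouts and gradient updates omitted)."""
    obs = env.reset()
    done, total_extrinsic = False, 0.0
    while not done:
        subgoal = meta_policy(obs)                  # high level: pick a subgoal
        meta_reward = 0.0
        for _ in range(bptt):                       # low level: act for up to `bptt` steps
            action = sub_policy(obs, subgoal)
            next_obs, extrinsic, done, _ = env.step(action)
            extrinsic = max(min(extrinsic, 1), -1)  # reward clipping as in a3c.py
            # the sub-controller is trained on a mix of extrinsic and intrinsic reward
            sub_reward = beta * extrinsic + (1 - beta) * intrinsic_reward(next_obs, obs, subgoal)
            meta_reward += extrinsic                # the meta-controller only sees extrinsic reward
            total_extrinsic += extrinsic
            obs = next_obs
            if done:
                break
        # `meta_reward` is the return credited to this subgoal choice
    return total_extrinsic
```

With `--bptt=20`, the repo runs five 20-step sub-rollouts per subgoal, so each subgoal is still pursued for 100 environment steps.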

23 | 24 | This method outperforms the state-of-the-art method (FeUdal Networks) and reaches a score of 2500 on MontezumaRevenge-v0. 25 | 26 | ## Dependencies 27 | 28 | - Ubuntu 16.04 29 | - Python 3.6 30 | - [tmux](https://tmux.github.io) 31 | - [htop](https://hisham.hm/htop) 32 | - [Tensorflow 1.3.0](https://www.tensorflow.org/) 33 | - [Universe](https://github.com/openai/universe) 34 | - [gym](https://github.com/openai/gym) 35 | - [tqdm](https://github.com/tqdm/tqdm) 36 | 37 | ## Usage 38 | 39 | - Execute the following command to train a model: 40 | 41 | ``` 42 | $ python train.py --log-dir='/tmp/feature-control' --intrinsic-type='feature' --bptt=100 43 | ``` 44 | 45 | - `intrinsic-type` can be either `'feature'` or `'pixel'`. 46 | 47 | - The `--bptt` option sets the truncated backpropagation-through-time (BPTT) length to either 20 or 100 time steps. 48 | 49 | - Once training has finished, you can test the agent: it will play the game 10 times and report the average reward. 50 | 51 | ``` 52 | $ python test.py --log-dir='/tmp/feature-control' --intrinsic-type='feature' --bptt=100 --visualise 53 | ``` 54 | 55 | - Check the training status on Tensorboard. The default port number is 12345 (i.e. http://localhost:12345). 56 | 57 | 58 | ## Results 59 | 60 | ### Montezuma's Revenge-v0 61 | 62 | - Feature-control agent with bptt 100 63 | ![training_curve_feature_control](assets/feature-control-bptt-100.png) 64 | 65 | - Pixel-control agent with bptt 100 66 | ![training_curve_pixel_control](assets/pixel-control-bptt-100.png) 67 | 68 | - Training converges more slowly than the results reported in the paper. Be patient and keep training the agent for at least 20M iterations. 69 | 70 | ### Videos (Feature-control agent) 71 | 72 | | Iterations | 10M | 27M | 50M | 90M | 160M | 73 | | :--------: | :--------------------------------------: | :--------------------------------------: | :--------------------------------------: | :--------------------------------------: | :--------------------------------------: | 74 | | Rewards | 0 | 100 | 400 | 0 | 2500 | 75 | | Videos | ![training_curve_feature_control](assets/feature-control-video-10M.gif) | ![training_curve_feature_control](assets/feature-control-video-27M.gif) | ![training_curve_feature_control](assets/feature-control-video-50M.gif) | ![training_curve_feature_control](assets/feature-control-video-90M.gif) | ![training_curve_feature_control](assets/feature-control-video-160M.gif) | 76 | 77 | 78 | ## References 79 | 80 | - [Feature Control as Intrinsic Motivation for Hierarchical Reinforcement Learning](https://arxiv.org/abs/1705.06769) 81 | - [OpenAI's A3C implementation (universe-starter-agent)](https://github.com/openai/universe-starter-agent) 82 | - [FeUdal Networks for Hierarchical Reinforcement Learning](https://arxiv.org/abs/1703.01161) 83 | -------------------------------------------------------------------------------- /a3c.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import scipy.signal 4 | 5 | from model import SubPolicy, MetaPolicy 6 | 7 | 8 | def discount(x, gamma): 9 | return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1] 10 | 11 | class Batch: 12 | def __init__(self, si, a, adv, r, terminal, features): 13 | self.si = si 14 | self.a = a 15 | self.adv = adv 16 | self.r = r 17 | self.terminal = terminal 18 | self.features = features 19 | 20 | def process_rollout(rollout, gamma, lambda_=1.0): 21 | """ 22 | given a rollout, compute its returns and the advantage 23 | """ 24 | batch_si = {}
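# rollout.states holds one dict per timestep (observation plus previous action/subgoal/reward); stack each key into a batched array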
25 | for key in rollout.states[0].keys(): 26 | batch_si[key] = np.stack([s[key] for s in rollout.states]) 27 | #batch_si = np.asarray(rollout.states) 28 | batch_a = np.asarray(rollout.actions) 29 | rewards = np.asarray(rollout.rewards) 30 | vpred_t = np.asarray(rollout.values + [rollout.r]) 31 | 32 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r]) 33 | batch_r = discount(rewards_plus_v, gamma)[:-1] 34 | delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1] 35 | # this formula for the advantage comes "Generalized Advantage Estimation": 36 | # https://arxiv.org/abs/1506.02438 37 | batch_adv = discount(delta_t, gamma * lambda_) 38 | 39 | features = rollout.features[0] 40 | return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal, features) 41 | 42 | 43 | class PartialRollout(object): 44 | """ 45 | a piece of a complete rollout. We run our agent, and process its experience 46 | once it has processed enough steps. 47 | """ 48 | def __init__(self): 49 | self.states = [] 50 | self.actions = [] 51 | self.rewards = [] 52 | self.values = [] 53 | self.r = 0.0 54 | self.terminal = False 55 | self.features = [] 56 | 57 | def add(self, state, action, reward, value, terminal, features): 58 | self.states += [state] 59 | self.actions += [action] 60 | self.rewards += [reward] 61 | self.values += [value] 62 | self.terminal = terminal 63 | self.features += [features] 64 | 65 | def extend(self, other): 66 | assert not self.terminal 67 | self.states.extend(other.states) 68 | self.actions.extend(other.actions) 69 | self.rewards.extend(other.rewards) 70 | self.values.extend(other.values) 71 | self.r = other.r 72 | self.terminal = other.terminal 73 | self.features.extend(other.features) 74 | 75 | 76 | class A3C(object): 77 | def __init__(self, env, task, visualise, intrinsic_type, bptt): 78 | """ 79 | An implementation of the A3C algorithm that is reasonably well-tuned for the VNC environments. 80 | Below, we will have a modest amount of complexity due to the way TensorFlow handles data parallelism. 81 | But overall, we'll define the model, specify its inputs, and describe how the policy gradients step 82 | should be computed. 
83 | """ 84 | 85 | self.env = env 86 | self.task = task 87 | self.visualise = visualise 88 | self.intrinsic_type = intrinsic_type 89 | self.bptt = bptt 90 | self.subgoal_space = 32 if intrinsic_type == 'feature' else 37 91 | self.action_space = env.action_space.n 92 | 93 | self.summary_writer = None 94 | self.local_steps = 0 95 | 96 | self.beta = 0.75 97 | self.eta = 0.05 98 | self.num_local_steps = self.bptt 99 | self.num_local_meta_steps = 20 100 | 101 | # Testing 102 | if task is None: 103 | with tf.variable_scope("global"): 104 | self.local_sub_network = SubPolicy(env.observation_space.shape, 105 | env.action_space.n, 106 | self.subgoal_space, 107 | self.intrinsic_type) 108 | self.local_meta_network = MetaPolicy(env.observation_space.shape, 109 | self.subgoal_space, 110 | self.intrinsic_type) 111 | return 112 | 113 | # Training 114 | worker_device = "/job:worker/task:{}/cpu:0".format(task) 115 | with tf.device(tf.train.replica_device_setter(1, worker_device=worker_device)): 116 | with tf.variable_scope("global"): 117 | print(env.observation_space.shape) 118 | self.sub_network = SubPolicy(env.observation_space.shape, 119 | env.action_space.n, 120 | self.subgoal_space, 121 | self.intrinsic_type) 122 | self.meta_network = MetaPolicy(env.observation_space.shape, 123 | self.subgoal_space, 124 | self.intrinsic_type) 125 | self.global_step = tf.get_variable("global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), 126 | trainable=False) 127 | 128 | with tf.device(worker_device): 129 | with tf.variable_scope("local"): 130 | self.local_sub_network = pi = SubPolicy(env.observation_space.shape, 131 | env.action_space.n, 132 | self.subgoal_space, 133 | self.intrinsic_type) 134 | self.local_meta_network = meta_pi = MetaPolicy(env.observation_space.shape, 135 | self.subgoal_space, 136 | self.intrinsic_type) 137 | pi.global_step = self.global_step 138 | 139 | self.ac = tf.placeholder(tf.float32, [None, env.action_space.n], name="ac") 140 | self.adv = tf.placeholder(tf.float32, [None], name="adv") 141 | self.r = tf.placeholder(tf.float32, [None], name="r") 142 | 143 | log_prob_tf = tf.nn.log_softmax(pi.logits) 144 | prob_tf = tf.nn.softmax(pi.logits) 145 | 146 | # the "policy gradients" loss: its derivative is precisely the policy gradient 147 | # notice that self.ac is a placeholder that is provided externally. 148 | # adv will contain the advantages, as calculated in process_rollout 149 | pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac, [1]) * self.adv) 150 | 151 | # loss of value function 152 | vf_loss = 0.5 * tf.reduce_sum(tf.square(pi.vf - self.r)) 153 | entropy = - tf.reduce_sum(prob_tf * log_prob_tf) 154 | 155 | bs = tf.to_float(tf.shape(pi.x)[0]) 156 | self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01 157 | 158 | 159 | self.meta_ac = tf.placeholder(tf.float32, [None, self.subgoal_space], name="meta_ac") 160 | self.meta_adv = tf.placeholder(tf.float32, [None], name="meta_adv") 161 | self.meta_r = tf.placeholder(tf.float32, [None], name="meta_r") 162 | 163 | meta_log_prob_tf = tf.nn.log_softmax(meta_pi.logits) 164 | meta_prob_tf = tf.nn.softmax(meta_pi.logits) 165 | 166 | # the "policy gradients" loss: its derivative is precisely the policy gradient 167 | # notice that self.ac is a placeholder that is provided externally. 
168 | # adv will contain the advantages, as calculated in process_rollout 169 | meta_pi_loss = - tf.reduce_sum(tf.reduce_sum(meta_log_prob_tf * self.meta_ac, [1]) * self.meta_adv) 170 | 171 | # loss of value function 172 | meta_vf_loss = 0.5 * tf.reduce_sum(tf.square(meta_pi.vf - self.meta_r)) 173 | meta_entropy = - tf.reduce_sum(meta_prob_tf * meta_log_prob_tf) 174 | 175 | meta_bs = tf.to_float(tf.shape(meta_pi.x)[0]) 176 | self.meta_loss = meta_pi_loss + 0.5 * meta_vf_loss - meta_entropy * 0.01 177 | 178 | # 20 represents the number of "local steps": the number of timesteps 179 | # we run the policy before we update the parameters. 180 | # The larger local steps is, the lower is the variance in our policy gradients estimate 181 | # on the one hand; but on the other hand, we get less frequent parameter updates, which 182 | # slows down learning. In this code, we found that making local steps be much 183 | # smaller than 20 makes the algorithm more difficult to tune and to get to work. 184 | # self.runner = RunnerThread(env, pi, meta_pi, 20, visualise) 185 | 186 | grads = tf.gradients(self.loss, pi.var_list) 187 | meta_grads = tf.gradients(self.meta_loss, meta_pi.var_list) 188 | 189 | summary = [ 190 | tf.summary.scalar("model/policy_loss", pi_loss / bs), 191 | tf.summary.scalar("model/value_loss", vf_loss / bs), 192 | tf.summary.scalar("model/entropy", entropy / bs), 193 | tf.summary.image("model/state", pi.x), 194 | tf.summary.scalar("model/grad_global_norm", tf.global_norm(grads)), 195 | tf.summary.scalar("model/var_global_norm", tf.global_norm(pi.var_list)) 196 | ] 197 | 198 | meta_summary = [ 199 | tf.summary.scalar("meta_model/policy_loss", meta_pi_loss / meta_bs), 200 | tf.summary.scalar("meta_model/value_loss", meta_vf_loss / meta_bs), 201 | tf.summary.scalar("meta_model/entropy", meta_entropy / meta_bs), 202 | tf.summary.scalar("meta_model/grad_global_norm", tf.global_norm(meta_grads)), 203 | tf.summary.scalar("meta_model/var_global_norm", tf.global_norm(meta_pi.var_list)) 204 | ] 205 | self.summary_op = tf.summary.merge(summary) 206 | self.meta_summary_op = tf.summary.merge(meta_summary) 207 | 208 | grads, _ = tf.clip_by_global_norm(grads, 40.0) 209 | meta_grads, _ = tf.clip_by_global_norm(meta_grads, 40.0) 210 | self.grads = grads 211 | 212 | # copy weights from the parameter server to the local model 213 | self.sync = tf.group(*[v1.assign(v2) for v1, v2 in zip(pi.var_list, self.sub_network.var_list)]) 214 | self.meta_sync = tf.group(*[v1.assign(v2) for v1, v2 in zip(meta_pi.var_list, self.meta_network.var_list)]) 215 | 216 | grads_and_vars = list(zip(grads, self.sub_network.var_list)) 217 | meta_grads_and_vars = list(zip(meta_grads, self.meta_network.var_list)) 218 | 219 | inc_step = self.global_step.assign_add(tf.shape(pi.x)[0]) 220 | 221 | # each worker has a different set of adam optimizer parameters 222 | opt = tf.train.AdamOptimizer(1e-4) 223 | self.train_op = tf.group(opt.apply_gradients(grads_and_vars), inc_step) 224 | meta_opt = tf.train.AdamOptimizer(1e-4) 225 | self.meta_train_op = meta_opt.apply_gradients(meta_grads_and_vars) 226 | 227 | def start(self, sess, summary_writer): 228 | self.summary_writer = summary_writer 229 | 230 | def process(self, sess): 231 | """ 232 | run one episode and process experience to train both meta and sub networks 233 | """ 234 | env = self.env 235 | 236 | sub_policy = self.local_sub_network 237 | meta_policy = self.local_meta_network 238 | 239 | self.last_state = env.reset() 240 | self.last_meta_state = env.reset() 241 | 
self.last_features = sub_policy.get_initial_features() 242 | self.last_meta_features = meta_policy.get_initial_features() 243 | 244 | self.last_action = np.zeros(self.action_space) 245 | self.last_subgoal = np.zeros(self.subgoal_space) 246 | self.last_reward = np.zeros(1) 247 | self.last_meta_reward = np.zeros(1) 248 | 249 | self.length = 0 250 | self.rewards = 0 251 | self.extrinsic_rewards = 0 252 | self.intrinsic_rewards = 0 253 | 254 | terminal_end = False 255 | while not terminal_end: 256 | terminal_end = self._meta_process(sess) 257 | 258 | def _meta_process(self, sess): 259 | sess.run(self.meta_sync) # copy weights from shared to local 260 | meta_rollout = PartialRollout() 261 | meta_policy = self.local_meta_network 262 | 263 | terminal_end = False 264 | for _ in range(self.num_local_meta_steps): 265 | fetched = meta_policy.act(self.last_meta_state, self.last_subgoal, 266 | self.last_meta_reward, *self.last_meta_features) 267 | subgoal, meta_value, meta_features = fetched[0], fetched[1], fetched[2:] 268 | 269 | assert self.bptt in [20, 100], 'bptt (%d) should be 20 or 100' % self.bptt 270 | 271 | if self.bptt == 20: 272 | meta_reward = 0 273 | for _ in range(5): 274 | state, reward, terminal_end = self._sub_process(sess, subgoal) 275 | meta_reward += reward 276 | if terminal_end: 277 | break 278 | elif self.bptt == 100: 279 | state, meta_reward, terminal_end = self._sub_process(sess, subgoal) 280 | 281 | 282 | si = { 283 | 'x': self.last_meta_state, 284 | 'subgoal_prev': self.last_subgoal, 285 | 'reward_prev': self.last_meta_reward 286 | } 287 | meta_rollout.add(si, subgoal, meta_reward, meta_value, 288 | terminal_end, self.last_meta_features) 289 | 290 | self.last_meta_state = state 291 | self.last_meta_features = meta_features 292 | self.last_meta_reward = [meta_reward] 293 | self.last_subgoal = subgoal 294 | 295 | if terminal_end: 296 | break 297 | 298 | if not terminal_end: 299 | meta_rollout.r = meta_policy.value(self.last_state, 300 | self.last_subgoal, 301 | self.last_meta_reward, 302 | *self.last_meta_features) 303 | 304 | # meta rollout 305 | batch = process_rollout(meta_rollout, gamma=0.99, lambda_=1.0) 306 | fetches = [self.meta_summary_op, self.meta_train_op, self.global_step] 307 | 308 | feed_dict = { 309 | meta_policy.x: batch.si['x'], 310 | meta_policy.subgoal_prev: batch.si['subgoal_prev'], 311 | meta_policy.reward_prev: batch.si['reward_prev'], 312 | self.meta_ac: batch.a, 313 | self.meta_adv: batch.adv, 314 | self.meta_r: batch.r, 315 | meta_policy.state_in[0]: batch.features[0], 316 | meta_policy.state_in[1]: batch.features[1], 317 | } 318 | fetched = sess.run(fetches, feed_dict=feed_dict) 319 | self.summary_writer.add_summary(fetched[0], fetched[-1]) 320 | 321 | return terminal_end 322 | 323 | def _sub_process(self, sess, subgoal): 324 | sess.run(self.sync) # copy weights from shared to local 325 | sub_rollout = PartialRollout() 326 | sub_policy = self.local_sub_network 327 | meta_reward = 0 328 | 329 | terminal_end = False 330 | for _ in range(self.num_local_steps): 331 | fetched = sub_policy.act(self.last_state, self.last_action, 332 | self.last_reward, subgoal, *self.last_features) 333 | action, value, features = fetched[0], fetched[1], fetched[2:] 334 | 335 | # argmax to convert from one-hot 336 | state, episode_reward, terminal, info = self.env.step(action.argmax()) 337 | # reward clipping to the range of [-1, 1] 338 | extrinsic_reward = max(min(episode_reward, 1), -1) 339 | 340 | def get_mask(shape, subgoal): 341 | mask = np.zeros(shape) 342 | if subgoal < 
36: 343 | y = subgoal // 6 344 | x = subgoal % 6 345 | mask[y*14:(y+1)*14, x*14:(x+1)*14] = 1 346 | mask = np.stack([mask] * 3) 347 | return mask 348 | 349 | def compute_intrinsic(state, last_state, subgoal): 350 | f = sub_policy.feature(state) 351 | last_f = sub_policy.feature(last_state) 352 | if self.intrinsic_type == 'feature': 353 | diff = np.abs(f - last_f) 354 | return self.eta * diff[subgoal] / (np.sum(diff) + 1e-10) 355 | else: 356 | diff = f - last_f 357 | diff = diff * diff 358 | mask = get_mask(diff.shape, subgoal) 359 | return self.eta * np.sum(mask * diff) / (np.sum(diff) + 1e-10) 360 | 361 | intrinsic_reward = compute_intrinsic(state, self.last_state, subgoal.argmax()) 362 | reward = self.beta * extrinsic_reward + (1 - self.beta) * intrinsic_reward 363 | 364 | meta_reward += extrinsic_reward 365 | # meta_reward += reward 366 | self.intrinsic_rewards += intrinsic_reward 367 | self.extrinsic_rewards += extrinsic_reward 368 | 369 | # collect the experience 370 | si = { 371 | 'x': self.last_state, 372 | 'action_prev': self.last_action, 373 | 'reward_prev': self.last_reward, 374 | 'subgoal': subgoal 375 | } 376 | sub_rollout.add(si, action, reward, value, terminal, self.last_features) 377 | 378 | self.length += 1 379 | self.rewards += episode_reward 380 | 381 | self.last_state = state 382 | self.last_action = action 383 | self.last_features = features 384 | self.last_reward = [reward] 385 | 386 | timestep_limit = self.env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps') 387 | if terminal or (timestep_limit and self.length >= timestep_limit): 388 | terminal_end = True 389 | 390 | summary = tf.Summary() 391 | summary.value.add(tag='global/episode_reward', 392 | simple_value=self.rewards) 393 | summary.value.add(tag='global/extrinsic_reward', 394 | simple_value=self.extrinsic_rewards) 395 | summary.value.add(tag='global/intrinsic_reward', 396 | simple_value=self.intrinsic_rewards) 397 | summary.value.add(tag='global/episode_length', 398 | simple_value=self.length) 399 | self.summary_writer.add_summary(summary, sub_policy.global_step.eval()) 400 | self.summary_writer.flush() 401 | 402 | print("Episode finished. Ep rewards: %.5f (In: %.5f, Ex: %.5f). 
Length: %d" % 403 | (self.rewards, self.intrinsic_rewards, self.extrinsic_rewards, self.length)) 404 | break 405 | 406 | if not terminal_end: 407 | sub_rollout.r = sub_policy.value(self.last_state, 408 | self.last_action, 409 | self.last_reward, 410 | subgoal, *self.last_features) 411 | 412 | self.local_steps += 1 413 | should_compute_summary = self.task == 0 and self.local_steps % 11 == 0 414 | 415 | # sub rollout 416 | batch = process_rollout(sub_rollout, gamma=0.99, lambda_=1.0) 417 | fetches = [self.train_op, self.global_step] 418 | if should_compute_summary: 419 | fetches = [self.summary_op] + fetches 420 | feed_dict = { 421 | sub_policy.x: batch.si['x'], 422 | sub_policy.action_prev: batch.si['action_prev'], 423 | sub_policy.reward_prev: batch.si['reward_prev'], 424 | sub_policy.subgoal: batch.si['subgoal'], 425 | self.ac: batch.a, 426 | self.adv: batch.adv, 427 | self.r: batch.r, 428 | sub_policy.state_in[0]: batch.features[0], 429 | sub_policy.state_in[1]: batch.features[1], 430 | } 431 | 432 | fetched = sess.run(fetches, feed_dict=feed_dict) 433 | if should_compute_summary: 434 | self.summary_writer.add_summary(fetched[0], fetched[-1]) 435 | 436 | return self.last_state, meta_reward, terminal_end 437 | 438 | 439 | def evaluate(self, sess): 440 | """ 441 | run one episode and process experience to train both meta and sub networks 442 | """ 443 | env = self.env 444 | 445 | sub_policy = self.local_sub_network 446 | meta_policy = self.local_meta_network 447 | 448 | self.last_state = env.reset() 449 | self.last_meta_state = env.reset() 450 | self.last_features = sub_policy.get_initial_features() 451 | self.last_meta_features = meta_policy.get_initial_features() 452 | 453 | self.last_action = np.zeros(self.action_space) 454 | self.last_subgoal = np.zeros(self.subgoal_space) 455 | self.last_reward = np.zeros(1) 456 | self.last_meta_reward = np.zeros(1) 457 | 458 | self.length = 0 459 | self.rewards = 0 460 | self.extrinsic_rewards = 0 461 | self.intrinsic_rewards = 0 462 | 463 | terminal_end = False 464 | frames = [self.last_state] 465 | while not terminal_end: 466 | frames_, terminal_end = self._meta_evaluate(sess) 467 | frames.extend(frames_) 468 | frames = np.stack(frames) 469 | return frames, self.rewards, self.length 470 | 471 | def _meta_evaluate(self, sess): 472 | meta_policy = self.local_meta_network 473 | 474 | terminal_end = False 475 | frames = [] 476 | for _ in range(self.num_local_meta_steps): 477 | fetched = meta_policy.act(self.last_meta_state, self.last_subgoal, 478 | self.last_meta_reward, *self.last_meta_features) 479 | subgoal, meta_features = fetched[0], fetched[2:] 480 | 481 | if self.bptt == 20: 482 | meta_reward = 0 483 | for _ in range(5): 484 | frames_, state, reward, terminal_end = self._sub_evaluate(sess, subgoal) 485 | frames.extend(frames_) 486 | meta_reward += reward 487 | if terminal_end: 488 | break 489 | elif self.bptt == 100: 490 | frames_, state, meta_reward, terminal_end = self._sub_evaluate(sess, subgoal) 491 | frames.extend(frames_) 492 | 493 | self.last_meta_state = state 494 | self.last_subgoal = subgoal 495 | self.last_meta_reward = [meta_reward] 496 | self.last_meta_features = meta_features 497 | if terminal_end: 498 | break 499 | return frames, terminal_end 500 | 501 | def _sub_evaluate(self, sess, subgoal): 502 | sub_policy = self.local_sub_network 503 | meta_reward = 0 504 | frames = [] 505 | 506 | for _ in range(self.num_local_steps): 507 | fetched = sub_policy.act(self.last_state, self.last_action, 508 | self.last_reward, subgoal, 
*self.last_features) 509 | action, features = fetched[0], fetched[2:] 510 | 511 | # argmax to convert from one-hot 512 | state, episode_reward, terminal, info = self.env.step(action.argmax()) 513 | # reward clipping to the range of [-1, 1] 514 | extrinsic_reward = max(min(episode_reward, 1), -1) 515 | frames.append(state) 516 | 517 | if self.visualise: 518 | self.env.render() 519 | 520 | def get_mask(shape, subgoal): 521 | mask = np.zeros(shape) 522 | if subgoal < 36: 523 | y = subgoal // 6 524 | x = subgoal % 6 525 | mask[y*14:(y+1)*14, x*14:(x+1)*14] = 1 526 | mask = np.stack([mask] * 3) 527 | return mask 528 | 529 | def compute_intrinsic(state, last_state, subgoal): 530 | if self.intrinsic_type == 'feature': 531 | f = sub_policy.feature(state) 532 | last_f = sub_policy.feature(last_state) 533 | diff = np.abs(f - last_f) 534 | return self.eta * diff[subgoal] / (np.sum(diff) + 1e-10) 535 | else: 536 | diff = state - last_state 537 | diff = diff * diff 538 | mask = get_mask(diff.shape, subgoal) 539 | return self.eta * np.sum(mask * diff) / (np.sum(diff) + 1e-10) 540 | 541 | intrinsic_reward = compute_intrinsic(state, self.last_state, subgoal.argmax()) 542 | reward = self.beta * extrinsic_reward + (1 - self.beta) * intrinsic_reward 543 | 544 | meta_reward += extrinsic_reward 545 | # meta_reward += reward 546 | self.intrinsic_rewards += intrinsic_reward 547 | self.extrinsic_rewards += extrinsic_reward 548 | 549 | self.length += 1 550 | self.rewards += episode_reward 551 | 552 | self.last_state = state 553 | self.last_action = action 554 | self.last_features = features 555 | self.last_reward = [reward] 556 | 557 | timestep_limit = self.env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps') 558 | if terminal or (timestep_limit and self.length >= timestep_limit): 559 | print("Episode finished. Ep rewards: %.5f (In: %.5f, Ex: %.5f). 
Length: %d" % 560 | (self.rewards, self.intrinsic_rewards, self.extrinsic_rewards, self.length)) 561 | return frames, state, meta_reward, True 562 | return frames, state, meta_reward, False 563 | -------------------------------------------------------------------------------- /assets/feature-control-bptt-100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-bptt-100.png -------------------------------------------------------------------------------- /assets/feature-control-video-10M.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-video-10M.gif -------------------------------------------------------------------------------- /assets/feature-control-video-160M.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-video-160M.gif -------------------------------------------------------------------------------- /assets/feature-control-video-27M.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-video-27M.gif -------------------------------------------------------------------------------- /assets/feature-control-video-50M.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-video-50M.gif -------------------------------------------------------------------------------- /assets/feature-control-video-90M.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-video-90M.gif -------------------------------------------------------------------------------- /assets/intrinsic_feature.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/intrinsic_feature.png -------------------------------------------------------------------------------- /assets/intrinsic_pixel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/intrinsic_pixel.png -------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/model.png -------------------------------------------------------------------------------- /assets/pixel-control-bptt-100.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/pixel-control-bptt-100.png -------------------------------------------------------------------------------- /envs.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | logger.setLevel(logging.INFO) 6 | 7 | 8 | def create_env(env_id, **kwargs): 9 | return create_atari_env(env_id) 10 | 11 | 12 | def create_atari_env(env_id): 13 | env = gym.make(env_id) 14 | return env 15 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.rnn as rnn 4 | from ops import flatten, conv2d, linear 5 | 6 | 7 | def normalized_columns_initializer(std=1.0): 8 | def _initializer(shape, dtype=None, partition_info=None): 9 | out = np.random.randn(*shape).astype(np.float32) 10 | out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) 11 | return tf.constant(out) 12 | return _initializer 13 | 14 | 15 | def categorical_sample(logits, d): 16 | value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1], keep_dims=True), 1), [1]) 17 | return tf.one_hot(value, d) 18 | 19 | 20 | class SubPolicy(object): 21 | def __init__(self, ob_space, ac_space, subgoal_space, intrinsic_type): 22 | self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space), name='x') 23 | self.action_prev = action_prev = tf.placeholder(tf.float32, [None, ac_space], name='action_prev') 24 | self.reward_prev = reward_prev = tf.placeholder(tf.float32, [None, 1], name='reward_prev') 25 | self.subgoal = subgoal = tf.placeholder(tf.float32, [None, subgoal_space], name='subgoal') 26 | self.intrinsic_type = intrinsic_type 27 | 28 | with tf.variable_scope('encoder'): 29 | x = tf.image.resize_images(x, [84, 84]) 30 | x = x / 255.0 31 | self.p = x 32 | x = tf.nn.relu(conv2d(x, 16, "l1", [8, 8], [4, 4])) 33 | x = tf.nn.relu(conv2d(x, 32, "l2", [4, 4], [2, 2])) 34 | self.f = tf.reduce_mean(x, axis=[1, 2]) 35 | x = flatten(x) 36 | 37 | with tf.variable_scope('sub_policy'): 38 | x = tf.nn.relu(linear(x, 256, "fc", 39 | normalized_columns_initializer(0.01))) 40 | x = tf.concat([x, action_prev], axis=1) 41 | x = tf.concat([x, reward_prev], axis=1) 42 | x = tf.concat([x, subgoal], axis=1) 43 | 44 | # introduce a "fake" batch dimension of 1 after flatten 45 | # so that we can do LSTM over time dim 46 | x = tf.expand_dims(x, [0]) 47 | 48 | size = 256 49 | lstm = rnn.BasicLSTMCell(size, state_is_tuple=True) 50 | self.state_size = lstm.state_size 51 | step_size = tf.shape(self.x)[:1] 52 | 53 | c_init = np.zeros((1, lstm.state_size.c), np.float32) 54 | h_init = np.zeros((1, lstm.state_size.h), np.float32) 55 | self.state_init = [c_init, h_init] 56 | c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c]) 57 | h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h]) 58 | self.state_in = [c_in, h_in] 59 | 60 | state_in = rnn.LSTMStateTuple(c_in, h_in) 61 | 62 | lstm_outputs, lstm_state = tf.nn.dynamic_rnn( 63 | lstm, x, initial_state=state_in, sequence_length=step_size, 64 | time_major=False 65 | ) 66 | lstm_c, lstm_h = lstm_state 67 | lstm_outputs = tf.reshape(lstm_outputs, [-1, size]) 68 | self.logits = linear(lstm_outputs, ac_space, "action", 69 | normalized_columns_initializer(0.01)) 70 | self.vf = 
tf.reshape(linear(lstm_outputs, 1, "value", 71 | normalized_columns_initializer(1.0)), [-1]) 72 | self.state_out = [lstm_c[:1, :], lstm_h[:1, :]] 73 | self.sample = categorical_sample(self.logits, ac_space)[0, :] 74 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 75 | tf.get_variable_scope().name) 76 | 77 | def get_initial_features(self): 78 | return self.state_init 79 | 80 | def act(self, ob, action_prev, reward_prev, subgoal, c, h): 81 | sess = tf.get_default_session() 82 | return sess.run([self.sample, self.vf] + self.state_out, 83 | {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h, 84 | self.action_prev: [action_prev], 85 | self.reward_prev: [reward_prev], 86 | self.subgoal: [subgoal]}) 87 | 88 | def value(self, ob, action_prev, reward_prev, subgoal, c, h): 89 | sess = tf.get_default_session() 90 | return sess.run(self.vf, {self.x: [ob], self.state_in[0]: c, 91 | self.state_in[1]: h, 92 | self.action_prev: [action_prev], 93 | self.reward_prev: [reward_prev], 94 | self.subgoal: [subgoal]})[0] 95 | 96 | def feature(self, state): 97 | sess = tf.get_default_session() 98 | if self.intrinsic_type == 'feature': 99 | return sess.run(self.f, {self.x: [state]})[0, :] 100 | else: 101 | return sess.run(self.p, {self.x: [state]})[0, :] 102 | 103 | 104 | class MetaPolicy(object): 105 | def __init__(self, ob_space, subgoal_space, intrinsic_type): 106 | self.x = x = \ 107 | tf.placeholder(tf.float32, [None] + list(ob_space), name='x_meta') 108 | self.subgoal_prev = subgoal_prev = \ 109 | tf.placeholder(tf.float32, [None, subgoal_space], name='subgoal_prev') 110 | self.reward_prev = reward_prev = \ 111 | tf.placeholder(tf.float32, [None, 1], name='reward_prev_meta') 112 | self.intrinsic_type = intrinsic_type 113 | 114 | with tf.variable_scope('encoder', reuse=True): 115 | x = tf.image.resize_images(x, [84, 84]) 116 | x = x / 255.0 117 | x = tf.nn.relu(conv2d(x, 16, "l1", [8, 8], [4, 4])) 118 | x = tf.nn.relu(conv2d(x, 32, "l2", [4, 4], [2, 2])) 119 | x = flatten(x) 120 | 121 | with tf.variable_scope('meta_policy'): 122 | x = tf.nn.relu(linear(x, 256, "fc", 123 | normalized_columns_initializer(0.01))) 124 | x = tf.concat([x, subgoal_prev], axis=1) 125 | x = tf.concat([x, reward_prev], axis=1) 126 | 127 | # introduce a "fake" batch dimension of 1 after flatten 128 | # so that we can do LSTM over time dim 129 | x = tf.expand_dims(x, [0]) 130 | 131 | size = 256 132 | lstm = rnn.BasicLSTMCell(size, state_is_tuple=True) 133 | self.state_size = lstm.state_size 134 | step_size = tf.shape(self.x)[:1] 135 | 136 | c_init = np.zeros((1, lstm.state_size.c), np.float32) 137 | h_init = np.zeros((1, lstm.state_size.h), np.float32) 138 | self.state_init = [c_init, h_init] 139 | c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c]) 140 | h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h]) 141 | self.state_in = [c_in, h_in] 142 | 143 | state_in = rnn.LSTMStateTuple(c_in, h_in) 144 | lstm_outputs, lstm_state = tf.nn.dynamic_rnn( 145 | lstm, x, initial_state=state_in, sequence_length=step_size, 146 | time_major=False) 147 | lstm_c, lstm_h = lstm_state 148 | lstm_outputs = tf.reshape(lstm_outputs, [-1, size]) 149 | self.logits = linear(lstm_outputs, subgoal_space, "action", 150 | normalized_columns_initializer(0.01)) 151 | self.vf = tf.reshape(linear(lstm_outputs, 1, "value", 152 | normalized_columns_initializer(1.0)), [-1]) 153 | self.state_out = [lstm_c[:1, :], lstm_h[:1, :]] 154 | self.sample = categorical_sample(self.logits, subgoal_space)[0, :] 155 | self.var_list = 
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 156 | 157 | def get_initial_features(self): 158 | return self.state_init 159 | 160 | def act(self, ob, subgoal_prev, reward_prev, c, h): 161 | sess = tf.get_default_session() 162 | return sess.run([self.sample, self.vf] + self.state_out, 163 | {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h, 164 | self.subgoal_prev: [subgoal_prev], 165 | self.reward_prev: [reward_prev]}) 166 | 167 | def value(self, ob, subgoal_prev, reward_prev, c, h): 168 | sess = tf.get_default_session() 169 | return sess.run(self.vf, {self.x: [ob], self.state_in[0]: c, 170 | self.state_in[1]: h, 171 | self.subgoal_prev: [subgoal_prev], 172 | self.reward_prev: [reward_prev]})[0] 173 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def flatten(x): 6 | return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])]) 7 | 8 | 9 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), 10 | pad="SAME", dtype=tf.float32, collections=None): 11 | with tf.variable_scope(name): 12 | stride_shape = [1, stride[0], stride[1], 1] 13 | filter_shape = [filter_size[0], filter_size[1], 14 | int(x.get_shape()[3]), num_filters] 15 | 16 | # there are "num input feature maps * filter height * filter width" 17 | # inputs to each hidden unit 18 | fan_in = np.prod(filter_shape[:3]) 19 | # each unit in the lower layer receives a gradient from: 20 | # "num output feature maps * filter height * filter width" / 21 | # pooling size 22 | fan_out = np.prod(filter_shape[:2]) * num_filters 23 | # initialize weights with random weights 24 | w_bound = np.sqrt(6. 
/ (fan_in + fan_out)) 25 | 26 | w = tf.get_variable("W", filter_shape, dtype, 27 | tf.random_uniform_initializer(-w_bound, w_bound), 28 | collections=collections) 29 | b = tf.get_variable("b", [1, 1, 1, num_filters], 30 | initializer=tf.constant_initializer(0.0), 31 | collections=collections) 32 | return tf.nn.conv2d(x, w, stride_shape, pad) + b 33 | 34 | 35 | def linear(x, size, name, initializer=None, bias_init=0): 36 | w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=initializer) 37 | b = tf.get_variable(name + "/b", [size], initializer=tf.constant_initializer(bias_init)) 38 | return tf.matmul(x, w) + b 39 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import logging 4 | import sys 5 | import signal 6 | import os 7 | 8 | import tensorflow as tf 9 | import imageio 10 | 11 | from a3c import A3C 12 | from envs import create_env 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.INFO) 17 | 18 | # parsing cmd arguments 19 | parser = argparse.ArgumentParser(description="Test commands") 20 | parser.add_argument('-e', '--env-id', type=str, default="MontezumaRevenge-v0", 21 | help="Environment id") 22 | parser.add_argument('-l', '--log-dir', type=str, default="/tmp/montezuma", 23 | help="Log directory path") 24 | 25 | # Add visualisation argument 26 | parser.add_argument('--visualise', action='store_true', 27 | help="Visualise the gym environment by running env.render() between each timestep") 28 | 29 | # Add model parameters 30 | parser.add_argument('--intrinsic-type', type=str, default='feature', 31 | choices=['feature', 'pixel'], help="Feature-control or Pixel-control") 32 | parser.add_argument('--bptt', type=int, default=100, 33 | help="BPTT") 34 | 35 | 36 | # Disables write_meta_graph argument, which freezes entire process and is mostly useless. 37 | class FastSaver(tf.train.Saver): 38 | def save(self, sess, save_path, global_step=None, latest_filename=None, 39 | meta_graph_suffix="meta", write_meta_graph=True): 40 | super(FastSaver, self).save(sess, save_path, global_step, latest_filename, 41 | meta_graph_suffix, False) 42 | 43 | 44 | def run(args): 45 | env = create_env(args.env_id) 46 | trainer = A3C(env, None, args.visualise, args.intrinsic_type, args.bptt) 47 | 48 | # Variable names that start with "local" are not saved in checkpoints. 
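# (in test mode the policy networks are built directly under the "global" scope, so the Supervisor below restores them from the training checkpoint)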
49 | variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")] 50 | init_op = tf.variables_initializer(variables_to_save) 51 | init_all_op = tf.global_variables_initializer() 52 | saver = FastSaver(variables_to_save) 53 | 54 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 55 | logger.info('Trainable vars:') 56 | for v in var_list: 57 | logger.info(' %s %s', v.name, v.get_shape()) 58 | 59 | def init_fn(ses): 60 | logger.info("Initializing all parameters.") 61 | ses.run(init_all_op) 62 | 63 | logdir = os.path.join(args.log_dir, 'train') 64 | summary_writer = tf.summary.FileWriter(logdir) 65 | logger.info("Events directory: %s", logdir) 66 | 67 | sv = tf.train.Supervisor(is_chief=True, 68 | logdir=logdir, 69 | saver=saver, 70 | summary_op=None, 71 | init_op=init_op, 72 | init_fn=init_fn, 73 | summary_writer=summary_writer, 74 | ready_op=tf.report_uninitialized_variables(variables_to_save), 75 | global_step=None, 76 | save_model_secs=0, 77 | save_summaries_secs=0) 78 | 79 | video_dir = os.path.join(args.log_dir, 'test_videos_' + args.intrinsic_type) 80 | if not os.path.exists(video_dir): 81 | os.makedirs(video_dir) 82 | video_filename = video_dir + "/%s_%02d_%d.gif" 83 | print("Video saved at %s" % video_dir) 84 | 85 | with sv.managed_session() as sess, sess.as_default(): 86 | trainer.start(sess, summary_writer) 87 | rewards = [] 88 | lengths = [] 89 | for i in range(10): 90 | frames, reward, length = trainer.evaluate(sess) 91 | rewards.append(reward) 92 | lengths.append(length) 93 | imageio.mimsave(video_filename % (args.env_id, i, reward), frames, fps=30) 94 | 95 | print('Evaluation: avg. reward %.2f avg.length %.2f' % 96 | (sum(rewards) / 10.0, sum(lengths) / 10.0)) 97 | 98 | # Ask for all the services to stop. 99 | sv.stop() 100 | 101 | 102 | def main(_): 103 | args, unparsed = parser.parse_known_args() 104 | 105 | def shutdown(signal, frame): 106 | logger.warn('Received signal %s: exiting', signal) 107 | sys.exit(128+signal) 108 | signal.signal(signal.SIGHUP, shutdown) 109 | signal.signal(signal.SIGINT, shutdown) 110 | signal.signal(signal.SIGTERM, shutdown) 111 | 112 | run(args) 113 | 114 | if __name__ == "__main__": 115 | tf.app.run() 116 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | from six.moves import shlex_quote 6 | 7 | parser = argparse.ArgumentParser(description="Run commands") 8 | parser.add_argument('-w', '--num-workers', default=8, type=int, 9 | help="Number of workers") 10 | parser.add_argument('-e', '--env-id', type=str, default="MontezumaRevenge-v0", 11 | help="Environment id") 12 | parser.add_argument('-l', '--log-dir', type=str, default="/tmp/montezuma", 13 | help="Log directory path") 14 | parser.add_argument('-n', '--dry-run', action='store_true', 15 | help="Print out commands rather than executing them") 16 | parser.add_argument('-m', '--mode', type=str, default='tmux', 17 | help="tmux: run workers in a tmux session. nohup: " 18 | "run workers with nohup. 
child: run workers as child processes") 19 | 20 | # Add visualise tag 21 | parser.add_argument('--visualise', action='store_true', 22 | help="Visualise the gym environment by running env.render() between each timestep") 23 | 24 | # Add model parameters 25 | parser.add_argument('--intrinsic-type', type=str, default='feature', 26 | choices=['feature', 'pixel'], help="feature or pixel") 27 | parser.add_argument('--bptt', type=int, default=100, 28 | help="BPTT") 29 | 30 | 31 | def new_cmd(session, name, cmd, mode, logdir, shell): 32 | if isinstance(cmd, (list, tuple)): 33 | cmd = " ".join(shlex_quote(str(v)) for v in cmd) 34 | if mode == 'tmux': 35 | return name, "tmux send-keys -t {}:{} {} Enter".format(session, name, shlex_quote(cmd)) 36 | elif mode == 'child': 37 | return name, "{} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(cmd, logdir, session, name, logdir) 38 | elif mode == 'nohup': 39 | return name, "nohup {} -c {} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(shell, shlex_quote(cmd), logdir, session, name, logdir) 40 | 41 | 42 | def create_commands(session, num_workers, env_id, logdir, 43 | intrinsic_type, bptt, shell='bash', mode='tmux', visualise=False): 44 | # Launch the TF workers and for launching tensorboard 45 | base_cmd = [ 46 | 'CUDA_VISIBLE_DEVICES=', 47 | sys.executable, 'worker.py', 48 | '--log-dir', logdir, 49 | '--env-id', env_id, 50 | '--num-workers', str(num_workers), 51 | '--intrinsic-type', intrinsic_type, 52 | '--bptt', str(bptt) 53 | ] 54 | 55 | if visualise: 56 | base_cmd += ['--visualise'] 57 | 58 | cmds_map = [new_cmd(session, "ps", base_cmd + ["--job-name", "ps"], mode, logdir, shell)] 59 | for i in range(num_workers): 60 | cmds_map += [new_cmd(session, "w-%d" % i, base_cmd + 61 | ["--job-name", "worker", "--task", str(i)], mode, logdir, shell)] 62 | 63 | cmds_map += [new_cmd(session, "tb", ["tensorboard", "--logdir", logdir, 64 | "--port", "12345"], mode, logdir, shell)] 65 | if mode == 'tmux': 66 | cmds_map += [new_cmd(session, "htop", ["htop"], mode, logdir, shell)] 67 | 68 | windows = [v[0] for v in cmds_map] 69 | 70 | notes = [] 71 | cmds = [ 72 | "mkdir -p {}".format(logdir), 73 | "echo {} {} > {}/cmd.sh".format(sys.executable, ' '.join([shlex_quote(arg) for arg in sys.argv if arg != '-n']), logdir), 74 | ] 75 | if mode == 'nohup' or mode == 'child': 76 | cmds += ["echo '#!/bin/sh' >{}/kill.sh".format(logdir)] 77 | notes += ["Run `source {}/kill.sh` to kill the job".format(logdir)] 78 | if mode == 'tmux': 79 | notes += ["Use `tmux attach -t {}` to watch process output".format(session)] 80 | notes += ["Use `tmux kill-session -t {}` to kill the job".format(session)] 81 | else: 82 | notes += ["Use `tail -f {}/*.out` to watch process output".format(logdir)] 83 | notes += ["Point your browser to http://localhost:12345 to see Tensorboard"] 84 | 85 | if mode == 'tmux': 86 | cmds += [ 87 | # kill any process using tensorboard's port 88 | "kill $( lsof -i:12345 -t ) > /dev/null 2>&1", 89 | # kill any processes using ps / worker ports 90 | "kill $( lsof -i:12222-{} -t ) > /dev/null 2>&1".format(num_workers+12222), 91 | "tmux kill-session -t {}".format(session), 92 | "tmux new-session -s {} -n {} -d {}".format(session, windows[0], shell) 93 | ] 94 | for w in windows[1:]: 95 | cmds += ["tmux new-window -t {} -n {} {}".format(session, w, shell)] 96 | cmds += ["sleep 1"] 97 | for window, cmd in cmds_map: 98 | cmds += [cmd] 99 | 100 | return cmds, notes 101 | 102 | 103 | def run(): 104 | args = parser.parse_args() 105 | cmds, notes = 
create_commands("a3c", args.num_workers, 106 | args.env_id, args.log_dir, args.intrinsic_type, 107 | args.bptt, mode=args.mode, visualise=args.visualise) 108 | if args.dry_run: 109 | print("Dry-run mode due to -n flag, otherwise the following commands would be executed:") 110 | else: 111 | print("Executing the following commands:") 112 | print("\n".join(cmds)) 113 | print("") 114 | if not args.dry_run: 115 | if args.mode == "tmux": 116 | os.environ["TMUX"] = "" 117 | os.system("\n".join(cmds)) 118 | print('\n'.join(notes)) 119 | 120 | if __name__ == "__main__": 121 | run() 122 | -------------------------------------------------------------------------------- /worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import logging 4 | import sys 5 | import signal 6 | import time 7 | import os 8 | 9 | import imageio 10 | from tqdm import tqdm 11 | import tensorflow as tf 12 | 13 | from a3c import A3C 14 | from envs import create_env 15 | 16 | logger = logging.getLogger(__name__) 17 | logger.setLevel(logging.INFO) 18 | 19 | 20 | # Disables write_meta_graph argument, which freezes entire process and is mostly useless. 21 | class FastSaver(tf.train.Saver): 22 | def save(self, sess, save_path, global_step=None, latest_filename=None, 23 | meta_graph_suffix="meta", write_meta_graph=True): 24 | super(FastSaver, self).save(sess, save_path, global_step, latest_filename, 25 | meta_graph_suffix, False) 26 | 27 | 28 | def run(args, server): 29 | env = create_env(args.env_id) 30 | trainer = A3C(env, args.task, args.visualise, args.intrinsic_type, args.bptt) 31 | 32 | # Variable names that start with "local" are not saved in checkpoints. 33 | variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")] 34 | init_op = tf.variables_initializer(variables_to_save) 35 | init_all_op = tf.global_variables_initializer() 36 | saver = FastSaver(variables_to_save) 37 | 38 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 39 | logger.info('Trainable vars:') 40 | for v in var_list: 41 | logger.info(' %s %s', v.name, v.get_shape()) 42 | 43 | def init_fn(ses): 44 | logger.info("Initializing all parameters.") 45 | ses.run(init_all_op) 46 | 47 | config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)]) 48 | logdir = os.path.join(args.log_dir, 'train') 49 | 50 | summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task) 51 | 52 | logger.info("Events directory: %s_%s", logdir, args.task) 53 | sv = tf.train.Supervisor(is_chief=(args.task == 0), 54 | logdir=logdir, 55 | saver=saver, 56 | summary_op=None, 57 | init_op=init_op, 58 | init_fn=init_fn, 59 | summary_writer=summary_writer, 60 | ready_op=tf.report_uninitialized_variables(variables_to_save), 61 | global_step=trainer.global_step, 62 | save_model_secs=30, 63 | save_summaries_secs=30) 64 | 65 | video_dir = os.path.join(args.log_dir, 'train_videos_' + args.intrinsic_type) 66 | if not os.path.exists(video_dir): 67 | os.makedirs(video_dir) 68 | video_filename = video_dir + "/%s_%010d_%d.gif" 69 | print("Video saved at %s" % video_dir) 70 | 71 | num_global_steps = 300000000 72 | num_record_steps = 1000000 73 | last_record_step = 0 74 | 75 | logger.info( 76 | "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. 
" + 77 | "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.") 78 | with sv.managed_session(server.target, config=config) as sess, sess.as_default(): 79 | sess.run(trainer.meta_sync) 80 | sess.run(trainer.sync) 81 | trainer.start(sess, summary_writer) 82 | global_step = sess.run(trainer.global_step) 83 | logger.info("Starting training at step=%d", global_step) 84 | 85 | pbar = tqdm(total=num_global_steps) 86 | pbar.update(global_step) 87 | 88 | while not sv.should_stop() and (not num_global_steps or global_step < num_global_steps): 89 | trainer.process(sess) 90 | 91 | new_global_step = sess.run(trainer.global_step) 92 | pbar.set_description('') 93 | pbar.update(max(1, new_global_step - global_step)) 94 | global_step = new_global_step 95 | 96 | if args.task == 0 and global_step - last_record_step > num_record_steps: 97 | sess.run(trainer.meta_sync) 98 | sess.run(trainer.sync) 99 | last_record_step = global_step 100 | frames, reward, length = trainer.evaluate(sess) 101 | imageio.mimsave(video_filename % (args.env_id, global_step, reward), frames, fps=30) 102 | 103 | # Ask for all the services to stop. 104 | sv.stop() 105 | logger.info('reached %s steps. worker stopped.', global_step) 106 | 107 | 108 | def cluster_spec(num_workers, num_ps): 109 | """ 110 | More tensorflow setup for data parallelism 111 | """ 112 | cluster = {} 113 | port = 12222 114 | 115 | all_ps = [] 116 | host = '127.0.0.1' 117 | for _ in range(num_ps): 118 | all_ps.append('{}:{}'.format(host, port)) 119 | port += 1 120 | cluster['ps'] = all_ps 121 | 122 | all_workers = [] 123 | for _ in range(num_workers): 124 | all_workers.append('{}:{}'.format(host, port)) 125 | port += 1 126 | cluster['worker'] = all_workers 127 | return cluster 128 | 129 | 130 | def main(_): 131 | """ 132 | Setting up Tensorflow for data parallel work 133 | """ 134 | 135 | parser = argparse.ArgumentParser(description=None) 136 | parser.add_argument('-v', '--verbose', action='count', dest='verbosity', default=0, help='Set verbosity.') 137 | parser.add_argument('--task', default=0, type=int, help='Task index') 138 | parser.add_argument('--job-name', default="worker", help='worker or ps') 139 | parser.add_argument('--num-workers', default=1, type=int, help='Number of workers') 140 | parser.add_argument('--log-dir', default="/tmp/pong", help='Log directory path') 141 | parser.add_argument('--env-id', default="PongDeterministic-v3", help='Environment id') 142 | 143 | # Add visualisation argument 144 | parser.add_argument('--visualise', action='store_true', 145 | help="Visualise the gym environment by running env.render() between each timestep") 146 | 147 | # Add model parameters 148 | parser.add_argument('--intrinsic-type', type=str, default='feature', 149 | choices=['feature', 'pixel'], help="feature or pixel") 150 | parser.add_argument('--bptt', type=int, default=100, 151 | help="BPTT") 152 | 153 | args = parser.parse_args() 154 | spec = cluster_spec(args.num_workers, 1) 155 | cluster = tf.train.ClusterSpec(spec).as_cluster_def() 156 | 157 | def shutdown(signal, frame): 158 | logger.warn('Received signal %s: exiting', signal) 159 | sys.exit(128+signal) 160 | signal.signal(signal.SIGHUP, shutdown) 161 | signal.signal(signal.SIGINT, shutdown) 162 | signal.signal(signal.SIGTERM, shutdown) 163 | 164 | if args.job_name == "worker": 165 | server = tf.train.Server(cluster, job_name="worker", task_index=args.task, 166 | config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=2)) 167 | 
run(args, server) 168 | else: 169 | server = tf.train.Server(cluster, job_name="ps", task_index=args.task, 170 | config=tf.ConfigProto(device_filters=["/job:ps"])) 171 | while True: 172 | time.sleep(1000) 173 | 174 | if __name__ == "__main__": 175 | tf.app.run() 176 | --------------------------------------------------------------------------------