├── .gitignore
├── LICENSE
├── README.md
├── a3c.py
├── assets
│   ├── feature-control-bptt-100.png
│   ├── feature-control-video-10M.gif
│   ├── feature-control-video-160M.gif
│   ├── feature-control-video-27M.gif
│   ├── feature-control-video-50M.gif
│   ├── feature-control-video-90M.gif
│   ├── intrinsic_feature.png
│   ├── intrinsic_pixel.png
│   ├── model.png
│   └── pixel-control-bptt-100.png
├── envs.py
├── model.py
├── ops.py
├── test.py
├── train.py
└── worker.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Tensorflow logs / datasets / results
2 | logs/
3 | datasets/
4 | results/
5 |
6 | # Temporary files
7 | *.zip
8 | *.swp
9 | *~
10 | .DS_Store
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | env/
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | .hypothesis/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # celery beat schedule file
88 | celerybeat-schedule
89 |
90 | # SageMath parsed files
91 | *.sage.py
92 |
93 | # dotenv
94 | .env
95 |
96 | # virtualenv
97 | .venv
98 | venv/
99 | ENV/
100 |
101 | # Spyder project settings
102 | .spyderproject
103 | .spyproject
104 |
105 | # Rope project settings
106 | .ropeproject
107 |
108 | # mkdocs documentation
109 | /site
110 |
111 | # mypy
112 | .mypy_cache/
113 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Youngwoon Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Feature Control as Intrinsic Motivation for Hierarchical RL in Tensorflow
2 |
3 | As part of the implementation series of [Joseph Lim's group at USC](http://www-bcf.usc.edu/~limjj/), our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers, and publicly share them with concise reports. Please visit our [group github site](https://github.com/gitlimlab) for other projects.
4 |
5 | This project was implemented by [Youngwoon Lee](https://github.com/youngwoon), and the code was reviewed by [Shao-Hua Sun](https://github.com/shaohua0116) before being published.
6 |
7 | ## Description
8 | This repo is a [Tensorflow](https://www.tensorflow.org/) implementation of Feature-control and Pixel-control agents on Montezuma's Revenge: [Feature Control as Intrinsic Motivation for Hierarchical Reinforcement Learning](https://arxiv.org/abs/1705.06769).
9 |
10 | This paper focuses on hierarchical reinforcement learning (HRL), where a task is decomposed into a hierarchy of subproblems or subtasks such that higher-level parent tasks invoke lower-level child tasks as if they were primitive actions. To tackle this problem, the paper presents an HRL framework that overcomes the issue of sparse rewards by extracting intrinsic rewards from changes between consecutive observations. The idea is that, given an intention of the agent, the model learns a set of features, each of which can judge whether the corresponding intention has been achieved. In other words, the agent learns a set of skills that change future observations in a certain direction. For example, one skill we want to learn in Montezuma's Revenge is grabbing the key. The success of this skill can be judged by the presence of the key: if the agent manages to remove the key from the observation, it receives a reward because the presence of the key has changed. These skills are the subgoals of the meta-controller.
11 |
12 | ![Model architecture](assets/model.png)
13 |
14 |
15 |
16 | The proposed model consists of two controllers: a meta-controller and a sub-controller. As a high-level planner, the meta-controller sets a subgoal to achieve over the next 100 timesteps, while the sub-controller figures out the optimal action sequence to accomplish this subgoal. The meta-controller chooses the next subgoal based on the previous subgoal, the current observation, and the reward gathered while pursuing the previous subgoal. The sub-controller then computes actions from the previous reward, the previous action, the current observation, and the current subgoal. To capture temporal information, the policy networks of both controllers use LSTMs.
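The control flow implemented in `a3c.py` roughly corresponds to the sketch below. It is a simplified illustration only, assuming the `MetaPolicy`/`SubPolicy` interfaces from `model.py`; the actual training code additionally collects rollouts and performs the A3C updates, and `compute_intrinsic` is sketched further below.

```
# Simplified sketch of the hierarchical control loop in a3c.py (illustration only).
import numpy as np

def run_episode(env, meta_policy, sub_policy, compute_intrinsic,
                subgoal_space=32, beta=0.75, horizon=100):
    obs, terminal = env.reset(), False
    meta_state = meta_policy.get_initial_features()
    sub_state = sub_policy.get_initial_features()
    prev_action = np.zeros(env.action_space.n)
    prev_subgoal = np.zeros(subgoal_space)          # one-hot of the previous subgoal
    prev_reward, prev_meta_reward = np.zeros(1), np.zeros(1)

    while not terminal:
        # Meta-controller: choose the next subgoal from the current observation,
        # the previous subgoal, and the reward gathered under the previous subgoal.
        fetched = meta_policy.act(obs, prev_subgoal, prev_meta_reward, *meta_state)
        subgoal, meta_state = fetched[0], fetched[2:]

        meta_reward = 0.0
        for _ in range(horizon):
            # Sub-controller: choose a primitive action conditioned on the subgoal.
            fetched = sub_policy.act(obs, prev_action, prev_reward, subgoal, *sub_state)
            action, sub_state = fetched[0], fetched[2:]
            next_obs, raw_reward, terminal, _ = env.step(action.argmax())
            extrinsic_reward = max(min(raw_reward, 1), -1)   # clip reward to [-1, 1]

            # The sub-controller is trained on a mix of extrinsic and intrinsic rewards;
            # the meta-controller only sees the accumulated extrinsic reward.
            intrinsic = compute_intrinsic(next_obs, obs, subgoal.argmax(), sub_policy.feature)
            reward = beta * extrinsic_reward + (1 - beta) * intrinsic
            meta_reward += extrinsic_reward

            prev_action, prev_reward, obs = action, [reward], next_obs
            if terminal:
                break

        prev_subgoal, prev_meta_reward = subgoal, [meta_reward]
```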
17 |
18 | For a feature-control agent, the intrinsic reward is the change in the k-th feature map of the second convolutional layer between two consecutive frames, relative to the total change over all feature maps. The paper also proposes a pixel-control agent, which computes the intrinsic reward from the pixel-value changes in a selected region of the frame; this variant is included in this repo as well.
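Concretely, this implementation computes the intrinsic reward as in the sketch below, a simplified version of `compute_intrinsic` in `a3c.py` (`eta = 0.05`; `feature_fn` plays the role of `SubPolicy.feature()` in `model.py`):

```
import numpy as np

def compute_intrinsic(state, last_state, subgoal, feature_fn,
                      intrinsic_type='feature', eta=0.05):
    # feature_fn(state) returns the 32 per-map mean activations of the second
    # conv layer for 'feature', or the resized 84x84x3 observation for 'pixel'.
    f, last_f = feature_fn(state), feature_fn(last_state)
    if intrinsic_type == 'feature':
        # relative change of the selected feature map
        diff = np.abs(f - last_f)
        return eta * diff[subgoal] / (np.sum(diff) + 1e-10)
    else:
        # relative squared pixel change inside the selected 14x14 block;
        # the subgoal indexes a 6x6 grid of blocks (index 36 selects no region)
        diff = np.square(f - last_f)
        mask = np.zeros(diff.shape)
        if subgoal < 36:
            y, x = subgoal // 6, subgoal % 6
            mask[y*14:(y+1)*14, x*14:(x+1)*14] = 1
        return eta * np.sum(mask * diff) / (np.sum(diff) + 1e-10)
```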
19 |
20 | ![Intrinsic reward (feature control)](assets/intrinsic_feature.png)
21 | ![Intrinsic reward (pixel control)](assets/intrinsic_pixel.png)
22 |
23 |
24 | This method outperforms the state-of-the-art method (FeUdal Networks) and reaches a score of 2500 on MontezumaRevenge-v0.
25 |
26 | ## Dependencies
27 |
28 | - Ubuntu 16.04
29 | - Python 3.6
30 | - [tmux](https://tmux.github.io)
31 | - [htop](https://hisham.hm/htop)
32 | - [Tensorflow 1.3.0](https://www.tensorflow.org/)
33 | - [Universe](https://github.com/openai/universe)
34 | - [gym](https://github.com/openai/gym)
35 | - [tqdm](https://github.com/tqdm/tqdm)
36 |
37 | ## Usage
38 |
39 | - Execute the following command to train a model:
40 |
41 | ```
42 | $ python train.py --log-dir='/tmp/feature-control' --intrinsic-type='feature' --bptt=100
43 | ```
44 |
45 | - `--intrinsic-type` can be either `'feature'` or `'pixel'`.
46 |
47 | - With the `--bptt` option you can choose either 20 or 100 timesteps for truncated backpropagation through time (each subgoal still lasts 100 environment steps; with `--bptt=20` the sub-controller performs five updates per subgoal).
48 |
49 | - Once training has finished, you can evaluate the agent with the command below; it will play the game 10 times and report the average reward.
50 |
51 | ```
52 | $ python test.py --log-dir='/tmp/feature-control' --intrinsic-type='feature' --bptt=100 --visualise
53 | ```
54 |
55 | - Check the training status on TensorBoard. The default port is 12345 (i.e., http://localhost:12345); `train.py` launches a TensorBoard instance for you, or you can start one manually as shown below.
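If you launch the workers yourself instead of going through `train.py`, TensorBoard can be started manually; the log directory below is just the one used in the training example above.

```
$ tensorboard --logdir='/tmp/feature-control' --port=12345
```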
56 |
57 |
58 | ## Results
59 |
60 | ### Montezuma's Revenge-v0
61 |
62 | - Feature-control agent with bptt 100
63 | ![Feature-control agent with bptt 100](assets/feature-control-bptt-100.png)
64 |
65 | - Pixel-control agent with bptt 100
66 | ![Pixel-control agent with bptt 100](assets/pixel-control-bptt-100.png)
67 |
68 | - Training converges more slowly than the results reported in the paper. Be patient and keep training the agent until it reaches 20M iterations.
69 |
70 | ### Videos (Feature-control agent)
71 |
72 | | Iterations | 10M | 27M | 50M | 90M | 160M |
73 | | :--------: | :--------------------------------------: | :--------------------------------------: | :--------------------------------------: | :--------------------------------------: | :--------------------------------------: |
74 | | Rewards | 0 | 100 | 400 | 0 | 2500 |
75 | | Videos | ![10M](assets/feature-control-video-10M.gif) | ![27M](assets/feature-control-video-27M.gif) | ![50M](assets/feature-control-video-50M.gif) | ![90M](assets/feature-control-video-90M.gif) | ![160M](assets/feature-control-video-160M.gif) |
76 |
77 |
78 | ## References
79 |
80 | - [Feature Control as Intrinsic Motivation for Hierarchical Reinforcement Learning](https://arxiv.org/abs/1705.06769)
81 | - [OpenAI's universe-starter-agent A3C implementation](https://github.com/openai/universe-starter-agent)
82 | - [FeUdal Networks for Hierarchical Reinforcement Learning](https://arxiv.org/abs/1703.01161)
83 |
--------------------------------------------------------------------------------
/a3c.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import scipy.signal
4 |
5 | from model import SubPolicy, MetaPolicy
6 |
7 |
8 | def discount(x, gamma):
9 | return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
10 |
11 | class Batch:
12 | def __init__(self, si, a, adv, r, terminal, features):
13 | self.si = si
14 | self.a = a
15 | self.adv = adv
16 | self.r = r
17 | self.terminal = terminal
18 | self.features = features
19 |
20 | def process_rollout(rollout, gamma, lambda_=1.0):
21 | """
22 | given a rollout, compute its returns and the advantage
23 | """
24 | batch_si = {}
25 | for key in rollout.states[0].keys():
26 | batch_si[key] = np.stack([s[key] for s in rollout.states])
27 | #batch_si = np.asarray(rollout.states)
28 | batch_a = np.asarray(rollout.actions)
29 | rewards = np.asarray(rollout.rewards)
30 | vpred_t = np.asarray(rollout.values + [rollout.r])
31 |
32 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
33 | batch_r = discount(rewards_plus_v, gamma)[:-1]
34 | delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
35 |     # this formula for the advantage comes from "Generalized Advantage Estimation":
36 | # https://arxiv.org/abs/1506.02438
37 | batch_adv = discount(delta_t, gamma * lambda_)
38 |
39 | features = rollout.features[0]
40 | return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal, features)
41 |
42 |
43 | class PartialRollout(object):
44 | """
45 | a piece of a complete rollout. We run our agent, and process its experience
46 | once it has processed enough steps.
47 | """
48 | def __init__(self):
49 | self.states = []
50 | self.actions = []
51 | self.rewards = []
52 | self.values = []
53 | self.r = 0.0
54 | self.terminal = False
55 | self.features = []
56 |
57 | def add(self, state, action, reward, value, terminal, features):
58 | self.states += [state]
59 | self.actions += [action]
60 | self.rewards += [reward]
61 | self.values += [value]
62 | self.terminal = terminal
63 | self.features += [features]
64 |
65 | def extend(self, other):
66 | assert not self.terminal
67 | self.states.extend(other.states)
68 | self.actions.extend(other.actions)
69 | self.rewards.extend(other.rewards)
70 | self.values.extend(other.values)
71 | self.r = other.r
72 | self.terminal = other.terminal
73 | self.features.extend(other.features)
74 |
75 |
76 | class A3C(object):
77 | def __init__(self, env, task, visualise, intrinsic_type, bptt):
78 | """
79 | An implementation of the A3C algorithm that is reasonably well-tuned for the VNC environments.
80 | Below, we will have a modest amount of complexity due to the way TensorFlow handles data parallelism.
81 | But overall, we'll define the model, specify its inputs, and describe how the policy gradients step
82 | should be computed.
83 | """
84 |
85 | self.env = env
86 | self.task = task
87 | self.visualise = visualise
88 | self.intrinsic_type = intrinsic_type
89 | self.bptt = bptt
90 | self.subgoal_space = 32 if intrinsic_type == 'feature' else 37
91 | self.action_space = env.action_space.n
92 |
93 | self.summary_writer = None
94 | self.local_steps = 0
95 |
96 | self.beta = 0.75
97 | self.eta = 0.05
98 | self.num_local_steps = self.bptt
99 | self.num_local_meta_steps = 20
100 |
101 | # Testing
102 | if task is None:
103 | with tf.variable_scope("global"):
104 | self.local_sub_network = SubPolicy(env.observation_space.shape,
105 | env.action_space.n,
106 | self.subgoal_space,
107 | self.intrinsic_type)
108 | self.local_meta_network = MetaPolicy(env.observation_space.shape,
109 | self.subgoal_space,
110 | self.intrinsic_type)
111 | return
112 |
113 | # Training
114 | worker_device = "/job:worker/task:{}/cpu:0".format(task)
115 | with tf.device(tf.train.replica_device_setter(1, worker_device=worker_device)):
116 | with tf.variable_scope("global"):
117 | print(env.observation_space.shape)
118 | self.sub_network = SubPolicy(env.observation_space.shape,
119 | env.action_space.n,
120 | self.subgoal_space,
121 | self.intrinsic_type)
122 | self.meta_network = MetaPolicy(env.observation_space.shape,
123 | self.subgoal_space,
124 | self.intrinsic_type)
125 | self.global_step = tf.get_variable("global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32),
126 | trainable=False)
127 |
128 | with tf.device(worker_device):
129 | with tf.variable_scope("local"):
130 | self.local_sub_network = pi = SubPolicy(env.observation_space.shape,
131 | env.action_space.n,
132 | self.subgoal_space,
133 | self.intrinsic_type)
134 | self.local_meta_network = meta_pi = MetaPolicy(env.observation_space.shape,
135 | self.subgoal_space,
136 | self.intrinsic_type)
137 | pi.global_step = self.global_step
138 |
139 | self.ac = tf.placeholder(tf.float32, [None, env.action_space.n], name="ac")
140 | self.adv = tf.placeholder(tf.float32, [None], name="adv")
141 | self.r = tf.placeholder(tf.float32, [None], name="r")
142 |
143 | log_prob_tf = tf.nn.log_softmax(pi.logits)
144 | prob_tf = tf.nn.softmax(pi.logits)
145 |
146 | # the "policy gradients" loss: its derivative is precisely the policy gradient
147 | # notice that self.ac is a placeholder that is provided externally.
148 | # adv will contain the advantages, as calculated in process_rollout
149 | pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac, [1]) * self.adv)
150 |
151 | # loss of value function
152 | vf_loss = 0.5 * tf.reduce_sum(tf.square(pi.vf - self.r))
153 | entropy = - tf.reduce_sum(prob_tf * log_prob_tf)
154 |
155 | bs = tf.to_float(tf.shape(pi.x)[0])
156 | self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01
157 |
158 |
159 | self.meta_ac = tf.placeholder(tf.float32, [None, self.subgoal_space], name="meta_ac")
160 | self.meta_adv = tf.placeholder(tf.float32, [None], name="meta_adv")
161 | self.meta_r = tf.placeholder(tf.float32, [None], name="meta_r")
162 |
163 | meta_log_prob_tf = tf.nn.log_softmax(meta_pi.logits)
164 | meta_prob_tf = tf.nn.softmax(meta_pi.logits)
165 |
166 | # the "policy gradients" loss: its derivative is precisely the policy gradient
167 | # notice that self.ac is a placeholder that is provided externally.
168 | # adv will contain the advantages, as calculated in process_rollout
169 | meta_pi_loss = - tf.reduce_sum(tf.reduce_sum(meta_log_prob_tf * self.meta_ac, [1]) * self.meta_adv)
170 |
171 | # loss of value function
172 | meta_vf_loss = 0.5 * tf.reduce_sum(tf.square(meta_pi.vf - self.meta_r))
173 | meta_entropy = - tf.reduce_sum(meta_prob_tf * meta_log_prob_tf)
174 |
175 | meta_bs = tf.to_float(tf.shape(meta_pi.x)[0])
176 | self.meta_loss = meta_pi_loss + 0.5 * meta_vf_loss - meta_entropy * 0.01
177 |
178 |             # num_local_steps (= bptt) is the number of timesteps we run the
179 |             # sub-policy before we update its parameters, and num_local_meta_steps
180 |             # plays the same role for the meta-policy.
181 |             # The larger the number of local steps, the lower the variance of the
182 |             # policy-gradient estimate; on the other hand, parameter updates become
183 |             # less frequent, which slows down learning. In the original A3C starter
184 |             # code, local steps much smaller than 20 made the algorithm harder to tune.
185 |
186 | grads = tf.gradients(self.loss, pi.var_list)
187 | meta_grads = tf.gradients(self.meta_loss, meta_pi.var_list)
188 |
189 | summary = [
190 | tf.summary.scalar("model/policy_loss", pi_loss / bs),
191 | tf.summary.scalar("model/value_loss", vf_loss / bs),
192 | tf.summary.scalar("model/entropy", entropy / bs),
193 | tf.summary.image("model/state", pi.x),
194 | tf.summary.scalar("model/grad_global_norm", tf.global_norm(grads)),
195 | tf.summary.scalar("model/var_global_norm", tf.global_norm(pi.var_list))
196 | ]
197 |
198 | meta_summary = [
199 | tf.summary.scalar("meta_model/policy_loss", meta_pi_loss / meta_bs),
200 | tf.summary.scalar("meta_model/value_loss", meta_vf_loss / meta_bs),
201 | tf.summary.scalar("meta_model/entropy", meta_entropy / meta_bs),
202 | tf.summary.scalar("meta_model/grad_global_norm", tf.global_norm(meta_grads)),
203 | tf.summary.scalar("meta_model/var_global_norm", tf.global_norm(meta_pi.var_list))
204 | ]
205 | self.summary_op = tf.summary.merge(summary)
206 | self.meta_summary_op = tf.summary.merge(meta_summary)
207 |
208 | grads, _ = tf.clip_by_global_norm(grads, 40.0)
209 | meta_grads, _ = tf.clip_by_global_norm(meta_grads, 40.0)
210 | self.grads = grads
211 |
212 | # copy weights from the parameter server to the local model
213 | self.sync = tf.group(*[v1.assign(v2) for v1, v2 in zip(pi.var_list, self.sub_network.var_list)])
214 | self.meta_sync = tf.group(*[v1.assign(v2) for v1, v2 in zip(meta_pi.var_list, self.meta_network.var_list)])
215 |
216 | grads_and_vars = list(zip(grads, self.sub_network.var_list))
217 | meta_grads_and_vars = list(zip(meta_grads, self.meta_network.var_list))
218 |
219 | inc_step = self.global_step.assign_add(tf.shape(pi.x)[0])
220 |
221 | # each worker has a different set of adam optimizer parameters
222 | opt = tf.train.AdamOptimizer(1e-4)
223 | self.train_op = tf.group(opt.apply_gradients(grads_and_vars), inc_step)
224 | meta_opt = tf.train.AdamOptimizer(1e-4)
225 | self.meta_train_op = meta_opt.apply_gradients(meta_grads_and_vars)
226 |
227 | def start(self, sess, summary_writer):
228 | self.summary_writer = summary_writer
229 |
230 | def process(self, sess):
231 | """
232 | run one episode and process experience to train both meta and sub networks
233 | """
234 | env = self.env
235 |
236 | sub_policy = self.local_sub_network
237 | meta_policy = self.local_meta_network
238 |
239 | self.last_state = env.reset()
240 | self.last_meta_state = env.reset()
241 | self.last_features = sub_policy.get_initial_features()
242 | self.last_meta_features = meta_policy.get_initial_features()
243 |
244 | self.last_action = np.zeros(self.action_space)
245 | self.last_subgoal = np.zeros(self.subgoal_space)
246 | self.last_reward = np.zeros(1)
247 | self.last_meta_reward = np.zeros(1)
248 |
249 | self.length = 0
250 | self.rewards = 0
251 | self.extrinsic_rewards = 0
252 | self.intrinsic_rewards = 0
253 |
254 | terminal_end = False
255 | while not terminal_end:
256 | terminal_end = self._meta_process(sess)
257 |
258 | def _meta_process(self, sess):
259 | sess.run(self.meta_sync) # copy weights from shared to local
260 | meta_rollout = PartialRollout()
261 | meta_policy = self.local_meta_network
262 |
263 | terminal_end = False
264 | for _ in range(self.num_local_meta_steps):
265 | fetched = meta_policy.act(self.last_meta_state, self.last_subgoal,
266 | self.last_meta_reward, *self.last_meta_features)
267 | subgoal, meta_value, meta_features = fetched[0], fetched[1], fetched[2:]
268 |
269 | assert self.bptt in [20, 100], 'bptt (%d) should be 20 or 100' % self.bptt
270 |
271 | if self.bptt == 20:
272 | meta_reward = 0
273 | for _ in range(5):
274 | state, reward, terminal_end = self._sub_process(sess, subgoal)
275 | meta_reward += reward
276 | if terminal_end:
277 | break
278 | elif self.bptt == 100:
279 | state, meta_reward, terminal_end = self._sub_process(sess, subgoal)
280 |
281 |
282 | si = {
283 | 'x': self.last_meta_state,
284 | 'subgoal_prev': self.last_subgoal,
285 | 'reward_prev': self.last_meta_reward
286 | }
287 | meta_rollout.add(si, subgoal, meta_reward, meta_value,
288 | terminal_end, self.last_meta_features)
289 |
290 | self.last_meta_state = state
291 | self.last_meta_features = meta_features
292 | self.last_meta_reward = [meta_reward]
293 | self.last_subgoal = subgoal
294 |
295 | if terminal_end:
296 | break
297 |
298 | if not terminal_end:
299 | meta_rollout.r = meta_policy.value(self.last_state,
300 | self.last_subgoal,
301 | self.last_meta_reward,
302 | *self.last_meta_features)
303 |
304 | # meta rollout
305 | batch = process_rollout(meta_rollout, gamma=0.99, lambda_=1.0)
306 | fetches = [self.meta_summary_op, self.meta_train_op, self.global_step]
307 |
308 | feed_dict = {
309 | meta_policy.x: batch.si['x'],
310 | meta_policy.subgoal_prev: batch.si['subgoal_prev'],
311 | meta_policy.reward_prev: batch.si['reward_prev'],
312 | self.meta_ac: batch.a,
313 | self.meta_adv: batch.adv,
314 | self.meta_r: batch.r,
315 | meta_policy.state_in[0]: batch.features[0],
316 | meta_policy.state_in[1]: batch.features[1],
317 | }
318 | fetched = sess.run(fetches, feed_dict=feed_dict)
319 | self.summary_writer.add_summary(fetched[0], fetched[-1])
320 |
321 | return terminal_end
322 |
323 | def _sub_process(self, sess, subgoal):
324 | sess.run(self.sync) # copy weights from shared to local
325 | sub_rollout = PartialRollout()
326 | sub_policy = self.local_sub_network
327 | meta_reward = 0
328 |
329 | terminal_end = False
330 | for _ in range(self.num_local_steps):
331 | fetched = sub_policy.act(self.last_state, self.last_action,
332 | self.last_reward, subgoal, *self.last_features)
333 | action, value, features = fetched[0], fetched[1], fetched[2:]
334 |
335 | # argmax to convert from one-hot
336 | state, episode_reward, terminal, info = self.env.step(action.argmax())
337 | # reward clipping to the range of [-1, 1]
338 | extrinsic_reward = max(min(episode_reward, 1), -1)
339 |
340 | def get_mask(shape, subgoal):
341 | mask = np.zeros(shape)
342 | if subgoal < 36:
343 | y = subgoal // 6
344 | x = subgoal % 6
345 | mask[y*14:(y+1)*14, x*14:(x+1)*14] = 1
346 | mask = np.stack([mask] * 3)
347 | return mask
348 |
349 | def compute_intrinsic(state, last_state, subgoal):
350 | f = sub_policy.feature(state)
351 | last_f = sub_policy.feature(last_state)
352 | if self.intrinsic_type == 'feature':
353 | diff = np.abs(f - last_f)
354 | return self.eta * diff[subgoal] / (np.sum(diff) + 1e-10)
355 | else:
356 | diff = f - last_f
357 | diff = diff * diff
358 | mask = get_mask(diff.shape, subgoal)
359 | return self.eta * np.sum(mask * diff) / (np.sum(diff) + 1e-10)
360 |
361 | intrinsic_reward = compute_intrinsic(state, self.last_state, subgoal.argmax())
362 | reward = self.beta * extrinsic_reward + (1 - self.beta) * intrinsic_reward
363 |
364 | meta_reward += extrinsic_reward
365 | # meta_reward += reward
366 | self.intrinsic_rewards += intrinsic_reward
367 | self.extrinsic_rewards += extrinsic_reward
368 |
369 | # collect the experience
370 | si = {
371 | 'x': self.last_state,
372 | 'action_prev': self.last_action,
373 | 'reward_prev': self.last_reward,
374 | 'subgoal': subgoal
375 | }
376 | sub_rollout.add(si, action, reward, value, terminal, self.last_features)
377 |
378 | self.length += 1
379 | self.rewards += episode_reward
380 |
381 | self.last_state = state
382 | self.last_action = action
383 | self.last_features = features
384 | self.last_reward = [reward]
385 |
386 | timestep_limit = self.env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
387 | if terminal or (timestep_limit and self.length >= timestep_limit):
388 | terminal_end = True
389 |
390 | summary = tf.Summary()
391 | summary.value.add(tag='global/episode_reward',
392 | simple_value=self.rewards)
393 | summary.value.add(tag='global/extrinsic_reward',
394 | simple_value=self.extrinsic_rewards)
395 | summary.value.add(tag='global/intrinsic_reward',
396 | simple_value=self.intrinsic_rewards)
397 | summary.value.add(tag='global/episode_length',
398 | simple_value=self.length)
399 | self.summary_writer.add_summary(summary, sub_policy.global_step.eval())
400 | self.summary_writer.flush()
401 |
402 | print("Episode finished. Ep rewards: %.5f (In: %.5f, Ex: %.5f). Length: %d" %
403 | (self.rewards, self.intrinsic_rewards, self.extrinsic_rewards, self.length))
404 | break
405 |
406 | if not terminal_end:
407 | sub_rollout.r = sub_policy.value(self.last_state,
408 | self.last_action,
409 | self.last_reward,
410 | subgoal, *self.last_features)
411 |
412 | self.local_steps += 1
413 | should_compute_summary = self.task == 0 and self.local_steps % 11 == 0
414 |
415 | # sub rollout
416 | batch = process_rollout(sub_rollout, gamma=0.99, lambda_=1.0)
417 | fetches = [self.train_op, self.global_step]
418 | if should_compute_summary:
419 | fetches = [self.summary_op] + fetches
420 | feed_dict = {
421 | sub_policy.x: batch.si['x'],
422 | sub_policy.action_prev: batch.si['action_prev'],
423 | sub_policy.reward_prev: batch.si['reward_prev'],
424 | sub_policy.subgoal: batch.si['subgoal'],
425 | self.ac: batch.a,
426 | self.adv: batch.adv,
427 | self.r: batch.r,
428 | sub_policy.state_in[0]: batch.features[0],
429 | sub_policy.state_in[1]: batch.features[1],
430 | }
431 |
432 | fetched = sess.run(fetches, feed_dict=feed_dict)
433 | if should_compute_summary:
434 | self.summary_writer.add_summary(fetched[0], fetched[-1])
435 |
436 | return self.last_state, meta_reward, terminal_end
437 |
438 |
439 | def evaluate(self, sess):
440 | """
441 |         run one episode with the current policy and collect frames for evaluation
442 | """
443 | env = self.env
444 |
445 | sub_policy = self.local_sub_network
446 | meta_policy = self.local_meta_network
447 |
448 | self.last_state = env.reset()
449 | self.last_meta_state = env.reset()
450 | self.last_features = sub_policy.get_initial_features()
451 | self.last_meta_features = meta_policy.get_initial_features()
452 |
453 | self.last_action = np.zeros(self.action_space)
454 | self.last_subgoal = np.zeros(self.subgoal_space)
455 | self.last_reward = np.zeros(1)
456 | self.last_meta_reward = np.zeros(1)
457 |
458 | self.length = 0
459 | self.rewards = 0
460 | self.extrinsic_rewards = 0
461 | self.intrinsic_rewards = 0
462 |
463 | terminal_end = False
464 | frames = [self.last_state]
465 | while not terminal_end:
466 | frames_, terminal_end = self._meta_evaluate(sess)
467 | frames.extend(frames_)
468 | frames = np.stack(frames)
469 | return frames, self.rewards, self.length
470 |
471 | def _meta_evaluate(self, sess):
472 | meta_policy = self.local_meta_network
473 |
474 | terminal_end = False
475 | frames = []
476 | for _ in range(self.num_local_meta_steps):
477 | fetched = meta_policy.act(self.last_meta_state, self.last_subgoal,
478 | self.last_meta_reward, *self.last_meta_features)
479 | subgoal, meta_features = fetched[0], fetched[2:]
480 |
481 | if self.bptt == 20:
482 | meta_reward = 0
483 | for _ in range(5):
484 | frames_, state, reward, terminal_end = self._sub_evaluate(sess, subgoal)
485 | frames.extend(frames_)
486 | meta_reward += reward
487 | if terminal_end:
488 | break
489 | elif self.bptt == 100:
490 | frames_, state, meta_reward, terminal_end = self._sub_evaluate(sess, subgoal)
491 | frames.extend(frames_)
492 |
493 | self.last_meta_state = state
494 | self.last_subgoal = subgoal
495 | self.last_meta_reward = [meta_reward]
496 | self.last_meta_features = meta_features
497 | if terminal_end:
498 | break
499 | return frames, terminal_end
500 |
501 | def _sub_evaluate(self, sess, subgoal):
502 | sub_policy = self.local_sub_network
503 | meta_reward = 0
504 | frames = []
505 |
506 | for _ in range(self.num_local_steps):
507 | fetched = sub_policy.act(self.last_state, self.last_action,
508 | self.last_reward, subgoal, *self.last_features)
509 | action, features = fetched[0], fetched[2:]
510 |
511 | # argmax to convert from one-hot
512 | state, episode_reward, terminal, info = self.env.step(action.argmax())
513 | # reward clipping to the range of [-1, 1]
514 | extrinsic_reward = max(min(episode_reward, 1), -1)
515 | frames.append(state)
516 |
517 | if self.visualise:
518 | self.env.render()
519 |
520 | def get_mask(shape, subgoal):
521 | mask = np.zeros(shape)
522 | if subgoal < 36:
523 | y = subgoal // 6
524 | x = subgoal % 6
525 | mask[y*14:(y+1)*14, x*14:(x+1)*14] = 1
526 | mask = np.stack([mask] * 3)
527 | return mask
528 |
529 | def compute_intrinsic(state, last_state, subgoal):
530 | if self.intrinsic_type == 'feature':
531 | f = sub_policy.feature(state)
532 | last_f = sub_policy.feature(last_state)
533 | diff = np.abs(f - last_f)
534 | return self.eta * diff[subgoal] / (np.sum(diff) + 1e-10)
535 | else:
536 | diff = state - last_state
537 | diff = diff * diff
538 | mask = get_mask(diff.shape, subgoal)
539 | return self.eta * np.sum(mask * diff) / (np.sum(diff) + 1e-10)
540 |
541 | intrinsic_reward = compute_intrinsic(state, self.last_state, subgoal.argmax())
542 | reward = self.beta * extrinsic_reward + (1 - self.beta) * intrinsic_reward
543 |
544 | meta_reward += extrinsic_reward
545 | # meta_reward += reward
546 | self.intrinsic_rewards += intrinsic_reward
547 | self.extrinsic_rewards += extrinsic_reward
548 |
549 | self.length += 1
550 | self.rewards += episode_reward
551 |
552 | self.last_state = state
553 | self.last_action = action
554 | self.last_features = features
555 | self.last_reward = [reward]
556 |
557 | timestep_limit = self.env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
558 | if terminal or (timestep_limit and self.length >= timestep_limit):
559 | print("Episode finished. Ep rewards: %.5f (In: %.5f, Ex: %.5f). Length: %d" %
560 | (self.rewards, self.intrinsic_rewards, self.extrinsic_rewards, self.length))
561 | return frames, state, meta_reward, True
562 | return frames, state, meta_reward, False
563 |
--------------------------------------------------------------------------------
/assets/feature-control-bptt-100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-bptt-100.png
--------------------------------------------------------------------------------
/assets/feature-control-video-10M.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-video-10M.gif
--------------------------------------------------------------------------------
/assets/feature-control-video-160M.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-video-160M.gif
--------------------------------------------------------------------------------
/assets/feature-control-video-27M.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-video-27M.gif
--------------------------------------------------------------------------------
/assets/feature-control-video-50M.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-video-50M.gif
--------------------------------------------------------------------------------
/assets/feature-control-video-90M.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/feature-control-video-90M.gif
--------------------------------------------------------------------------------
/assets/intrinsic_feature.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/intrinsic_feature.png
--------------------------------------------------------------------------------
/assets/intrinsic_pixel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/intrinsic_pixel.png
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/model.png
--------------------------------------------------------------------------------
/assets/pixel-control-bptt-100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clvrai/FeatureControlHRL-Tensorflow/7e611febd296bada68f44710992f9bcd284941d2/assets/pixel-control-bptt-100.png
--------------------------------------------------------------------------------
/envs.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import logging
3 |
4 | logger = logging.getLogger(__name__)
5 | logger.setLevel(logging.INFO)
6 |
7 |
8 | def create_env(env_id, **kwargs):
9 | return create_atari_env(env_id)
10 |
11 |
12 | def create_atari_env(env_id):
13 | env = gym.make(env_id)
14 | return env
15 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow.contrib.rnn as rnn
4 | from ops import flatten, conv2d, linear
5 |
6 |
7 | def normalized_columns_initializer(std=1.0):
8 | def _initializer(shape, dtype=None, partition_info=None):
9 | out = np.random.randn(*shape).astype(np.float32)
10 | out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
11 | return tf.constant(out)
12 | return _initializer
13 |
14 |
15 | def categorical_sample(logits, d):
16 | value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1], keep_dims=True), 1), [1])
17 | return tf.one_hot(value, d)
18 |
19 |
20 | class SubPolicy(object):
21 | def __init__(self, ob_space, ac_space, subgoal_space, intrinsic_type):
22 | self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space), name='x')
23 | self.action_prev = action_prev = tf.placeholder(tf.float32, [None, ac_space], name='action_prev')
24 | self.reward_prev = reward_prev = tf.placeholder(tf.float32, [None, 1], name='reward_prev')
25 | self.subgoal = subgoal = tf.placeholder(tf.float32, [None, subgoal_space], name='subgoal')
26 | self.intrinsic_type = intrinsic_type
27 |
28 | with tf.variable_scope('encoder'):
29 | x = tf.image.resize_images(x, [84, 84])
30 | x = x / 255.0
31 | self.p = x
32 | x = tf.nn.relu(conv2d(x, 16, "l1", [8, 8], [4, 4]))
33 | x = tf.nn.relu(conv2d(x, 32, "l2", [4, 4], [2, 2]))
34 | self.f = tf.reduce_mean(x, axis=[1, 2])
35 | x = flatten(x)
36 |
37 | with tf.variable_scope('sub_policy'):
38 | x = tf.nn.relu(linear(x, 256, "fc",
39 | normalized_columns_initializer(0.01)))
40 | x = tf.concat([x, action_prev], axis=1)
41 | x = tf.concat([x, reward_prev], axis=1)
42 | x = tf.concat([x, subgoal], axis=1)
43 |
44 | # introduce a "fake" batch dimension of 1 after flatten
45 | # so that we can do LSTM over time dim
46 | x = tf.expand_dims(x, [0])
47 |
48 | size = 256
49 | lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
50 | self.state_size = lstm.state_size
51 | step_size = tf.shape(self.x)[:1]
52 |
53 | c_init = np.zeros((1, lstm.state_size.c), np.float32)
54 | h_init = np.zeros((1, lstm.state_size.h), np.float32)
55 | self.state_init = [c_init, h_init]
56 | c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
57 | h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
58 | self.state_in = [c_in, h_in]
59 |
60 | state_in = rnn.LSTMStateTuple(c_in, h_in)
61 |
62 | lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
63 | lstm, x, initial_state=state_in, sequence_length=step_size,
64 | time_major=False
65 | )
66 | lstm_c, lstm_h = lstm_state
67 | lstm_outputs = tf.reshape(lstm_outputs, [-1, size])
68 | self.logits = linear(lstm_outputs, ac_space, "action",
69 | normalized_columns_initializer(0.01))
70 | self.vf = tf.reshape(linear(lstm_outputs, 1, "value",
71 | normalized_columns_initializer(1.0)), [-1])
72 | self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
73 | self.sample = categorical_sample(self.logits, ac_space)[0, :]
74 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
75 | tf.get_variable_scope().name)
76 |
77 | def get_initial_features(self):
78 | return self.state_init
79 |
80 | def act(self, ob, action_prev, reward_prev, subgoal, c, h):
81 | sess = tf.get_default_session()
82 | return sess.run([self.sample, self.vf] + self.state_out,
83 | {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h,
84 | self.action_prev: [action_prev],
85 | self.reward_prev: [reward_prev],
86 | self.subgoal: [subgoal]})
87 |
88 | def value(self, ob, action_prev, reward_prev, subgoal, c, h):
89 | sess = tf.get_default_session()
90 | return sess.run(self.vf, {self.x: [ob], self.state_in[0]: c,
91 | self.state_in[1]: h,
92 | self.action_prev: [action_prev],
93 | self.reward_prev: [reward_prev],
94 | self.subgoal: [subgoal]})[0]
95 |
96 | def feature(self, state):
97 | sess = tf.get_default_session()
98 | if self.intrinsic_type == 'feature':
99 | return sess.run(self.f, {self.x: [state]})[0, :]
100 | else:
101 | return sess.run(self.p, {self.x: [state]})[0, :]
102 |
103 |
104 | class MetaPolicy(object):
105 | def __init__(self, ob_space, subgoal_space, intrinsic_type):
106 | self.x = x = \
107 | tf.placeholder(tf.float32, [None] + list(ob_space), name='x_meta')
108 | self.subgoal_prev = subgoal_prev = \
109 | tf.placeholder(tf.float32, [None, subgoal_space], name='subgoal_prev')
110 | self.reward_prev = reward_prev = \
111 | tf.placeholder(tf.float32, [None, 1], name='reward_prev_meta')
112 | self.intrinsic_type = intrinsic_type
113 |
114 | with tf.variable_scope('encoder', reuse=True):
115 | x = tf.image.resize_images(x, [84, 84])
116 | x = x / 255.0
117 | x = tf.nn.relu(conv2d(x, 16, "l1", [8, 8], [4, 4]))
118 | x = tf.nn.relu(conv2d(x, 32, "l2", [4, 4], [2, 2]))
119 | x = flatten(x)
120 |
121 | with tf.variable_scope('meta_policy'):
122 | x = tf.nn.relu(linear(x, 256, "fc",
123 | normalized_columns_initializer(0.01)))
124 | x = tf.concat([x, subgoal_prev], axis=1)
125 | x = tf.concat([x, reward_prev], axis=1)
126 |
127 | # introduce a "fake" batch dimension of 1 after flatten
128 | # so that we can do LSTM over time dim
129 | x = tf.expand_dims(x, [0])
130 |
131 | size = 256
132 | lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
133 | self.state_size = lstm.state_size
134 | step_size = tf.shape(self.x)[:1]
135 |
136 | c_init = np.zeros((1, lstm.state_size.c), np.float32)
137 | h_init = np.zeros((1, lstm.state_size.h), np.float32)
138 | self.state_init = [c_init, h_init]
139 | c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
140 | h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
141 | self.state_in = [c_in, h_in]
142 |
143 | state_in = rnn.LSTMStateTuple(c_in, h_in)
144 | lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
145 | lstm, x, initial_state=state_in, sequence_length=step_size,
146 | time_major=False)
147 | lstm_c, lstm_h = lstm_state
148 | lstm_outputs = tf.reshape(lstm_outputs, [-1, size])
149 | self.logits = linear(lstm_outputs, subgoal_space, "action",
150 | normalized_columns_initializer(0.01))
151 | self.vf = tf.reshape(linear(lstm_outputs, 1, "value",
152 | normalized_columns_initializer(1.0)), [-1])
153 | self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
154 | self.sample = categorical_sample(self.logits, subgoal_space)[0, :]
155 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
156 |
157 | def get_initial_features(self):
158 | return self.state_init
159 |
160 | def act(self, ob, subgoal_prev, reward_prev, c, h):
161 | sess = tf.get_default_session()
162 | return sess.run([self.sample, self.vf] + self.state_out,
163 | {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h,
164 | self.subgoal_prev: [subgoal_prev],
165 | self.reward_prev: [reward_prev]})
166 |
167 | def value(self, ob, subgoal_prev, reward_prev, c, h):
168 | sess = tf.get_default_session()
169 | return sess.run(self.vf, {self.x: [ob], self.state_in[0]: c,
170 | self.state_in[1]: h,
171 | self.subgoal_prev: [subgoal_prev],
172 | self.reward_prev: [reward_prev]})[0]
173 |
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 |
5 | def flatten(x):
6 | return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
7 |
8 |
9 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1),
10 | pad="SAME", dtype=tf.float32, collections=None):
11 | with tf.variable_scope(name):
12 | stride_shape = [1, stride[0], stride[1], 1]
13 | filter_shape = [filter_size[0], filter_size[1],
14 | int(x.get_shape()[3]), num_filters]
15 |
16 | # there are "num input feature maps * filter height * filter width"
17 | # inputs to each hidden unit
18 | fan_in = np.prod(filter_shape[:3])
19 | # each unit in the lower layer receives a gradient from:
20 | # "num output feature maps * filter height * filter width" /
21 | # pooling size
22 | fan_out = np.prod(filter_shape[:2]) * num_filters
23 | # initialize weights with random weights
24 | w_bound = np.sqrt(6. / (fan_in + fan_out))
25 |
26 | w = tf.get_variable("W", filter_shape, dtype,
27 | tf.random_uniform_initializer(-w_bound, w_bound),
28 | collections=collections)
29 | b = tf.get_variable("b", [1, 1, 1, num_filters],
30 | initializer=tf.constant_initializer(0.0),
31 | collections=collections)
32 | return tf.nn.conv2d(x, w, stride_shape, pad) + b
33 |
34 |
35 | def linear(x, size, name, initializer=None, bias_init=0):
36 | w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=initializer)
37 | b = tf.get_variable(name + "/b", [size], initializer=tf.constant_initializer(bias_init))
38 | return tf.matmul(x, w) + b
39 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import argparse
3 | import logging
4 | import sys
5 | import signal
6 | import os
7 |
8 | import tensorflow as tf
9 | import imageio
10 |
11 | from a3c import A3C
12 | from envs import create_env
13 |
14 |
15 | logger = logging.getLogger(__name__)
16 | logger.setLevel(logging.INFO)
17 |
18 | # parsing cmd arguments
19 | parser = argparse.ArgumentParser(description="Test commands")
20 | parser.add_argument('-e', '--env-id', type=str, default="MontezumaRevenge-v0",
21 | help="Environment id")
22 | parser.add_argument('-l', '--log-dir', type=str, default="/tmp/montezuma",
23 | help="Log directory path")
24 |
25 | # Add visualisation argument
26 | parser.add_argument('--visualise', action='store_true',
27 | help="Visualise the gym environment by running env.render() between each timestep")
28 |
29 | # Add model parameters
30 | parser.add_argument('--intrinsic-type', type=str, default='feature',
31 | choices=['feature', 'pixel'], help="Feature-control or Pixel-control")
32 | parser.add_argument('--bptt', type=int, default=100,
33 | help="BPTT")
34 |
35 |
36 | # Disables the write_meta_graph argument, which freezes the entire process and is mostly useless.
37 | class FastSaver(tf.train.Saver):
38 | def save(self, sess, save_path, global_step=None, latest_filename=None,
39 | meta_graph_suffix="meta", write_meta_graph=True):
40 | super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
41 | meta_graph_suffix, False)
42 |
43 |
44 | def run(args):
45 | env = create_env(args.env_id)
46 | trainer = A3C(env, None, args.visualise, args.intrinsic_type, args.bptt)
47 |
48 | # Variable names that start with "local" are not saved in checkpoints.
49 | variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")]
50 | init_op = tf.variables_initializer(variables_to_save)
51 | init_all_op = tf.global_variables_initializer()
52 | saver = FastSaver(variables_to_save)
53 |
54 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
55 | logger.info('Trainable vars:')
56 | for v in var_list:
57 | logger.info(' %s %s', v.name, v.get_shape())
58 |
59 | def init_fn(ses):
60 | logger.info("Initializing all parameters.")
61 | ses.run(init_all_op)
62 |
63 | logdir = os.path.join(args.log_dir, 'train')
64 | summary_writer = tf.summary.FileWriter(logdir)
65 | logger.info("Events directory: %s", logdir)
66 |
67 | sv = tf.train.Supervisor(is_chief=True,
68 | logdir=logdir,
69 | saver=saver,
70 | summary_op=None,
71 | init_op=init_op,
72 | init_fn=init_fn,
73 | summary_writer=summary_writer,
74 | ready_op=tf.report_uninitialized_variables(variables_to_save),
75 | global_step=None,
76 | save_model_secs=0,
77 | save_summaries_secs=0)
78 |
79 | video_dir = os.path.join(args.log_dir, 'test_videos_' + args.intrinsic_type)
80 | if not os.path.exists(video_dir):
81 | os.makedirs(video_dir)
82 | video_filename = video_dir + "/%s_%02d_%d.gif"
83 | print("Video saved at %s" % video_dir)
84 |
85 | with sv.managed_session() as sess, sess.as_default():
86 | trainer.start(sess, summary_writer)
87 | rewards = []
88 | lengths = []
89 | for i in range(10):
90 | frames, reward, length = trainer.evaluate(sess)
91 | rewards.append(reward)
92 | lengths.append(length)
93 | imageio.mimsave(video_filename % (args.env_id, i, reward), frames, fps=30)
94 |
95 | print('Evaluation: avg. reward %.2f avg.length %.2f' %
96 | (sum(rewards) / 10.0, sum(lengths) / 10.0))
97 |
98 | # Ask for all the services to stop.
99 | sv.stop()
100 |
101 |
102 | def main(_):
103 | args, unparsed = parser.parse_known_args()
104 |
105 | def shutdown(signal, frame):
106 | logger.warn('Received signal %s: exiting', signal)
107 | sys.exit(128+signal)
108 | signal.signal(signal.SIGHUP, shutdown)
109 | signal.signal(signal.SIGINT, shutdown)
110 | signal.signal(signal.SIGTERM, shutdown)
111 |
112 | run(args)
113 |
114 | if __name__ == "__main__":
115 | tf.app.run()
116 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | from six.moves import shlex_quote
6 |
7 | parser = argparse.ArgumentParser(description="Run commands")
8 | parser.add_argument('-w', '--num-workers', default=8, type=int,
9 | help="Number of workers")
10 | parser.add_argument('-e', '--env-id', type=str, default="MontezumaRevenge-v0",
11 | help="Environment id")
12 | parser.add_argument('-l', '--log-dir', type=str, default="/tmp/montezuma",
13 | help="Log directory path")
14 | parser.add_argument('-n', '--dry-run', action='store_true',
15 | help="Print out commands rather than executing them")
16 | parser.add_argument('-m', '--mode', type=str, default='tmux',
17 | help="tmux: run workers in a tmux session. nohup: "
18 | "run workers with nohup. child: run workers as child processes")
19 |
20 | # Add visualise tag
21 | parser.add_argument('--visualise', action='store_true',
22 | help="Visualise the gym environment by running env.render() between each timestep")
23 |
24 | # Add model parameters
25 | parser.add_argument('--intrinsic-type', type=str, default='feature',
26 | choices=['feature', 'pixel'], help="feature or pixel")
27 | parser.add_argument('--bptt', type=int, default=100,
28 | help="BPTT")
29 |
30 |
31 | def new_cmd(session, name, cmd, mode, logdir, shell):
32 | if isinstance(cmd, (list, tuple)):
33 | cmd = " ".join(shlex_quote(str(v)) for v in cmd)
34 | if mode == 'tmux':
35 | return name, "tmux send-keys -t {}:{} {} Enter".format(session, name, shlex_quote(cmd))
36 | elif mode == 'child':
37 | return name, "{} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(cmd, logdir, session, name, logdir)
38 | elif mode == 'nohup':
39 | return name, "nohup {} -c {} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(shell, shlex_quote(cmd), logdir, session, name, logdir)
40 |
41 |
42 | def create_commands(session, num_workers, env_id, logdir,
43 | intrinsic_type, bptt, shell='bash', mode='tmux', visualise=False):
44 |     # Commands for launching the TF workers and TensorBoard
45 | base_cmd = [
46 | 'CUDA_VISIBLE_DEVICES=',
47 | sys.executable, 'worker.py',
48 | '--log-dir', logdir,
49 | '--env-id', env_id,
50 | '--num-workers', str(num_workers),
51 | '--intrinsic-type', intrinsic_type,
52 | '--bptt', str(bptt)
53 | ]
54 |
55 | if visualise:
56 | base_cmd += ['--visualise']
57 |
58 | cmds_map = [new_cmd(session, "ps", base_cmd + ["--job-name", "ps"], mode, logdir, shell)]
59 | for i in range(num_workers):
60 | cmds_map += [new_cmd(session, "w-%d" % i, base_cmd +
61 | ["--job-name", "worker", "--task", str(i)], mode, logdir, shell)]
62 |
63 | cmds_map += [new_cmd(session, "tb", ["tensorboard", "--logdir", logdir,
64 | "--port", "12345"], mode, logdir, shell)]
65 | if mode == 'tmux':
66 | cmds_map += [new_cmd(session, "htop", ["htop"], mode, logdir, shell)]
67 |
68 | windows = [v[0] for v in cmds_map]
69 |
70 | notes = []
71 | cmds = [
72 | "mkdir -p {}".format(logdir),
73 | "echo {} {} > {}/cmd.sh".format(sys.executable, ' '.join([shlex_quote(arg) for arg in sys.argv if arg != '-n']), logdir),
74 | ]
75 | if mode == 'nohup' or mode == 'child':
76 | cmds += ["echo '#!/bin/sh' >{}/kill.sh".format(logdir)]
77 | notes += ["Run `source {}/kill.sh` to kill the job".format(logdir)]
78 | if mode == 'tmux':
79 | notes += ["Use `tmux attach -t {}` to watch process output".format(session)]
80 | notes += ["Use `tmux kill-session -t {}` to kill the job".format(session)]
81 | else:
82 | notes += ["Use `tail -f {}/*.out` to watch process output".format(logdir)]
83 | notes += ["Point your browser to http://localhost:12345 to see Tensorboard"]
84 |
85 | if mode == 'tmux':
86 | cmds += [
87 | # kill any process using tensorboard's port
88 | "kill $( lsof -i:12345 -t ) > /dev/null 2>&1",
89 | # kill any processes using ps / worker ports
90 | "kill $( lsof -i:12222-{} -t ) > /dev/null 2>&1".format(num_workers+12222),
91 | "tmux kill-session -t {}".format(session),
92 | "tmux new-session -s {} -n {} -d {}".format(session, windows[0], shell)
93 | ]
94 | for w in windows[1:]:
95 | cmds += ["tmux new-window -t {} -n {} {}".format(session, w, shell)]
96 | cmds += ["sleep 1"]
97 | for window, cmd in cmds_map:
98 | cmds += [cmd]
99 |
100 | return cmds, notes
101 |
102 |
103 | def run():
104 | args = parser.parse_args()
105 | cmds, notes = create_commands("a3c", args.num_workers,
106 | args.env_id, args.log_dir, args.intrinsic_type,
107 | args.bptt, mode=args.mode, visualise=args.visualise)
108 | if args.dry_run:
109 | print("Dry-run mode due to -n flag, otherwise the following commands would be executed:")
110 | else:
111 | print("Executing the following commands:")
112 | print("\n".join(cmds))
113 | print("")
114 | if not args.dry_run:
115 | if args.mode == "tmux":
116 | os.environ["TMUX"] = ""
117 | os.system("\n".join(cmds))
118 | print('\n'.join(notes))
119 |
120 | if __name__ == "__main__":
121 | run()
122 |
--------------------------------------------------------------------------------
/worker.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import argparse
3 | import logging
4 | import sys
5 | import signal
6 | import time
7 | import os
8 |
9 | import imageio
10 | from tqdm import tqdm
11 | import tensorflow as tf
12 |
13 | from a3c import A3C
14 | from envs import create_env
15 |
16 | logger = logging.getLogger(__name__)
17 | logger.setLevel(logging.INFO)
18 |
19 |
20 | # Disables the write_meta_graph argument, which freezes the entire process and is mostly useless.
21 | class FastSaver(tf.train.Saver):
22 | def save(self, sess, save_path, global_step=None, latest_filename=None,
23 | meta_graph_suffix="meta", write_meta_graph=True):
24 | super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
25 | meta_graph_suffix, False)
26 |
27 |
28 | def run(args, server):
29 | env = create_env(args.env_id)
30 | trainer = A3C(env, args.task, args.visualise, args.intrinsic_type, args.bptt)
31 |
32 | # Variable names that start with "local" are not saved in checkpoints.
33 | variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")]
34 | init_op = tf.variables_initializer(variables_to_save)
35 | init_all_op = tf.global_variables_initializer()
36 | saver = FastSaver(variables_to_save)
37 |
38 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
39 | logger.info('Trainable vars:')
40 | for v in var_list:
41 | logger.info(' %s %s', v.name, v.get_shape())
42 |
43 | def init_fn(ses):
44 | logger.info("Initializing all parameters.")
45 | ses.run(init_all_op)
46 |
47 | config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)])
48 | logdir = os.path.join(args.log_dir, 'train')
49 |
50 | summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
51 |
52 | logger.info("Events directory: %s_%s", logdir, args.task)
53 | sv = tf.train.Supervisor(is_chief=(args.task == 0),
54 | logdir=logdir,
55 | saver=saver,
56 | summary_op=None,
57 | init_op=init_op,
58 | init_fn=init_fn,
59 | summary_writer=summary_writer,
60 | ready_op=tf.report_uninitialized_variables(variables_to_save),
61 | global_step=trainer.global_step,
62 | save_model_secs=30,
63 | save_summaries_secs=30)
64 |
65 | video_dir = os.path.join(args.log_dir, 'train_videos_' + args.intrinsic_type)
66 | if not os.path.exists(video_dir):
67 | os.makedirs(video_dir)
68 | video_filename = video_dir + "/%s_%010d_%d.gif"
69 | print("Video saved at %s" % video_dir)
70 |
71 | num_global_steps = 300000000
72 | num_record_steps = 1000000
73 | last_record_step = 0
74 |
75 | logger.info(
76 |         "Starting session. If this hangs, we're most likely waiting to connect to the parameter server. " +
77 | "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.")
78 | with sv.managed_session(server.target, config=config) as sess, sess.as_default():
79 | sess.run(trainer.meta_sync)
80 | sess.run(trainer.sync)
81 | trainer.start(sess, summary_writer)
82 | global_step = sess.run(trainer.global_step)
83 | logger.info("Starting training at step=%d", global_step)
84 |
85 | pbar = tqdm(total=num_global_steps)
86 | pbar.update(global_step)
87 |
88 | while not sv.should_stop() and (not num_global_steps or global_step < num_global_steps):
89 | trainer.process(sess)
90 |
91 | new_global_step = sess.run(trainer.global_step)
92 | pbar.set_description('')
93 | pbar.update(max(1, new_global_step - global_step))
94 | global_step = new_global_step
95 |
96 | if args.task == 0 and global_step - last_record_step > num_record_steps:
97 | sess.run(trainer.meta_sync)
98 | sess.run(trainer.sync)
99 | last_record_step = global_step
100 | frames, reward, length = trainer.evaluate(sess)
101 | imageio.mimsave(video_filename % (args.env_id, global_step, reward), frames, fps=30)
102 |
103 | # Ask for all the services to stop.
104 | sv.stop()
105 | logger.info('reached %s steps. worker stopped.', global_step)
106 |
107 |
108 | def cluster_spec(num_workers, num_ps):
109 | """
110 | More tensorflow setup for data parallelism
111 | """
112 | cluster = {}
113 | port = 12222
114 |
115 | all_ps = []
116 | host = '127.0.0.1'
117 | for _ in range(num_ps):
118 | all_ps.append('{}:{}'.format(host, port))
119 | port += 1
120 | cluster['ps'] = all_ps
121 |
122 | all_workers = []
123 | for _ in range(num_workers):
124 | all_workers.append('{}:{}'.format(host, port))
125 | port += 1
126 | cluster['worker'] = all_workers
127 | return cluster
128 |
129 |
130 | def main(_):
131 | """
132 | Setting up Tensorflow for data parallel work
133 | """
134 |
135 | parser = argparse.ArgumentParser(description=None)
136 | parser.add_argument('-v', '--verbose', action='count', dest='verbosity', default=0, help='Set verbosity.')
137 | parser.add_argument('--task', default=0, type=int, help='Task index')
138 | parser.add_argument('--job-name', default="worker", help='worker or ps')
139 | parser.add_argument('--num-workers', default=1, type=int, help='Number of workers')
140 | parser.add_argument('--log-dir', default="/tmp/pong", help='Log directory path')
141 | parser.add_argument('--env-id', default="PongDeterministic-v3", help='Environment id')
142 |
143 | # Add visualisation argument
144 | parser.add_argument('--visualise', action='store_true',
145 | help="Visualise the gym environment by running env.render() between each timestep")
146 |
147 | # Add model parameters
148 | parser.add_argument('--intrinsic-type', type=str, default='feature',
149 | choices=['feature', 'pixel'], help="feature or pixel")
150 | parser.add_argument('--bptt', type=int, default=100,
151 | help="BPTT")
152 |
153 | args = parser.parse_args()
154 | spec = cluster_spec(args.num_workers, 1)
155 | cluster = tf.train.ClusterSpec(spec).as_cluster_def()
156 |
157 | def shutdown(signal, frame):
158 | logger.warn('Received signal %s: exiting', signal)
159 | sys.exit(128+signal)
160 | signal.signal(signal.SIGHUP, shutdown)
161 | signal.signal(signal.SIGINT, shutdown)
162 | signal.signal(signal.SIGTERM, shutdown)
163 |
164 | if args.job_name == "worker":
165 | server = tf.train.Server(cluster, job_name="worker", task_index=args.task,
166 | config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
167 | run(args, server)
168 | else:
169 | server = tf.train.Server(cluster, job_name="ps", task_index=args.task,
170 | config=tf.ConfigProto(device_filters=["/job:ps"]))
171 | while True:
172 | time.sleep(1000)
173 |
174 | if __name__ == "__main__":
175 | tf.app.run()
176 |
--------------------------------------------------------------------------------