├── .gitignore
├── README.md
├── a3c.py
├── async.py
├── config
│   ├── collect_deterministic.xml
│   └── collect_stochastic.xml
├── envs.py
├── maze.py
├── model.py
├── q.py
├── test.py
├── test_vpn
├── train.py
├── train_vpn
├── util.py
├── vpn.py
└── worker.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 | *.swp
91 |
92 | result/
93 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | This repository implements **[Value Prediction Network (Oh et al., NIPS 2017)](https://arxiv.org/abs/1707.03497)** in TensorFlow.
3 | ```
4 | @inproceedings{Oh2017VPN,
5 | title={Value Prediction Network},
6 | author={Junhyuk Oh and Satinder Singh and Honglak Lee},
7 | booktitle={NIPS},
8 | year={2017}
9 | }
10 | ```
11 | Our code is based on [OpenAI's A3C implementation](https://github.com/openai/universe-starter-agent).
12 |
13 | # Dependencies
14 | * [TensorFlow](https://www.tensorflow.org/install/)
15 | * [Beautiful Soup](https://www.crummy.com/software/BeautifulSoup/bs4/doc/)
16 | * [Golang](https://golang.org/doc/install)
17 | * [six](https://pypi.python.org/pypi/six) (for py2/3 compatibility)
18 | * [tmux](https://tmux.github.io/) (the start script opens up a tmux session with multiple windows)
19 | * [htop](https://hisham.hm/htop/) (shown in one of the tmux windows)
20 | * [gym](https://pypi.python.org/pypi/gym)
21 | * gym[atari]
22 | * [universe](https://pypi.python.org/pypi/universe)
23 | * [opencv-python](https://pypi.python.org/pypi/opencv-python)
24 | * [numpy](https://pypi.python.org/pypi/numpy)
25 | * [scipy](https://pypi.python.org/pypi/scipy)
26 |
27 | # Training
28 | The following command trains a value prediction network (VPN) with a plan depth of 3 on the deterministic Collect domain:
29 | ```
30 | python train.py --config config/collect_deterministic.xml --branch 4,4,4 --alg VPN
31 | ```
32 | The `train_vpn` script contains the commands for reproducing the main results of the paper.
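For the stochastic variant of the Collect domain, the corresponding command would presumably swap in the other config file, e.g. `python train.py --config config/collect_stochastic.xml --branch 4,4,4 --alg VPN` (see `train_vpn` for the exact settings used in the paper).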
33 |
34 | # Notes
35 | * TensorBoard shows the performance of the epsilon-greedy policy. This is NOT the learning curve reported in the paper, because epsilon decreases from 1.0 to 0.05 over the first 1e6 steps. Instead, `[logdir]/eval.csv` shows the performance of the agent under the greedy policy. A minimal sketch of the epsilon schedule is given after these notes.
36 | * Our code supports multi-GPU training. You can specify GPU IDs via the `--gpu` option (e.g., `--gpu 0,1,2,3`).
37 |
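The sketch below simply assumes a linear decay over those 1e6 steps; the exact schedule is defined in the training code, not here.
```
def epsilon(step, eps_start=1.0, eps_end=0.05, anneal_steps=int(1e6)):
    """Illustrative linear decay from 1.0 to 0.05 over the first 1e6 steps."""
    frac = min(float(step) / anneal_steps, 1.0)
    return eps_start + frac * (eps_end - eps_start)
```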
--------------------------------------------------------------------------------
/a3c.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import tensorflow as tf
4 | import util
5 | from async import AsyncSolver
6 | import model
7 |
8 | class A3C(AsyncSolver):
9 | def define_network(self, name):
10 | self.args.meta_dim = 0 if self.env.meta() is None else len(self.env.meta())
11 | return eval("model." + name)(self.env.observation_space.shape,
12 | self.env.action_space.n, type='policy',
13 | gamma=self.args.gamma,
14 | dim=self.args.dim,
15 | f_num=self.args.f_num,
16 | f_pad=self.args.f_pad,
17 | f_stride=self.args.f_stride,
18 | f_size=self.args.f_size,
19 | meta_dim=self.args.meta_dim,
20 | )
21 |
22 | def process_rollout(self, rollout, gamma, lambda_=1.0):
23 | """
24 | given a rollout, compute its returns and the advantage
25 | """
26 | batch_si = np.asarray(rollout.states)
27 | batch_a = np.asarray(rollout.actions)
28 | rewards = np.asarray(rollout.rewards)
29 | time = np.asarray(rollout.time)
30 | meta = np.asarray(rollout.meta)
31 | vpred_t = np.asarray(rollout.values + [rollout.r])
32 |
33 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
34 | batch_r = util.discount(rewards_plus_v, gamma)[:-1]
35 | delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
36 |         # this formula for the advantage comes from "Generalized Advantage Estimation":
37 | # https://arxiv.org/abs/1506.02438
38 | batch_adv = util.discount(delta_t, gamma * lambda_)
39 |
40 | features = rollout.features[0]
41 | return util.Batch(si=batch_si,
42 | a=batch_a,
43 | adv=batch_adv,
44 | r=batch_r,
45 | terminal=rollout.terminal,
46 | features=features,
47 | reward=rewards,
48 | step=time,
49 | meta=meta)
50 |
51 | def init_variables(self):
52 | pi = self.local_network
53 | self.ac = tf.placeholder(tf.float32, [None, self.env.action_space.n], name="ac")
54 | self.adv = tf.placeholder(tf.float32, [None], name="adv")
55 | self.r = tf.placeholder(tf.float32, [None], name="r")
56 |
57 | log_prob_tf = tf.nn.log_softmax(pi.logits)
58 | prob_tf = tf.nn.softmax(pi.logits)
59 |
60 | # the "policy gradients" loss: its derivative is precisely the policy gradient
61 | # notice that self.ac is a placeholder that is provided externally.
62 | # adv will contain the advantages, as calculated in process_rollout
63 | self.pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac, [1]) * self.adv)
64 |
65 | # loss of value function
66 | self.vf_loss = 0.5 * tf.reduce_sum(tf.square(pi.vf - self.r))
67 | self.entropy = - tf.reduce_sum(prob_tf * log_prob_tf)
68 |
69 | self.bs = tf.to_float(tf.shape(pi.x)[0])
70 | self.loss = self.pi_loss + 0.5 * self.vf_loss - self.entropy * 0.01
71 |
72 | def define_summary(self):
73 | super(A3C, self).define_summary()
74 | tf.summary.scalar("model/policy_loss", self.pi_loss / self.bs)
75 | tf.summary.scalar("model/value_loss", self.vf_loss / self.bs)
76 | tf.summary.scalar("model/entropy", self.entropy / self.bs)
77 | self.summary_op = tf.summary.merge_all()
78 |
79 | def prepare_input(self, batch):
80 | feed_dict = {self.local_network.x: batch.si,
81 | self.ac: batch.a,
82 | self.adv: batch.adv,
83 | self.r: batch.r}
84 | if self.args.meta_dim > 0:
85 | feed_dict[self.local_network.meta] = batch.meta
86 | for i in range(len(self.local_network.state_in)):
87 | feed_dict[self.local_network.state_in[i]] = batch.features[i]
88 | return feed_dict
89 |
--------------------------------------------------------------------------------
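`util.discount` and `util.Batch` are defined in `util.py`, which is not included above. Below is a minimal NumPy sketch that is consistent with how `process_rollout` uses `discount` (the repository's own implementation may differ, e.g. it may use `scipy.signal.lfilter`), followed by a toy GAE computation mirroring the code above:
```
import numpy as np

def discount(x, gamma):
    # discount(x, g)[i] = sum_k g**k * x[i + k]
    out = np.zeros(len(x))
    running = 0.0
    for i in reversed(range(len(x))):
        running = x[i] + gamma * running
        out[i] = running
    return out

# Toy 3-step rollout, mirroring process_rollout in a3c.py.
rewards = np.array([0.0, 0.0, 1.0])
values  = np.array([0.5, 0.6, 0.9])
bootstrap_r = 0.0                  # rollout.r (0 for a terminal rollout)
gamma, lambda_ = 0.99, 1.0

vpred_t   = np.append(values, bootstrap_r)
batch_r   = discount(np.append(rewards, bootstrap_r), gamma)[:-1]  # discounted returns
delta_t   = rewards + gamma * vpred_t[1:] - vpred_t[:-1]           # TD residuals
batch_adv = discount(delta_t, gamma * lambda_)                     # GAE advantages
print(batch_r, batch_adv)
```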
/async.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import logging
3 | import numpy as np
4 | import tensorflow as tf
5 | import six.moves.queue as queue
6 | import threading
7 | import distutils.version
8 |
9 | use_tf12_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('0.12.0')
10 | logger = logging.getLogger(__name__)
11 | logger.setLevel(logging.INFO)
12 |
13 | class PartialRollout(object):
14 | """
15 | a piece of a complete rollout. We run our agent, and process its experience
16 | once it has processed enough steps.
17 | """
18 | def __init__(self):
19 | self.states = []
20 | self.actions = []
21 | self.rewards = []
22 | self.values = []
23 | self.r = 0.0
24 | self.terminal = False
25 | self.features = []
26 | self.time = []
27 | self.meta = []
28 |
29 | def add(self, state, action, reward, terminal, features,
30 | value = None, time = None, meta=None):
31 | self.states += [state]
32 | self.actions += [action]
33 | self.rewards += [reward]
34 | self.terminal = terminal
35 | self.features += [features]
36 | if value is not None:
37 | self.values += [value]
38 | if time is not None:
39 | self.time += [time]
40 | if meta is not None:
41 | self.meta += [meta]
42 |
43 | def extend(self, other):
44 | assert not self.terminal
45 | self.states.extend(other.states)
46 | self.actions.extend(other.actions)
47 | self.rewards.extend(other.rewards)
48 | self.r = other.r
49 | self.terminal = other.terminal
50 | self.features.extend(other.features)
51 | if other.values is not None:
52 | self.values.extend(other.values)
53 | if other.time is not None:
54 | self.time.extend(other.time)
55 | if other.meta is not None:
56 | self.meta.extend(other.meta)
57 |
58 | class RunnerThread(threading.Thread):
59 | """
60 | One of the key distinctions between a normal environment and a universe environment
61 | is that a universe environment is _real time_. This means that there should be a thread
62 | that would constantly interact with the environment and tell it what to do. This thread is here.
63 | """
64 | def __init__(self, solver):
65 | threading.Thread.__init__(self)
66 | self.queue = queue.Queue(5)
67 | self.solver = solver
68 | self.num_local_steps = solver.t_max
69 | self.env = solver.env
70 | self.last_features = None
71 | self.network = solver.local_network
72 | self.daemon = True
73 | self.sess = None
74 | self.summary_writer = None
75 |
76 | def start_runner(self, sess, summary_writer):
77 | self.sess = sess
78 | self.summary_writer = summary_writer
79 | self.start()
80 |
81 | def run(self):
82 | with self.sess.as_default():
83 | self._run()
84 |
85 | def _run(self):
86 | rollout_provider = env_runner(self.env, self.network, self.num_local_steps,
87 | self.summary_writer, solver=self.solver)
88 | while True:
89 | # the timeout variable exists because apparently, if one worker dies, the other workers
90 | # won't die with it, unless the timeout is set to some large number. This is an empirical
91 | # observation.
92 |
93 | self.queue.put(next(rollout_provider), timeout=600.0)
94 |
95 |
96 | def env_runner(env, network, num_local_steps, summary_writer, solver=None):
97 | """
98 | The logic of the thread runner. In brief, it constantly keeps on running
99 |     the policy, and once the rollout reaches a certain length, the thread
100 |     runner appends the rollout to the queue.
101 | """
102 | last_state = env.reset()
103 | last_features = network.get_initial_features()
104 | last_meta = env.meta()
105 | if solver.use_target_network():
106 | last_target_features = solver.target_network.get_initial_features()
107 |
108 | while True:
109 | terminal_end = False
110 | rollout = PartialRollout()
111 |
112 | for _ in range(num_local_steps):
113 | value = None
114 |
115 | # choose an action from the policy
116 | if not hasattr(solver, 'epsilon') or solver.epsilon() < np.random.uniform():
117 | fetched = network.act(last_state, last_features,
118 | meta=last_meta)
119 | if network.type == 'policy':
120 | action, value, features = fetched[0], fetched[1], fetched[2:]
121 | else:
122 | action, features = fetched[0], fetched[1:]
123 | else:
124 | # choose a random action
125 | assert network.type != 'policy'
126 | act_idx = np.random.randint(0, env.action_space.n)
127 | action = np.zeros(env.action_space.n)
128 | action[act_idx] = 1
129 | if network.is_recurrent():
130 | features = network.update_state(last_state, last_features,
131 | meta=last_meta)
132 | else:
133 | features = []
134 |
135 | # argmax to convert from one-hot
136 | state, reward, terminal, info, time = env.step(action.argmax())
137 | if hasattr(env, 'atari'):
138 | reward = np.clip(reward, -1, 1)
139 |
140 | # collect the experience
141 | rollout.add(last_state, action, reward, terminal, last_features,
142 | value = value, time = time, meta=last_meta)
143 |
144 | last_state = state
145 | last_features = features
146 | last_meta = env.meta()
147 |
148 | if info:
149 | summary = tf.Summary()
150 | for k, v in info.items():
151 | summary.value.add(tag=k, simple_value=float(v))
152 | summary_writer.add_summary(summary, network.global_step.eval())
153 | summary_writer.flush()
154 |
155 | if terminal:
156 | terminal_end = True
157 | last_state = env.reset()
158 | last_features = network.get_initial_features()
159 | last_meta = env.meta()
160 | break
161 |
162 | if not terminal_end:
163 | if solver.use_target_network():
164 | rollout.r = solver.target_network.value(last_state,
165 | last_features,
166 | meta=last_meta)
167 | else:
168 | rollout.r = network.value(last_state, last_features,
169 | meta=last_meta)
170 |
171 |         # once we have enough experience, yield it, and have the RunnerThread place it on a queue
172 | yield rollout
173 |
174 | class AsyncSolver(object):
175 | def __init__(self, env, args, env_off=None):
176 | self.env = env
177 | self.args = args
178 | self.task = args.task
179 | self.t_max = args.t_max
180 | self.ld = args.ld
181 | self.lr = args.lr
182 | self.model = args.model
183 | self.env_off = env_off
184 | self.last_global_step = 0
185 |
186 | device = 'gpu' if self.args.gpu > 0 else 'cpu'
187 | worker_device = "/job:worker/task:{}/{}:0".format(self.task, device)
188 | def _load_fn(unused_op):
189 | return 1
190 | with tf.device(tf.train.replica_device_setter(self.args.num_ps,
191 | worker_device=worker_device,
192 | ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
193 | self.args.num_ps, _load_fn))):
194 | with tf.variable_scope("global"):
195 | with tf.variable_scope("learner"):
196 | self.network = self.define_network(self.model)
197 | if self.use_target_network():
198 | with tf.variable_scope("target"):
199 | self.global_target_network = self.define_network(self.model)
200 | self.global_target_sync_step = tf.get_variable("target_sync_step", [],
201 | tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32),
202 | trainable=False)
203 | self.global_step = tf.get_variable("global_step", [], tf.int32,
204 | initializer=tf.constant_initializer(0, dtype=tf.int32),
205 | trainable=False)
206 |
207 |
208 | with tf.device(worker_device):
209 | with tf.variable_scope("local"):
210 | with tf.variable_scope("learner"):
211 | self.local_network = pi = self.define_network(self.model)
212 | pi.global_step = self.global_step
213 | if self.use_target_network():
214 | with tf.variable_scope("target"):
215 | self.target_network = self.define_network(self.model)
216 |
217 | self.init_variables()
218 |
219 |         # t_max is the number of "local steps": the number of timesteps
220 |         # we run the policy before we update the parameters.
221 |         # The larger the number of local steps, the lower the variance of the policy gradient estimate
222 |         # on the one hand; on the other hand, parameter updates become less frequent, which
223 |         # slows down learning. In this code, we found that making the number of local steps much
224 |         # smaller than 20 makes the algorithm more difficult to tune and to get to work.
225 | self.runner = RunnerThread(self)
226 |
227 | self.grads = tf.gradients(self.loss, pi.var_list)
228 | self.grads, _ = tf.clip_by_global_norm(self.grads, 40.0)
229 |
230 | # copy weights from the parameter server to the local model
231 | self.sync = tf.group(*[v1.assign(v2) for v1, v2 in zip(pi.var_list,
232 | self.network.var_list)])
233 |
234 | self.grads_and_vars = list(zip(self.grads, self.network.var_list))
235 | inc_step = self.global_step.assign_add(tf.shape(pi.x)[0])
236 |
237 | self.learning_rate = tf.placeholder(tf.float32, shape=[])
238 | # each worker has a different set of adam optimizer parameters
239 | opt = tf.train.AdamOptimizer(self.learning_rate)
240 | self.train_op = tf.group(opt.apply_gradients(self.grads_and_vars), inc_step)
241 | if self.use_target_network():
242 | self.update_target_step = self.global_target_sync_step.assign(self.global_step)
243 |
244 | with tf.device(None):
245 | self.define_summary()
246 | self.summary_writer = None
247 | self.local_steps = 0
248 |
249 | def define_summary(self):
250 | tf.summary.scalar("model/lr", self.learning_rate)
251 | tf.summary.image("model/state", self.env.tf_visualize(self.local_network.x), max_outputs=10)
252 | tf.summary.scalar("gradient/grad_norm", tf.global_norm(self.grads))
253 | tf.summary.scalar("param/param_norm", tf.global_norm(self.local_network.var_list))
254 | for grad_var in self.grads_and_vars:
255 | grad = grad_var[0]
256 | var = grad_var[1]
257 | if var.name.find('/W:') >= 0 or var.name.find('/w:') >= 0:
258 | if grad is None:
259 | raise ValueError(var.name + " grads are missing")
260 | tf.summary.scalar("gradient/%s" % var.name, tf.norm(grad))
261 | tf.summary.scalar("param/%s" % var.name, tf.norm(var))
262 |
263 | self.summary_op = tf.summary.merge_all()
264 |
265 | def use_target_network(self):
266 | return False
267 |
268 | def start(self, sess, summary_writer):
269 | self.runner.start_runner(sess, summary_writer)
270 | self.summary_writer = summary_writer
271 |
272 | def pull_batch_from_queue(self):
273 | """
274 | self explanatory: take a rollout from the queue of the thread runner.
275 | """
276 | rollout = self.runner.queue.get(timeout=600.0)
277 | '''
278 | while not rollout.terminal:
279 | try:
280 | rollout.extend(self.runner.queue.get_nowait())
281 | except queue.Empty:
282 | break
283 | '''
284 | return rollout
285 |
286 | def process(self, sess):
287 | """
288 | process grabs a rollout that's been produced by the thread runner,
289 | and updates the parameters. The update is then sent to the parameter
290 | server.
291 | """
292 | sess.run(self.sync) # copy weights from shared to local
293 | rollout = self.pull_batch_from_queue()
294 | should_compute_summary = self.task == 0 and self.local_steps % 101 == 0
295 |
296 | if self.local_steps % self.args.update_freq == 0:
297 | batch = self.process_rollout(rollout, gamma=self.args.gamma, lambda_=self.ld)
298 | extra_fetches = self.extra_fetches()
299 | if should_compute_summary:
300 | fetches = [self.train_op, self.summary_op, self.global_step]
301 | else:
302 | fetches = [self.train_op, self.global_step]
303 |
304 | feed_dict = self.prepare_input(batch)
305 | feed_dict[self.learning_rate] = \
306 | self.args.lr * self.args.decay ** (self.last_global_step/float(10**6))
307 | fetched = sess.run(extra_fetches + fetches, feed_dict=feed_dict)
308 | if should_compute_summary:
309 | self.summary_writer.add_summary(tf.Summary.FromString(fetched[-2]), fetched[-1])
310 | self.write_extra_summary(rollout=rollout)
311 | self.summary_writer.flush()
312 | self.last_global_step = fetched[-1]
313 | self.handle_extra_fetches(fetched[:len(extra_fetches)])
314 |
315 | self.local_steps += 1
316 | self.post_process(sess)
317 |
318 | def extra_fetches(self):
319 | return []
320 |
321 | def handle_extra_fetches(self, fetches):
322 | return None
323 |
324 | def post_process(self, sess):
325 | return None
326 |
327 | def write_extra_summary(self, rollout=None):
328 | return None
329 |
--------------------------------------------------------------------------------
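`AsyncSolver` is the shared driver: it wires up the global/local networks, the runner thread, gradient clipping, and the Adam updates, while the algorithm-specific pieces come from subclasses such as `A3C` above (and, presumably, the solvers in `vpn.py` and `q.py`). A schematic outline of the hooks a subclass overrides, taken from how `a3c.py` uses them (outline only, not the repository's implementation):
```
from async import AsyncSolver   # as in a3c.py

class MySolver(AsyncSolver):
    def define_network(self, name):
        # build and return a model.Model-style network for this worker
        raise NotImplementedError

    def init_variables(self):
        # declare placeholders and define self.loss from self.local_network
        raise NotImplementedError

    def process_rollout(self, rollout, gamma, lambda_=1.0):
        # turn a PartialRollout into a batch of training targets
        raise NotImplementedError

    def prepare_input(self, batch):
        # map the batch onto the placeholders declared in init_variables
        raise NotImplementedError
```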
/config/collect_deterministic.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/config/collect_stochastic.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
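Based on how `maze.py` parses these files, a Collect configuration provides `<maze>`, `<object>`, `<reward>`, `<random>`, and `<meta>` tags. A minimal illustration of reading such a config with the same attribute names `MazeEnv`/`MazeSMDP` look up (the numeric values below are placeholders, not the settings used in the paper):
```
from bs4 import BeautifulSoup

config = """
<maze size="10" time="20" holes="4"/>
<object num_goal="2"/>
<reward default="-0.01" goal="1.0" lazy="-0.1"/>
<random p_stop="0.0" p_goal="0.0" p_slip="0.0"/>
<meta remaining_time="false"/>
"""
soup = BeautifulSoup(config, "lxml")
print(soup.maze["size"], soup.object["num_goal"], soup.reward["goal"], soup.random["p_slip"])
```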
/envs.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from gym.spaces.box import Box
3 | import numpy as np
4 | import gym
5 | from gym import spaces
6 | import logging
7 | import universe
8 | from universe import vectorized
9 | from universe.wrappers import BlockingReset, GymCoreAction, EpisodeID, Unvectorize, Vectorize, Vision, Logger
10 | from universe import spaces as vnc_spaces
11 | from universe.spaces.vnc_event import keycode
12 | import time
13 | logger = logging.getLogger(__name__)
14 | logger.setLevel(logging.INFO)
15 | universe.configure_logging()
16 | import maze
17 |
18 | def create_env(env_id, client_id, remotes, **kwargs):
19 | if env_id == 'maze':
20 | return maze.MazeSMDP(**kwargs)
21 |
22 | spec = gym.spec(env_id)
23 |
24 | if spec.tags.get('flashgames', False):
25 | return create_flash_env(env_id, client_id, remotes, **kwargs)
26 | elif spec.tags.get('atari', False) and spec.tags.get('vnc', False):
27 | return create_vncatari_env(env_id, client_id, remotes, **kwargs)
28 | else:
29 | # Assume atari.
30 | assert "." not in env_id # universe environments have dots in names.
31 | return create_atari_env(env_id, **kwargs)
32 |
33 | def create_flash_env(env_id, client_id, remotes, **_):
34 | env = gym.make(env_id)
35 | env = Vision(env)
36 | env = Logger(env)
37 | env = BlockingReset(env)
38 |
39 | reg = universe.runtime_spec('flashgames').server_registry
40 | height = reg[env_id]["height"]
41 | width = reg[env_id]["width"]
42 | env = CropScreen(env, height, width, 84, 18)
43 | env = FlashRescale(env)
44 |
45 | keys = ['left', 'right', 'up', 'down', 'x']
46 | if env_id == 'flashgames.NeonRace-v0':
47 | # Better key space for this game.
48 | keys = ['left', 'right', 'up', 'left up', 'right up', 'down', 'up x']
49 | logger.info('create_flash_env(%s): keys=%s', env_id, keys)
50 |
51 | env = DiscreteToFixedKeysVNCActions(env, keys)
52 | env = EpisodeID(env)
53 | env = DiagnosticsInfo(env)
54 | env = Unvectorize(env)
55 | env.configure(fps=5.0, remotes=remotes, start_timeout=15 * 60, client_id=client_id,
56 | vnc_driver='go', vnc_kwargs={
57 | 'encoding': 'tight', 'compress_level': 0,
58 | 'fine_quality_level': 50, 'subsample_level': 3})
59 | return env
60 |
61 | def create_vncatari_env(env_id, client_id, remotes, **_):
62 | env = gym.make(env_id)
63 | env = Vision(env)
64 | env = Logger(env)
65 | env = BlockingReset(env)
66 | env = GymCoreAction(env)
67 | env = AtariRescale84x84(env)
68 | env = EpisodeID(env)
69 | env = DiagnosticsInfo(env)
70 | env = Unvectorize(env)
71 |
72 | logger.info('Connecting to remotes: %s', remotes)
73 | fps = env.metadata['video.frames_per_second']
74 | env.configure(remotes=remotes, start_timeout=15 * 60, fps=fps, client_id=client_id)
75 | env.atari = True
76 | return env
77 |
78 | def create_atari_env(env_id, **kwargs):
79 | env = gym.make(env_id)
80 | env = Vectorize(env)
81 | env = AtariRescale84x84(env)
82 | env = DiagnosticsInfo(env)
83 | env = Unvectorize(env)
84 | env.atari = True
85 | return env
86 |
87 | def DiagnosticsInfo(env, *args, **kwargs):
88 | return vectorized.VectorizeFilter(env, DiagnosticsInfoI, *args, **kwargs)
89 |
90 | class DiagnosticsInfoI(vectorized.Filter):
91 | def __init__(self, log_interval=503):
92 | super(DiagnosticsInfoI, self).__init__()
93 |
94 | self._episode_time = time.time()
95 | self._last_time = time.time()
96 | self._local_t = 0
97 | self._log_interval = log_interval
98 | self._episode_reward = 0
99 | self._episode_length = 0
100 | self._all_rewards = []
101 | self._num_vnc_updates = 0
102 | self._last_episode_id = -1
103 |
104 | def _after_reset(self, observation):
105 | # logger.info('Resetting environment')
106 | self._episode_reward = 0
107 | self._episode_length = 0
108 | self._all_rewards = []
109 | return observation
110 |
111 | def _after_step(self, observation, reward, done, info):
112 | to_log = {}
113 | if self._episode_length == 0:
114 | self._episode_time = time.time()
115 |
116 | self._local_t += 1
117 | if info.get("stats.vnc.updates.n") is not None:
118 | self._num_vnc_updates += info.get("stats.vnc.updates.n")
119 |
120 | if self._local_t % self._log_interval == 0:
121 | cur_time = time.time()
122 | elapsed = cur_time - self._last_time
123 | fps = self._log_interval / elapsed
124 | self._last_time = cur_time
125 | cur_episode_id = info.get('vectorized.episode_id', 0)
126 | to_log["diagnostics/fps"] = fps
127 | if self._last_episode_id == cur_episode_id:
128 | to_log["diagnostics/fps_within_episode"] = fps
129 | self._last_episode_id = cur_episode_id
130 | if info.get("stats.gauges.diagnostics.lag.action") is not None:
131 | to_log["diagnostics/action_lag_lb"] = info["stats.gauges.diagnostics.lag.action"][0]
132 | to_log["diagnostics/action_lag_ub"] = info["stats.gauges.diagnostics.lag.action"][1]
133 | if info.get("reward.count") is not None:
134 | to_log["diagnostics/reward_count"] = info["reward.count"]
135 | if info.get("stats.gauges.diagnostics.clock_skew") is not None:
136 | to_log["diagnostics/clock_skew_lb"] = info["stats.gauges.diagnostics.clock_skew"][0]
137 | to_log["diagnostics/clock_skew_ub"] = info["stats.gauges.diagnostics.clock_skew"][1]
138 | if info.get("stats.gauges.diagnostics.lag.observation") is not None:
139 | to_log["diagnostics/observation_lag_lb"] = info["stats.gauges.diagnostics.lag.observation"][0]
140 | to_log["diagnostics/observation_lag_ub"] = info["stats.gauges.diagnostics.lag.observation"][1]
141 |
142 | if info.get("stats.vnc.updates.n") is not None:
143 | to_log["diagnostics/vnc_updates_n"] = info["stats.vnc.updates.n"]
144 | to_log["diagnostics/vnc_updates_n_ps"] = self._num_vnc_updates / elapsed
145 | self._num_vnc_updates = 0
146 | if info.get("stats.vnc.updates.bytes") is not None:
147 | to_log["diagnostics/vnc_updates_bytes"] = info["stats.vnc.updates.bytes"]
148 | if info.get("stats.vnc.updates.pixels") is not None:
149 | to_log["diagnostics/vnc_updates_pixels"] = info["stats.vnc.updates.pixels"]
150 | if info.get("stats.vnc.updates.rectangles") is not None:
151 | to_log["diagnostics/vnc_updates_rectangles"] = info["stats.vnc.updates.rectangles"]
152 | if info.get("env_status.state_id") is not None:
153 | to_log["diagnostics/env_state_id"] = info["env_status.state_id"]
154 |
155 | if reward is not None:
156 | self._episode_reward += reward
157 | if observation is not None:
158 | self._episode_length += 1
159 | self._all_rewards.append(reward)
160 |
161 | if done:
162 | logger.info('Episode terminating: episode_reward=%s episode_length=%s', self._episode_reward, self._episode_length)
163 | total_time = time.time() - self._episode_time
164 | to_log["global/episode_reward"] = self._episode_reward
165 | to_log["global/episode_length"] = self._episode_length
166 | to_log["global/episode_time"] = total_time
167 | to_log["global/reward_per_time"] = self._episode_reward / total_time
168 | self._episode_reward = 0
169 | self._episode_length = 0
170 | self._all_rewards = []
171 |
172 | return observation, reward, done, to_log
173 |
174 | def _process_frame84gray(frame):
175 | frame = cv2.resize(frame, (84, 84))
176 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
177 | frame = frame.astype(np.float32)
178 | frame *= (1.0 / 255.0)
179 | frame = np.reshape(frame, [84, 84, 1])
180 | return frame
181 |
182 | def _process_frame42(frame):
183 | frame = frame[34:34+160, :160]
184 | # Resize by half, then down to 42x42 (essentially mipmapping). If
185 | # we resize directly we lose pixels that, when mapped to 42x42,
186 | # aren't close enough to the pixel boundary.
187 | frame = cv2.resize(frame, (84, 84))
188 | frame = cv2.resize(frame, (42, 42))
189 | frame = frame.mean(2)
190 | frame = frame.astype(np.float32)
191 | frame *= (1.0 / 255.0)
192 | frame = np.reshape(frame, [42, 42, 1])
193 | return frame
194 |
195 | class AtariRescale84x84(vectorized.ObservationWrapper):
196 | def __init__(self, env=None):
197 | super(AtariRescale84x84, self).__init__(env)
198 | self.observation_space = Box(0.0, 1.0, [84, 84, 1])
199 |
200 | def _observation(self, observation_n):
201 | return [_process_frame84gray(observation) for observation in observation_n]
202 |
203 | class AtariRescale42x42(vectorized.ObservationWrapper):
204 | def __init__(self, env=None):
205 | super(AtariRescale42x42, self).__init__(env)
206 | self.observation_space = Box(0.0, 1.0, [42, 42, 1])
207 |
208 | def _observation(self, observation_n):
209 | return [_process_frame42(observation) for observation in observation_n]
210 |
211 | class FixedKeyState(object):
212 | def __init__(self, keys):
213 | self._keys = [keycode(key) for key in keys]
214 | self._down_keysyms = set()
215 |
216 | def apply_vnc_actions(self, vnc_actions):
217 | for event in vnc_actions:
218 | if isinstance(event, vnc_spaces.KeyEvent):
219 | if event.down:
220 | self._down_keysyms.add(event.key)
221 | else:
222 | self._down_keysyms.discard(event.key)
223 |
224 | def to_index(self):
225 | action_n = 0
226 | for key in self._down_keysyms:
227 | if key in self._keys:
228 | # If multiple keys are pressed, just use the first one
229 | action_n = self._keys.index(key) + 1
230 | break
231 | return action_n
232 |
233 | class DiscreteToFixedKeysVNCActions(vectorized.ActionWrapper):
234 | """
235 | Define a fixed action space. Action 0 is all keys up. Each element of keys can be a single key or a space-separated list of keys
236 |
237 | For example,
238 | e=DiscreteToFixedKeysVNCActions(e, ['left', 'right'])
239 | will have 3 actions: [none, left, right]
240 |
241 | You can define a state with more than one key down by separating with spaces. For example,
242 | e=DiscreteToFixedKeysVNCActions(e, ['left', 'right', 'space', 'left space', 'right space'])
243 | will have 6 actions: [none, left, right, space, left space, right space]
244 | """
245 | def __init__(self, env, keys):
246 | super(DiscreteToFixedKeysVNCActions, self).__init__(env)
247 |
248 | self._keys = keys
249 | self._generate_actions()
250 | self.action_space = spaces.Discrete(len(self._actions))
251 |
252 | def _generate_actions(self):
253 | self._actions = []
254 | uniq_keys = set()
255 | for key in self._keys:
256 | for cur_key in key.split(' '):
257 | uniq_keys.add(cur_key)
258 |
259 | for key in [''] + self._keys:
260 | split_keys = key.split(' ')
261 | cur_action = []
262 | for cur_key in uniq_keys:
263 | cur_action.append(vnc_spaces.KeyEvent.by_name(cur_key, down=(cur_key in split_keys)))
264 | self._actions.append(cur_action)
265 | self.key_state = FixedKeyState(uniq_keys)
266 |
267 | def _action(self, action_n):
268 | # Each action might be a length-1 np.array. Cast to int to
269 | # avoid warnings.
270 | return [self._actions[int(action)] for action in action_n]
271 |
272 | class CropScreen(vectorized.ObservationWrapper):
273 | """Crops out a [height]x[width] area starting from (top,left) """
274 | def __init__(self, env, height, width, top=0, left=0):
275 | super(CropScreen, self).__init__(env)
276 | self.height = height
277 | self.width = width
278 | self.top = top
279 | self.left = left
280 | self.observation_space = Box(0, 255, shape=(height, width, 3))
281 |
282 | def _observation(self, observation_n):
283 | return [ob[self.top:self.top+self.height, self.left:self.left+self.width, :] if ob is not None else None
284 | for ob in observation_n]
285 |
286 | def _process_frame_flash(frame):
287 | frame = cv2.resize(frame, (200, 128))
288 | frame = frame.mean(2).astype(np.float32)
289 | frame *= (1.0 / 255.0)
290 | frame = np.reshape(frame, [128, 200, 1])
291 | return frame
292 |
293 | class FlashRescale(vectorized.ObservationWrapper):
294 | def __init__(self, env=None):
295 | super(FlashRescale, self).__init__(env)
296 | self.observation_space = Box(0.0, 1.0, [128, 200, 1])
297 |
298 | def _observation(self, observation_n):
299 | return [_process_frame_flash(observation) for observation in observation_n]
300 |
--------------------------------------------------------------------------------
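A quick sanity check of the 84x84 grayscale preprocessing used by the Atari wrappers above (a usage sketch; importing `envs` requires its module-level dependencies such as `gym`, `universe`, and `cv2` to be installed):
```
import numpy as np
from envs import _process_frame84gray

frame = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)  # fake RGB Atari frame
out = _process_frame84gray(frame)
print(out.shape, out.dtype, float(out.min()), float(out.max()))  # (84, 84, 1) float32, values in [0, 1]
```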
/maze.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import universe
4 | import gym
5 | import logging
6 | import copy
7 | from bs4 import BeautifulSoup
8 | import tensorflow as tf
9 | logger = logging.getLogger(__name__)
10 | logger.setLevel(logging.INFO)
11 | universe.configure_logging()
12 |
13 | BLOCK = 0
14 | AGENT = 1
15 | GOAL = 2
16 | DX = [0, 1, 0, -1]
17 | DY = [-1, 0, 1, 0]
18 |
19 | COLOR = [[44, 42, 60], # block
20 | [91, 255, 123], # agent
21 | [52, 152, 219], # goal
22 | ]
23 |
24 | def str2bool(v):
25 | return v.lower() in ("yes", "true", "t", "1")
26 |
27 | def generate_maze(size, holes=0):
28 | # Source: http://code.activestate.com/recipes/578356-random-maze-generator/
29 | # Random Maze Generator using Depth-first Search
30 | # http://en.wikipedia.org/wiki/Maze_generation_algorithm
31 | mx = size-2; my = size-2 # width and height of the maze
32 | maze = np.ones((my, mx))
33 | dx = [0, 1, 0, -1]; dy = [-1, 0, 1, 0] # 4 directions to move in the maze
34 | # start the maze from a random cell
35 | start_x = np.random.randint(0, mx); start_y = np.random.randint(0, my)
36 | cx, cy = 0, 0
37 | # stack element: (x, y, direction)
38 | maze[start_y][start_x] = 0; stack = [(start_x, start_y, 0)]
39 | while len(stack) > 0:
40 | (cx, cy, cd) = stack[-1]
41 | # to prevent zigzags:
42 | # if changed direction in the last move then cannot change again
43 | if len(stack) > 2:
44 | if cd != stack[-2][2]: dirRange = [cd]
45 | else: dirRange = range(4)
46 | else: dirRange = range(4)
47 |
48 | # find a new cell to add
49 | nlst = [] # list of available neighbors
50 | for i in dirRange:
51 | nx = cx + dx[i]; ny = cy + dy[i]
52 | if nx >= 0 and nx < mx and ny >= 0 and ny < my:
53 | if maze[ny][nx] == 1:
54 | ctr = 0 # of occupied neighbors must be 1
55 | for j in range(4):
56 | ex = nx + dx[j]; ey = ny + dy[j]
57 | if ex >= 0 and ex < mx and ey >= 0 and ey < my:
58 | if maze[ey][ex] == 0: ctr += 1
59 | if ctr == 1: nlst.append(i)
60 |
61 | # if 1 or more neighbors available then randomly select one and move
62 | if len(nlst) > 0:
63 | ir = nlst[np.random.randint(0, len(nlst))]
64 | cx += dx[ir]; cy += dy[ir]; maze[cy][cx] = 0
65 | stack.append((cx, cy, ir))
66 | else: stack.pop()
67 |
68 | maze_tensor = np.zeros((size, size, 3))
69 | maze_tensor[:,:,BLOCK] = 1
70 | maze_tensor[1:-1, 1:-1, BLOCK] = maze
71 | maze_tensor[start_y+1][start_x+1][AGENT] = 1
72 |
73 | while holes > 0:
74 | removable = []
75 | for y in range(0, my):
76 | for x in range(0, mx):
77 | if maze_tensor[y+1][x+1][BLOCK] == 1:
78 | if maze_tensor[y][x+1][BLOCK] == 1 and maze_tensor[y+2][x+1][BLOCK] == 1 and \
79 | maze_tensor[y+1][x][BLOCK] == 0 and maze_tensor[y+1][x+2][BLOCK] == 0:
80 | removable.append((y+1, x+1))
81 | elif maze_tensor[y][x+1][BLOCK] == 0 and maze_tensor[y+2][x+1][BLOCK] == 0 and \
82 | maze_tensor[y+1][x][BLOCK] == 1 and maze_tensor[y+1][x+2][BLOCK] == 1:
83 | removable.append((y+1, x+1))
84 |
85 | if len(removable) == 0:
86 | break
87 |
88 | idx = np.random.randint(0, len(removable))
89 | maze_tensor[removable[idx][0]][removable[idx][1]][BLOCK] = 0
90 | holes -= 1
91 |
92 | return maze_tensor, start_y+1, start_x+1
93 |
94 | def find_empty_loc(maze):
95 | size = maze.shape[0]
96 |     # Randomly pick an empty location
97 | for i in range(300):
98 | y = np.random.randint(0, size-2) + 1
99 | x = np.random.randint(0, size-2) + 1
100 | if np.sum(maze[y][x]) == 0:
101 | return [y, x]
102 |
103 | raise AttributeError("Cannot find an empty location in 300 trials")
104 |
105 | def generate_maze_with_multiple_goal(size, num_goal=1, holes=3):
106 | maze, start_y, start_x = generate_maze(size, holes=holes)
107 |
108 | # Randomly determine agent position
109 | maze[start_y][start_x][AGENT] = 0
110 | agent_pos = find_empty_loc(maze)
111 | maze[agent_pos[0]][agent_pos[1]][AGENT] = 1
112 |
113 | object_pos = [[],[],[],[]]
114 | for i in range(num_goal):
115 | pos = find_empty_loc(maze)
116 | maze[pos[0]][pos[1]][GOAL] = 1
117 | object_pos[GOAL].append(pos)
118 |
119 | return maze, agent_pos, object_pos
120 |
121 | def visualize_maze(maze, img_size=320):
122 | my = maze.shape[0]
123 | mx = maze.shape[1]
124 | colors = np.array(COLOR, np.uint8)
125 | num_channel = maze.shape[2]
126 | vis_maze = np.matmul(maze, colors[:num_channel])
127 | vis_maze = vis_maze.astype(np.uint8)
128 | for i in range(vis_maze.shape[0]):
129 | for j in range(vis_maze.shape[1]):
130 | if maze[i][j].sum() == 0.0:
131 | vis_maze[i][j][:] = int(255)
132 | image = Image.fromarray(vis_maze)
133 | return image.resize((int(float(img_size) * mx / my), img_size), Image.NEAREST)
134 |
135 | def visualize_mazes(maze, img_size=320):
136 | if maze.ndim == 3:
137 | return visualize_maze(maze, img_size=img_size)
138 | elif maze.ndim == 4:
139 | n = maze.shape[0]
140 | size = maze.shape[1]
141 | dim = maze.shape[-1]
142 | concat_m = maze.transpose((1,0,2,3)).reshape((size, n * size, dim))
143 | return visualize_maze(concat_m, img_size=img_size)
144 | else:
145 | raise ValueError("maze should be 3d or 4d tensor")
146 |
147 | def to_string(maze):
148 | my = maze.shape[0]
149 | mx = maze.shape[1]
150 |     s = ''
151 |     for y in range(my):
152 |         for x in range(mx):
153 |             if maze[y][x][BLOCK] == 1:
154 |                 s += '#'
155 |             elif maze[y][x][AGENT] == 1:
156 |                 s += 'o'
157 |             elif maze[y][x][GOAL] == 1:
158 |                 s += 'x'
159 |             else:
160 |                 s += ' '
161 |         s += '\n'
162 |     return s
163 |
164 | class Maze(object):
165 | def __init__(self, size=10, num_goal=1, holes=0):
166 | self.size = size
167 | self.dx = [0, 1, 0, -1]
168 | self.dy = [-1, 0, 1, 0]
169 | self.num_goal = num_goal
170 | self.holes = holes
171 | self.reset()
172 |
173 | def reset(self):
174 | self.maze, self.agent_pos, self.obj_pos = \
175 | generate_maze_with_multiple_goal(self.size, num_goal=self.num_goal,
176 | holes=self.holes)
177 |
178 | def is_reachable(self, y, x):
179 | return self.maze[y][x][BLOCK] == 0
180 |
181 | def is_branch(self, y, x):
182 | if self.maze[y][x][BLOCK] == 1:
183 | return False
184 | neighbor_count = 0
185 | for i in range(4):
186 | new_y = y + self.dy[i]
187 | new_x = x + self.dx[i]
188 | if self.maze[new_y][new_x][BLOCK] == 0:
189 | neighbor_count += 1
190 | return neighbor_count > 2
191 |
192 | def is_agent_on_branch(self):
193 | return self.is_branch(self.agent_pos[0], self.agent_pos[1])
194 |
195 | def is_end_of_corridor(self, y, x, direction):
196 | return self.maze[y + self.dy[direction]][x + self.dx[direction]][BLOCK] == 1
197 |
198 | def is_agent_on_end_of_corridor(self, direction):
199 | return self.is_end_of_corridor(self.agent_pos[0], self.agent_pos[1], direction)
200 |
201 | def move_agent(self, direction):
202 | y = self.agent_pos[0] + self.dy[direction]
203 | x = self.agent_pos[1] + self.dx[direction]
204 | if not self.is_reachable(y, x):
205 | return False
206 | self.maze[self.agent_pos[0]][self.agent_pos[1]][AGENT] = 0
207 | self.maze[y][x][AGENT] = 1
208 | self.agent_pos = [y, x]
209 | return True
210 |
211 | def is_object_reached(self, obj_idx):
212 | if self.maze.shape[2] <= obj_idx:
213 | return False
214 | return self.maze[self.agent_pos[0]][self.agent_pos[1]][obj_idx] == 1
215 |
216 | def is_empty(self, y, x):
217 | return np.sum(self.maze[y][x]) == 0
218 |
219 | def remove_object(self, y, x, obj_idx):
220 | removed = self.maze[y][x][obj_idx] == 1
221 | self.maze[y][x][obj_idx] = 0
222 | self.obj_pos[obj_idx].remove([y, x])
223 | return removed
224 |
225 | def remaining_goal(self):
226 | return self.remaining_object(GOAL)
227 |
228 | def remaining_object(self, obj_idx):
229 | return len(self.obj_pos[obj_idx])
230 |
231 | def add_object(self, y, x, obj_idx):
232 | if self.is_empty(y, x):
233 | self.maze[y][x][obj_idx] = 1
234 | self.obj_pos[obj_idx].append([y, x])
235 | else:
236 |             raise ValueError("%d, %d is not empty" % (y, x))
237 |
238 | def move_object_random(self, prob, obj_idx):
239 | pos_copy = copy.deepcopy(self.obj_pos[obj_idx])
240 | for pos in pos_copy:
241 | if not hasattr(self, "goal_move_prob"):
242 | self.goal_move_prob = np.random.rand(1000)
243 | self.goal_move_idx = 0
244 | else:
245 | self.goal_move_idx = (self.goal_move_idx + 1) \
246 | % self.goal_move_prob.size
247 | if self.goal_move_prob[self.goal_move_idx] < prob:
248 | possible_moves = []
249 | for i in range(4):
250 | y = pos[0] + DY[i]
251 | x = pos[1] + DX[i]
252 | if self.is_empty(y, x):
253 | possible_moves.append(i)
254 | if len(possible_moves) > 0:
255 | self.move_object(pos, obj_idx,
256 | possible_moves[np.random.randint(len(possible_moves))])
257 |
258 | def move_object(self, pos, obj_idx, direction):
259 | y = pos[0] + self.dy[direction]
260 | x = pos[1] + self.dx[direction]
261 | if not self.is_reachable(y, x):
262 | return False
263 | self.remove_object(pos[0], pos[1], obj_idx)
264 | self.add_object(y, x, obj_idx)
265 | return True
266 |
267 | def observation(self, clone=True):
268 | return np.array(self.maze, copy=clone)
269 |
270 | def visualize(self):
271 | return visualize_maze(self.maze)
272 |
273 | def to_string(self):
274 | return to_string(self.maze)
275 |
276 | class MazeEnv(object):
277 | def __init__(self, config="", verbose=1):
278 | self.config = BeautifulSoup(config, "lxml")
279 | # map
280 | self.size = int(self.config.maze["size"])
281 | self.max_step = int(self.config.maze["time"])
282 | self.holes = int(self.config.maze["holes"])
283 | self.num_goal = int(self.config.object["num_goal"])
284 | # reward
285 | self.default_reward = float(self.config.reward["default"])
286 | self.goal_reward = float(self.config.reward["goal"])
287 | self.lazy_reward = float(self.config.reward["lazy"])
288 | # randomness
289 | self.prob_stop = float(self.config.random["p_stop"])
290 | self.prob_goal_move = float(self.config.random["p_goal"])
291 | # meta
292 | self.meta_remaining_time = str2bool(self.config.meta["remaining_time"]) if \
293 | self.config.meta.has_attr("remaining_time") else False
294 |
295 | # log
296 | self.log_freq = 100
297 | self.log_t = 0
298 | self.max_history = 1000
299 | self.reward_history = []
300 | self.length_history = []
301 | self.verbose = verbose
302 |
303 | self.reset()
304 | self.action_space = gym.spaces.discrete.Discrete(4)
305 | self.observation_space = gym.spaces.box.Box(0, 1, self.observation().shape)
306 |
307 |
308 | def observation(self, clone=True):
309 | return self.maze.observation(clone=clone)
310 |
311 | def reset(self, reset_episode=True, holes=None):
312 | if reset_episode:
313 | self.t = 0
314 | self.episode_reward = 0
315 | self.last_step_reward = 0.0
316 | self.terminated = False
317 |
318 | holes = self.holes if holes is None else holes
319 | self.maze = Maze(self.size, num_goal=self.num_goal,
320 | holes=holes)
321 |
322 | return self.observation()
323 |
324 | def remaining_time(self, normalized=True):
325 | return float(self.max_step - self.t) / float(self.max_step)
326 |
327 | def last_reward(self):
328 | return self.last_step_reward
329 |
330 | def meta(self):
331 | meta = []
332 | if self.meta_remaining_time:
333 | meta.append(self.remaining_time())
334 | if len(meta) == 0:
335 | return None
336 | return meta
337 |
338 | def visualize(self):
339 | return self.maze.visualize()
340 |
341 | def to_string(self):
342 | return self.maze.to_string()
343 |
344 | def step(self, act):
345 | assert self.action_space.contains(act), "invalid action: %d" % act
346 | assert not self.terminated, "episode is terminated"
347 | self.t += 1
348 |
349 | self.object_reached = False
350 | self.rand_stopped = False
351 | if self.prob_stop > 0 and np.random.rand() < self.prob_stop:
352 | reward = self.default_reward
353 | self.rand_stopped = True
354 | else:
355 | moved = self.maze.move_agent(act)
356 | reward = self.default_reward if moved else self.lazy_reward
357 |
358 | if self.maze.is_object_reached(GOAL):
359 | self.object_reached = True
360 | reward = self.goal_reward
361 | self.maze.remove_object(self.maze.agent_pos[0], self.maze.agent_pos[1], GOAL)
362 | if self.maze.remaining_goal() == 0:
363 | self.terminated = True
364 |
365 | if self.t >= self.max_step:
366 | self.terminated = True
367 |
368 | self.episode_reward += reward
369 | self.last_step_reward = reward
370 |
371 | to_log = None
372 | if self.terminated:
373 | if self.verbose > 0:
374 | logger.info('Episode terminating: episode_reward=%s episode_length=%s',
375 | self.episode_reward, self.t)
376 | self.log_episode(self.episode_reward, self.t)
377 | if self.log_t < self.log_freq:
378 | self.log_t += 1
379 | else:
380 | to_log = {}
381 | to_log["global/episode_reward"] = self.reward_mean(self.log_freq)
382 | to_log["global/episode_length"] = self.length_mean(self.log_freq)
383 | self.log_t = 0
384 | else:
385 | if self.prob_goal_move > 0:
386 | self.maze.move_object_random(self.prob_goal_move, GOAL)
387 | # print("goal_moved")
388 |
389 | return self.observation(), reward, self.terminated, to_log, 1
390 |
391 | def log_episode(self, reward, length):
392 | self.reward_history.insert(0, reward)
393 | self.length_history.insert(0, length)
394 | while len(self.reward_history) > self.max_history:
395 | self.reward_history.pop()
396 | self.length_history.pop()
397 |
398 | def reward_mean(self, num):
399 | return np.asarray(self.reward_history[:num]).mean()
400 |
401 | def length_mean(self, num):
402 | return np.asarray(self.length_history[:num]).mean()
403 |
404 | def tf_visualize(self, x):
405 | colors = np.array(COLOR, np.uint8)
406 | colors = colors.astype(np.float32) / 255.0
407 | color = tf.constant(colors)
408 | obs_dim = self.observation_space.shape[-1]
409 | x = x[:, :, :, :obs_dim]
410 | xdim = x.get_shape()
411 | x = tf.clip_by_value(x, 0, 1)
412 | bg = tf.ones((tf.shape(x)[0], int(x.shape[1]), int(x.shape[2]), 3))
413 | w = tf.minimum(tf.expand_dims(tf.reduce_sum(x, axis=xdim.ndims-1), -1), 1.0)
414 | w = tf.reshape(tf.tile(w, [1, 1, 1, 3]), tf.shape(bg))
415 | fg = tf.reshape(tf.matmul(tf.reshape(x, (-1, int(xdim[-1]))),
416 | color[:xdim[-1], :]), tf.shape(bg))
417 | return bg * (1.0 - w) + fg
418 |
419 | class MazeSMDP(MazeEnv):
420 | def __init__(self, gamma=0.99, *args, **kwargs):
421 | super(MazeSMDP, self).__init__(*args, **kwargs)
422 | self.gamma = gamma
423 | self.prob_slip = float(self.config.random["p_slip"])
424 |
425 | def step(self, act):
426 | assert self.action_space.contains(act), "invalid action: %d" % act
427 | assert not self.terminated, "episode is terminated"
428 |
429 | reward = 0
430 | steps = 0
431 | time = 0
432 | gamma = 1.0
433 | self.last_observation = self.maze.observation()
434 | while not self.terminated:
435 | _, r, _, to_log, t = super(MazeSMDP, self).step(act)
436 | reward += r * gamma
437 | steps += 1
438 | time += t
439 | gamma = gamma * self.gamma
440 | if not self.rand_stopped:
441 | if self.maze.is_agent_on_end_of_corridor(act):
442 | break
443 | if self.object_reached:
444 | break
445 | if self.maze.is_agent_on_branch():
446 | if self.prob_slip > 0 and np.random.rand() < self.prob_slip:
447 | pass
448 | else:
449 | break
450 |
451 | self.last_step_reward = reward
452 | return self.observation(), reward, self.terminated, to_log, time
453 |
--------------------------------------------------------------------------------
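A small usage sketch of the maze utilities above (importing `maze` pulls in `numpy`, `PIL`, `universe`, `gym`, `bs4`, and `tensorflow`, so those dependencies must be installed):
```
import numpy as np
import maze

np.random.seed(0)  # for a reproducible layout
grid, agent_pos, object_pos = maze.generate_maze_with_multiple_goal(10, num_goal=2, holes=3)
print(maze.to_string(grid))             # '#' walls, 'o' agent, 'x' goals
print(agent_pos, object_pos[maze.GOAL])
```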
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import tensorflow.contrib.rnn as rnn
4 | import math
5 |
6 | act_fn = tf.nn.elu
7 |
8 | def normalized_columns_initializer(std=1.0):
9 | def _initializer(shape, dtype=None, partition_info=None):
10 | out = np.random.randn(*shape).astype(np.float32)
11 | out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
12 | return tf.constant(out)
13 | return _initializer
14 |
15 | def he_initializer(fan_in, uniform=False, factor=2, seed=None):
16 | def _initializer(shape, dtype=None, partition_info=None):
17 | n = fan_in
18 | if uniform:
19 | # To get stddev = math.sqrt(factor / n) need to adjust for uniform.
20 | limit = math.sqrt(3.0 * factor / n)
21 | return tf.random_uniform(shape, -limit, limit, dtype, seed=seed)
22 | else:
23 | # To get stddev = math.sqrt(factor / n) need to adjust for truncated.
24 | trunc_stddev = math.sqrt(1.3 * factor / n)
25 | return tf.truncated_normal(shape, 0.0, trunc_stddev, dtype, seed=seed)
26 | return _initializer
27 |
28 | def flatten(x):
29 | if x.get_shape().ndims > 2:
30 | return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
31 | return x
32 |
33 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None, init="he"):
34 | with tf.variable_scope(name):
35 | stride_shape = [1, stride[0], stride[1], 1]
36 | filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters]
37 |
38 | if init != "he":
39 | fan_in = np.prod(filter_shape[:3])
40 | fan_out = np.prod(filter_shape[:2]) * num_filters
41 | w_bound = np.sqrt(6. / (fan_in + fan_out))
42 | w = tf.get_variable("W", filter_shape, dtype,
43 | tf.random_uniform_initializer(-w_bound, w_bound),
44 | collections=collections)
45 | else:
46 | fan_in = np.prod(filter_shape[:3])
47 | w = tf.get_variable("W", filter_shape, dtype, he_initializer(fan_in),
48 | collections=collections)
49 | b = tf.get_variable("b", [1, 1, 1, num_filters],
50 | initializer=tf.constant_initializer(0.0),
51 | collections=collections)
52 | return tf.nn.conv2d(x, w, stride_shape, pad) + b
53 |
54 | def linear(x, size, name, initializer=None, bias_init=0):
55 | w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=initializer)
56 | b = tf.get_variable(name + "/b", [size], initializer=tf.constant_initializer(bias_init))
57 | return tf.matmul(x, w) + b
58 |
59 | def categorical_sample(logits, d):
60 | value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1], keep_dims=True), 1), [1])
61 | return tf.one_hot(value, d)
62 |
63 | def transform_fc(x, a, n_actions, name, bias_init=0, pad="SAME"):
64 | if x.shape.ndims > 2:
65 | x = flatten(x)
66 | xdim = int(x.get_shape()[1])
67 | w = tf.get_variable(name + "/w", [n_actions, xdim],
68 | initializer=normalized_columns_initializer(0.1))
69 | if a is not None:
70 | # Transform only for the given action
71 | mul = x * w[tf.to_int32(tf.squeeze(tf.argmax(a, axis=1))), :]
72 | else:
73 | # Enumerate all possible actions and concatenate them
74 | transformed = []
75 | for i in range(0, n_actions):
76 | transformed.append(x * w[i, :])
77 | mul = pack(tf.concat(transformed, 1), [xdim])
78 |
79 | h = linear(mul, xdim, name + "_dec",
80 | initializer=normalized_columns_initializer(0.1), bias_init=bias_init)
81 | return act_fn(h)
82 |
83 | def transform_conv_state(x, a, n_actions, filter_size=(3, 3), pad="SAME"):
84 | # 3x3 option-conv -> 3x3 conv * 1x1 mask (with residual connection)
85 | stride_shape = [1, 1, 1, 1]
86 | dec_f_size = filter_size[0]
87 | num_filters = int(x.get_shape()[3])
88 | xdim = [int(x.get_shape()[1]), int(x.get_shape()[2]), num_filters]
89 | filter_shape = [filter_size[0], filter_size[1], num_filters, n_actions, num_filters]
90 | fan_in = np.prod(filter_shape[:3])
91 | dec_filter_shape = [dec_f_size, dec_f_size, num_filters, num_filters]
92 | w = tf.get_variable("W", filter_shape, initializer=he_initializer(fan_in))
93 | b = tf.get_variable("b", [1, 1, 1, n_actions, num_filters],
94 | initializer=tf.constant_initializer(0.0))
95 | w_dec = tf.get_variable("dec1-W", dec_filter_shape,
96 | initializer=he_initializer(fan_in))
97 | b_dec = tf.get_variable("dec1-b", [1, 1, 1, num_filters],
98 | initializer=tf.constant_initializer(0.0))
99 | w_dec2 = tf.get_variable("dec2-W", dec_filter_shape,
100 | initializer=he_initializer(fan_in))
101 | b_dec2 = tf.get_variable("dec2-b", [1, 1, 1, num_filters],
102 | initializer=tf.constant_initializer(0.0))
103 | w_gate = tf.get_variable("gate-W", [1, 1, num_filters, num_filters],
104 | initializer=he_initializer(num_filters))
105 | b_gate = tf.get_variable("gate-b", [1, 1, 1, num_filters],
106 | initializer=tf.constant_initializer(0.0))
107 | if a is not None:
108 | idx = tf.to_int32(tf.squeeze(tf.argmax(a, axis=1)))
109 | conv = tf.nn.conv2d(x, w[:, :, :, idx, :], stride_shape, pad) + b[:, :, :, idx, :]
110 | conv = act_fn(conv)
111 | else:
112 | w = tf.reshape(w, [filter_size[0], filter_size[1], num_filters, n_actions * num_filters])
113 | b = tf.reshape(b, [1, 1, 1, n_actions * num_filters])
114 | conv = act_fn(tf.nn.conv2d(x, w, stride_shape, pad) + b)
115 | conv = pack(tf.transpose(tf.reshape(conv,
116 | [-1, xdim[0], xdim[1], n_actions, num_filters]),
117 | [0, 3, 1, 2, 4]), xdim)
118 | conv = act_fn(tf.nn.conv2d(conv, w_dec, stride_shape, pad) + b_dec)
119 | gate = tf.sigmoid(tf.nn.conv2d(conv, w_gate, stride_shape, pad) + b_gate)
120 | conv = tf.nn.conv2d(conv, w_dec2, stride_shape, pad) + b_dec2
121 | if a is not None:
122 | conv = conv * gate + x
123 | else:
124 | conv = tf.transpose(tf.reshape(conv, [-1, n_actions] + xdim), [1, 0, 2, 3, 4])
125 | gate = tf.transpose(tf.reshape(gate, [-1, n_actions] + xdim), [1, 0, 2, 3, 4])
126 | conv = conv * gate + x
127 | conv = pack(tf.transpose(conv, [1, 0, 2, 3, 4]), xdim)
128 | return act_fn(conv)
129 |
130 | def transform_conv_pred(x, a, n_actions, filter_size=(3, 3), pad="SAME"):
131 | # 3x3 option-conv -> 3x3 conv
132 | stride_shape = [1, 1, 1, 1]
133 | dec_f_size = filter_size[0]
134 | num_filters = int(x.get_shape()[3])
135 | xdim = [int(x.get_shape()[1]), int(x.get_shape()[2]), num_filters]
136 | filter_shape = [filter_size[0], filter_size[1], num_filters, n_actions, num_filters]
137 | fan_in = np.prod(filter_shape[:3])
138 | fan_out = np.prod(filter_shape[:2]) * num_filters
139 | w_bound = np.sqrt(6. / (fan_in + fan_out))
140 | w = tf.get_variable("W", filter_shape,
141 | initializer=tf.random_uniform_initializer(-w_bound, w_bound))
142 | b = tf.get_variable("b", [1, 1, 1, n_actions, num_filters],
143 | initializer=tf.constant_initializer(0.0))
144 | w_dec = tf.get_variable("W-dec", [dec_f_size, dec_f_size, num_filters, num_filters],
145 | initializer=tf.random_uniform_initializer(-w_bound, w_bound))
146 | b_dec = tf.get_variable("b-dec", [1, 1, 1, num_filters],
147 | initializer=tf.constant_initializer(0.0))
148 | if a is not None:
149 | idx = tf.to_int32(tf.squeeze(tf.argmax(a, axis=1)))
150 | conv = tf.nn.conv2d(x, w[:, :, :, idx, :], stride_shape, pad) + b[:, :, :, idx, :]
151 | conv = act_fn(conv)
152 | else:
153 | w = tf.reshape(w, [filter_size[0], filter_size[1], num_filters, n_actions * num_filters])
154 | b = tf.reshape(b, [1, 1, 1, n_actions * num_filters])
155 | conv = act_fn(tf.nn.conv2d(x, w, stride_shape, pad) + b)
156 | conv = pack(tf.transpose(tf.reshape(conv,
157 | [-1, xdim[0], xdim[1], n_actions, num_filters]),
158 | [0, 3, 1, 2, 4]), xdim)
159 | conv = tf.nn.conv2d(conv, w_dec, stride_shape, pad) + b_dec
160 | return act_fn(conv)
161 |
162 | def pack(x, dim):
163 | return tf.reshape(x, [-1] + dim)
164 |
165 | def to_value(x, dim=256, initializer=None, bias_init=0):
166 | if x.shape.ndims == 2: # fc layer
167 | return linear(x, 1, "v", initializer=initializer, bias_init=bias_init)
168 | else: # conv layer
169 | x = act_fn(linear(flatten(x), dim, "v1",
170 | initializer=tf.contrib.layers.xavier_initializer(), bias_init=bias_init))
171 | return linear(x, 1, "v", initializer=initializer, bias_init=bias_init)
172 |
173 | def to_pred(x, dim=256, initializer=None, bias_init=0):
174 | return linear(flatten(x), dim, "p", initializer=initializer, bias_init=bias_init)
175 |
176 | def to_reward(x, dim=256, initializer=None, bias_init=0):
177 | x = act_fn(linear(flatten(x), dim, "r1",
178 | initializer=tf.contrib.layers.xavier_initializer(), bias_init=bias_init))
179 | return linear(x, 1, "r", initializer=initializer, bias_init=bias_init)
180 |
181 | def to_steps(x, dim=256, initializer=None, bias_init=0):
182 | x = act_fn(linear(flatten(x), dim, "t1",
183 | initializer=tf.contrib.layers.xavier_initializer(), bias_init=bias_init))
184 | return linear(x, 1, "t", initializer=initializer, bias_init=bias_init)
185 |
186 | def rollout_step(x, a, n_actions, op_trans_state, op_trans_pred,
187 | op_value, op_steps, op_reward, gamma=0.98):
188 | state = op_trans_state(x, a)
189 | p = op_trans_pred(x, a)
190 | if a is not None:
191 | v_next = op_value(state)
192 | r = op_reward(p)
193 | t = op_steps(p)
194 | else:
195 | v_next = pack(op_value(state), [n_actions])
196 | r = pack(op_reward(p), [n_actions])
197 | t = pack(op_steps(p), [n_actions])
198 |
199 | t = tf.nn.relu(t) + 1
200 | g = tf.pow(tf.constant(gamma), t)
201 | return r, g, t, v_next, state
202 |
203 | def predict_over_time(x, a, n_actions, op_rollout, prediction_step=5):
204 | time_steps = tf.shape(a)[0]
205 | xdim = x.get_shape().as_list()[1:]
206 |
207 | def _create_ta(name, dtype, size, clear=True):
208 | return tf.TensorArray(dtype=dtype,
209 | size=size, tensor_array_name=name,
210 | clear_after_read=clear)
211 |
212 | v_ta = _create_ta("output_v", x.dtype, time_steps)
213 | r_ta = _create_ta("output_r", x.dtype, time_steps)
214 | g_ta = _create_ta("output_g", x.dtype, time_steps)
215 | t_ta = _create_ta("output_t", x.dtype, time_steps)
216 | q_ta = _create_ta("output_q", x.dtype, time_steps)
217 | s_ta = _create_ta("output_s", x.dtype, time_steps)
218 |
219 | x_ta = _create_ta("input_x", x.dtype, time_steps).unstack(x)
220 | a_ta = _create_ta("input_a", x.dtype, time_steps).unstack(a)
221 |
222 | time = tf.constant(0, dtype=tf.int32)
223 | roll_step = tf.minimum(prediction_step, time_steps)
224 | state = tf.zeros([roll_step] + xdim)
225 |
226 | def _time_step(time, r_ta, g_ta, t_ta, v_ta, q_ta, s_ta, state):
227 | a_t = a_ta.read(time)
228 | a_t = tf.expand_dims(a_t, 0)
229 |
230 | # stack previously generated states with the new state through batch
231 | x_t = x_ta.read(time)
232 | x_t = tf.expand_dims(x_t, 0)
233 | state = tf.concat([x_t, tf.slice(state, [0] * (len(xdim) + 1),
234 | [roll_step-1] + xdim)], 0)
235 |
236 | r, gamma, t, v_next, state = op_rollout(state, a_t)
237 | q = r + gamma * v_next
238 | r_ta = r_ta.write(time, tf.reshape(r, [-1]))
239 | g_ta = g_ta.write(time, tf.reshape(gamma, [-1]))
240 | t_ta = t_ta.write(time, tf.reshape(t, [-1]))
241 | v_ta = v_ta.write(time, tf.reshape(v_next, [-1]))
242 | q_ta = q_ta.write(time, tf.reshape(q, [-1]))
243 | s_ta = s_ta.write(time, state)
244 | return (time+1, r_ta, g_ta, t_ta, v_ta, q_ta, s_ta, state)
245 |
246 | _, r_ta, g_ta, t_ta, v_ta, q_ta, s_ta, state = tf.while_loop(
247 | cond=lambda time, *_: time < time_steps,
248 | body=_time_step,
249 | loop_vars=(time, v_ta, r_ta, g_ta, t_ta, q_ta, s_ta, state))
250 |
251 | r = r_ta.stack()
252 | g = g_ta.stack()
253 | t = t_ta.stack()
254 | v = v_ta.stack()
255 | q = q_ta.stack()
256 | s = s_ta.stack()
257 |
258 | return r, g, t, v, q, s
259 |
260 | class Model(object):
261 | def __init__(self, ob_space, n_actions, type,
262 | gamma=0.99, prediction_step=1,
263 | dim=256,
264 | f_num=[32,32,64],
265 | f_stride=[1,1,2],
266 | f_size=[3,3,4],
267 | f_pad="SAME",
268 | branch=[4,4,4],
269 | meta_dim=0):
270 | self.n_actions = n_actions
271 | self.type = type
272 | self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
273 | self.a = tf.placeholder(tf.float32, [None, n_actions])
274 | self.meta = tf.placeholder(tf.float32, [None, meta_dim]) if meta_dim > 0 else None
275 | self.state_init = []
276 | self.state_in = []
277 | self.state_out = []
278 | self.dim = dim
279 | self.f_num = f_num
280 | self.f_stride = f_stride
281 | self.f_size = f_size
282 | self.f_pad = f_pad
283 | self.meta_dim = meta_dim
284 | self.xdim = list(ob_space)
285 | self.branch = [min(n_actions, k) for k in branch]
286 |
287 | self.s, self.state_in, self.state_out = self.build_model(self.x, self.meta)
288 | self.sdim = self.s.get_shape().as_list()[1:]
289 |
290 | # output layer
291 | if self.type == 'policy':
292 | self.logits = linear(flatten(self.s), n_actions, "action",
293 | normalized_columns_initializer(0.01))
294 | self.vf = tf.reshape(linear(flatten(self.s), 1, "value",
295 | normalized_columns_initializer(1.0)), [-1])
296 | self.sample = categorical_sample(self.logits, n_actions)[0, :]
297 | elif self.type == 'q':
298 | h = transform_conv_state(self.s, None, n_actions)
299 | self.h = linear(flatten(h), self.dim, "fc",
300 | normalized_columns_initializer(0.01))
301 | self.q = pack(linear(self.h, 1, "action",
302 | normalized_columns_initializer(0.01)), [n_actions])
303 | self.sample = tf.one_hot(tf.squeeze(tf.argmax(self.q, axis=1)), n_actions)
304 | self.qmax = tf.reduce_max(self.q, axis=[1])
305 | elif self.type == 'vpn':
306 | self.op_value = tf.make_template('v', to_value, dim=self.dim,
307 | initializer=normalized_columns_initializer(0.01))
308 | self.op_reward = tf.make_template('r', to_reward, dim=self.dim,
309 | initializer=normalized_columns_initializer(0.01))
310 | self.op_steps = tf.make_template('t', to_steps, dim=self.dim,
311 | initializer=normalized_columns_initializer(0.01))
312 | self.op_trans_state = tf.make_template('trans_state', transform_conv_state,
313 | n_actions=n_actions)
314 | self.op_trans_pred = tf.make_template('trans_pred', transform_conv_pred,
315 | n_actions=n_actions)
316 | self.op_rollout = tf.make_template('rollout', rollout_step,
317 | n_actions=n_actions,
318 | op_trans_state=self.op_trans_state,
319 | op_trans_pred=self.op_trans_pred,
320 | op_value=self.op_value,
321 | op_steps=self.op_steps,
322 | op_reward=self.op_reward,
323 | gamma=gamma)
324 |
325 | # Unconditional rollout
326 | self.r, self.gamma, self.steps, self.v_next, self.state = self.op_rollout(self.s, None)
327 | self.q = self.r + self.gamma * self.v_next
328 |
329 | # Action-conditional rollout over time for training
330 | self.r_a, self.gamma_a, self.t_a, self.v_next_a, self.q_a, self.states = \
331 | predict_over_time(self.s, self.a, n_actions, self.op_rollout,
332 | prediction_step=prediction_step)
333 |
334 | # Tree expansion/backup
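    |             # Expansion: at each depth only the top branch[i] actions by Q = r + gamma * V
    |             # are kept and rolled out further (the selected branches are gathered into the
    |             # batch dimension). Backup: values are propagated bottom-up as
    |             #   v_plan[i-1] = (v[i-1] + n * max_a q_plan[i]) / (n + 1), with n = depth - i,
    |             # mixing the model-free value estimate with the planned backup.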
335 | depth = len(self.branch)
336 | q_list = []
337 | r_list = []
338 | g_list = []
339 | v_list = []
340 | idx_list = []
341 | s_list = []
342 | s = self.s
343 |
344 | # Expansion
345 | for i in range(depth):
346 | r, gamma, _, v, s = self.op_rollout(s, None)
347 | r_list.append(tf.squeeze(r))
348 | v_list.append(tf.squeeze(v))
349 | s_list.append(s)
350 | g_list.append(tf.squeeze(gamma))
351 |
352 | b = self.branch[i]
353 | q_list.append(r_list[i] + g_list[i] * v_list[i])
354 | q_list[i] = tf.reshape(q_list[i], [-1, self.n_actions])
355 | _, idx = tf.nn.top_k(q_list[i], k=b)
356 | idx_list.append(idx)
357 |
358 | l = tf.tile(tf.expand_dims(tf.range(0, tf.shape(idx)[0]), 1), [1, b])
359 | l = tf.concat([tf.reshape(l, [-1, 1]), tf.reshape(idx, [-1, 1])], axis=1)
360 | s = tf.reshape(tf.gather_nd(
361 | tf.reshape(s, [-1, self.n_actions] + self.sdim), l), [-1] + self.sdim)
362 | r_list[i] = tf.reshape(tf.gather_nd(
363 | tf.reshape(r_list[i], [-1, self.n_actions]), l), [-1])
364 | g_list[i] = tf.reshape(tf.gather_nd(
365 | tf.reshape(g_list[i], [-1, self.n_actions]), l), [-1])
366 | v_list[i] = tf.reshape(tf.gather_nd(
367 | tf.reshape(v_list[i], [-1, self.n_actions]), l), [-1])
368 |
369 | self.q_list = q_list
370 | self.r_list = r_list
371 | self.g_list = g_list
372 | self.v_list = v_list
373 | self.s_list = s_list
374 | self.idx_list = idx_list
375 |
376 | # Backup
377 | v_plan = [None] * depth
378 | q_plan = [None] * depth
379 |
380 | v_plan[-1] = v_list[-1]
381 | for i in reversed(range(0, depth)):
382 | q_plan[i] = r_list[i] + g_list[i] * v_plan[i]
383 | if i > 0:
384 | q_max = tf.reduce_max(tf.reshape(q_plan[i], [-1, self.branch[i]]), axis=1)
385 | n = float(depth - i)
386 | v_plan[i-1] = (v_list[i-1] + q_max * n) / (n + 1)
387 |
388 | idx = tf.squeeze(idx_list[0])
389 | self.q_deep = tf.squeeze(q_plan[0])
390 | self.q_plan = tf.sparse_to_dense(idx, [self.n_actions], self.q_deep,
391 | default_value=-100, validate_indices=False)
392 |
393 | self.x_off = tf.placeholder(tf.float32, [None] + list(ob_space))
394 | self.a_off = tf.placeholder(tf.float32, [None, n_actions])
395 | self.meta_off = tf.placeholder(tf.float32, [None, self.meta_dim]) \
396 | if self.meta_dim > 0 else None
397 | tf.get_variable_scope().reuse_variables()
398 | self.s_off, self.state_in_off, self.state_out_off = \
399 | self.build_model(self.x_off, self.meta_off)
400 |
401 | # Action-conditional rollout over time for training
402 | self.r_off, self.gamma_off, self.t_off, self.v_next_off, _, _ = \
403 | predict_over_time(self.s_off, self.a_off, n_actions, self.op_rollout,
404 | prediction_step=prediction_step)
405 |
406 | else:
407 | raise ValueError('Invalid model type %s' % (self.type))
408 |
409 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
410 | tf.get_variable_scope().name)
411 |
412 | self.num_param = 0
413 | for v in self.var_list:
414 | self.num_param += v.get_shape().num_elements()
415 |
416 | def is_recurrent(self):
417 | return self.state_in is not None and len(self.state_in) > 0
418 |
419 | def get_initial_features(self):
420 | return self.state_init
421 |
422 | def act(self, ob, state_in=[], meta=None):
423 | sess = tf.get_default_session()
424 | feed_dict = {self.x: [ob]}
425 | for i in range(len(state_in)):
426 | feed_dict[self.state_in[i]] = state_in[i]
427 | if self.meta_dim > 0:
428 | feed_dict[self.meta] = [meta]
429 |
430 | if self.type == 'policy':
431 | return sess.run([self.sample, self.vf] + self.state_out, feed_dict)
432 | elif self.type == 'q':
433 | return sess.run([self.sample] + self.state_out, feed_dict)
434 | elif self.type == 'vpn':
435 | out = sess.run([self.q_plan] + self.state_out, feed_dict)
436 | q = out[0]
437 | state_out = out[1:]
438 | act = np.zeros_like(q)
439 | act[q.argmax()] = 1
440 | return [act] + state_out
441 |
442 | def update_state(self, ob, state_in=[], meta=None):
443 | sess = tf.get_default_session()
444 | feed_dict = {self.x: [ob]}
445 | for i in range(len(state_in)):
446 | feed_dict[self.state_in[i]] = state_in[i]
447 | if self.meta_dim > 0:
448 | feed_dict[self.meta] = [meta]
449 | return sess.run(self.state_out, feed_dict)
450 |
451 | def value(self, ob, state_in=[], meta=None):
452 | sess = tf.get_default_session()
453 | feed_dict = {self.x: [ob]}
454 | for i in range(len(state_in)):
455 | feed_dict[self.state_in[i]] = state_in[i]
456 | if self.meta_dim > 0:
457 | feed_dict[self.meta] = [meta]
458 |
459 | if self.type == 'policy':
460 | return sess.run(self.vf, feed_dict)[0]
461 | elif self.type == 'q':
462 | return sess.run(self.qmax, feed_dict)[0]
463 | elif self.type == 'vpn':
464 | q = sess.run(self.q_plan, feed_dict)
465 | return q.max()
466 |
467 | class CNN(Model):
468 | def __init__(self, *args, **kwargs):
469 | super(CNN, self).__init__(*args, **kwargs)
470 |
471 | def build_model(self, x, meta=None):
472 | for i in range(len(self.f_num)):
473 | x = act_fn(conv2d(x, self.f_num[i], "l{}".format(i+1),
474 | [self.f_size[i], self.f_size[i]],
475 | [self.f_stride[i], self.f_stride[i]], pad=self.f_pad,
476 | init="he"))
477 | self.conv = x
478 | if meta is not None:
479 | space_dim = x.get_shape().as_list()[1:3]
480 | meta_dim = meta.get_shape().as_list()[-1]
481 | t = tf.reshape(tf.tile(meta,
482 | [1, np.prod(space_dim)]), [-1] + space_dim + [meta_dim])
483 | x = tf.concat([t, x], axis=3)
484 |
485 | return x, [], []
486 |
487 | class LSTM(Model):
488 | def __init__(self, *args, **kwargs):
489 | super(LSTM, self).__init__(*args, **kwargs)
490 |
491 |     def build_model(self, x, meta=None):
    |         # `meta` is accepted for interface compatibility with Model.build_model
    |         # (CNN concatenates it onto the conv features); the LSTM variant ignores it.
492 | for i in range(len(self.f_num)):
493 | x = act_fn(conv2d(x, self.f_num[i], "l{}".format(i+1),
494 | [self.f_size[i], self.f_size[i]],
495 | [self.f_stride[i], self.f_stride[i]], pad=self.f_pad))
496 | self.conv = x
497 | x = act_fn(linear(flatten(x), 256, "l{}".format(3),
498 | normalized_columns_initializer(0.01)))
499 |
500 | # introduce a "fake" batch dimension of 1 after flatten
501 | # so that we can do LSTM over time dim
502 | x = tf.expand_dims(x, [0])
503 |
504 | size = 256
505 | lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
506 | self.state_size = lstm.state_size
507 | step_size = tf.shape(self.x)[:1]
508 |
509 | c_init = np.zeros((1, lstm.state_size.c), np.float32)
510 | h_init = np.zeros((1, lstm.state_size.h), np.float32)
511 | self.state_init = [c_init, h_init]
512 | c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
513 | h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
514 |
515 | state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
516 | lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
517 | lstm, x, initial_state=state_in, sequence_length=step_size,
518 | time_major=False)
519 | lstm_c, lstm_h = lstm_state
520 | state_out = [lstm_c[:1, :], lstm_h[:1, :]]
521 |
522 | x = tf.reshape(lstm_outputs, [-1, size])
523 | return x, state_in, state_out
524 |
--------------------------------------------------------------------------------
/q.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import tensorflow as tf
4 | import model
5 | import util
6 | from async import AsyncSolver
7 | import logging
8 | logger = logging.getLogger(__name__)
9 | logger.setLevel(logging.INFO)
10 |
11 | class Q(AsyncSolver):
12 | def define_network(self, name):
13 | self.args.meta_dim = 0 if self.env.meta() is None else len(self.env.meta())
14 | return eval("model." + name)(self.env.observation_space.shape,
15 | self.env.action_space.n, type='q',
16 | gamma=self.args.gamma,
17 | dim=self.args.dim,
18 | f_num=self.args.f_num,
19 | f_pad=self.args.f_pad,
20 | f_stride=self.args.f_stride,
21 | f_size=self.args.f_size,
22 | meta_dim=self.args.meta_dim,
23 | )
24 |
25 | def use_target_network(self):
26 | return True
27 |
28 | def process_rollout(self, rollout, gamma, lambda_=1.0):
29 | """
30 | given a rollout, compute its returns
31 | """
32 | batch_si = np.asarray(rollout.states)
33 | batch_a = np.asarray(rollout.actions)
34 | rewards = np.asarray(rollout.rewards)
35 | time = np.asarray(rollout.time)
36 | meta = np.asarray(rollout.meta)
37 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
38 | batch_r = util.discount(rewards_plus_v, gamma, time)[:-1]
39 | features = rollout.features[0]
40 |
41 | return util.Batch(si=batch_si,
42 | a=batch_a,
43 | adv=None,
44 | r=batch_r,
45 | terminal=rollout.terminal,
46 | features=features,
47 | reward=rewards,
48 | step=time,
49 | meta=meta,
50 | )
51 |
52 | def init_variables(self):
53 | pi = self.local_network
54 |
55 |         # target network is synchronized every args.sync global steps (10,000 by default)
56 | self.local_to_target = tf.group(*[v1.assign(v2) for v1, v2 in
57 | zip(self.global_target_network.var_list, pi.var_list)])
58 | self.target_sync = tf.group(*[v1.assign(v2) for v1, v2 in
59 | zip(self.target_network.var_list, self.global_target_network.var_list)])
60 | self.sync_count = 0
61 | self.target_sync_step = 0
62 |
63 | # epsilon
64 | self.eps = [1.0]
65 | self.eps_start = [1.0]
66 | self.eps_end = [self.args.eps]
67 | self.eps_prob = [1]
68 | self.anneal_step = self.args.eps_step
69 |
70 | # batch size
71 | self.bs = tf.to_float(tf.shape(pi.x)[0])
72 |
73 | # loss function
74 | self.define_loss()
75 |
76 | def define_loss(self):
77 | pi = self.local_network
78 |
79 | # loss function
80 | self.ac = tf.placeholder(tf.float32, [None, self.env.action_space.n], name="ac")
81 | self.r = tf.placeholder(tf.float32, [None], name="r") # target
82 |
83 | self.q_val = tf.reduce_sum(pi.q * self.ac, [1])
84 | self.delta = self.q_val - self.r
85 | # clipping gradient to [-1, 1] amounts to using Huber loss
86 | self.q_loss = tf.reduce_sum(tf.where(tf.abs(self.delta) < 1,
87 | 0.5 * tf.square(self.delta),
88 | tf.abs(self.delta) - 0.5))
89 |
90 | self.loss = self.q_loss
91 |
92 | def define_summary(self):
93 | super(Q, self).define_summary()
94 | tf.summary.scalar("loss/loss", self.loss / self.bs)
95 | if hasattr(self, "q_loss"):
96 | tf.summary.scalar("loss/q_loss", self.q_loss / self.bs)
97 | tf.summary.scalar("param/target_param_norm",
98 | tf.global_norm(self.target_network.var_list))
99 | self.summary_op = tf.summary.merge_all()
100 |
101 | def start(self, sess, summary_writer):
102 | sess.run(self.sync) # copy weights from shared to local
103 | if self.task == 0:
104 | sess.run(self.local_to_target) # copy weights from local to shared target
105 | sess.run(self.target_sync) # copy weights from global target to local target
106 | super(Q, self).start(sess, summary_writer)
107 |
108 | def prepare_input(self, batch):
109 | feed_dict = {self.local_network.x: batch.si,
110 | self.ac: batch.a,
111 | self.r: batch.r}
112 | if self.args.meta_dim > 0:
113 | feed_dict[self.local_network.meta] = batch.meta
114 | for i in range(len(self.local_network.state_in)):
115 | feed_dict[self.local_network.state_in[i]] = batch.features[i]
116 | return feed_dict
117 |
118 | def post_process(self, sess):
119 | if self.task == 0:
120 | global_step = self.last_global_step
121 | if int(global_step / self.args.sync) > self.sync_count:
122 | # copy weights from local to shared target
123 | self.sync_count = int(global_step / self.args.sync)
124 | sess.run([self.local_to_target, self.target_sync, self.update_target_step])
125 | logger.info("[Step: %d] Target network is synchronized", global_step)
126 | else:
127 | target_step = self.global_target_sync_step.eval()
128 | if target_step != self.target_sync_step:
129 | self.target_sync_step = target_step
130 | sess.run(self.target_sync)
131 | logger.info("[Step: %d] Target network is synchronized", target_step)
132 |
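    |         # Linearly anneal epsilon from eps_start to eps_end over --eps-step global steps.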
133 | for i in range(len(self.eps)):
134 | self.eps[i] = self.eps_start[i]
135 | self.eps[i] -= self.last_global_step * (self.eps_start[i] - self.eps_end[i])\
136 | / self.anneal_step
137 | self.eps[i] = max(self.eps[i], self.eps_end[i])
138 |
139 | def epsilon(self):
140 | return np.random.choice(self.eps, p=self.eps_prob)
141 |
142 | def write_extra_summary(self, rollout=None):
143 | summary = tf.Summary()
144 | summary.value.add(tag='model/epsilon', simple_value=float(
145 | np.sum(np.array(self.eps) * np.array(self.eps_prob))))
146 | summary.value.add(tag='model/rollout_r', simple_value=float(rollout.r))
147 | self.summary_writer.add_summary(summary, self.last_global_step)
148 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import go_vncdriver
3 | import tensorflow as tf
4 | from envs import create_env
5 | import subprocess as sp
6 | import util
7 | import model
8 | import numpy as np
9 | from worker import new_env
10 |
11 | parser = argparse.ArgumentParser(description="Run commands")
12 | parser.add_argument('-gpu', '--gpu', default=0, type=int, help='Number of GPUs')
13 | parser.add_argument('-r', '--remotes', default=None,
14 | help='The address of pre-existing VNC servers and rewarders to use'
15 | '(e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901).')
16 | parser.add_argument('-e', '--env-id', type=str, default="maze",
17 | help="Environment id")
18 | parser.add_argument('-a', '--alg', type=str, default="VPN", help="Algorithm: [A3C | Q | VPN]")
19 | parser.add_argument('-mo', '--model', type=str, default="CNN", help="Name of model: [CNN | LSTM]")
20 | parser.add_argument('-ck', '--checkpoint', type=str, default="", help="Path of the checkpoint")
21 | parser.add_argument('-n', '--n-play', type=int, default=1000, help="Number of episodes to play")
22 | parser.add_argument('--eps', type=float, default=0.0, help="Epsilon-greedy")
23 | parser.add_argument('--config', type=str, default="", help="config xml file for environment")
24 | parser.add_argument('--seed', type=int, default=0, help="Random seed")
25 |
26 | # Hyperparameters
27 | parser.add_argument('-g', '--gamma', type=float, default=0.98, help="Discount factor")
28 | parser.add_argument('--dim', type=int, default=64, help="Number of final hidden units")
29 | parser.add_argument('--f-num', type=str, default='32,32,64', help="num of conv filters")
30 | parser.add_argument('--f-stride', type=str, default='1,1,2', help="stride of conv filters")
31 | parser.add_argument('--f-size', type=str, default='3,3,4', help="size of conv filters")
32 | parser.add_argument('--f-pad', type=str, default='SAME', help="padding of conv filters")
33 |
34 | # VPN parameters
35 | parser.add_argument('--branch', type=str, default="4,4,4", help="branching factor")
36 |
37 | def evaluate(env, agent, num_play=3000, eps=0.0):
38 | env.max_history = num_play
39 | for iter in range(0, num_play):
40 | last_state = env.reset()
41 | last_features = agent.get_initial_features()
42 | last_meta = env.meta()
43 | while True:
44 | # import pdb; pdb.set_trace()
45 | if eps == 0.0 or np.random.rand() > eps:
46 | fetched = agent.act(last_state, last_features,
47 | meta=last_meta)
48 | if agent.type == 'policy':
49 | action, features = fetched[0], fetched[2:]
50 | else:
51 | action, features = fetched[0], fetched[1:]
52 | else:
53 | act_idx = np.random.randint(0, env.action_space.n)
54 | action = np.zeros(env.action_space.n)
55 | action[act_idx] = 1
56 | features = []
57 |
58 | state, reward, terminal, info, _ = env.step(action.argmax())
59 | last_state = state
60 | last_features = features
61 | last_meta = env.meta()
62 | if terminal:
63 | break
64 |
65 | return env.reward_mean(num_play)
66 |
67 | def run():
68 | args = parser.parse_args()
69 | args.task = 0
70 | args.f_num = util.parse_to_num(args.f_num)
71 | args.f_stride = util.parse_to_num(args.f_stride)
72 | args.f_size = util.parse_to_num(args.f_size)
73 | args.branch = util.parse_to_num(args.branch)
74 |
75 | env = new_env(args)
76 | args.meta_dim = 0 if env.meta() is None else len(env.meta())
77 | device = '/gpu:0' if args.gpu > 0 else '/cpu:0'
78 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
79 | config = tf.ConfigProto(device_filters=device,
80 | gpu_options=gpu_options,
81 | allow_soft_placement=True)
82 | with tf.Session(config=config) as sess:
83 | if args.alg == 'A3C':
84 | model_type = 'policy'
85 | elif args.alg == 'Q':
86 | model_type = 'q'
87 | elif args.alg == 'VPN':
88 | model_type = 'vpn'
89 | else:
90 | raise ValueError('Invalid algorithm: ' + args.alg)
91 | with tf.device(device):
92 | with tf.variable_scope("local/learner"):
93 | agent = eval("model." + args.model)(env.observation_space.shape,
94 | env.action_space.n, type=model_type,
95 | gamma=args.gamma,
96 | dim=args.dim,
97 | f_num=args.f_num,
98 | f_stride=args.f_stride,
99 | f_size=args.f_size,
100 | f_pad=args.f_pad,
101 | branch=args.branch,
102 | meta_dim=args.meta_dim)
103 | print("Num parameters: %d" % agent.num_param)
104 |
105 | saver = tf.train.Saver()
106 | saver.restore(sess, args.checkpoint)
107 | np.random.seed(args.seed)
108 | reward = evaluate(env, agent, args.n_play, eps=args.eps)
109 | print("Reward: %.2f" % (reward))
110 |
111 | if __name__ == "__main__":
112 | run()
113 |
--------------------------------------------------------------------------------
/test_vpn:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | if [ -z "$1" ]
3 | then echo "[checkpoint]"; exit 0
4 | fi
5 |
6 | args="--config config/collect_deterministic.xml --branch 4,4,4,1,1 --checkpoint $1 --gpu 1 ${@:2}"
7 | echo $args
8 | CUDA_VISIBLE_DEVICES=$3 python test.py $args
9 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | from six.moves import shlex_quote, input
5 | import util
6 |
7 | parser = argparse.ArgumentParser(description="Run commands")
8 | parser.add_argument('-gpu', '--gpu', default="", type=str, help='GPU Ids')
9 | parser.add_argument('-w', '--num-workers', type=int, default=16, help="Number of workers")
10 | parser.add_argument('-ps', '--num-ps', type=int, default=4, help="Number of parameter servers")
11 | parser.add_argument('-r', '--remotes', default=None,
12 | help='The address of pre-existing VNC servers and rewarders to use'
13 | '(e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901).')
14 | parser.add_argument('-e', '--env-id', type=str, default="maze",
15 | help="Environment id")
16 | parser.add_argument('-l', '--log', type=str, default="/tmp/vpn",
17 | help="Log directory path")
18 | parser.add_argument('-d', '--dry-run', action='store_true',
19 | help="Print out commands rather than executing them")
20 | parser.add_argument('-m', '--mode', type=str, default='tmux',
21 | help="tmux: run workers in a tmux session. "
22 | "nohup: run workers with nohup. "
23 | "child: run workers as child processes")
24 | parser.add_argument('-a', '--alg', choices=['A3C', 'Q', 'VPN'], default="VPN")
25 | parser.add_argument('-mo', '--model', type=str, default="CNN", help="Name of model: [CNN | LSTM]")
26 | parser.add_argument('-ms', '--max-step', type=int, default=int(15e6), help="Max global step")
27 | parser.add_argument('--config', type=str, default="config/collect_deterministic.xml",
28 | help="config xml file for environment")
29 | parser.add_argument('--seed', type=int, default=0, help="Random seed")
30 | parser.add_argument('--eval-freq', type=int, default=250000, help="Evaluation frequency")
31 | parser.add_argument('--eval-num', type=int, default=2000, help="Number of evaluation episodes")
32 |
33 | # Hyperparameters
34 | parser.add_argument('-n', '--t-max', type=int, default=10, help="Number of unrolling steps")
35 | parser.add_argument('-g', '--gamma', type=float, default=0.98, help="Discount factor")
36 | parser.add_argument('-lr', '--lr', type=float, default=1e-4, help="Learning rate")
37 | parser.add_argument('--decay', type=float, default=0.95, help="Decay rate")
38 | parser.add_argument('--dim', type=int, default=64, help="Number of final hidden units")
39 | parser.add_argument('--f-num', type=str, default='32,32,64', help="num of conv filters")
40 | parser.add_argument('--f-stride', type=str, default='1,1,2', help="stride of conv filters")
41 | parser.add_argument('--f-size', type=str, default='3,3,4', help="size of conv filters")
42 | parser.add_argument('--f-pad', type=str, default='SAME', help="padding of conv filters")
43 | parser.add_argument('--h-dim', type=str, default='', help="num of hidden units")
44 |
45 | # Q-Learning parameters
46 | parser.add_argument('-s', '--sync', type=int, default=10000,
47 | help="Target network synchronization frequency")
48 | parser.add_argument('-f', '--update-freq', type=int, default=1,
49 | help="Parameter update frequency")
50 | parser.add_argument('--eps-step', type=int, default=int(1e6),
51 | help="Num of local steps for epsilon scheduling")
52 | parser.add_argument('--eps', type=float, default=0.05, help="Final epsilon")
53 | parser.add_argument('--eps-eval', type=float, default=0.0, help="Epsilon for evaluation")
54 |
55 | # VPN parameters
56 | parser.add_argument('--prediction-step', type=int, default=3, help="number of prediction steps")
57 | parser.add_argument('--branch', type=str, default="4,4,4", help="branching factor")
58 | parser.add_argument('--buf', type=int, default=10**6, help="num of steps for random buffer")
59 |
60 | def new_cmd(session, name, cmd, mode, logdir, shell):
61 | if isinstance(cmd, (list, tuple)):
62 | cmd = " ".join(shlex_quote(str(v)) for v in cmd)
63 | if mode == 'tmux':
64 | return name, "tmux send-keys -t {}:{} {} Enter".format(session, name, shlex_quote(cmd))
65 | elif mode == 'child':
66 | return name, "{} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(cmd, logdir, session, name, logdir)
67 | elif mode == 'nohup':
68 | return name, "nohup {} -c {} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(shell, shlex_quote(cmd), logdir, session, name, logdir)
69 |
70 |
71 | def create_commands(session, args, shell='bash'):
72 | # for launching the TF workers and for launching tensorboard
73 | base_cmd = [
74 | sys.executable, 'worker.py',
75 | '--log', args.log, '--env-id', args.env_id,
76 | '--num-workers', str(args.num_workers),
77 | '--num-ps', str(args.num_ps),
78 | '--alg', args.alg,
79 | '--model', args.model,
80 | '--max-step', args.max_step,
81 | '--t-max', args.t_max,
82 | '--eps-step', args.eps_step,
83 | '--eps', args.eps,
84 | '--eps-eval', args.eps_eval,
85 | '--gamma', args.gamma,
86 | '--lr', args.lr,
87 | '--decay', args.decay,
88 | '--sync', args.sync,
89 | '--update-freq', args.update_freq,
90 | '--eval-freq', args.eval_freq,
91 | '--eval-num', args.eval_num,
92 | '--prediction-step', args.prediction_step,
93 | '--dim', args.dim,
94 | '--f-num', args.f_num,
95 | '--f-pad', args.f_pad,
96 | '--f-stride', args.f_stride,
97 | '--f-size', args.f_size,
98 | '--h-dim', args.h_dim,
99 | '--branch', args.branch,
100 | '--config', args.config,
101 | '--buf', args.buf,
102 | ]
103 |
104 | if len(args.gpu) > 0:
105 | base_cmd += ['--gpu', 1]
106 |
107 | if args.remotes is None:
108 | args.remotes = ["1"] * args.num_workers
109 | else:
110 | args.remotes = args.remotes.split(',')
111 | assert len(args.remotes) == args.num_workers
112 |
113 | cmds_map = []
114 | for i in range(args.num_ps):
115 | prefix = ['CUDA_VISIBLE_DEVICES=']
116 | cmds_map += [new_cmd(session, "ps-%d" % i, prefix + base_cmd + ["--job-name", "ps",
117 | "--task", str(i)], args.mode, args.log, shell)]
118 |
119 | for i in range(args.num_workers):
120 | prefix = []
121 | if len(args.gpu) > 0:
122 | prefix = ['CUDA_VISIBLE_DEVICES=%d' % args.gpu[(i % len(args.gpu))]]
123 | else:
124 | prefix = ['CUDA_VISIBLE_DEVICES=']
125 | cmds_map += [new_cmd(session,
126 | "w-%d" % i,
127 | prefix + base_cmd + ["--job-name", "worker", "--task", str(i),
128 | "--remotes", args.remotes[i]], args.mode, args.log, shell)]
129 |
130 | cmds_map += [new_cmd(session, "tb", ["tensorboard", "--logdir", args.log,
131 | "--port", "12345"], args.mode, args.log, shell)]
132 | cmds_map += [new_cmd(session, "test", prefix + base_cmd + ["--job-name", "test",
133 | "--task", str(args.num_workers)], args.mode, args.log, shell)]
134 | windows = [v[0] for v in cmds_map]
135 |
136 | notes = []
137 | cmds = [
138 | "mkdir -p {}".format(args.log),
139 | "echo {} {} > {}/cmd.sh".format(sys.executable, ' '.join([shlex_quote(arg) for arg in sys.argv if arg != '-n']), args.log),
140 | ]
141 | if args.mode == 'nohup' or args.mode == 'child':
142 | cmds += ["echo '#!/bin/sh' >{}/kill.sh".format(args.log)]
143 | notes += ["Run `source {}/kill.sh` to kill the job".format(args.log)]
144 | if args.mode == 'tmux':
145 | notes += ["Use `tmux attach -t {}` to watch process output".format(session)]
146 | notes += ["Use `tmux kill-session -t {}` to kill the job".format(session)]
147 | else:
148 | notes += ["Use `tail -f {}/*.out` to watch process output".format(args.log)]
149 | notes += ["Point your browser to http://localhost:12345 to see Tensorboard"]
150 |
151 | if args.mode == 'tmux':
152 | cmds += [
153 | "tmux kill-session -t {}".format(session),
154 | "tmux new-session -s {} -n {} -d {}".format(session, windows[0], shell)
155 | ]
156 | for w in windows[1:]:
157 | cmds += ["tmux new-window -t {} -n {} {}".format(session, w, shell)]
158 | cmds += ["sleep 1"]
159 | for window, cmd in cmds_map:
160 | cmds += [cmd]
161 |
162 | return cmds, notes
163 |
164 |
165 | def run():
166 | args = parser.parse_args()
167 | args.gpu = util.parse_to_num(args.gpu)
168 | cmds, notes = create_commands("e", args)
169 | if args.dry_run:
170 | print("Dry-run mode due to -d flag, otherwise the following commands would be executed:")
171 | else:
172 | print("Executing the following commands:")
173 | print("\n".join(cmds))
174 | print("")
175 | if not args.dry_run:
176 | if args.mode == "tmux":
177 | os.environ["TMUX"] = ""
178 | path = os.path.join(os.getcwd(), args.log)
179 | if os.path.exists(path):
180 |             key = input("%s exists. Do you want to delete it? (y/n): " % path)
181 | if key != 'n':
182 | os.system("rm -rf %s" % path)
183 | os.system("\n".join(cmds))
184 | print('\n'.join(notes))
185 | else:
186 | os.system("\n".join(cmds))
187 | print('\n'.join(notes))
188 |
189 |
190 | if __name__ == "__main__":
191 | run()
192 |
--------------------------------------------------------------------------------
/train_vpn:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
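    | # Uncomment one of the commands below to reproduce the corresponding result.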
3 | # Deterministic Collect
4 | # python train.py --config config/collect_deterministic.xml --prediction-step 1 --branch 4 --decay 0.9 --alg VPN --log result/vpn1 ${@:1}
5 | # python train.py --config config/collect_deterministic.xml --prediction-step 2 --branch 4,4 --decay 0.9 --alg VPN --log result/vpn2 ${@:1}
6 | # python train.py --config config/collect_deterministic.xml --prediction-step 3 --branch 4,4,4 --decay 0.9 --alg VPN --log result/vpn3 ${@:1}
7 | # python train.py --config config/collect_deterministic.xml --prediction-step 5 --branch 4,4,4,1,1 --decay 0.8 --alg VPN --log result/vpn5 ${@:1}
8 |
9 | # Stochastic Collect
10 | # python train.py --config config/collect_stochastic.xml --prediction-step 1 --branch 4 --decay 0.9 --alg VPN --log result/vpn1 ${@:1}
11 | # python train.py --config config/collect_stochastic.xml --prediction-step 2 --branch 4,4 --decay 0.9 --alg VPN --log result/vpn2 ${@:1}
12 | # python train.py --config config/collect_stochastic.xml --prediction-step 3 --branch 4,4,4 --decay 0.9 --alg VPN --log result/vpn3 ${@:1}
13 | # python train.py --config config/collect_stochastic.xml --prediction-step 5 --branch 4,4,4,1,1 --decay 0.8 --alg VPN --log result/vpn5 ${@:1}
14 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import scipy.signal
2 | import numpy as np
3 | import tensorflow as tf
4 | from collections import namedtuple
5 |
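    | # Discounted returns. When `time` is given, consecutive rewards are separated by time[i]
    | # environment steps, so gamma is raised to that power. Illustrative example:
    | #   discount(np.array([1., 0., 2.]), 0.9, np.array([2, 1, 0]))
    | #   -> [1 + 0.9**2 * 1.8, 0 + 0.9**1 * 2.0, 2.0] = [2.458, 1.8, 2.0]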
6 | def discount(x, gamma, time=None):
7 | if time is not None and time.size > 0:
8 | y = np.array(x, copy=True)
9 | for i in reversed(range(y.size-1)):
10 | y[i] += (gamma ** time[i]) * y[i+1]
11 | return y
12 | else:
13 | return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
14 |
15 | Batch = namedtuple("Batch", ["si", "a", "adv", "r",
16 | "terminal", "features", "reward", "step", "meta"])
17 |
18 | def huber_loss(delta, sum=True):
19 | if sum:
20 | return tf.reduce_sum(tf.where(tf.abs(delta) < 1,
21 | 0.5 * tf.square(delta),
22 | tf.abs(delta) - 0.5))
23 | else:
24 | return tf.where(tf.abs(delta) < 1,
25 | 0.5 * tf.square(delta),
26 | tf.abs(delta) - 0.5)
27 |
28 | def lower_triangular(x):
29 | return tf.matrix_band_part(x, -1, 0)
30 |
31 | def to_bool(x):
32 | return x == 1
33 |
34 | def parse_to_num(s):
35 | l = s.split(',')
36 | for i in range(0, len(l)):
37 | try:
38 | l[i] = int(l[i])
39 | except ValueError:
40 | l = []
41 | break
42 | return l
43 |
--------------------------------------------------------------------------------
/vpn.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import tensorflow as tf
4 | import model # NOQA
5 | import util
6 | from q import Q
7 | import logging
8 | logger = logging.getLogger(__name__)
9 | logger.setLevel(logging.INFO)
10 |
11 | class RolloutMemory(object):
12 | def __init__(self, max_size, sampling='rand'):
13 | self.max_size = max_size
14 | self.s = []
15 | self.a = []
16 | self.r = []
17 | self.t = []
18 | self.r_t = []
19 | self.term = []
20 | self.sampling = sampling
21 | self.sample_idx = 0
22 |
23 | def add(self, s, a, r, t, r_t, term):
24 | self.s.append(s)
25 | self.a.append(a)
26 | self.r.append(r)
27 | self.t.append(t)
28 | self.r_t.append(r_t)
29 | self.term.append(term)
30 |
31 | def size(self):
32 | return len(self.s)
33 |
34 | def is_full(self):
35 | return len(self.s) >= self.max_size
36 |
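    |     # Sample a sub-trajectory of up to `length` transitions. 'rand' picks a random,
    |     # non-terminal start index; 'seq' (used for recurrent models) walks through the buffer
    |     # in order so the LSTM state can be carried over between samples. The second return
    |     # value flags whether the sample starts at the beginning of an episode.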
37 | def sample(self, length):
38 | size = len(self.s)
39 | is_initial_state = False
40 | if self.sampling == 'rand':
41 | idx = np.random.randint(0, size-1)
42 | if self.term[idx]:
43 | return self.sample(length)
44 | for end_idx in range(idx, idx + length):
45 | if self.term[end_idx] or end_idx == size-1:
46 | break
47 | is_initial_state = (idx > 0 and self.term[idx-1]) or idx == 0
48 | else:
49 | idx = self.sample_idx
50 | if self.term[idx]:
51 | idx = idx + 1
52 | for end_idx in range(idx, idx + length):
53 | if self.term[end_idx] or end_idx == size-1:
54 | break
55 | self.sample_idx = end_idx + 1 if end_idx < size-1 else 0
56 | is_initial_state = (idx > 0 and self.term[idx-1]) or idx == 0
57 |
58 | assert end_idx == idx + length - 1 or self.term[end_idx] or end_idx == size-1
59 | return util.Batch(si=np.asarray(self.s[idx:end_idx+1]),
60 | a=np.asarray(self.a[idx:end_idx+1]),
61 | adv=None,
62 | r=None,
63 | terminal=self.term,
64 | features=[],
65 | reward=np.asarray(self.r[idx:end_idx+1]),
66 | step=np.asarray(self.t[idx:end_idx+1]),
67 | meta=np.asarray(self.r_t[idx:end_idx+1])), is_initial_state
68 |
69 | class VPN(Q):
70 | def define_network(self, name):
71 | self.state_off = None
72 | self.args.meta_dim = 0 if self.env.meta() is None else len(self.env.meta())
73 | m = eval("model." + name)(self.env.observation_space.shape,
74 | self.env.action_space.n, type='vpn',
75 | gamma=self.args.gamma,
76 | prediction_step=self.args.prediction_step,
77 | dim=self.args.dim,
78 | f_num=self.args.f_num,
79 | f_pad=self.args.f_pad,
80 | f_stride=self.args.f_stride,
81 | f_size=self.args.f_size,
82 | branch=self.args.branch,
83 | meta_dim=self.args.meta_dim,
84 | )
85 |
86 | return m
87 |
88 | def process_rollout(self, rollout, gamma, lambda_=1.0):
89 | """
90 | given a rollout, compute its returns
91 | """
92 | batch_si = np.asarray(rollout.states)
93 | batch_a = np.asarray(rollout.actions)
94 | rewards = np.asarray(rollout.rewards)
95 | time = np.asarray(rollout.time)
96 | meta = np.asarray(rollout.meta)
97 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
98 | batch_r = util.discount(rewards_plus_v, gamma, time)
99 | features = rollout.features[0]
100 |
101 | return util.Batch(si=batch_si,
102 | a=batch_a,
103 | adv=None,
104 | r=batch_r,
105 | terminal=rollout.terminal,
106 | features=features,
107 | reward=rewards,
108 | step=time,
109 | meta=meta,
110 | )
111 |
112 | def define_loss(self):
113 | pi = self.local_network
114 | if self.args.buf > 0:
115 | if pi.is_recurrent():
116 | self.rand_rollouts = RolloutMemory(int(self.args.buf / self.args.num_workers),
117 | sampling='seq')
118 | self.off_state = pi.get_initial_features()
119 | else:
120 | self.rand_rollouts = RolloutMemory(int(self.args.buf / self.args.num_workers))
121 |
122 | # loss function
123 | self.ac = tf.placeholder(tf.float32, [None, self.env.action_space.n], name="ac")
124 | self.v_target = tf.placeholder(tf.float32, [None], name="v_target") # target
125 | self.reward = tf.placeholder(tf.float32, [None], name="reward") # immediate reward
126 | self.step = tf.placeholder(tf.float32, [None], name="step") # num of steps
127 | self.terminal = tf.placeholder(tf.float32, (), name="terminal")
128 |
129 | time = tf.shape(pi.x)[0]
130 | steps = tf.minimum(self.args.prediction_step, time)
131 | self.rollout_num = tf.to_float(time * steps - steps * (steps - 1) / 2)
132 |
133 | # reward/gamma/value prediction
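    |         # pi.r_a / pi.t_a / pi.v_next_a have shape [time, k]: column j holds the prediction
    |         # made j+1 steps ahead from the observation at time t-j. Entries with j > t come
    |         # from the zero-padded initial states, so lower_triangular() masks them out;
    |         # rollout_num counts the remaining valid entries (used to normalize the summaries).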
134 | self.r_delta = util.lower_triangular(
135 | pi.r_a - tf.reshape(self.reward, [-1, 1]))
136 | self.r_loss_mat = util.huber_loss(self.r_delta, sum=False)
137 | self.r_loss = tf.reduce_sum(self.r_loss_mat)
138 |
139 | self.gamma_loss_mat = util.huber_loss(util.lower_triangular(
140 | pi.t_a - tf.reshape(self.step, [-1, 1])), sum=False)
141 | self.gamma_loss = tf.reduce_sum(self.gamma_loss_mat)
142 |
143 | self.v_next_loss_mat = util.huber_loss(util.lower_triangular(
144 | pi.v_next_a - tf.reshape(self.v_target[1:], [-1, 1])), sum=False)
145 | self.v_next_loss = tf.reduce_sum(self.v_next_loss_mat)
146 | self.loss = self.r_loss + self.gamma_loss + self.v_next_loss
147 |
148 | # reward/gamma prediction for off-policy data (optional)
149 | self.a_off = tf.placeholder(tf.float32, [None, self.env.action_space.n], name="a_off")
150 | self.r_off = tf.placeholder(tf.float32, [None], name="r_off") # immediate reward
151 | self.step_off = tf.placeholder(tf.float32, [None], name="step_off") # num of steps
152 |
153 | self.r_delta_off = util.lower_triangular(
154 | pi.r_off - tf.reshape(self.r_off, [-1, 1]))
155 | self.r_loss_mat_off = util.huber_loss(self.r_delta_off, sum=False)
156 | self.r_loss_off = tf.reduce_sum(self.r_loss_mat_off)
157 |
158 | self.gamma_loss_mat_off = util.huber_loss(util.lower_triangular(
159 | pi.t_off - tf.reshape(self.step_off, [-1, 1])), sum=False)
160 |
161 | self.gamma_loss_off = tf.reduce_sum(self.gamma_loss_mat_off)
162 | self.loss += self.r_loss_off + self.gamma_loss_off
163 |
164 | def prepare_input(self, batch):
165 | feed_dict = {self.local_network.x: batch.si,
166 | self.local_network.a: batch.a,
167 | self.ac: batch.a,
168 | self.reward: batch.reward,
169 | self.step: batch.step,
170 | self.target_network.x: batch.si,
171 | self.terminal: float(batch.terminal),
172 | self.v_target: batch.r}
173 |
174 | for i in range(len(self.local_network.state_in)):
175 | feed_dict[self.local_network.state_in[i]] = batch.features[i]
176 |
177 | if self.args.meta_dim > 0:
178 | feed_dict[self.local_network.meta] = batch.meta
179 |
180 | traj, initial = self.random_trajectory()
181 | feed_dict[self.local_network.x_off] = traj.si
182 | feed_dict[self.local_network.a_off] = traj.a
183 | feed_dict[self.a_off] = traj.a
184 | feed_dict[self.r_off] = traj.reward
185 | feed_dict[self.step_off] = traj.step
186 |
187 | if self.local_network.is_recurrent():
188 | if initial:
189 | state_in = self.local_network.get_initial_features()
190 | else:
191 | state_in = self.off_state
192 | for i in range(len(self.local_network.state_in_off)):
193 | feed_dict[self.local_network.state_in_off[i]] = state_in[i]
194 |
195 | if self.args.meta_dim > 0:
196 | feed_dict[self.local_network.meta_off] = traj.meta
197 |
198 | return feed_dict
199 |
200 | def random_trajectory(self):
201 | if not self.rand_rollouts.is_full():
202 | env = self.env_off
203 | state_off = env.reset()
204 | meta_off = env.meta()
205 | print("Generating random rollouts: %d steps" % self.rand_rollouts.max_size)
206 | while not self.rand_rollouts.is_full():
207 | act_idx = np.random.randint(0, env.action_space.n)
208 | action = np.zeros(env.action_space.n)
209 | action[act_idx] = 1
210 | state, reward, terminal, _, time = env.step(action.argmax())
211 | self.rand_rollouts.add(state_off, action, reward, time,
212 | meta_off, terminal)
213 | state_off = state
214 | meta_off = env.meta()
215 | if terminal:
216 | state_off = env.reset()
217 | meta_off = env.meta()
218 | return self.rand_rollouts.sample(self.args.t_max)
219 |
220 | def extra_fetches(self):
221 | if self.local_network.is_recurrent():
222 | return self.local_network.state_out_off
223 | return []
224 |
225 | def handle_extra_fetches(self, fetches):
226 | if self.local_network.is_recurrent():
227 | self.off_state = fetches[:len(self.off_state)]
228 |
229 | def compute_depth(self, steps):
230 | return self.args.depth
231 |
232 | def write_extra_summary(self, rollout=None):
233 | super(VPN, self).write_extra_summary(rollout)
234 |
235 | def define_summary(self):
236 | super(VPN, self).define_summary()
237 | tf.summary.scalar("loss/r_loss", self.r_loss / self.rollout_num)
238 | tf.summary.scalar("loss/gamma_loss", self.gamma_loss / self.rollout_num)
239 | tf.summary.scalar("model/r", tf.reduce_mean(self.local_network.r))
240 | tf.summary.scalar("model/v_next", tf.reduce_mean(self.local_network.v_next))
241 | tf.summary.scalar("model/gamma", tf.reduce_mean(self.local_network.gamma))
242 | self.summary_op = tf.summary.merge_all()
243 |
--------------------------------------------------------------------------------
/worker.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import go_vncdriver
3 | import tensorflow as tf
4 | import argparse
5 | import logging
6 | import sys, signal
7 | import time
8 | import os
9 | from a3c import A3C
10 | from q import Q
11 | from vpn import VPN
12 | from envs import create_env
13 | import util
14 | import numpy as np
15 |
16 | logger = logging.getLogger(__name__)
17 | logger.setLevel(logging.INFO)
18 |
19 | def new_env(args):
20 | config = open(args.config) if args.config != "" else None
21 | env = create_env(args.env_id,
22 | str(args.task),
23 | args.remotes,
24 | config=config)
25 | return env
26 |
27 | # Disables the write_meta_graph argument, which freezes the entire process and is mostly useless.
28 | class FastSaver(tf.train.Saver):
29 | def save(self, sess, save_path, global_step=None, latest_filename=None,
30 | meta_graph_suffix="meta", write_meta_graph=True):
31 | super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
32 | meta_graph_suffix, False)
33 |
34 | def run(args, server):
35 | env = new_env(args)
36 | if args.alg == 'A3C':
37 | trainer = A3C(env, args)
38 | elif args.alg == 'Q':
39 | trainer = Q(env, args)
40 | elif args.alg == 'VPN':
41 | env_off = new_env(args)
42 | env_off.verbose = 0
43 | env_off.reset()
44 | trainer = VPN(env, args, env_off=env_off)
45 | else:
46 | raise ValueError('Invalid algorithm: ' + args.alg)
47 |
48 | # Variable names that start with "local" are not saved in checkpoints.
49 | variables_to_save = [v for v in tf.global_variables() if \
50 | not v.name.startswith("global") and not v.name.startswith("local/target/")]
51 | global_variables = [v for v in tf.global_variables() if not v.name.startswith("local")]
52 |
53 | init_op = tf.variables_initializer(global_variables)
54 | init_all_op = tf.global_variables_initializer()
55 | saver = FastSaver(variables_to_save, max_to_keep=0)
56 |
57 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
58 | logger.info('Trainable vars:')
59 | for v in var_list:
60 | logger.info(' %s %s', v.name, v.get_shape())
61 | logger.info("Num parameters: %d", trainer.local_network.num_param)
62 |
63 | def init_fn(ses):
64 | logger.info("Initializing all parameters.")
65 | ses.run(init_all_op)
66 |
67 | device = 'gpu' if args.gpu > 0 else 'cpu'
68 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.15)
69 | config = tf.ConfigProto(device_filters=["/job:ps",
70 | "/job:worker/task:{}/{}:0".format(args.task, device)],
71 | gpu_options=gpu_options,
72 | allow_soft_placement=True)
73 | logdir = os.path.join(args.log, 'train')
74 | summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
75 |
76 | logger.info("Events directory: %s_%s", logdir, args.task)
77 | sv = tf.train.Supervisor(is_chief=(args.task == 0),
78 | logdir=logdir,
79 | saver=saver,
80 | summary_op=None,
81 | init_op=init_op,
82 | init_fn=init_fn,
83 | summary_writer=summary_writer,
84 | ready_op=tf.report_uninitialized_variables(global_variables),
85 | global_step=trainer.global_step,
86 | save_model_secs=0,
87 | save_summaries_secs=30)
88 |
89 |
90 | logger.info(
91 |         "Starting session. If this hangs, we're most likely waiting to connect to the parameter server. " +
92 | "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.")
93 | with sv.managed_session(server.target, config=config) as sess, sess.as_default():
94 | sess.run(trainer.sync)
95 | trainer.start(sess, summary_writer)
96 | global_step = sess.run(trainer.global_step)
97 | epoch = -1
98 | logger.info("Starting training at step=%d", global_step)
99 | while not sv.should_stop() and (not args.max_step or global_step < args.max_step):
100 | if args.task == 0 and int(global_step / args.eval_freq) > epoch:
101 | epoch = int(global_step / args.eval_freq)
102 | filename = os.path.join(args.log, 'e%d' % (epoch))
103 | sv.saver.save(sess, filename)
104 | sv.saver.save(sess, os.path.join(args.log, 'latest'))
105 | print("Saved to: %s" % filename)
106 | trainer.process(sess)
107 | global_step = sess.run(trainer.global_step)
108 |
109 | if args.task == 0 and int(global_step / args.eval_freq) > epoch:
110 | epoch = int(global_step / args.eval_freq)
111 | filename = os.path.join(args.log, 'e%d' % (epoch))
112 | sv.saver.save(sess, filename)
113 | sv.saver.save(sess, os.path.join(args.log, 'latest'))
114 | print("Saved to: %s" % filename)
115 | # Ask for all the services to stop.
116 | sv.stop()
117 | logger.info('reached %s steps. worker stopped.', global_step)
118 |
119 | def cluster_spec(num_workers, num_ps):
120 | """
121 | More tensorflow setup for data parallelism
122 | """
123 | cluster = {}
124 | port = 12222
125 |
126 | all_ps = []
127 | host = '127.0.0.1'
128 | for _ in range(num_ps):
129 | all_ps.append('{}:{}'.format(host, port))
130 | port += 1
131 | cluster['ps'] = all_ps
132 |
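    |     # one extra worker address is reserved for the evaluation ("test") job,
    |     # which train.py launches with task index == num_workers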
133 | all_workers = []
134 | for _ in range(num_workers + 1):
135 | all_workers.append('{}:{}'.format(host, port))
136 | port += 1
137 | cluster['worker'] = all_workers
138 | port += 1
139 | return cluster
140 |
141 | def evaluate(env, network, num_play=3000, eps=0.0):
142 | for iter in range(0, num_play):
143 | last_state = env.reset()
144 | last_features = network.get_initial_features()
145 | last_meta = env.meta()
146 | while True:
147 | if eps == 0.0 or np.random.rand() > eps:
148 | fetched = network.act(last_state, last_features,
149 | meta=last_meta)
150 | if network.type == 'policy':
151 | action, features = fetched[0], fetched[2:]
152 | else:
153 | action, features = fetched[0], fetched[1:]
154 | else:
155 | act_idx = np.random.randint(0, env.action_space.n)
156 | action = np.zeros(env.action_space.n)
157 | action[act_idx] = 1
158 | features = []
159 |
160 | state, reward, terminal, info, time = env.step(action.argmax())
161 | last_state = state
162 | last_features = features
163 | last_meta = env.meta()
164 |
165 | if terminal:
166 | break
167 |
168 | return env.reward_mean(num_play)
169 |
170 | def run_tester(args, server):
171 | env = new_env(args)
172 | env.reset()
173 | env.max_history = args.eval_num
174 | if args.alg == 'A3C':
175 | agent = A3C(env, args)
176 | elif args.alg == 'Q':
177 | agent = Q(env, args)
178 | elif args.alg == 'VPN':
179 | agent = VPN(env, args)
180 | else:
181 | raise ValueError('Invalid algorithm: ' + args.alg)
182 |
183 | device = 'gpu' if args.gpu > 0 else 'cpu'
184 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.15)
185 | config = tf.ConfigProto(device_filters=["/job:ps",
186 | "/job:worker/task:{}/{}:0".format(args.task, device)],
187 | gpu_options=gpu_options,
188 | allow_soft_placement=True)
189 | variables_to_save = [v for v in tf.global_variables() if \
190 | not v.name.startswith("global") and not v.name.startswith("local/target/")]
191 | global_variables = [v for v in tf.global_variables() if not v.name.startswith("local")]
192 |
193 | init_op = tf.variables_initializer(global_variables)
194 | init_all_op = tf.global_variables_initializer()
195 |
196 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
197 | logger.info('Trainable vars:')
198 | for v in var_list:
199 | logger.info(' %s %s', v.name, v.get_shape())
200 | logger.info("Num parameters: %d", agent.local_network.num_param)
201 |
202 | def init_fn(ses):
203 | logger.info("Initializing all parameters.")
204 | ses.run(init_all_op)
205 |
206 | saver = FastSaver(variables_to_save, max_to_keep=0)
207 | sv = tf.train.Supervisor(is_chief=False,
208 | global_step=agent.global_step,
209 | summary_op=None,
210 | init_op=init_op,
211 | init_fn=init_fn,
212 | ready_op=tf.report_uninitialized_variables(global_variables),
213 | saver=saver,
214 | save_model_secs=0,
215 | save_summaries_secs=0)
216 |
217 | best_reward = -10000
218 | with sv.managed_session(server.target, config=config) as sess, sess.as_default():
219 | epoch = args.eval_epoch
220 | while args.eval_freq * epoch <= args.max_step:
221 | path = os.path.join(args.log, "e%d" % epoch)
222 | if not os.path.exists(path + ".index"):
223 | time.sleep(10)
224 | continue
225 | logger.info("Start evaluation (Epoch %d)", epoch)
226 | saver.restore(sess, path)
227 | np.random.seed(args.seed)
228 | reward = evaluate(env, agent.local_network, args.eval_num, eps=args.eps_eval)
229 |
230 | logfile = open(os.path.join(args.log, "eval.csv"), "a")
231 | print("Epoch: %d, Reward: %.2f" % (epoch, reward))
232 | logfile.write("%d, %.3f\n" % (epoch, reward))
233 | logfile.close()
234 | if reward > best_reward:
235 | best_reward = reward
236 | sv.saver.save(sess, os.path.join(args.log, 'best'))
237 | print("Saved to: %s" % os.path.join(args.log, 'best'))
238 |
239 | epoch += 1
240 |
241 | logger.info('tester stopped.')
242 |
243 | def main(_):
244 | """
245 | Setting up Tensorflow for data parallel work
246 | """
247 |
248 | parser = argparse.ArgumentParser(description=None)
249 | parser.add_argument('-gpu', '--gpu', default=0, type=int, help='Number of GPUs')
250 | parser.add_argument('-v', '--verbose', action='count', dest='verbosity', default=0, help='Set verbosity.')
251 | parser.add_argument('--task', default=0, type=int, help='Task index')
252 | parser.add_argument('--job-name', default="worker", help='worker or ps')
253 | parser.add_argument('--num-workers', default=1, type=int, help='Number of workers')
254 | parser.add_argument('--num-ps', type=int, default=1, help="Number of parameter servers")
255 | parser.add_argument('--log', default="/tmp/vpn", help='Log directory path')
256 | parser.add_argument('--env-id', default="maze", help='Environment id')
257 | parser.add_argument('-r', '--remotes', default=None,
258 | help='References to environments to create (e.g. -r 20), '
259 | 'or the address of pre-existing VNC servers and '
260 | 'rewarders to use (e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901)')
261 | parser.add_argument('-a', '--alg', choices=['A3C', 'Q', 'VPN'], default="A3C")
262 | parser.add_argument('-mo', '--model', type=str, default="LSTM", help="Name of model: [CNN | LSTM]")
263 | parser.add_argument('--eval-freq', type=int, default=250000, help="Evaluation frequency")
264 |     parser.add_argument('--eval-num', type=int, default=500, help="Number of evaluation episodes")
265 | parser.add_argument('--eval-epoch', type=int, default=0, help="Evaluation epoch")
266 | parser.add_argument('--seed', type=int, default=0, help="Random seed")
267 | parser.add_argument('--config', type=str, default="config/collect_deterministic.xml",
268 | help="config xml file for environment")
269 |
270 | # Hyperparameters
271 | parser.add_argument('-n', '--t-max', type=int, default=10, help="Number of unrolling steps")
272 | parser.add_argument('-g', '--gamma', type=float, default=0.98, help="Discount factor")
273 | parser.add_argument('-ld', '--ld', type=float, default=1, help="Lambda for GAE")
274 | parser.add_argument('-lr', '--lr', type=float, default=1e-4, help="Learning rate")
275 |     parser.add_argument('--decay', type=float, default=0.95, help="Decay rate")
276 | parser.add_argument('-ms', '--max-step', type=int, default=int(15e6), help="Max global step")
277 | parser.add_argument('--dim', type=int, default=0, help="Number of final hidden units")
278 | parser.add_argument('--f-num', type=str, default='32,32,64', help="num of conv filters")
279 | parser.add_argument('--f-pad', type=str, default='SAME', help="padding of conv filters")
280 | parser.add_argument('--f-stride', type=str, default='1,1,2', help="stride of conv filters")
281 | parser.add_argument('--f-size', type=str, default='3,3,4', help="size of conv filters")
282 | parser.add_argument('--h-dim', type=str, default='', help="num of hidden units")
283 |
284 | # Q-Learning parameters
285 | parser.add_argument('-s', '--sync', type=int, default=10000,
286 | help="Target network synchronization frequency")
287 | parser.add_argument('-f', '--update-freq', type=int, default=1,
288 | help="Parameter update frequency")
289 | parser.add_argument('--eps-step', type=int, default=int(1e6),
290 | help="Num of local steps for epsilon scheduling")
291 | parser.add_argument('--eps', type=float, default=0.05, help="Final epsilon value")
292 | parser.add_argument('--eps-eval', type=float, default=0.0, help="Epsilon for evaluation")
293 |
294 | # VPN parameters
295 | parser.add_argument('--prediction-step', type=int, default=3, help="number of prediction steps")
296 | parser.add_argument('--branch', type=str, default="4,4,4", help="branching factor")
297 | parser.add_argument('--buf', type=int, default=10**6, help="num of steps for random buffer")
298 |
299 | args = parser.parse_args()
300 | args.f_num = util.parse_to_num(args.f_num)
301 | args.f_stride = util.parse_to_num(args.f_stride)
302 | args.f_size = util.parse_to_num(args.f_size)
303 | args.h_dim = util.parse_to_num(args.h_dim)
304 | args.eps_eval = min(args.eps, args.eps_eval)
305 | args.branch = util.parse_to_num(args.branch)
306 | spec = cluster_spec(args.num_workers, args.num_ps)
307 | cluster = tf.train.ClusterSpec(spec).as_cluster_def()
308 |
309 | def shutdown(signal, frame):
310 | logger.warn('Received signal %s: exiting', signal)
311 | sys.exit(128+signal)
312 | signal.signal(signal.SIGHUP, shutdown)
313 | signal.signal(signal.SIGINT, shutdown)
314 | signal.signal(signal.SIGTERM, shutdown)
315 |
316 | gpu_options = None
317 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.15)
318 |
319 | if args.job_name == "worker":
320 | server = tf.train.Server(cluster, job_name="worker", task_index=args.task,
321 | config=tf.ConfigProto(intra_op_parallelism_threads=1,
322 | inter_op_parallelism_threads=1,
323 | gpu_options=gpu_options))
324 | run(args, server)
325 | elif args.job_name == "test":
326 | server = tf.train.Server(cluster, job_name="worker", task_index=args.task,
327 | config=tf.ConfigProto(intra_op_parallelism_threads=1,
328 | inter_op_parallelism_threads=1,
329 | gpu_options=gpu_options))
330 | run_tester(args, server)
331 | elif args.job_name == "ps":
332 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)
333 | server = tf.train.Server(cluster, job_name="ps", task_index=args.task,
334 | config=tf.ConfigProto(device_filters=["/job:ps"],
335 | gpu_options=gpu_options))
336 | while True:
337 | time.sleep(1000)
338 |
339 | if __name__ == "__main__":
340 | tf.app.run()
341 |
--------------------------------------------------------------------------------