├── .gitignore
├── README.md
├── a3c.py
├── async.py
├── config
│   ├── collect_deterministic.xml
│   └── collect_stochastic.xml
├── envs.py
├── maze.py
├── model.py
├── q.py
├── test.py
├── test_vpn
├── train.py
├── train_vpn
├── util.py
├── vpn.py
└── worker.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 | *.swp
91 |
92 | result/
93 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | This repository implements **[NIPS 2017 Value Prediction Network (Oh et al.)](https://arxiv.org/abs/1707.03497)** in TensorFlow.
3 | ```
4 | @inproceedings{Oh2017VPN,
5 |   title={Value Prediction Network},
6 |   author={Junhyuk Oh and Satinder Singh and Honglak Lee},
7 |   booktitle={NIPS},
8 |   year={2017}
9 | }
10 | ```
11 | Our code is based on [OpenAI's A3C implementation](https://github.com/openai/universe-starter-agent).
12 |
13 | # Dependencies
14 | * [TensorFlow](https://www.tensorflow.org/install/)
15 | * [Beautiful Soup](https://www.crummy.com/software/BeautifulSoup/bs4/doc/)
16 | * [Golang](https://golang.org/doc/install)
17 | * [six](https://pypi.python.org/pypi/six) (for py2/3 compatibility)
18 | * [tmux](https://tmux.github.io/) (the start script opens up a tmux session with multiple windows)
19 | * [htop](https://hisham.hm/htop/) (shown in one of the tmux windows)
20 | * [gym](https://pypi.python.org/pypi/gym)
21 | * gym[atari]
22 | * [universe](https://pypi.python.org/pypi/universe)
23 | * [opencv-python](https://pypi.python.org/pypi/opencv-python)
24 | * [numpy](https://pypi.python.org/pypi/numpy)
25 | * [scipy](https://pypi.python.org/pypi/scipy)
26 |
27 | # Training
28 | The following command trains a value prediction network (VPN) with a plan depth of 3 on the deterministic Collect domain:
29 | ```
30 | python train.py --config config/collect_deterministic.xml --branch 4,4,4 --alg VPN
31 | ```
32 | The `train_vpn` script contains the commands for reproducing the main result of the paper.
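As a sketch (not taken from `train_vpn`, and using only the flags documented in this README), the stochastic Collect config can be trained the same way, optionally on multiple GPUs:
```
python train.py --config config/collect_stochastic.xml --branch 4,4,4 --alg VPN --gpu 0,1
```
The `config/*.xml` files are parsed by `maze.MazeEnv` with BeautifulSoup; their contents are not reproduced in this dump. The sketch below lists only the elements and attributes that `maze.py` actually reads (any wrapping root element is omitted), and every value is a made-up placeholder rather than the settings used in the paper:
```
<maze size="10" time="20" holes="0"/>
<object num_goal="1"/>
<reward default="-0.01" goal="1.0" lazy="-0.1"/>
<random p_stop="0.0" p_goal="0.0" p_slip="0.0"/>
<meta remaining_time="false"/>
```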
33 | 34 | # Notes 35 | * Tensorboard shows the performance of the epsilon-greedy policy. This is NOT the learning curve in the paper, because epsilon decreases from 1.0 to 0.05 for the first 1e6 steps. Instead, `[logdir]/eval.csv` shows the performance of the agent using greedy-policy. 36 | * Our code supports multi-gpu training. You can specify GPU IDs in `--gpu` option (e.g., `--gpu 0,1,2,3`). 37 | -------------------------------------------------------------------------------- /a3c.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | import util 5 | from async import AsyncSolver 6 | import model 7 | 8 | class A3C(AsyncSolver): 9 | def define_network(self, name): 10 | self.args.meta_dim = 0 if self.env.meta() is None else len(self.env.meta()) 11 | return eval("model." + name)(self.env.observation_space.shape, 12 | self.env.action_space.n, type='policy', 13 | gamma=self.args.gamma, 14 | dim=self.args.dim, 15 | f_num=self.args.f_num, 16 | f_pad=self.args.f_pad, 17 | f_stride=self.args.f_stride, 18 | f_size=self.args.f_size, 19 | meta_dim=self.args.meta_dim, 20 | ) 21 | 22 | def process_rollout(self, rollout, gamma, lambda_=1.0): 23 | """ 24 | given a rollout, compute its returns and the advantage 25 | """ 26 | batch_si = np.asarray(rollout.states) 27 | batch_a = np.asarray(rollout.actions) 28 | rewards = np.asarray(rollout.rewards) 29 | time = np.asarray(rollout.time) 30 | meta = np.asarray(rollout.meta) 31 | vpred_t = np.asarray(rollout.values + [rollout.r]) 32 | 33 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r]) 34 | batch_r = util.discount(rewards_plus_v, gamma)[:-1] 35 | delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1] 36 | # this formula for the advantage comes "Generalized Advantage Estimation": 37 | # https://arxiv.org/abs/1506.02438 38 | batch_adv = util.discount(delta_t, gamma * lambda_) 39 | 40 | features = rollout.features[0] 41 | return util.Batch(si=batch_si, 42 | a=batch_a, 43 | adv=batch_adv, 44 | r=batch_r, 45 | terminal=rollout.terminal, 46 | features=features, 47 | reward=rewards, 48 | step=time, 49 | meta=meta) 50 | 51 | def init_variables(self): 52 | pi = self.local_network 53 | self.ac = tf.placeholder(tf.float32, [None, self.env.action_space.n], name="ac") 54 | self.adv = tf.placeholder(tf.float32, [None], name="adv") 55 | self.r = tf.placeholder(tf.float32, [None], name="r") 56 | 57 | log_prob_tf = tf.nn.log_softmax(pi.logits) 58 | prob_tf = tf.nn.softmax(pi.logits) 59 | 60 | # the "policy gradients" loss: its derivative is precisely the policy gradient 61 | # notice that self.ac is a placeholder that is provided externally. 
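        # As a sketch, with log-probabilities log_prob_tf = log_softmax(logits), the
        # losses built below are
        #   pi_loss = - sum_t log pi(a_t | s_t) * adv_t      (policy surrogate)
        #   vf_loss = 0.5 * sum_t (V(s_t) - R_t)^2           (value regression)
        # and the total loss additionally rewards entropy with weight 0.01 to keep
        # the policy from collapsing prematurely.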
62 | # adv will contain the advantages, as calculated in process_rollout 63 | self.pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac, [1]) * self.adv) 64 | 65 | # loss of value function 66 | self.vf_loss = 0.5 * tf.reduce_sum(tf.square(pi.vf - self.r)) 67 | self.entropy = - tf.reduce_sum(prob_tf * log_prob_tf) 68 | 69 | self.bs = tf.to_float(tf.shape(pi.x)[0]) 70 | self.loss = self.pi_loss + 0.5 * self.vf_loss - self.entropy * 0.01 71 | 72 | def define_summary(self): 73 | super(A3C, self).define_summary() 74 | tf.summary.scalar("model/policy_loss", self.pi_loss / self.bs) 75 | tf.summary.scalar("model/value_loss", self.vf_loss / self.bs) 76 | tf.summary.scalar("model/entropy", self.entropy / self.bs) 77 | self.summary_op = tf.summary.merge_all() 78 | 79 | def prepare_input(self, batch): 80 | feed_dict = {self.local_network.x: batch.si, 81 | self.ac: batch.a, 82 | self.adv: batch.adv, 83 | self.r: batch.r} 84 | if self.args.meta_dim > 0: 85 | feed_dict[self.local_network.meta] = batch.meta 86 | for i in range(len(self.local_network.state_in)): 87 | feed_dict[self.local_network.state_in[i]] = batch.features[i] 88 | return feed_dict 89 | -------------------------------------------------------------------------------- /async.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import logging 3 | import numpy as np 4 | import tensorflow as tf 5 | import six.moves.queue as queue 6 | import threading 7 | import distutils.version 8 | 9 | use_tf12_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('0.12.0') 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.INFO) 12 | 13 | class PartialRollout(object): 14 | """ 15 | a piece of a complete rollout. We run our agent, and process its experience 16 | once it has processed enough steps. 17 | """ 18 | def __init__(self): 19 | self.states = [] 20 | self.actions = [] 21 | self.rewards = [] 22 | self.values = [] 23 | self.r = 0.0 24 | self.terminal = False 25 | self.features = [] 26 | self.time = [] 27 | self.meta = [] 28 | 29 | def add(self, state, action, reward, terminal, features, 30 | value = None, time = None, meta=None): 31 | self.states += [state] 32 | self.actions += [action] 33 | self.rewards += [reward] 34 | self.terminal = terminal 35 | self.features += [features] 36 | if value is not None: 37 | self.values += [value] 38 | if time is not None: 39 | self.time += [time] 40 | if meta is not None: 41 | self.meta += [meta] 42 | 43 | def extend(self, other): 44 | assert not self.terminal 45 | self.states.extend(other.states) 46 | self.actions.extend(other.actions) 47 | self.rewards.extend(other.rewards) 48 | self.r = other.r 49 | self.terminal = other.terminal 50 | self.features.extend(other.features) 51 | if other.values is not None: 52 | self.values.extend(other.values) 53 | if other.time is not None: 54 | self.time.extend(other.time) 55 | if other.meta is not None: 56 | self.meta.extend(other.meta) 57 | 58 | class RunnerThread(threading.Thread): 59 | """ 60 | One of the key distinctions between a normal environment and a universe environment 61 | is that a universe environment is _real time_. This means that there should be a thread 62 | that would constantly interact with the environment and tell it what to do. This thread is here. 
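    (For the Maze environments in this repository the real-time aspect does not
    apply, but the same runner is reused so that environment interaction and
    gradient updates stay decoupled.)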
63 | """ 64 | def __init__(self, solver): 65 | threading.Thread.__init__(self) 66 | self.queue = queue.Queue(5) 67 | self.solver = solver 68 | self.num_local_steps = solver.t_max 69 | self.env = solver.env 70 | self.last_features = None 71 | self.network = solver.local_network 72 | self.daemon = True 73 | self.sess = None 74 | self.summary_writer = None 75 | 76 | def start_runner(self, sess, summary_writer): 77 | self.sess = sess 78 | self.summary_writer = summary_writer 79 | self.start() 80 | 81 | def run(self): 82 | with self.sess.as_default(): 83 | self._run() 84 | 85 | def _run(self): 86 | rollout_provider = env_runner(self.env, self.network, self.num_local_steps, 87 | self.summary_writer, solver=self.solver) 88 | while True: 89 | # the timeout variable exists because apparently, if one worker dies, the other workers 90 | # won't die with it, unless the timeout is set to some large number. This is an empirical 91 | # observation. 92 | 93 | self.queue.put(next(rollout_provider), timeout=600.0) 94 | 95 | 96 | def env_runner(env, network, num_local_steps, summary_writer, solver=None): 97 | """ 98 | The logic of the thread runner. In brief, it constantly keeps on running 99 | the policy, and as long as the rollout exceeds a certain length, the thread 100 | runner appends the policy to the queue. 101 | """ 102 | last_state = env.reset() 103 | last_features = network.get_initial_features() 104 | last_meta = env.meta() 105 | if solver.use_target_network(): 106 | last_target_features = solver.target_network.get_initial_features() 107 | 108 | while True: 109 | terminal_end = False 110 | rollout = PartialRollout() 111 | 112 | for _ in range(num_local_steps): 113 | value = None 114 | 115 | # choose an action from the policy 116 | if not hasattr(solver, 'epsilon') or solver.epsilon() < np.random.uniform(): 117 | fetched = network.act(last_state, last_features, 118 | meta=last_meta) 119 | if network.type == 'policy': 120 | action, value, features = fetched[0], fetched[1], fetched[2:] 121 | else: 122 | action, features = fetched[0], fetched[1:] 123 | else: 124 | # choose a random action 125 | assert network.type != 'policy' 126 | act_idx = np.random.randint(0, env.action_space.n) 127 | action = np.zeros(env.action_space.n) 128 | action[act_idx] = 1 129 | if network.is_recurrent(): 130 | features = network.update_state(last_state, last_features, 131 | meta=last_meta) 132 | else: 133 | features = [] 134 | 135 | # argmax to convert from one-hot 136 | state, reward, terminal, info, time = env.step(action.argmax()) 137 | if hasattr(env, 'atari'): 138 | reward = np.clip(reward, -1, 1) 139 | 140 | # collect the experience 141 | rollout.add(last_state, action, reward, terminal, last_features, 142 | value = value, time = time, meta=last_meta) 143 | 144 | last_state = state 145 | last_features = features 146 | last_meta = env.meta() 147 | 148 | if info: 149 | summary = tf.Summary() 150 | for k, v in info.items(): 151 | summary.value.add(tag=k, simple_value=float(v)) 152 | summary_writer.add_summary(summary, network.global_step.eval()) 153 | summary_writer.flush() 154 | 155 | if terminal: 156 | terminal_end = True 157 | last_state = env.reset() 158 | last_features = network.get_initial_features() 159 | last_meta = env.meta() 160 | break 161 | 162 | if not terminal_end: 163 | if solver.use_target_network(): 164 | rollout.r = solver.target_network.value(last_state, 165 | last_features, 166 | meta=last_meta) 167 | else: 168 | rollout.r = network.value(last_state, last_features, 169 | meta=last_meta) 170 | 171 | # 
once we have enough experience, yield it, and have the ThreadRunner place it on a queue 172 | yield rollout 173 | 174 | class AsyncSolver(object): 175 | def __init__(self, env, args, env_off=None): 176 | self.env = env 177 | self.args = args 178 | self.task = args.task 179 | self.t_max = args.t_max 180 | self.ld = args.ld 181 | self.lr = args.lr 182 | self.model = args.model 183 | self.env_off = env_off 184 | self.last_global_step = 0 185 | 186 | device = 'gpu' if self.args.gpu > 0 else 'cpu' 187 | worker_device = "/job:worker/task:{}/{}:0".format(self.task, device) 188 | def _load_fn(unused_op): 189 | return 1 190 | with tf.device(tf.train.replica_device_setter(self.args.num_ps, 191 | worker_device=worker_device, 192 | ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy( 193 | self.args.num_ps, _load_fn))): 194 | with tf.variable_scope("global"): 195 | with tf.variable_scope("learner"): 196 | self.network = self.define_network(self.model) 197 | if self.use_target_network(): 198 | with tf.variable_scope("target"): 199 | self.global_target_network = self.define_network(self.model) 200 | self.global_target_sync_step = tf.get_variable("target_sync_step", [], 201 | tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), 202 | trainable=False) 203 | self.global_step = tf.get_variable("global_step", [], tf.int32, 204 | initializer=tf.constant_initializer(0, dtype=tf.int32), 205 | trainable=False) 206 | 207 | 208 | with tf.device(worker_device): 209 | with tf.variable_scope("local"): 210 | with tf.variable_scope("learner"): 211 | self.local_network = pi = self.define_network(self.model) 212 | pi.global_step = self.global_step 213 | if self.use_target_network(): 214 | with tf.variable_scope("target"): 215 | self.target_network = self.define_network(self.model) 216 | 217 | self.init_variables() 218 | 219 | # 20 represents the number of "local steps": the number of timesteps 220 | # we run the policy before we update the parameters. 221 | # The larger local steps is, the lower is the variance in our policy gradients estimate 222 | # on the one hand; but on the other hand, we get less frequent parameter updates, which 223 | # slows down learning. In this code, we found that making local steps be much 224 | # smaller than 20 makes the algorithm more difficult to tune and to get to work. 
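        # Note that in this code the rollout length is not hard-coded to 20:
        # RunnerThread reads it from solver.t_max, which is set from args.t_max above.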
225 | self.runner = RunnerThread(self) 226 | 227 | self.grads = tf.gradients(self.loss, pi.var_list) 228 | self.grads, _ = tf.clip_by_global_norm(self.grads, 40.0) 229 | 230 | # copy weights from the parameter server to the local model 231 | self.sync = tf.group(*[v1.assign(v2) for v1, v2 in zip(pi.var_list, 232 | self.network.var_list)]) 233 | 234 | self.grads_and_vars = list(zip(self.grads, self.network.var_list)) 235 | inc_step = self.global_step.assign_add(tf.shape(pi.x)[0]) 236 | 237 | self.learning_rate = tf.placeholder(tf.float32, shape=[]) 238 | # each worker has a different set of adam optimizer parameters 239 | opt = tf.train.AdamOptimizer(self.learning_rate) 240 | self.train_op = tf.group(opt.apply_gradients(self.grads_and_vars), inc_step) 241 | if self.use_target_network(): 242 | self.update_target_step = self.global_target_sync_step.assign(self.global_step) 243 | 244 | with tf.device(None): 245 | self.define_summary() 246 | self.summary_writer = None 247 | self.local_steps = 0 248 | 249 | def define_summary(self): 250 | tf.summary.scalar("model/lr", self.learning_rate) 251 | tf.summary.image("model/state", self.env.tf_visualize(self.local_network.x), max_outputs=10) 252 | tf.summary.scalar("gradient/grad_norm", tf.global_norm(self.grads)) 253 | tf.summary.scalar("param/param_norm", tf.global_norm(self.local_network.var_list)) 254 | for grad_var in self.grads_and_vars: 255 | grad = grad_var[0] 256 | var = grad_var[1] 257 | if var.name.find('/W:') >= 0 or var.name.find('/w:') >= 0: 258 | if grad is None: 259 | raise ValueError(var.name + " grads are missing") 260 | tf.summary.scalar("gradient/%s" % var.name, tf.norm(grad)) 261 | tf.summary.scalar("param/%s" % var.name, tf.norm(var)) 262 | 263 | self.summary_op = tf.summary.merge_all() 264 | 265 | def use_target_network(self): 266 | return False 267 | 268 | def start(self, sess, summary_writer): 269 | self.runner.start_runner(sess, summary_writer) 270 | self.summary_writer = summary_writer 271 | 272 | def pull_batch_from_queue(self): 273 | """ 274 | self explanatory: take a rollout from the queue of the thread runner. 275 | """ 276 | rollout = self.runner.queue.get(timeout=600.0) 277 | ''' 278 | while not rollout.terminal: 279 | try: 280 | rollout.extend(self.runner.queue.get_nowait()) 281 | except queue.Empty: 282 | break 283 | ''' 284 | return rollout 285 | 286 | def process(self, sess): 287 | """ 288 | process grabs a rollout that's been produced by the thread runner, 289 | and updates the parameters. The update is then sent to the parameter 290 | server. 
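        Summaries are written only by the task-0 worker, and at most once every
        101 local steps (see should_compute_summary below).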
291 | """ 292 | sess.run(self.sync) # copy weights from shared to local 293 | rollout = self.pull_batch_from_queue() 294 | should_compute_summary = self.task == 0 and self.local_steps % 101 == 0 295 | 296 | if self.local_steps % self.args.update_freq == 0: 297 | batch = self.process_rollout(rollout, gamma=self.args.gamma, lambda_=self.ld) 298 | extra_fetches = self.extra_fetches() 299 | if should_compute_summary: 300 | fetches = [self.train_op, self.summary_op, self.global_step] 301 | else: 302 | fetches = [self.train_op, self.global_step] 303 | 304 | feed_dict = self.prepare_input(batch) 305 | feed_dict[self.learning_rate] = \ 306 | self.args.lr * self.args.decay ** (self.last_global_step/float(10**6)) 307 | fetched = sess.run(extra_fetches + fetches, feed_dict=feed_dict) 308 | if should_compute_summary: 309 | self.summary_writer.add_summary(tf.Summary.FromString(fetched[-2]), fetched[-1]) 310 | self.write_extra_summary(rollout=rollout) 311 | self.summary_writer.flush() 312 | self.last_global_step = fetched[-1] 313 | self.handle_extra_fetches(fetched[:len(extra_fetches)]) 314 | 315 | self.local_steps += 1 316 | self.post_process(sess) 317 | 318 | def extra_fetches(self): 319 | return [] 320 | 321 | def handle_extra_fetches(self, fetches): 322 | return None 323 | 324 | def post_process(self, sess): 325 | return None 326 | 327 | def write_extra_summary(self, rollout=None): 328 | return None 329 | -------------------------------------------------------------------------------- /config/collect_deterministic.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /config/collect_stochastic.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /envs.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from gym.spaces.box import Box 3 | import numpy as np 4 | import gym 5 | from gym import spaces 6 | import logging 7 | import universe 8 | from universe import vectorized 9 | from universe.wrappers import BlockingReset, GymCoreAction, EpisodeID, Unvectorize, Vectorize, Vision, Logger 10 | from universe import spaces as vnc_spaces 11 | from universe.spaces.vnc_event import keycode 12 | import time 13 | logger = logging.getLogger(__name__) 14 | logger.setLevel(logging.INFO) 15 | universe.configure_logging() 16 | import maze 17 | 18 | def create_env(env_id, client_id, remotes, **kwargs): 19 | if env_id == 'maze': 20 | return maze.MazeSMDP(**kwargs) 21 | 22 | spec = gym.spec(env_id) 23 | 24 | if spec.tags.get('flashgames', False): 25 | return create_flash_env(env_id, client_id, remotes, **kwargs) 26 | elif spec.tags.get('atari', False) and spec.tags.get('vnc', False): 27 | return create_vncatari_env(env_id, client_id, remotes, **kwargs) 28 | else: 29 | # Assume atari. 30 | assert "." not in env_id # universe environments have dots in names. 
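        # e.g. (hypothetical id/arguments) a local Atari environment without VNC remotes:
        #   create_env('PongDeterministic-v3', client_id='0', remotes=None)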
31 | return create_atari_env(env_id, **kwargs) 32 | 33 | def create_flash_env(env_id, client_id, remotes, **_): 34 | env = gym.make(env_id) 35 | env = Vision(env) 36 | env = Logger(env) 37 | env = BlockingReset(env) 38 | 39 | reg = universe.runtime_spec('flashgames').server_registry 40 | height = reg[env_id]["height"] 41 | width = reg[env_id]["width"] 42 | env = CropScreen(env, height, width, 84, 18) 43 | env = FlashRescale(env) 44 | 45 | keys = ['left', 'right', 'up', 'down', 'x'] 46 | if env_id == 'flashgames.NeonRace-v0': 47 | # Better key space for this game. 48 | keys = ['left', 'right', 'up', 'left up', 'right up', 'down', 'up x'] 49 | logger.info('create_flash_env(%s): keys=%s', env_id, keys) 50 | 51 | env = DiscreteToFixedKeysVNCActions(env, keys) 52 | env = EpisodeID(env) 53 | env = DiagnosticsInfo(env) 54 | env = Unvectorize(env) 55 | env.configure(fps=5.0, remotes=remotes, start_timeout=15 * 60, client_id=client_id, 56 | vnc_driver='go', vnc_kwargs={ 57 | 'encoding': 'tight', 'compress_level': 0, 58 | 'fine_quality_level': 50, 'subsample_level': 3}) 59 | return env 60 | 61 | def create_vncatari_env(env_id, client_id, remotes, **_): 62 | env = gym.make(env_id) 63 | env = Vision(env) 64 | env = Logger(env) 65 | env = BlockingReset(env) 66 | env = GymCoreAction(env) 67 | env = AtariRescale84x84(env) 68 | env = EpisodeID(env) 69 | env = DiagnosticsInfo(env) 70 | env = Unvectorize(env) 71 | 72 | logger.info('Connecting to remotes: %s', remotes) 73 | fps = env.metadata['video.frames_per_second'] 74 | env.configure(remotes=remotes, start_timeout=15 * 60, fps=fps, client_id=client_id) 75 | env.atari = True 76 | return env 77 | 78 | def create_atari_env(env_id, **kwargs): 79 | env = gym.make(env_id) 80 | env = Vectorize(env) 81 | env = AtariRescale84x84(env) 82 | env = DiagnosticsInfo(env) 83 | env = Unvectorize(env) 84 | env.atari = True 85 | return env 86 | 87 | def DiagnosticsInfo(env, *args, **kwargs): 88 | return vectorized.VectorizeFilter(env, DiagnosticsInfoI, *args, **kwargs) 89 | 90 | class DiagnosticsInfoI(vectorized.Filter): 91 | def __init__(self, log_interval=503): 92 | super(DiagnosticsInfoI, self).__init__() 93 | 94 | self._episode_time = time.time() 95 | self._last_time = time.time() 96 | self._local_t = 0 97 | self._log_interval = log_interval 98 | self._episode_reward = 0 99 | self._episode_length = 0 100 | self._all_rewards = [] 101 | self._num_vnc_updates = 0 102 | self._last_episode_id = -1 103 | 104 | def _after_reset(self, observation): 105 | # logger.info('Resetting environment') 106 | self._episode_reward = 0 107 | self._episode_length = 0 108 | self._all_rewards = [] 109 | return observation 110 | 111 | def _after_step(self, observation, reward, done, info): 112 | to_log = {} 113 | if self._episode_length == 0: 114 | self._episode_time = time.time() 115 | 116 | self._local_t += 1 117 | if info.get("stats.vnc.updates.n") is not None: 118 | self._num_vnc_updates += info.get("stats.vnc.updates.n") 119 | 120 | if self._local_t % self._log_interval == 0: 121 | cur_time = time.time() 122 | elapsed = cur_time - self._last_time 123 | fps = self._log_interval / elapsed 124 | self._last_time = cur_time 125 | cur_episode_id = info.get('vectorized.episode_id', 0) 126 | to_log["diagnostics/fps"] = fps 127 | if self._last_episode_id == cur_episode_id: 128 | to_log["diagnostics/fps_within_episode"] = fps 129 | self._last_episode_id = cur_episode_id 130 | if info.get("stats.gauges.diagnostics.lag.action") is not None: 131 | to_log["diagnostics/action_lag_lb"] = 
info["stats.gauges.diagnostics.lag.action"][0] 132 | to_log["diagnostics/action_lag_ub"] = info["stats.gauges.diagnostics.lag.action"][1] 133 | if info.get("reward.count") is not None: 134 | to_log["diagnostics/reward_count"] = info["reward.count"] 135 | if info.get("stats.gauges.diagnostics.clock_skew") is not None: 136 | to_log["diagnostics/clock_skew_lb"] = info["stats.gauges.diagnostics.clock_skew"][0] 137 | to_log["diagnostics/clock_skew_ub"] = info["stats.gauges.diagnostics.clock_skew"][1] 138 | if info.get("stats.gauges.diagnostics.lag.observation") is not None: 139 | to_log["diagnostics/observation_lag_lb"] = info["stats.gauges.diagnostics.lag.observation"][0] 140 | to_log["diagnostics/observation_lag_ub"] = info["stats.gauges.diagnostics.lag.observation"][1] 141 | 142 | if info.get("stats.vnc.updates.n") is not None: 143 | to_log["diagnostics/vnc_updates_n"] = info["stats.vnc.updates.n"] 144 | to_log["diagnostics/vnc_updates_n_ps"] = self._num_vnc_updates / elapsed 145 | self._num_vnc_updates = 0 146 | if info.get("stats.vnc.updates.bytes") is not None: 147 | to_log["diagnostics/vnc_updates_bytes"] = info["stats.vnc.updates.bytes"] 148 | if info.get("stats.vnc.updates.pixels") is not None: 149 | to_log["diagnostics/vnc_updates_pixels"] = info["stats.vnc.updates.pixels"] 150 | if info.get("stats.vnc.updates.rectangles") is not None: 151 | to_log["diagnostics/vnc_updates_rectangles"] = info["stats.vnc.updates.rectangles"] 152 | if info.get("env_status.state_id") is not None: 153 | to_log["diagnostics/env_state_id"] = info["env_status.state_id"] 154 | 155 | if reward is not None: 156 | self._episode_reward += reward 157 | if observation is not None: 158 | self._episode_length += 1 159 | self._all_rewards.append(reward) 160 | 161 | if done: 162 | logger.info('Episode terminating: episode_reward=%s episode_length=%s', self._episode_reward, self._episode_length) 163 | total_time = time.time() - self._episode_time 164 | to_log["global/episode_reward"] = self._episode_reward 165 | to_log["global/episode_length"] = self._episode_length 166 | to_log["global/episode_time"] = total_time 167 | to_log["global/reward_per_time"] = self._episode_reward / total_time 168 | self._episode_reward = 0 169 | self._episode_length = 0 170 | self._all_rewards = [] 171 | 172 | return observation, reward, done, to_log 173 | 174 | def _process_frame84gray(frame): 175 | frame = cv2.resize(frame, (84, 84)) 176 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 177 | frame = frame.astype(np.float32) 178 | frame *= (1.0 / 255.0) 179 | frame = np.reshape(frame, [84, 84, 1]) 180 | return frame 181 | 182 | def _process_frame42(frame): 183 | frame = frame[34:34+160, :160] 184 | # Resize by half, then down to 42x42 (essentially mipmapping). If 185 | # we resize directly we lose pixels that, when mapped to 42x42, 186 | # aren't close enough to the pixel boundary. 
187 | frame = cv2.resize(frame, (84, 84)) 188 | frame = cv2.resize(frame, (42, 42)) 189 | frame = frame.mean(2) 190 | frame = frame.astype(np.float32) 191 | frame *= (1.0 / 255.0) 192 | frame = np.reshape(frame, [42, 42, 1]) 193 | return frame 194 | 195 | class AtariRescale84x84(vectorized.ObservationWrapper): 196 | def __init__(self, env=None): 197 | super(AtariRescale84x84, self).__init__(env) 198 | self.observation_space = Box(0.0, 1.0, [84, 84, 1]) 199 | 200 | def _observation(self, observation_n): 201 | return [_process_frame84gray(observation) for observation in observation_n] 202 | 203 | class AtariRescale42x42(vectorized.ObservationWrapper): 204 | def __init__(self, env=None): 205 | super(AtariRescale42x42, self).__init__(env) 206 | self.observation_space = Box(0.0, 1.0, [42, 42, 1]) 207 | 208 | def _observation(self, observation_n): 209 | return [_process_frame42(observation) for observation in observation_n] 210 | 211 | class FixedKeyState(object): 212 | def __init__(self, keys): 213 | self._keys = [keycode(key) for key in keys] 214 | self._down_keysyms = set() 215 | 216 | def apply_vnc_actions(self, vnc_actions): 217 | for event in vnc_actions: 218 | if isinstance(event, vnc_spaces.KeyEvent): 219 | if event.down: 220 | self._down_keysyms.add(event.key) 221 | else: 222 | self._down_keysyms.discard(event.key) 223 | 224 | def to_index(self): 225 | action_n = 0 226 | for key in self._down_keysyms: 227 | if key in self._keys: 228 | # If multiple keys are pressed, just use the first one 229 | action_n = self._keys.index(key) + 1 230 | break 231 | return action_n 232 | 233 | class DiscreteToFixedKeysVNCActions(vectorized.ActionWrapper): 234 | """ 235 | Define a fixed action space. Action 0 is all keys up. Each element of keys can be a single key or a space-separated list of keys 236 | 237 | For example, 238 | e=DiscreteToFixedKeysVNCActions(e, ['left', 'right']) 239 | will have 3 actions: [none, left, right] 240 | 241 | You can define a state with more than one key down by separating with spaces. For example, 242 | e=DiscreteToFixedKeysVNCActions(e, ['left', 'right', 'space', 'left space', 'right space']) 243 | will have 6 actions: [none, left, right, space, left space, right space] 244 | """ 245 | def __init__(self, env, keys): 246 | super(DiscreteToFixedKeysVNCActions, self).__init__(env) 247 | 248 | self._keys = keys 249 | self._generate_actions() 250 | self.action_space = spaces.Discrete(len(self._actions)) 251 | 252 | def _generate_actions(self): 253 | self._actions = [] 254 | uniq_keys = set() 255 | for key in self._keys: 256 | for cur_key in key.split(' '): 257 | uniq_keys.add(cur_key) 258 | 259 | for key in [''] + self._keys: 260 | split_keys = key.split(' ') 261 | cur_action = [] 262 | for cur_key in uniq_keys: 263 | cur_action.append(vnc_spaces.KeyEvent.by_name(cur_key, down=(cur_key in split_keys))) 264 | self._actions.append(cur_action) 265 | self.key_state = FixedKeyState(uniq_keys) 266 | 267 | def _action(self, action_n): 268 | # Each action might be a length-1 np.array. Cast to int to 269 | # avoid warnings. 
270 | return [self._actions[int(action)] for action in action_n] 271 | 272 | class CropScreen(vectorized.ObservationWrapper): 273 | """Crops out a [height]x[width] area starting from (top,left) """ 274 | def __init__(self, env, height, width, top=0, left=0): 275 | super(CropScreen, self).__init__(env) 276 | self.height = height 277 | self.width = width 278 | self.top = top 279 | self.left = left 280 | self.observation_space = Box(0, 255, shape=(height, width, 3)) 281 | 282 | def _observation(self, observation_n): 283 | return [ob[self.top:self.top+self.height, self.left:self.left+self.width, :] if ob is not None else None 284 | for ob in observation_n] 285 | 286 | def _process_frame_flash(frame): 287 | frame = cv2.resize(frame, (200, 128)) 288 | frame = frame.mean(2).astype(np.float32) 289 | frame *= (1.0 / 255.0) 290 | frame = np.reshape(frame, [128, 200, 1]) 291 | return frame 292 | 293 | class FlashRescale(vectorized.ObservationWrapper): 294 | def __init__(self, env=None): 295 | super(FlashRescale, self).__init__(env) 296 | self.observation_space = Box(0.0, 1.0, [128, 200, 1]) 297 | 298 | def _observation(self, observation_n): 299 | return [_process_frame_flash(observation) for observation in observation_n] 300 | -------------------------------------------------------------------------------- /maze.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import universe 4 | import gym 5 | import logging 6 | import copy 7 | from bs4 import BeautifulSoup 8 | import tensorflow as tf 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.INFO) 11 | universe.configure_logging() 12 | 13 | BLOCK = 0 14 | AGENT = 1 15 | GOAL = 2 16 | DX = [0, 1, 0, -1] 17 | DY = [-1, 0, 1, 0] 18 | 19 | COLOR = [[44, 42, 60], # block 20 | [91, 255, 123], # agent 21 | [52, 152, 219], # goal 22 | ] 23 | 24 | def str2bool(v): 25 | return v.lower() in ("yes", "true", "t", "1") 26 | 27 | def generate_maze(size, holes=0): 28 | # Source: http://code.activestate.com/recipes/578356-random-maze-generator/ 29 | # Random Maze Generator using Depth-first Search 30 | # http://en.wikipedia.org/wiki/Maze_generation_algorithm 31 | mx = size-2; my = size-2 # width and height of the maze 32 | maze = np.ones((my, mx)) 33 | dx = [0, 1, 0, -1]; dy = [-1, 0, 1, 0] # 4 directions to move in the maze 34 | # start the maze from a random cell 35 | start_x = np.random.randint(0, mx); start_y = np.random.randint(0, my) 36 | cx, cy = 0, 0 37 | # stack element: (x, y, direction) 38 | maze[start_y][start_x] = 0; stack = [(start_x, start_y, 0)] 39 | while len(stack) > 0: 40 | (cx, cy, cd) = stack[-1] 41 | # to prevent zigzags: 42 | # if changed direction in the last move then cannot change again 43 | if len(stack) > 2: 44 | if cd != stack[-2][2]: dirRange = [cd] 45 | else: dirRange = range(4) 46 | else: dirRange = range(4) 47 | 48 | # find a new cell to add 49 | nlst = [] # list of available neighbors 50 | for i in dirRange: 51 | nx = cx + dx[i]; ny = cy + dy[i] 52 | if nx >= 0 and nx < mx and ny >= 0 and ny < my: 53 | if maze[ny][nx] == 1: 54 | ctr = 0 # of occupied neighbors must be 1 55 | for j in range(4): 56 | ex = nx + dx[j]; ey = ny + dy[j] 57 | if ex >= 0 and ex < mx and ey >= 0 and ey < my: 58 | if maze[ey][ex] == 0: ctr += 1 59 | if ctr == 1: nlst.append(i) 60 | 61 | # if 1 or more neighbors available then randomly select one and move 62 | if len(nlst) > 0: 63 | ir = nlst[np.random.randint(0, len(nlst))] 64 | cx += dx[ir]; cy += dy[ir]; 
maze[cy][cx] = 0 65 | stack.append((cx, cy, ir)) 66 | else: stack.pop() 67 | 68 | maze_tensor = np.zeros((size, size, 3)) 69 | maze_tensor[:,:,BLOCK] = 1 70 | maze_tensor[1:-1, 1:-1, BLOCK] = maze 71 | maze_tensor[start_y+1][start_x+1][AGENT] = 1 72 | 73 | while holes > 0: 74 | removable = [] 75 | for y in range(0, my): 76 | for x in range(0, mx): 77 | if maze_tensor[y+1][x+1][BLOCK] == 1: 78 | if maze_tensor[y][x+1][BLOCK] == 1 and maze_tensor[y+2][x+1][BLOCK] == 1 and \ 79 | maze_tensor[y+1][x][BLOCK] == 0 and maze_tensor[y+1][x+2][BLOCK] == 0: 80 | removable.append((y+1, x+1)) 81 | elif maze_tensor[y][x+1][BLOCK] == 0 and maze_tensor[y+2][x+1][BLOCK] == 0 and \ 82 | maze_tensor[y+1][x][BLOCK] == 1 and maze_tensor[y+1][x+2][BLOCK] == 1: 83 | removable.append((y+1, x+1)) 84 | 85 | if len(removable) == 0: 86 | break 87 | 88 | idx = np.random.randint(0, len(removable)) 89 | maze_tensor[removable[idx][0]][removable[idx][1]][BLOCK] = 0 90 | holes -= 1 91 | 92 | return maze_tensor, start_y+1, start_x+1 93 | 94 | def find_empty_loc(maze): 95 | size = maze.shape[0] 96 | # Randomly determine a goal position 97 | for i in range(300): 98 | y = np.random.randint(0, size-2) + 1 99 | x = np.random.randint(0, size-2) + 1 100 | if np.sum(maze[y][x]) == 0: 101 | return [y, x] 102 | 103 | raise AttributeError("Cannot find an empty location in 300 trials") 104 | 105 | def generate_maze_with_multiple_goal(size, num_goal=1, holes=3): 106 | maze, start_y, start_x = generate_maze(size, holes=holes) 107 | 108 | # Randomly determine agent position 109 | maze[start_y][start_x][AGENT] = 0 110 | agent_pos = find_empty_loc(maze) 111 | maze[agent_pos[0]][agent_pos[1]][AGENT] = 1 112 | 113 | object_pos = [[],[],[],[]] 114 | for i in range(num_goal): 115 | pos = find_empty_loc(maze) 116 | maze[pos[0]][pos[1]][GOAL] = 1 117 | object_pos[GOAL].append(pos) 118 | 119 | return maze, agent_pos, object_pos 120 | 121 | def visualize_maze(maze, img_size=320): 122 | my = maze.shape[0] 123 | mx = maze.shape[1] 124 | colors = np.array(COLOR, np.uint8) 125 | num_channel = maze.shape[2] 126 | vis_maze = np.matmul(maze, colors[:num_channel]) 127 | vis_maze = vis_maze.astype(np.uint8) 128 | for i in range(vis_maze.shape[0]): 129 | for j in range(vis_maze.shape[1]): 130 | if maze[i][j].sum() == 0.0: 131 | vis_maze[i][j][:] = int(255) 132 | image = Image.fromarray(vis_maze) 133 | return image.resize((int(float(img_size) * mx / my), img_size), Image.NEAREST) 134 | 135 | def visualize_mazes(maze, img_size=320): 136 | if maze.ndim == 3: 137 | return visualize_maze(maze, img_size=img_size) 138 | elif maze.ndim == 4: 139 | n = maze.shape[0] 140 | size = maze.shape[1] 141 | dim = maze.shape[-1] 142 | concat_m = maze.transpose((1,0,2,3)).reshape((size, n * size, dim)) 143 | return visualize_maze(concat_m, img_size=img_size) 144 | else: 145 | raise ValueError("maze should be 3d or 4d tensor") 146 | 147 | def to_string(maze): 148 | my = maze.shape[0] 149 | mx = maze.shape[1] 150 | str = '' 151 | for y in range(my): 152 | for x in range(mx): 153 | if maze[y][x][BLOCK] == 1: 154 | str += '#' 155 | elif maze[y][x][AGENT] == 1: 156 | str += 'o' 157 | elif maze[y][x][GOAL] == 1: 158 | str += 'x' 159 | else: 160 | str += ' ' 161 | str += '\n' 162 | return str 163 | 164 | class Maze(object): 165 | def __init__(self, size=10, num_goal=1, holes=0): 166 | self.size = size 167 | self.dx = [0, 1, 0, -1] 168 | self.dy = [-1, 0, 1, 0] 169 | self.num_goal = num_goal 170 | self.holes = holes 171 | self.reset() 172 | 173 | def reset(self): 174 | self.maze, 
self.agent_pos, self.obj_pos = \ 175 | generate_maze_with_multiple_goal(self.size, num_goal=self.num_goal, 176 | holes=self.holes) 177 | 178 | def is_reachable(self, y, x): 179 | return self.maze[y][x][BLOCK] == 0 180 | 181 | def is_branch(self, y, x): 182 | if self.maze[y][x][BLOCK] == 1: 183 | return False 184 | neighbor_count = 0 185 | for i in range(4): 186 | new_y = y + self.dy[i] 187 | new_x = x + self.dx[i] 188 | if self.maze[new_y][new_x][BLOCK] == 0: 189 | neighbor_count += 1 190 | return neighbor_count > 2 191 | 192 | def is_agent_on_branch(self): 193 | return self.is_branch(self.agent_pos[0], self.agent_pos[1]) 194 | 195 | def is_end_of_corridor(self, y, x, direction): 196 | return self.maze[y + self.dy[direction]][x + self.dx[direction]][BLOCK] == 1 197 | 198 | def is_agent_on_end_of_corridor(self, direction): 199 | return self.is_end_of_corridor(self.agent_pos[0], self.agent_pos[1], direction) 200 | 201 | def move_agent(self, direction): 202 | y = self.agent_pos[0] + self.dy[direction] 203 | x = self.agent_pos[1] + self.dx[direction] 204 | if not self.is_reachable(y, x): 205 | return False 206 | self.maze[self.agent_pos[0]][self.agent_pos[1]][AGENT] = 0 207 | self.maze[y][x][AGENT] = 1 208 | self.agent_pos = [y, x] 209 | return True 210 | 211 | def is_object_reached(self, obj_idx): 212 | if self.maze.shape[2] <= obj_idx: 213 | return False 214 | return self.maze[self.agent_pos[0]][self.agent_pos[1]][obj_idx] == 1 215 | 216 | def is_empty(self, y, x): 217 | return np.sum(self.maze[y][x]) == 0 218 | 219 | def remove_object(self, y, x, obj_idx): 220 | removed = self.maze[y][x][obj_idx] == 1 221 | self.maze[y][x][obj_idx] = 0 222 | self.obj_pos[obj_idx].remove([y, x]) 223 | return removed 224 | 225 | def remaining_goal(self): 226 | return self.remaining_object(GOAL) 227 | 228 | def remaining_object(self, obj_idx): 229 | return len(self.obj_pos[obj_idx]) 230 | 231 | def add_object(self, y, x, obj_idx): 232 | if self.is_empty(y, x): 233 | self.maze[y][x][obj_idx] = 1 234 | self.obj_pos[obj_idx].append([y, x]) 235 | else: 236 | ValueError("%d, %d is not empty" % (y, x)) 237 | 238 | def move_object_random(self, prob, obj_idx): 239 | pos_copy = copy.deepcopy(self.obj_pos[obj_idx]) 240 | for pos in pos_copy: 241 | if not hasattr(self, "goal_move_prob"): 242 | self.goal_move_prob = np.random.rand(1000) 243 | self.goal_move_idx = 0 244 | else: 245 | self.goal_move_idx = (self.goal_move_idx + 1) \ 246 | % self.goal_move_prob.size 247 | if self.goal_move_prob[self.goal_move_idx] < prob: 248 | possible_moves = [] 249 | for i in range(4): 250 | y = pos[0] + DY[i] 251 | x = pos[1] + DX[i] 252 | if self.is_empty(y, x): 253 | possible_moves.append(i) 254 | if len(possible_moves) > 0: 255 | self.move_object(pos, obj_idx, 256 | possible_moves[np.random.randint(len(possible_moves))]) 257 | 258 | def move_object(self, pos, obj_idx, direction): 259 | y = pos[0] + self.dy[direction] 260 | x = pos[1] + self.dx[direction] 261 | if not self.is_reachable(y, x): 262 | return False 263 | self.remove_object(pos[0], pos[1], obj_idx) 264 | self.add_object(y, x, obj_idx) 265 | return True 266 | 267 | def observation(self, clone=True): 268 | return np.array(self.maze, copy=clone) 269 | 270 | def visualize(self): 271 | return visualize_maze(self.maze) 272 | 273 | def to_string(self): 274 | return to_string(self.maze) 275 | 276 | class MazeEnv(object): 277 | def __init__(self, config="", verbose=1): 278 | self.config = BeautifulSoup(config, "lxml") 279 | # map 280 | self.size = int(self.config.maze["size"]) 281 | 
self.max_step = int(self.config.maze["time"]) 282 | self.holes = int(self.config.maze["holes"]) 283 | self.num_goal = int(self.config.object["num_goal"]) 284 | # reward 285 | self.default_reward = float(self.config.reward["default"]) 286 | self.goal_reward = float(self.config.reward["goal"]) 287 | self.lazy_reward = float(self.config.reward["lazy"]) 288 | # randomness 289 | self.prob_stop = float(self.config.random["p_stop"]) 290 | self.prob_goal_move = float(self.config.random["p_goal"]) 291 | # meta 292 | self.meta_remaining_time = str2bool(self.config.meta["remaining_time"]) if \ 293 | self.config.meta.has_attr("remaining_time") else False 294 | 295 | # log 296 | self.log_freq = 100 297 | self.log_t = 0 298 | self.max_history = 1000 299 | self.reward_history = [] 300 | self.length_history = [] 301 | self.verbose = verbose 302 | 303 | self.reset() 304 | self.action_space = gym.spaces.discrete.Discrete(4) 305 | self.observation_space = gym.spaces.box.Box(0, 1, self.observation().shape) 306 | 307 | 308 | def observation(self, clone=True): 309 | return self.maze.observation(clone=clone) 310 | 311 | def reset(self, reset_episode=True, holes=None): 312 | if reset_episode: 313 | self.t = 0 314 | self.episode_reward = 0 315 | self.last_step_reward = 0.0 316 | self.terminated = False 317 | 318 | holes = self.holes if holes is None else holes 319 | self.maze = Maze(self.size, num_goal=self.num_goal, 320 | holes=holes) 321 | 322 | return self.observation() 323 | 324 | def remaining_time(self, normalized=True): 325 | return float(self.max_step - self.t) / float(self.max_step) 326 | 327 | def last_reward(self): 328 | return self.last_step_reward 329 | 330 | def meta(self): 331 | meta = [] 332 | if self.meta_remaining_time: 333 | meta.append(self.remaining_time()) 334 | if len(meta) == 0: 335 | return None 336 | return meta 337 | 338 | def visualize(self): 339 | return self.maze.visualize() 340 | 341 | def to_string(self): 342 | return self.maze.to_string() 343 | 344 | def step(self, act): 345 | assert self.action_space.contains(act), "invalid action: %d" % act 346 | assert not self.terminated, "episode is terminated" 347 | self.t += 1 348 | 349 | self.object_reached = False 350 | self.rand_stopped = False 351 | if self.prob_stop > 0 and np.random.rand() < self.prob_stop: 352 | reward = self.default_reward 353 | self.rand_stopped = True 354 | else: 355 | moved = self.maze.move_agent(act) 356 | reward = self.default_reward if moved else self.lazy_reward 357 | 358 | if self.maze.is_object_reached(GOAL): 359 | self.object_reached = True 360 | reward = self.goal_reward 361 | self.maze.remove_object(self.maze.agent_pos[0], self.maze.agent_pos[1], GOAL) 362 | if self.maze.remaining_goal() == 0: 363 | self.terminated = True 364 | 365 | if self.t >= self.max_step: 366 | self.terminated = True 367 | 368 | self.episode_reward += reward 369 | self.last_step_reward = reward 370 | 371 | to_log = None 372 | if self.terminated: 373 | if self.verbose > 0: 374 | logger.info('Episode terminating: episode_reward=%s episode_length=%s', 375 | self.episode_reward, self.t) 376 | self.log_episode(self.episode_reward, self.t) 377 | if self.log_t < self.log_freq: 378 | self.log_t += 1 379 | else: 380 | to_log = {} 381 | to_log["global/episode_reward"] = self.reward_mean(self.log_freq) 382 | to_log["global/episode_length"] = self.length_mean(self.log_freq) 383 | self.log_t = 0 384 | else: 385 | if self.prob_goal_move > 0: 386 | self.maze.move_object_random(self.prob_goal_move, GOAL) 387 | # print("goal_moved") 388 | 389 | 
return self.observation(), reward, self.terminated, to_log, 1 390 | 391 | def log_episode(self, reward, length): 392 | self.reward_history.insert(0, reward) 393 | self.length_history.insert(0, length) 394 | while len(self.reward_history) > self.max_history: 395 | self.reward_history.pop() 396 | self.length_history.pop() 397 | 398 | def reward_mean(self, num): 399 | return np.asarray(self.reward_history[:num]).mean() 400 | 401 | def length_mean(self, num): 402 | return np.asarray(self.length_history[:num]).mean() 403 | 404 | def tf_visualize(self, x): 405 | colors = np.array(COLOR, np.uint8) 406 | colors = colors.astype(np.float32) / 255.0 407 | color = tf.constant(colors) 408 | obs_dim = self.observation_space.shape[-1] 409 | x = x[:, :, :, :obs_dim] 410 | xdim = x.get_shape() 411 | x = tf.clip_by_value(x, 0, 1) 412 | bg = tf.ones((tf.shape(x)[0], int(x.shape[1]), int(x.shape[2]), 3)) 413 | w = tf.minimum(tf.expand_dims(tf.reduce_sum(x, axis=xdim.ndims-1), -1), 1.0) 414 | w = tf.reshape(tf.tile(w, [1, 1, 1, 3]), tf.shape(bg)) 415 | fg = tf.reshape(tf.matmul(tf.reshape(x, (-1, int(xdim[-1]))), 416 | color[:xdim[-1], :]), tf.shape(bg)) 417 | return bg * (1.0 - w) + fg 418 | 419 | class MazeSMDP(MazeEnv): 420 | def __init__(self, gamma=0.99, *args, **kwargs): 421 | super(MazeSMDP, self).__init__(*args, **kwargs) 422 | self.gamma = gamma 423 | self.prob_slip = float(self.config.random["p_slip"]) 424 | 425 | def step(self, act): 426 | assert self.action_space.contains(act), "invalid action: %d" % act 427 | assert not self.terminated, "episode is terminated" 428 | 429 | reward = 0 430 | steps = 0 431 | time = 0 432 | gamma = 1.0 433 | self.last_observation = self.maze.observation() 434 | while not self.terminated: 435 | _, r, _, to_log, t = super(MazeSMDP, self).step(act) 436 | reward += r * gamma 437 | steps += 1 438 | time += t 439 | gamma = gamma * self.gamma 440 | if not self.rand_stopped: 441 | if self.maze.is_agent_on_end_of_corridor(act): 442 | break 443 | if self.object_reached: 444 | break 445 | if self.maze.is_agent_on_branch(): 446 | if self.prob_slip > 0 and np.random.rand() < self.prob_slip: 447 | pass 448 | else: 449 | break 450 | 451 | self.last_step_reward = reward 452 | return self.observation(), reward, self.terminated, to_log, time 453 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.rnn as rnn 4 | import math 5 | 6 | act_fn = tf.nn.elu 7 | 8 | def normalized_columns_initializer(std=1.0): 9 | def _initializer(shape, dtype=None, partition_info=None): 10 | out = np.random.randn(*shape).astype(np.float32) 11 | out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) 12 | return tf.constant(out) 13 | return _initializer 14 | 15 | def he_initializer(fan_in, uniform=False, factor=2, seed=None): 16 | def _initializer(shape, dtype=None, partition_info=None): 17 | n = fan_in 18 | if uniform: 19 | # To get stddev = math.sqrt(factor / n) need to adjust for uniform. 20 | limit = math.sqrt(3.0 * factor / n) 21 | return tf.random_uniform(shape, -limit, limit, dtype, seed=seed) 22 | else: 23 | # To get stddev = math.sqrt(factor / n) need to adjust for truncated. 
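            # (The 1.3 factor, roughly 1/0.774, compensates for the variance removed by
            # tf.truncated_normal, which resamples draws beyond 2 standard deviations.)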
24 | trunc_stddev = math.sqrt(1.3 * factor / n) 25 | return tf.truncated_normal(shape, 0.0, trunc_stddev, dtype, seed=seed) 26 | return _initializer 27 | 28 | def flatten(x): 29 | if x.get_shape().ndims > 2: 30 | return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])]) 31 | return x 32 | 33 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None, init="he"): 34 | with tf.variable_scope(name): 35 | stride_shape = [1, stride[0], stride[1], 1] 36 | filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters] 37 | 38 | if init != "he": 39 | fan_in = np.prod(filter_shape[:3]) 40 | fan_out = np.prod(filter_shape[:2]) * num_filters 41 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 42 | w = tf.get_variable("W", filter_shape, dtype, 43 | tf.random_uniform_initializer(-w_bound, w_bound), 44 | collections=collections) 45 | else: 46 | fan_in = np.prod(filter_shape[:3]) 47 | w = tf.get_variable("W", filter_shape, dtype, he_initializer(fan_in), 48 | collections=collections) 49 | b = tf.get_variable("b", [1, 1, 1, num_filters], 50 | initializer=tf.constant_initializer(0.0), 51 | collections=collections) 52 | return tf.nn.conv2d(x, w, stride_shape, pad) + b 53 | 54 | def linear(x, size, name, initializer=None, bias_init=0): 55 | w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=initializer) 56 | b = tf.get_variable(name + "/b", [size], initializer=tf.constant_initializer(bias_init)) 57 | return tf.matmul(x, w) + b 58 | 59 | def categorical_sample(logits, d): 60 | value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1], keep_dims=True), 1), [1]) 61 | return tf.one_hot(value, d) 62 | 63 | def transform_fc(x, a, n_actions, name, bias_init=0, pad="SAME"): 64 | if x.shape.ndims > 2: 65 | x = flatten(x) 66 | xdim = int(x.get_shape()[1]) 67 | w = tf.get_variable(name + "/w", [n_actions, xdim], 68 | initializer=normalized_columns_initializer(0.1)) 69 | if a is not None: 70 | # Transform only for the given action 71 | mul = x * w[tf.to_int32(tf.squeeze(tf.argmax(a, axis=1))), :] 72 | else: 73 | # Enumerate all possible actions and concatenate them 74 | transformed = [] 75 | for i in range(0, n_actions): 76 | transformed.append(x * w[i, :]) 77 | mul = pack(tf.concat(transformed, 1), [xdim]) 78 | 79 | h = linear(mul, xdim, name + "_dec", 80 | initializer=normalized_columns_initializer(0.1), bias_init=bias_init) 81 | return act_fn(h) 82 | 83 | def transform_conv_state(x, a, n_actions, filter_size=(3, 3), pad="SAME"): 84 | # 3x3 option-conv -> 3x3 conv * 1x1 mask (with residual connection) 85 | stride_shape = [1, 1, 1, 1] 86 | dec_f_size = filter_size[0] 87 | num_filters = int(x.get_shape()[3]) 88 | xdim = [int(x.get_shape()[1]), int(x.get_shape()[2]), num_filters] 89 | filter_shape = [filter_size[0], filter_size[1], num_filters, n_actions, num_filters] 90 | fan_in = np.prod(filter_shape[:3]) 91 | dec_filter_shape = [dec_f_size, dec_f_size, num_filters, num_filters] 92 | w = tf.get_variable("W", filter_shape, initializer=he_initializer(fan_in)) 93 | b = tf.get_variable("b", [1, 1, 1, n_actions, num_filters], 94 | initializer=tf.constant_initializer(0.0)) 95 | w_dec = tf.get_variable("dec1-W", dec_filter_shape, 96 | initializer=he_initializer(fan_in)) 97 | b_dec = tf.get_variable("dec1-b", [1, 1, 1, num_filters], 98 | initializer=tf.constant_initializer(0.0)) 99 | w_dec2 = tf.get_variable("dec2-W", dec_filter_shape, 100 | initializer=he_initializer(fan_in)) 101 | b_dec2 = tf.get_variable("dec2-b", 
[1, 1, 1, num_filters], 102 | initializer=tf.constant_initializer(0.0)) 103 | w_gate = tf.get_variable("gate-W", [1, 1, num_filters, num_filters], 104 | initializer=he_initializer(num_filters)) 105 | b_gate = tf.get_variable("gate-b", [1, 1, 1, num_filters], 106 | initializer=tf.constant_initializer(0.0)) 107 | if a is not None: 108 | idx = tf.to_int32(tf.squeeze(tf.argmax(a, axis=1))) 109 | conv = tf.nn.conv2d(x, w[:, :, :, idx, :], stride_shape, pad) + b[:, :, :, idx, :] 110 | conv = act_fn(conv) 111 | else: 112 | w = tf.reshape(w, [filter_size[0], filter_size[1], num_filters, n_actions * num_filters]) 113 | b = tf.reshape(b, [1, 1, 1, n_actions * num_filters]) 114 | conv = act_fn(tf.nn.conv2d(x, w, stride_shape, pad) + b) 115 | conv = pack(tf.transpose(tf.reshape(conv, 116 | [-1, xdim[0], xdim[1], n_actions, num_filters]), 117 | [0, 3, 1, 2, 4]), xdim) 118 | conv = act_fn(tf.nn.conv2d(conv, w_dec, stride_shape, pad) + b_dec) 119 | gate = tf.sigmoid(tf.nn.conv2d(conv, w_gate, stride_shape, pad) + b_gate) 120 | conv = tf.nn.conv2d(conv, w_dec2, stride_shape, pad) + b_dec2 121 | if a is not None: 122 | conv = conv * gate + x 123 | else: 124 | conv = tf.transpose(tf.reshape(conv, [-1, n_actions] + xdim), [1, 0, 2, 3, 4]) 125 | gate = tf.transpose(tf.reshape(gate, [-1, n_actions] + xdim), [1, 0, 2, 3, 4]) 126 | conv = conv * gate + x 127 | conv = pack(tf.transpose(conv, [1, 0, 2, 3, 4]), xdim) 128 | return act_fn(conv) 129 | 130 | def transform_conv_pred(x, a, n_actions, filter_size=(3, 3), pad="SAME"): 131 | # 3x3 option-conv -> 3x3 conv 132 | stride_shape = [1, 1, 1, 1] 133 | dec_f_size = filter_size[0] 134 | num_filters = int(x.get_shape()[3]) 135 | xdim = [int(x.get_shape()[1]), int(x.get_shape()[2]), num_filters] 136 | filter_shape = [filter_size[0], filter_size[1], num_filters, n_actions, num_filters] 137 | fan_in = np.prod(filter_shape[:3]) 138 | fan_out = np.prod(filter_shape[:2]) * num_filters 139 | w_bound = np.sqrt(6. 
/ (fan_in + fan_out)) 140 | w = tf.get_variable("W", filter_shape, 141 | initializer=tf.random_uniform_initializer(-w_bound, w_bound)) 142 | b = tf.get_variable("b", [1, 1, 1, n_actions, num_filters], 143 | initializer=tf.constant_initializer(0.0)) 144 | w_dec = tf.get_variable("W-dec", [dec_f_size, dec_f_size, num_filters, num_filters], 145 | initializer=tf.random_uniform_initializer(-w_bound, w_bound)) 146 | b_dec = tf.get_variable("b-dec", [1, 1, 1, num_filters], 147 | initializer=tf.constant_initializer(0.0)) 148 | if a is not None: 149 | idx = tf.to_int32(tf.squeeze(tf.argmax(a, axis=1))) 150 | conv = tf.nn.conv2d(x, w[:, :, :, idx, :], stride_shape, pad) + b[:, :, :, idx, :] 151 | conv = act_fn(conv) 152 | else: 153 | w = tf.reshape(w, [filter_size[0], filter_size[1], num_filters, n_actions * num_filters]) 154 | b = tf.reshape(b, [1, 1, 1, n_actions * num_filters]) 155 | conv = act_fn(tf.nn.conv2d(x, w, stride_shape, pad) + b) 156 | conv = pack(tf.transpose(tf.reshape(conv, 157 | [-1, xdim[0], xdim[1], n_actions, num_filters]), 158 | [0, 3, 1, 2, 4]), xdim) 159 | conv = tf.nn.conv2d(conv, w_dec, stride_shape, pad) + b_dec 160 | return act_fn(conv) 161 | 162 | def pack(x, dim): 163 | return tf.reshape(x, [-1] + dim) 164 | 165 | def to_value(x, dim=256, initializer=None, bias_init=0): 166 | if x.shape.ndims == 2: # fc layer 167 | return linear(x, 1, "v", initializer=initializer, bias_init=bias_init) 168 | else: # conv layer 169 | x = act_fn(linear(flatten(x), dim, "v1", 170 | initializer=tf.contrib.layers.xavier_initializer(), bias_init=bias_init)) 171 | return linear(x, 1, "v", initializer=initializer, bias_init=bias_init) 172 | 173 | def to_pred(x, dim=256, initializer=None, bias_init=0): 174 | return linear(flatten(x), dim, "p", initializer=initializer, bias_init=bias_init) 175 | 176 | def to_reward(x, dim=256, initializer=None, bias_init=0): 177 | x = act_fn(linear(flatten(x), dim, "r1", 178 | initializer=tf.contrib.layers.xavier_initializer(), bias_init=bias_init)) 179 | return linear(x, 1, "r", initializer=initializer, bias_init=bias_init) 180 | 181 | def to_steps(x, dim=256, initializer=None, bias_init=0): 182 | x = act_fn(linear(flatten(x), dim, "t1", 183 | initializer=tf.contrib.layers.xavier_initializer(), bias_init=bias_init)) 184 | return linear(x, 1, "t", initializer=initializer, bias_init=bias_init) 185 | 186 | def rollout_step(x, a, n_actions, op_trans_state, op_trans_pred, 187 | op_value, op_steps, op_reward, gamma=0.98): 188 | state = op_trans_state(x, a) 189 | p = op_trans_pred(x, a) 190 | if a is not None: 191 | v_next = op_value(state) 192 | r = op_reward(p) 193 | t = op_steps(p) 194 | else: 195 | v_next = pack(op_value(state), [n_actions]) 196 | r = pack(op_reward(p), [n_actions]) 197 | t = pack(op_steps(p), [n_actions]) 198 | 199 | t = tf.nn.relu(t) + 1 200 | g = tf.pow(tf.constant(gamma), t) 201 | return r, g, t, v_next, state 202 | 203 | def predict_over_time(x, a, n_actions, op_rollout, prediction_step=5): 204 | time_steps = tf.shape(a)[0] 205 | xdim = x.get_shape().as_list()[1:] 206 | 207 | def _create_ta(name, dtype, size, clear=True): 208 | return tf.TensorArray(dtype=dtype, 209 | size=size, tensor_array_name=name, 210 | clear_after_read=clear) 211 | 212 | v_ta = _create_ta("output_v", x.dtype, time_steps) 213 | r_ta = _create_ta("output_r", x.dtype, time_steps) 214 | g_ta = _create_ta("output_g", x.dtype, time_steps) 215 | t_ta = _create_ta("output_t", x.dtype, time_steps) 216 | q_ta = _create_ta("output_q", x.dtype, time_steps) 217 | s_ta = 
_create_ta("output_s", x.dtype, time_steps) 218 | 219 | x_ta = _create_ta("input_x", x.dtype, time_steps).unstack(x) 220 | a_ta = _create_ta("input_a", x.dtype, time_steps).unstack(a) 221 | 222 | time = tf.constant(0, dtype=tf.int32) 223 | roll_step = tf.minimum(prediction_step, time_steps) 224 | state = tf.zeros([roll_step] + xdim) 225 | 226 | def _time_step(time, r_ta, g_ta, t_ta, v_ta, q_ta, s_ta, state): 227 | a_t = a_ta.read(time) 228 | a_t = tf.expand_dims(a_t, 0) 229 | 230 | # stack previously generated states with the new state through batch 231 | x_t = x_ta.read(time) 232 | x_t = tf.expand_dims(x_t, 0) 233 | state = tf.concat([x_t, tf.slice(state, [0] * (len(xdim) + 1), 234 | [roll_step-1] + xdim)], 0) 235 | 236 | r, gamma, t, v_next, state = op_rollout(state, a_t) 237 | q = r + gamma * v_next 238 | r_ta = r_ta.write(time, tf.reshape(r, [-1])) 239 | g_ta = g_ta.write(time, tf.reshape(gamma, [-1])) 240 | t_ta = t_ta.write(time, tf.reshape(t, [-1])) 241 | v_ta = v_ta.write(time, tf.reshape(v_next, [-1])) 242 | q_ta = q_ta.write(time, tf.reshape(q, [-1])) 243 | s_ta = s_ta.write(time, state) 244 | return (time+1, r_ta, g_ta, t_ta, v_ta, q_ta, s_ta, state) 245 | 246 | _, r_ta, g_ta, t_ta, v_ta, q_ta, s_ta, state = tf.while_loop( 247 | cond=lambda time, *_: time < time_steps, 248 | body=_time_step, 249 | loop_vars=(time, v_ta, r_ta, g_ta, t_ta, q_ta, s_ta, state)) 250 | 251 | r = r_ta.stack() 252 | g = g_ta.stack() 253 | t = t_ta.stack() 254 | v = v_ta.stack() 255 | q = q_ta.stack() 256 | s = s_ta.stack() 257 | 258 | return r, g, t, v, q, s 259 | 260 | class Model(object): 261 | def __init__(self, ob_space, n_actions, type, 262 | gamma=0.99, prediction_step=1, 263 | dim=256, 264 | f_num=[32,32,64], 265 | f_stride=[1,1,2], 266 | f_size=[3,3,4], 267 | f_pad="SAME", 268 | branch=[4,4,4], 269 | meta_dim=0): 270 | self.n_actions = n_actions 271 | self.type = type 272 | self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) 273 | self.a = tf.placeholder(tf.float32, [None, n_actions]) 274 | self.meta = tf.placeholder(tf.float32, [None, meta_dim]) if meta_dim > 0 else None 275 | self.state_init = [] 276 | self.state_in = [] 277 | self.state_out = [] 278 | self.dim = dim 279 | self.f_num = f_num 280 | self.f_stride = f_stride 281 | self.f_size = f_size 282 | self.f_pad = f_pad 283 | self.meta_dim = meta_dim 284 | self.xdim = list(ob_space) 285 | self.branch = [min(n_actions, k) for k in branch] 286 | 287 | self.s, self.state_in, self.state_out = self.build_model(self.x, self.meta) 288 | self.sdim = self.s.get_shape().as_list()[1:] 289 | 290 | # output layer 291 | if self.type == 'policy': 292 | self.logits = linear(flatten(self.s), n_actions, "action", 293 | normalized_columns_initializer(0.01)) 294 | self.vf = tf.reshape(linear(flatten(self.s), 1, "value", 295 | normalized_columns_initializer(1.0)), [-1]) 296 | self.sample = categorical_sample(self.logits, n_actions)[0, :] 297 | elif self.type == 'q': 298 | h = transform_conv_state(self.s, None, n_actions) 299 | self.h = linear(flatten(h), self.dim, "fc", 300 | normalized_columns_initializer(0.01)) 301 | self.q = pack(linear(self.h, 1, "action", 302 | normalized_columns_initializer(0.01)), [n_actions]) 303 | self.sample = tf.one_hot(tf.squeeze(tf.argmax(self.q, axis=1)), n_actions) 304 | self.qmax = tf.reduce_max(self.q, axis=[1]) 305 | elif self.type == 'vpn': 306 | self.op_value = tf.make_template('v', to_value, dim=self.dim, 307 | initializer=normalized_columns_initializer(0.01)) 308 | self.op_reward = tf.make_template('r', to_reward, 
dim=self.dim, 309 | initializer=normalized_columns_initializer(0.01)) 310 | self.op_steps = tf.make_template('t', to_steps, dim=self.dim, 311 | initializer=normalized_columns_initializer(0.01)) 312 | self.op_trans_state = tf.make_template('trans_state', transform_conv_state, 313 | n_actions=n_actions) 314 | self.op_trans_pred = tf.make_template('trans_pred', transform_conv_pred, 315 | n_actions=n_actions) 316 | self.op_rollout = tf.make_template('rollout', rollout_step, 317 | n_actions=n_actions, 318 | op_trans_state=self.op_trans_state, 319 | op_trans_pred=self.op_trans_pred, 320 | op_value=self.op_value, 321 | op_steps=self.op_steps, 322 | op_reward=self.op_reward, 323 | gamma=gamma) 324 | 325 | # Unconditional rollout 326 | self.r, self.gamma, self.steps, self.v_next, self.state = self.op_rollout(self.s, None) 327 | self.q = self.r + self.gamma * self.v_next 328 | 329 | # Action-conditional rollout over time for training 330 | self.r_a, self.gamma_a, self.t_a, self.v_next_a, self.q_a, self.states = \ 331 | predict_over_time(self.s, self.a, n_actions, self.op_rollout, 332 | prediction_step=prediction_step) 333 | 334 | # Tree expansion/backup 335 | depth = len(self.branch) 336 | q_list = [] 337 | r_list = [] 338 | g_list = [] 339 | v_list = [] 340 | idx_list = [] 341 | s_list = [] 342 | s = self.s 343 | 344 | # Expansion 345 | for i in range(depth): 346 | r, gamma, _, v, s = self.op_rollout(s, None) 347 | r_list.append(tf.squeeze(r)) 348 | v_list.append(tf.squeeze(v)) 349 | s_list.append(s) 350 | g_list.append(tf.squeeze(gamma)) 351 | 352 | b = self.branch[i] 353 | q_list.append(r_list[i] + g_list[i] * v_list[i]) 354 | q_list[i] = tf.reshape(q_list[i], [-1, self.n_actions]) 355 | _, idx = tf.nn.top_k(q_list[i], k=b) 356 | idx_list.append(idx) 357 | 358 | l = tf.tile(tf.expand_dims(tf.range(0, tf.shape(idx)[0]), 1), [1, b]) 359 | l = tf.concat([tf.reshape(l, [-1, 1]), tf.reshape(idx, [-1, 1])], axis=1) 360 | s = tf.reshape(tf.gather_nd( 361 | tf.reshape(s, [-1, self.n_actions] + self.sdim), l), [-1] + self.sdim) 362 | r_list[i] = tf.reshape(tf.gather_nd( 363 | tf.reshape(r_list[i], [-1, self.n_actions]), l), [-1]) 364 | g_list[i] = tf.reshape(tf.gather_nd( 365 | tf.reshape(g_list[i], [-1, self.n_actions]), l), [-1]) 366 | v_list[i] = tf.reshape(tf.gather_nd( 367 | tf.reshape(v_list[i], [-1, self.n_actions]), l), [-1]) 368 | 369 | self.q_list = q_list 370 | self.r_list = r_list 371 | self.g_list = g_list 372 | self.v_list = v_list 373 | self.s_list = s_list 374 | self.idx_list = idx_list 375 | 376 | # Backup 377 | v_plan = [None] * depth 378 | q_plan = [None] * depth 379 | 380 | v_plan[-1] = v_list[-1] 381 | for i in reversed(range(0, depth)): 382 | q_plan[i] = r_list[i] + g_list[i] * v_plan[i] 383 | if i > 0: 384 | q_max = tf.reduce_max(tf.reshape(q_plan[i], [-1, self.branch[i]]), axis=1) 385 | n = float(depth - i) 386 | v_plan[i-1] = (v_list[i-1] + q_max * n) / (n + 1) 387 | 388 | idx = tf.squeeze(idx_list[0]) 389 | self.q_deep = tf.squeeze(q_plan[0]) 390 | self.q_plan = tf.sparse_to_dense(idx, [self.n_actions], self.q_deep, 391 | default_value=-100, validate_indices=False) 392 | 393 | self.x_off = tf.placeholder(tf.float32, [None] + list(ob_space)) 394 | self.a_off = tf.placeholder(tf.float32, [None, n_actions]) 395 | self.meta_off = tf.placeholder(tf.float32, [None, self.meta_dim]) \ 396 | if self.meta_dim > 0 else None 397 | tf.get_variable_scope().reuse_variables() 398 | self.s_off, self.state_in_off, self.state_out_off = \ 399 | self.build_model(self.x_off, self.meta_off) 400 | 401 | 
# Action-conditional rollout over time for training 402 | self.r_off, self.gamma_off, self.t_off, self.v_next_off, _, _ = \ 403 | predict_over_time(self.s_off, self.a_off, n_actions, self.op_rollout, 404 | prediction_step=prediction_step) 405 | 406 | else: 407 | raise ValueError('Invalid model type %s' % (self.type)) 408 | 409 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 410 | tf.get_variable_scope().name) 411 | 412 | self.num_param = 0 413 | for v in self.var_list: 414 | self.num_param += v.get_shape().num_elements() 415 | 416 | def is_recurrent(self): 417 | return self.state_in is not None and len(self.state_in) > 0 418 | 419 | def get_initial_features(self): 420 | return self.state_init 421 | 422 | def act(self, ob, state_in=[], meta=None): 423 | sess = tf.get_default_session() 424 | feed_dict = {self.x: [ob]} 425 | for i in range(len(state_in)): 426 | feed_dict[self.state_in[i]] = state_in[i] 427 | if self.meta_dim > 0: 428 | feed_dict[self.meta] = [meta] 429 | 430 | if self.type == 'policy': 431 | return sess.run([self.sample, self.vf] + self.state_out, feed_dict) 432 | elif self.type == 'q': 433 | return sess.run([self.sample] + self.state_out, feed_dict) 434 | elif self.type == 'vpn': 435 | out = sess.run([self.q_plan] + self.state_out, feed_dict) 436 | q = out[0] 437 | state_out = out[1:] 438 | act = np.zeros_like(q) 439 | act[q.argmax()] = 1 440 | return [act] + state_out 441 | 442 | def update_state(self, ob, state_in=[], meta=None): 443 | sess = tf.get_default_session() 444 | feed_dict = {self.x: [ob]} 445 | for i in range(len(state_in)): 446 | feed_dict[self.state_in[i]] = state_in[i] 447 | if self.meta_dim > 0: 448 | feed_dict[self.meta] = [meta] 449 | return sess.run(self.state_out, feed_dict) 450 | 451 | def value(self, ob, state_in=[], meta=None): 452 | sess = tf.get_default_session() 453 | feed_dict = {self.x: [ob]} 454 | for i in range(len(state_in)): 455 | feed_dict[self.state_in[i]] = state_in[i] 456 | if self.meta_dim > 0: 457 | feed_dict[self.meta] = [meta] 458 | 459 | if self.type == 'policy': 460 | return sess.run(self.vf, feed_dict)[0] 461 | elif self.type == 'q': 462 | return sess.run(self.qmax, feed_dict)[0] 463 | elif self.type == 'vpn': 464 | q = sess.run(self.q_plan, feed_dict) 465 | return q.max() 466 | 467 | class CNN(Model): 468 | def __init__(self, *args, **kwargs): 469 | super(CNN, self).__init__(*args, **kwargs) 470 | 471 | def build_model(self, x, meta=None): 472 | for i in range(len(self.f_num)): 473 | x = act_fn(conv2d(x, self.f_num[i], "l{}".format(i+1), 474 | [self.f_size[i], self.f_size[i]], 475 | [self.f_stride[i], self.f_stride[i]], pad=self.f_pad, 476 | init="he")) 477 | self.conv = x 478 | if meta is not None: 479 | space_dim = x.get_shape().as_list()[1:3] 480 | meta_dim = meta.get_shape().as_list()[-1] 481 | t = tf.reshape(tf.tile(meta, 482 | [1, np.prod(space_dim)]), [-1] + space_dim + [meta_dim]) 483 | x = tf.concat([t, x], axis=3) 484 | 485 | return x, [], [] 486 | 487 | class LSTM(Model): 488 | def __init__(self, *args, **kwargs): 489 | super(LSTM, self).__init__(*args, **kwargs) 490 | 491 | def build_model(self, x, meta=None): # meta accepted to match Model.__init__'s build_model(x, meta) call; unused by the LSTM trunk 492 | for i in range(len(self.f_num)): 493 | x = act_fn(conv2d(x, self.f_num[i], "l{}".format(i+1), 494 | [self.f_size[i], self.f_size[i]], 495 | [self.f_stride[i], self.f_stride[i]], pad=self.f_pad)) 496 | self.conv = x 497 | x = act_fn(linear(flatten(x), 256, "l{}".format(3), 498 | normalized_columns_initializer(0.01))) 499 | 500 | # introduce a "fake" batch dimension of 1 after flatten 501 | # so 
that we can do LSTM over time dim 502 | x = tf.expand_dims(x, [0]) 503 | 504 | size = 256 505 | lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True) 506 | self.state_size = lstm.state_size 507 | step_size = tf.shape(self.x)[:1] 508 | 509 | c_init = np.zeros((1, lstm.state_size.c), np.float32) 510 | h_init = np.zeros((1, lstm.state_size.h), np.float32) 511 | self.state_init = [c_init, h_init] 512 | c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c]) 513 | h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h]) 514 | 515 | state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in) 516 | lstm_outputs, lstm_state = tf.nn.dynamic_rnn( 517 | lstm, x, initial_state=state_in, sequence_length=step_size, 518 | time_major=False) 519 | lstm_c, lstm_h = lstm_state 520 | state_out = [lstm_c[:1, :], lstm_h[:1, :]] 521 | 522 | x = tf.reshape(lstm_outputs, [-1, size]) 523 | return x, state_in, state_out 524 | -------------------------------------------------------------------------------- /q.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | import model 5 | import util 6 | from async import AsyncSolver 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | logger.setLevel(logging.INFO) 10 | 11 | class Q(AsyncSolver): 12 | def define_network(self, name): 13 | self.args.meta_dim = 0 if self.env.meta() is None else len(self.env.meta()) 14 | return eval("model." + name)(self.env.observation_space.shape, 15 | self.env.action_space.n, type='q', 16 | gamma=self.args.gamma, 17 | dim=self.args.dim, 18 | f_num=self.args.f_num, 19 | f_pad=self.args.f_pad, 20 | f_stride=self.args.f_stride, 21 | f_size=self.args.f_size, 22 | meta_dim=self.args.meta_dim, 23 | ) 24 | 25 | def use_target_network(self): 26 | return True 27 | 28 | def process_rollout(self, rollout, gamma, lambda_=1.0): 29 | """ 30 | given a rollout, compute its returns 31 | """ 32 | batch_si = np.asarray(rollout.states) 33 | batch_a = np.asarray(rollout.actions) 34 | rewards = np.asarray(rollout.rewards) 35 | time = np.asarray(rollout.time) 36 | meta = np.asarray(rollout.meta) 37 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r]) 38 | batch_r = util.discount(rewards_plus_v, gamma, time)[:-1] 39 | features = rollout.features[0] 40 | 41 | return util.Batch(si=batch_si, 42 | a=batch_a, 43 | adv=None, 44 | r=batch_r, 45 | terminal=rollout.terminal, 46 | features=features, 47 | reward=rewards, 48 | step=time, 49 | meta=meta, 50 | ) 51 | 52 | def init_variables(self): 53 | pi = self.local_network 54 | 55 | # target network is synchronized after every 10,000 steps 56 | self.local_to_target = tf.group(*[v1.assign(v2) for v1, v2 in 57 | zip(self.global_target_network.var_list, pi.var_list)]) 58 | self.target_sync = tf.group(*[v1.assign(v2) for v1, v2 in 59 | zip(self.target_network.var_list, self.global_target_network.var_list)]) 60 | self.sync_count = 0 61 | self.target_sync_step = 0 62 | 63 | # epsilon 64 | self.eps = [1.0] 65 | self.eps_start = [1.0] 66 | self.eps_end = [self.args.eps] 67 | self.eps_prob = [1] 68 | self.anneal_step = self.args.eps_step 69 | 70 | # batch size 71 | self.bs = tf.to_float(tf.shape(pi.x)[0]) 72 | 73 | # loss function 74 | self.define_loss() 75 | 76 | def define_loss(self): 77 | pi = self.local_network 78 | 79 | # loss function 80 | self.ac = tf.placeholder(tf.float32, [None, self.env.action_space.n], name="ac") 81 | self.r = tf.placeholder(tf.float32, [None], name="r") # target 
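        # The expression assembled just below is the Huber (smooth-L1) loss with
        # threshold 1 applied to delta = Q(s, a) - target:
        #   L(delta) = 0.5 * delta^2    if |delta| < 1
        #   L(delta) = |delta| - 0.5    otherwise
        # Its gradient is clip(delta, -1, 1), which is what the "clipping gradient
        # to [-1, 1]" comment below refers to. util.huber_loss in util.py computes
        # the same quantity, so an equivalent (though untested here) formulation is:
        #   self.q_loss = util.huber_loss(self.q_val - self.r)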
82 | 83 | self.q_val = tf.reduce_sum(pi.q * self.ac, [1]) 84 | self.delta = self.q_val - self.r 85 | # clipping gradient to [-1, 1] amounts to using Huber loss 86 | self.q_loss = tf.reduce_sum(tf.where(tf.abs(self.delta) < 1, 87 | 0.5 * tf.square(self.delta), 88 | tf.abs(self.delta) - 0.5)) 89 | 90 | self.loss = self.q_loss 91 | 92 | def define_summary(self): 93 | super(Q, self).define_summary() 94 | tf.summary.scalar("loss/loss", self.loss / self.bs) 95 | if hasattr(self, "q_loss"): 96 | tf.summary.scalar("loss/q_loss", self.q_loss / self.bs) 97 | tf.summary.scalar("param/target_param_norm", 98 | tf.global_norm(self.target_network.var_list)) 99 | self.summary_op = tf.summary.merge_all() 100 | 101 | def start(self, sess, summary_writer): 102 | sess.run(self.sync) # copy weights from shared to local 103 | if self.task == 0: 104 | sess.run(self.local_to_target) # copy weights from local to shared target 105 | sess.run(self.target_sync) # copy weights from global target to local target 106 | super(Q, self).start(sess, summary_writer) 107 | 108 | def prepare_input(self, batch): 109 | feed_dict = {self.local_network.x: batch.si, 110 | self.ac: batch.a, 111 | self.r: batch.r} 112 | if self.args.meta_dim > 0: 113 | feed_dict[self.local_network.meta] = batch.meta 114 | for i in range(len(self.local_network.state_in)): 115 | feed_dict[self.local_network.state_in[i]] = batch.features[i] 116 | return feed_dict 117 | 118 | def post_process(self, sess): 119 | if self.task == 0: 120 | global_step = self.last_global_step 121 | if int(global_step / self.args.sync) > self.sync_count: 122 | # copy weights from local to shared target 123 | self.sync_count = int(global_step / self.args.sync) 124 | sess.run([self.local_to_target, self.target_sync, self.update_target_step]) 125 | logger.info("[Step: %d] Target network is synchronized", global_step) 126 | else: 127 | target_step = self.global_target_sync_step.eval() 128 | if target_step != self.target_sync_step: 129 | self.target_sync_step = target_step 130 | sess.run(self.target_sync) 131 | logger.info("[Step: %d] Target network is synchronized", target_step) 132 | 133 | for i in range(len(self.eps)): 134 | self.eps[i] = self.eps_start[i] 135 | self.eps[i] -= self.last_global_step * (self.eps_start[i] - self.eps_end[i])\ 136 | / self.anneal_step 137 | self.eps[i] = max(self.eps[i], self.eps_end[i]) 138 | 139 | def epsilon(self): 140 | return np.random.choice(self.eps, p=self.eps_prob) 141 | 142 | def write_extra_summary(self, rollout=None): 143 | summary = tf.Summary() 144 | summary.value.add(tag='model/epsilon', simple_value=float( 145 | np.sum(np.array(self.eps) * np.array(self.eps_prob)))) 146 | summary.value.add(tag='model/rollout_r', simple_value=float(rollout.r)) 147 | self.summary_writer.add_summary(summary, self.last_global_step) 148 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import go_vncdriver 3 | import tensorflow as tf 4 | from envs import create_env 5 | import subprocess as sp 6 | import util 7 | import model 8 | import numpy as np 9 | from worker import new_env 10 | 11 | parser = argparse.ArgumentParser(description="Run commands") 12 | parser.add_argument('-gpu', '--gpu', default=0, type=int, help='Number of GPUs') 13 | parser.add_argument('-r', '--remotes', default=None, 14 | help='The address of pre-existing VNC servers and rewarders to use' 15 | '(e.g. 
-r vnc://localhost:5900+15900,vnc://localhost:5901+15901).') 16 | parser.add_argument('-e', '--env-id', type=str, default="maze", 17 | help="Environment id") 18 | parser.add_argument('-a', '--alg', type=str, default="VPN", help="Algorithm: [A3C | Q | VPN]") 19 | parser.add_argument('-mo', '--model', type=str, default="CNN", help="Name of model: [CNN | LSTM]") 20 | parser.add_argument('-ck', '--checkpoint', type=str, default="", help="Path of the checkpoint") 21 | parser.add_argument('-n', '--n-play', type=int, default=1000, help="Num of play") 22 | parser.add_argument('--eps', type=float, default=0.0, help="Epsilon-greedy") 23 | parser.add_argument('--config', type=str, default="", help="config xml file for environment") 24 | parser.add_argument('--seed', type=int, default=0, help="Random seed") 25 | 26 | # Hyperparameters 27 | parser.add_argument('-g', '--gamma', type=float, default=0.98, help="Discount factor") 28 | parser.add_argument('--dim', type=int, default=64, help="Number of final hidden units") 29 | parser.add_argument('--f-num', type=str, default='32,32,64', help="num of conv filters") 30 | parser.add_argument('--f-stride', type=str, default='1,1,2', help="stride of conv filters") 31 | parser.add_argument('--f-size', type=str, default='3,3,4', help="size of conv filters") 32 | parser.add_argument('--f-pad', type=str, default='SAME', help="padding of conv filters") 33 | 34 | # VPN parameters 35 | parser.add_argument('--branch', type=str, default="4,4,4", help="branching factor") 36 | 37 | def evaluate(env, agent, num_play=3000, eps=0.0): 38 | env.max_history = num_play 39 | for iter in range(0, num_play): 40 | last_state = env.reset() 41 | last_features = agent.get_initial_features() 42 | last_meta = env.meta() 43 | while True: 44 | # import pdb; pdb.set_trace() 45 | if eps == 0.0 or np.random.rand() > eps: 46 | fetched = agent.act(last_state, last_features, 47 | meta=last_meta) 48 | if agent.type == 'policy': 49 | action, features = fetched[0], fetched[2:] 50 | else: 51 | action, features = fetched[0], fetched[1:] 52 | else: 53 | act_idx = np.random.randint(0, env.action_space.n) 54 | action = np.zeros(env.action_space.n) 55 | action[act_idx] = 1 56 | features = [] 57 | 58 | state, reward, terminal, info, _ = env.step(action.argmax()) 59 | last_state = state 60 | last_features = features 61 | last_meta = env.meta() 62 | if terminal: 63 | break 64 | 65 | return env.reward_mean(num_play) 66 | 67 | def run(): 68 | args = parser.parse_args() 69 | args.task = 0 70 | args.f_num = util.parse_to_num(args.f_num) 71 | args.f_stride = util.parse_to_num(args.f_stride) 72 | args.f_size = util.parse_to_num(args.f_size) 73 | args.branch = util.parse_to_num(args.branch) 74 | 75 | env = new_env(args) 76 | args.meta_dim = 0 if env.meta() is None else len(env.meta()) 77 | device = '/gpu:0' if args.gpu > 0 else '/cpu:0' 78 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2) 79 | config = tf.ConfigProto(device_filters=device, 80 | gpu_options=gpu_options, 81 | allow_soft_placement=True) 82 | with tf.Session(config=config) as sess: 83 | if args.alg == 'A3C': 84 | model_type = 'policy' 85 | elif args.alg == 'Q': 86 | model_type = 'q' 87 | elif args.alg == 'VPN': 88 | model_type = 'vpn' 89 | else: 90 | raise ValueError('Invalid algorithm: ' + args.alg) 91 | with tf.device(device): 92 | with tf.variable_scope("local/learner"): 93 | agent = eval("model." 
+ args.model)(env.observation_space.shape, 94 | env.action_space.n, type=model_type, 95 | gamma=args.gamma, 96 | dim=args.dim, 97 | f_num=args.f_num, 98 | f_stride=args.f_stride, 99 | f_size=args.f_size, 100 | f_pad=args.f_pad, 101 | branch=args.branch, 102 | meta_dim=args.meta_dim) 103 | print("Num parameters: %d" % agent.num_param) 104 | 105 | saver = tf.train.Saver() 106 | saver.restore(sess, args.checkpoint) 107 | np.random.seed(args.seed) 108 | reward = evaluate(env, agent, args.n_play, eps=args.eps) 109 | print("Reward: %.2f" % (reward)) 110 | 111 | if __name__ == "__main__": 112 | run() 113 | -------------------------------------------------------------------------------- /test_vpn: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ -z "$1" ] 3 | then echo "[checkpoint]"; exit 0 4 | fi 5 | 6 | args="--config config/collect_deterministic.xml --branch 4,4,4,1,1 --checkpoint $1 --gpu 1 ${@:2}" 7 | echo $args 8 | CUDA_VISIBLE_DEVICES=$3 python test.py $args 9 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from six.moves import shlex_quote 5 | import util 6 | 7 | parser = argparse.ArgumentParser(description="Run commands") 8 | parser.add_argument('-gpu', '--gpu', default="", type=str, help='GPU Ids') 9 | parser.add_argument('-w', '--num-workers', type=int, default=16, help="Number of workers") 10 | parser.add_argument('-ps', '--num-ps', type=int, default=4, help="Number of parameter servers") 11 | parser.add_argument('-r', '--remotes', default=None, 12 | help='The address of pre-existing VNC servers and rewarders to use' 13 | '(e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901).') 14 | parser.add_argument('-e', '--env-id', type=str, default="maze", 15 | help="Environment id") 16 | parser.add_argument('-l', '--log', type=str, default="/tmp/vpn", 17 | help="Log directory path") 18 | parser.add_argument('-d', '--dry-run', action='store_true', 19 | help="Print out commands rather than executing them") 20 | parser.add_argument('-m', '--mode', type=str, default='tmux', 21 | help="tmux: run workers in a tmux session. " 22 | "nohup: run workers with nohup. 
" 23 | "child: run workers as child processes") 24 | parser.add_argument('-a', '--alg', choices=['A3C', 'Q', 'VPN'], default="VPN") 25 | parser.add_argument('-mo', '--model', type=str, default="CNN", help="Name of model: [CNN | LSTM]") 26 | parser.add_argument('-ms', '--max-step', type=int, default=int(15e6), help="Max global step") 27 | parser.add_argument('--config', type=str, default="config/collect_deterministic.xml", 28 | help="config xml file for environment") 29 | parser.add_argument('--seed', type=int, default=0, help="Random seed") 30 | parser.add_argument('--eval-freq', type=int, default=250000, help="Evaluation frequency") 31 | parser.add_argument('--eval-num', type=int, default=2000, help="Evaluation frequency") 32 | 33 | # Hyperparameters 34 | parser.add_argument('-n', '--t-max', type=int, default=10, help="Number of unrolling steps") 35 | parser.add_argument('-g', '--gamma', type=float, default=0.98, help="Discount factor") 36 | parser.add_argument('-lr', '--lr', type=float, default=1e-4, help="Learning rate") 37 | parser.add_argument('--decay', type=float, default=0.95, help="Learning rate") 38 | parser.add_argument('--dim', type=int, default=64, help="Number of final hidden units") 39 | parser.add_argument('--f-num', type=str, default='32,32,64', help="num of conv filters") 40 | parser.add_argument('--f-stride', type=str, default='1,1,2', help="stride of conv filters") 41 | parser.add_argument('--f-size', type=str, default='3,3,4', help="size of conv filters") 42 | parser.add_argument('--f-pad', type=str, default='SAME', help="padding of conv filters") 43 | parser.add_argument('--h-dim', type=str, default='', help="num of hidden units") 44 | 45 | # Q-Learning parameters 46 | parser.add_argument('-s', '--sync', type=int, default=10000, 47 | help="Target network synchronization frequency") 48 | parser.add_argument('-f', '--update-freq', type=int, default=1, 49 | help="Parameter update frequency") 50 | parser.add_argument('--eps-step', type=int, default=int(1e6), 51 | help="Num of local steps for epsilon scheduling") 52 | parser.add_argument('--eps', type=float, default=0.05, help="Final epsilon") 53 | parser.add_argument('--eps-eval', type=float, default=0.0, help="Epsilon for evaluation") 54 | 55 | # VPN parameters 56 | parser.add_argument('--prediction-step', type=int, default=3, help="number of prediction steps") 57 | parser.add_argument('--branch', type=str, default="4,4,4", help="branching factor") 58 | parser.add_argument('--buf', type=int, default=10**6, help="num of steps for random buffer") 59 | 60 | def new_cmd(session, name, cmd, mode, logdir, shell): 61 | if isinstance(cmd, (list, tuple)): 62 | cmd = " ".join(shlex_quote(str(v)) for v in cmd) 63 | if mode == 'tmux': 64 | return name, "tmux send-keys -t {}:{} {} Enter".format(session, name, shlex_quote(cmd)) 65 | elif mode == 'child': 66 | return name, "{} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(cmd, logdir, session, name, logdir) 67 | elif mode == 'nohup': 68 | return name, "nohup {} -c {} >{}/{}.{}.out 2>&1 & echo kill $! 
>>{}/kill.sh".format(shell, shlex_quote(cmd), logdir, session, name, logdir) 69 | 70 | 71 | def create_commands(session, args, shell='bash'): 72 | # for launching the TF workers and for launching tensorboard 73 | base_cmd = [ 74 | sys.executable, 'worker.py', 75 | '--log', args.log, '--env-id', args.env_id, 76 | '--num-workers', str(args.num_workers), 77 | '--num-ps', str(args.num_ps), 78 | '--alg', args.alg, 79 | '--model', args.model, 80 | '--max-step', args.max_step, 81 | '--t-max', args.t_max, 82 | '--eps-step', args.eps_step, 83 | '--eps', args.eps, 84 | '--eps-eval', args.eps_eval, 85 | '--gamma', args.gamma, 86 | '--lr', args.lr, 87 | '--decay', args.decay, 88 | '--sync', args.sync, 89 | '--update-freq', args.update_freq, 90 | '--eval-freq', args.eval_freq, 91 | '--eval-num', args.eval_num, 92 | '--prediction-step', args.prediction_step, 93 | '--dim', args.dim, 94 | '--f-num', args.f_num, 95 | '--f-pad', args.f_pad, 96 | '--f-stride', args.f_stride, 97 | '--f-size', args.f_size, 98 | '--h-dim', args.h_dim, 99 | '--branch', args.branch, 100 | '--config', args.config, 101 | '--buf', args.buf, 102 | ] 103 | 104 | if len(args.gpu) > 0: 105 | base_cmd += ['--gpu', 1] 106 | 107 | if args.remotes is None: 108 | args.remotes = ["1"] * args.num_workers 109 | else: 110 | args.remotes = args.remotes.split(',') 111 | assert len(args.remotes) == args.num_workers 112 | 113 | cmds_map = [] 114 | for i in range(args.num_ps): 115 | prefix = ['CUDA_VISIBLE_DEVICES='] 116 | cmds_map += [new_cmd(session, "ps-%d" % i, prefix + base_cmd + ["--job-name", "ps", 117 | "--task", str(i)], args.mode, args.log, shell)] 118 | 119 | for i in range(args.num_workers): 120 | prefix = [] 121 | if len(args.gpu) > 0: 122 | prefix = ['CUDA_VISIBLE_DEVICES=%d' % args.gpu[(i % len(args.gpu))]] 123 | else: 124 | prefix = ['CUDA_VISIBLE_DEVICES='] 125 | cmds_map += [new_cmd(session, 126 | "w-%d" % i, 127 | prefix + base_cmd + ["--job-name", "worker", "--task", str(i), 128 | "--remotes", args.remotes[i]], args.mode, args.log, shell)] 129 | 130 | cmds_map += [new_cmd(session, "tb", ["tensorboard", "--logdir", args.log, 131 | "--port", "12345"], args.mode, args.log, shell)] 132 | cmds_map += [new_cmd(session, "test", prefix + base_cmd + ["--job-name", "test", 133 | "--task", str(args.num_workers)], args.mode, args.log, shell)] 134 | windows = [v[0] for v in cmds_map] 135 | 136 | notes = [] 137 | cmds = [ 138 | "mkdir -p {}".format(args.log), 139 | "echo {} {} > {}/cmd.sh".format(sys.executable, ' '.join([shlex_quote(arg) for arg in sys.argv if arg != '-n']), args.log), 140 | ] 141 | if args.mode == 'nohup' or args.mode == 'child': 142 | cmds += ["echo '#!/bin/sh' >{}/kill.sh".format(args.log)] 143 | notes += ["Run `source {}/kill.sh` to kill the job".format(args.log)] 144 | if args.mode == 'tmux': 145 | notes += ["Use `tmux attach -t {}` to watch process output".format(session)] 146 | notes += ["Use `tmux kill-session -t {}` to kill the job".format(session)] 147 | else: 148 | notes += ["Use `tail -f {}/*.out` to watch process output".format(args.log)] 149 | notes += ["Point your browser to http://localhost:12345 to see Tensorboard"] 150 | 151 | if args.mode == 'tmux': 152 | cmds += [ 153 | "tmux kill-session -t {}".format(session), 154 | "tmux new-session -s {} -n {} -d {}".format(session, windows[0], shell) 155 | ] 156 | for w in windows[1:]: 157 | cmds += ["tmux new-window -t {} -n {} {}".format(session, w, shell)] 158 | cmds += ["sleep 1"] 159 | for window, cmd in cmds_map: 160 | cmds += [cmd] 161 | 162 | return cmds, notes 
163 | 164 | 165 | def run(): 166 | args = parser.parse_args() 167 | args.gpu = util.parse_to_num(args.gpu) 168 | cmds, notes = create_commands("e", args) 169 | if args.dry_run: 170 | print("Dry-run mode due to -d flag, otherwise the following commands would be executed:") 171 | else: 172 | print("Executing the following commands:") 173 | print("\n".join(cmds)) 174 | print("") 175 | if not args.dry_run: 176 | if args.mode == "tmux": 177 | os.environ["TMUX"] = "" 178 | path = os.path.join(os.getcwd(), args.log) 179 | if os.path.exists(path): 180 | key = raw_input("%s exists. Do you want to delete it? (y/n): " % path) 181 | if key != 'n': 182 | os.system("rm -rf %s" % path) 183 | os.system("\n".join(cmds)) 184 | print('\n'.join(notes)) 185 | else: 186 | os.system("\n".join(cmds)) 187 | print('\n'.join(notes)) 188 | 189 | 190 | if __name__ == "__main__": 191 | run() 192 | -------------------------------------------------------------------------------- /train_vpn: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Deterministic Collect 4 | # python train.py --config config/collect_deterministic.xml --prediction-step 1 --branch 4 --decay 0.9 --alg VPN --log result/vpn1 ${@:1} 5 | # python train.py --config config/collect_deterministic.xml --prediction-step 2 --branch 4,4 --decay 0.9 --alg VPN --log result/vpn2 ${@:1} 6 | # python train.py --config config/collect_deterministic.xml --prediction-step 3 --branch 4,4,4 --decay 0.9 --alg VPN --log result/vpn3 ${@:1} 7 | # python train.py --config config/collect_deterministic.xml --prediction-step 5 --branch 4,4,4,1,1 --decay 0.8 --alg VPN --log result/vpn5 ${@:1} 8 | 9 | # Stochastic Collect 10 | # python train.py --config config/collect_stochastic.xml --prediction-step 1 --branch 4 --decay 0.9 --alg VPN --log result/vpn1 ${@:1} 11 | # python train.py --config config/collect_stochastic.xml --prediction-step 2 --branch 4,4 --decay 0.9 --alg VPN --log result/vpn2 ${@:1} 12 | # python train.py --config config/collect_stochastic.xml --prediction-step 3 --branch 4,4,4 --decay 0.9 --alg VPN --log result/vpn3 ${@:1} 13 | # python train.py --config config/collect_stochastic.xml --prediction-step 5 --branch 4,4,4,1,1 --decay 0.8 --alg VPN --log result/vpn5 ${@:1} 14 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import scipy.signal 2 | import numpy as np 3 | import tensorflow as tf 4 | from collections import namedtuple 5 | 6 | def discount(x, gamma, time=None): 7 | if time is not None and time.size > 0: 8 | y = np.array(x, copy=True) 9 | for i in reversed(range(y.size-1)): 10 | y[i] += (gamma ** time[i]) * y[i+1] 11 | return y 12 | else: 13 | return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1] 14 | 15 | Batch = namedtuple("Batch", ["si", "a", "adv", "r", 16 | "terminal", "features", "reward", "step", "meta"]) 17 | 18 | def huber_loss(delta, sum=True): 19 | if sum: 20 | return tf.reduce_sum(tf.where(tf.abs(delta) < 1, 21 | 0.5 * tf.square(delta), 22 | tf.abs(delta) - 0.5)) 23 | else: 24 | return tf.where(tf.abs(delta) < 1, 25 | 0.5 * tf.square(delta), 26 | tf.abs(delta) - 0.5) 27 | 28 | def lower_triangular(x): 29 | return tf.matrix_band_part(x, -1, 0) 30 | 31 | def to_bool(x): 32 | return x == 1 33 | 34 | def parse_to_num(s): 35 | l = s.split(',') 36 | for i in range(0, len(l)): 37 | try: 38 | l[i] = int(l[i]) 39 | except ValueError: 40 | l = [] 41 | break 42 
| return l 43 | -------------------------------------------------------------------------------- /vpn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | import model # NOQA 5 | import util 6 | from q import Q 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | logger.setLevel(logging.INFO) 10 | 11 | class RolloutMemory(object): 12 | def __init__(self, max_size, sampling='rand'): 13 | self.max_size = max_size 14 | self.s = [] 15 | self.a = [] 16 | self.r = [] 17 | self.t = [] 18 | self.r_t = [] 19 | self.term = [] 20 | self.sampling = sampling 21 | self.sample_idx = 0 22 | 23 | def add(self, s, a, r, t, r_t, term): 24 | self.s.append(s) 25 | self.a.append(a) 26 | self.r.append(r) 27 | self.t.append(t) 28 | self.r_t.append(r_t) 29 | self.term.append(term) 30 | 31 | def size(self): 32 | return len(self.s) 33 | 34 | def is_full(self): 35 | return len(self.s) >= self.max_size 36 | 37 | def sample(self, length): 38 | size = len(self.s) 39 | is_initial_state = False 40 | if self.sampling == 'rand': 41 | idx = np.random.randint(0, size-1) 42 | if self.term[idx]: 43 | return self.sample(length) 44 | for end_idx in range(idx, idx + length): 45 | if self.term[end_idx] or end_idx == size-1: 46 | break 47 | is_initial_state = (idx > 0 and self.term[idx-1]) or idx == 0 48 | else: 49 | idx = self.sample_idx 50 | if self.term[idx]: 51 | idx = idx + 1 52 | for end_idx in range(idx, idx + length): 53 | if self.term[end_idx] or end_idx == size-1: 54 | break 55 | self.sample_idx = end_idx + 1 if end_idx < size-1 else 0 56 | is_initial_state = (idx > 0 and self.term[idx-1]) or idx == 0 57 | 58 | assert end_idx == idx + length - 1 or self.term[end_idx] or end_idx == size-1 59 | return util.Batch(si=np.asarray(self.s[idx:end_idx+1]), 60 | a=np.asarray(self.a[idx:end_idx+1]), 61 | adv=None, 62 | r=None, 63 | terminal=self.term, 64 | features=[], 65 | reward=np.asarray(self.r[idx:end_idx+1]), 66 | step=np.asarray(self.t[idx:end_idx+1]), 67 | meta=np.asarray(self.r_t[idx:end_idx+1])), is_initial_state 68 | 69 | class VPN(Q): 70 | def define_network(self, name): 71 | self.state_off = None 72 | self.args.meta_dim = 0 if self.env.meta() is None else len(self.env.meta()) 73 | m = eval("model." 
+ name)(self.env.observation_space.shape, 74 | self.env.action_space.n, type='vpn', 75 | gamma=self.args.gamma, 76 | prediction_step=self.args.prediction_step, 77 | dim=self.args.dim, 78 | f_num=self.args.f_num, 79 | f_pad=self.args.f_pad, 80 | f_stride=self.args.f_stride, 81 | f_size=self.args.f_size, 82 | branch=self.args.branch, 83 | meta_dim=self.args.meta_dim, 84 | ) 85 | 86 | return m 87 | 88 | def process_rollout(self, rollout, gamma, lambda_=1.0): 89 | """ 90 | given a rollout, compute its returns 91 | """ 92 | batch_si = np.asarray(rollout.states) 93 | batch_a = np.asarray(rollout.actions) 94 | rewards = np.asarray(rollout.rewards) 95 | time = np.asarray(rollout.time) 96 | meta = np.asarray(rollout.meta) 97 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r]) 98 | batch_r = util.discount(rewards_plus_v, gamma, time) 99 | features = rollout.features[0] 100 | 101 | return util.Batch(si=batch_si, 102 | a=batch_a, 103 | adv=None, 104 | r=batch_r, 105 | terminal=rollout.terminal, 106 | features=features, 107 | reward=rewards, 108 | step=time, 109 | meta=meta, 110 | ) 111 | 112 | def define_loss(self): 113 | pi = self.local_network 114 | if self.args.buf > 0: 115 | if pi.is_recurrent(): 116 | self.rand_rollouts = RolloutMemory(int(self.args.buf / self.args.num_workers), 117 | sampling='seq') 118 | self.off_state = pi.get_initial_features() 119 | else: 120 | self.rand_rollouts = RolloutMemory(int(self.args.buf / self.args.num_workers)) 121 | 122 | # loss function 123 | self.ac = tf.placeholder(tf.float32, [None, self.env.action_space.n], name="ac") 124 | self.v_target = tf.placeholder(tf.float32, [None], name="v_target") # target 125 | self.reward = tf.placeholder(tf.float32, [None], name="reward") # immediate reward 126 | self.step = tf.placeholder(tf.float32, [None], name="step") # num of steps 127 | self.terminal = tf.placeholder(tf.float32, (), name="terminal") 128 | 129 | time = tf.shape(pi.x)[0] 130 | steps = tf.minimum(self.args.prediction_step, time) 131 | self.rollout_num = tf.to_float(time * steps - steps * (steps - 1) / 2) 132 | 133 | # reward/gamma/value prediction 134 | self.r_delta = util.lower_triangular( 135 | pi.r_a - tf.reshape(self.reward, [-1, 1])) 136 | self.r_loss_mat = util.huber_loss(self.r_delta, sum=False) 137 | self.r_loss = tf.reduce_sum(self.r_loss_mat) 138 | 139 | self.gamma_loss_mat = util.huber_loss(util.lower_triangular( 140 | pi.t_a - tf.reshape(self.step, [-1, 1])), sum=False) 141 | self.gamma_loss = tf.reduce_sum(self.gamma_loss_mat) 142 | 143 | self.v_next_loss_mat = util.huber_loss(util.lower_triangular( 144 | pi.v_next_a - tf.reshape(self.v_target[1:], [-1, 1])), sum=False) 145 | self.v_next_loss = tf.reduce_sum(self.v_next_loss_mat) 146 | self.loss = self.r_loss + self.gamma_loss + self.v_next_loss 147 | 148 | # reward/gamma prediction for off-policy data (optional) 149 | self.a_off = tf.placeholder(tf.float32, [None, self.env.action_space.n], name="a_off") 150 | self.r_off = tf.placeholder(tf.float32, [None], name="r_off") # immediate reward 151 | self.step_off = tf.placeholder(tf.float32, [None], name="step_off") # num of steps 152 | 153 | self.r_delta_off = util.lower_triangular( 154 | pi.r_off - tf.reshape(self.r_off, [-1, 1])) 155 | self.r_loss_mat_off = util.huber_loss(self.r_delta_off, sum=False) 156 | self.r_loss_off = tf.reduce_sum(self.r_loss_mat_off) 157 | 158 | self.gamma_loss_mat_off = util.huber_loss(util.lower_triangular( 159 | pi.t_off - tf.reshape(self.step_off, [-1, 1])), sum=False) 160 | 161 | self.gamma_loss_off = 
tf.reduce_sum(self.gamma_loss_mat_off) 162 | self.loss += self.r_loss_off + self.gamma_loss_off 163 | 164 | def prepare_input(self, batch): 165 | feed_dict = {self.local_network.x: batch.si, 166 | self.local_network.a: batch.a, 167 | self.ac: batch.a, 168 | self.reward: batch.reward, 169 | self.step: batch.step, 170 | self.target_network.x: batch.si, 171 | self.terminal: float(batch.terminal), 172 | self.v_target: batch.r} 173 | 174 | for i in range(len(self.local_network.state_in)): 175 | feed_dict[self.local_network.state_in[i]] = batch.features[i] 176 | 177 | if self.args.meta_dim > 0: 178 | feed_dict[self.local_network.meta] = batch.meta 179 | 180 | traj, initial = self.random_trajectory() 181 | feed_dict[self.local_network.x_off] = traj.si 182 | feed_dict[self.local_network.a_off] = traj.a 183 | feed_dict[self.a_off] = traj.a 184 | feed_dict[self.r_off] = traj.reward 185 | feed_dict[self.step_off] = traj.step 186 | 187 | if self.local_network.is_recurrent(): 188 | if initial: 189 | state_in = self.local_network.get_initial_features() 190 | else: 191 | state_in = self.off_state 192 | for i in range(len(self.local_network.state_in_off)): 193 | feed_dict[self.local_network.state_in_off[i]] = state_in[i] 194 | 195 | if self.args.meta_dim > 0: 196 | feed_dict[self.local_network.meta_off] = traj.meta 197 | 198 | return feed_dict 199 | 200 | def random_trajectory(self): 201 | if not self.rand_rollouts.is_full(): 202 | env = self.env_off 203 | state_off = env.reset() 204 | meta_off = env.meta() 205 | print("Generating random rollouts: %d steps" % self.rand_rollouts.max_size) 206 | while not self.rand_rollouts.is_full(): 207 | act_idx = np.random.randint(0, env.action_space.n) 208 | action = np.zeros(env.action_space.n) 209 | action[act_idx] = 1 210 | state, reward, terminal, _, time = env.step(action.argmax()) 211 | self.rand_rollouts.add(state_off, action, reward, time, 212 | meta_off, terminal) 213 | state_off = state 214 | meta_off = env.meta() 215 | if terminal: 216 | state_off = env.reset() 217 | meta_off = env.meta() 218 | return self.rand_rollouts.sample(self.args.t_max) 219 | 220 | def extra_fetches(self): 221 | if self.local_network.is_recurrent(): 222 | return self.local_network.state_out_off 223 | return [] 224 | 225 | def handle_extra_fetches(self, fetches): 226 | if self.local_network.is_recurrent(): 227 | self.off_state = fetches[:len(self.off_state)] 228 | 229 | def compute_depth(self, steps): 230 | return self.args.depth 231 | 232 | def write_extra_summary(self, rollout=None): 233 | super(VPN, self).write_extra_summary(rollout) 234 | 235 | def define_summary(self): 236 | super(VPN, self).define_summary() 237 | tf.summary.scalar("loss/r_loss", self.r_loss / self.rollout_num) 238 | tf.summary.scalar("loss/gamma_loss", self.gamma_loss / self.rollout_num) 239 | tf.summary.scalar("model/r", tf.reduce_mean(self.local_network.r)) 240 | tf.summary.scalar("model/v_next", tf.reduce_mean(self.local_network.v_next)) 241 | tf.summary.scalar("model/gamma", tf.reduce_mean(self.local_network.gamma)) 242 | self.summary_op = tf.summary.merge_all() 243 | -------------------------------------------------------------------------------- /worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import go_vncdriver 3 | import tensorflow as tf 4 | import argparse 5 | import logging 6 | import sys, signal 7 | import time 8 | import os 9 | from a3c import A3C 10 | from q import Q 11 | from vpn import VPN 12 | from envs import create_env 13 
| import util 14 | import numpy as np 15 | 16 | logger = logging.getLogger(__name__) 17 | logger.setLevel(logging.INFO) 18 | 19 | def new_env(args): 20 | config = open(args.config) if args.config != "" else None 21 | env = create_env(args.env_id, 22 | str(args.task), 23 | args.remotes, 24 | config=config) 25 | return env 26 | 27 | # Disables write_meta_graph argument, which freezes entire process and is mostly useless. 28 | class FastSaver(tf.train.Saver): 29 | def save(self, sess, save_path, global_step=None, latest_filename=None, 30 | meta_graph_suffix="meta", write_meta_graph=True): 31 | super(FastSaver, self).save(sess, save_path, global_step, latest_filename, 32 | meta_graph_suffix, False) 33 | 34 | def run(args, server): 35 | env = new_env(args) 36 | if args.alg == 'A3C': 37 | trainer = A3C(env, args) 38 | elif args.alg == 'Q': 39 | trainer = Q(env, args) 40 | elif args.alg == 'VPN': 41 | env_off = new_env(args) 42 | env_off.verbose = 0 43 | env_off.reset() 44 | trainer = VPN(env, args, env_off=env_off) 45 | else: 46 | raise ValueError('Invalid algorithm: ' + args.alg) 47 | 48 | # Variable names that start with "local" are not saved in checkpoints. 49 | variables_to_save = [v for v in tf.global_variables() if \ 50 | not v.name.startswith("global") and not v.name.startswith("local/target/")] 51 | global_variables = [v for v in tf.global_variables() if not v.name.startswith("local")] 52 | 53 | init_op = tf.variables_initializer(global_variables) 54 | init_all_op = tf.global_variables_initializer() 55 | saver = FastSaver(variables_to_save, max_to_keep=0) 56 | 57 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 58 | logger.info('Trainable vars:') 59 | for v in var_list: 60 | logger.info(' %s %s', v.name, v.get_shape()) 61 | logger.info("Num parameters: %d", trainer.local_network.num_param) 62 | 63 | def init_fn(ses): 64 | logger.info("Initializing all parameters.") 65 | ses.run(init_all_op) 66 | 67 | device = 'gpu' if args.gpu > 0 else 'cpu' 68 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.15) 69 | config = tf.ConfigProto(device_filters=["/job:ps", 70 | "/job:worker/task:{}/{}:0".format(args.task, device)], 71 | gpu_options=gpu_options, 72 | allow_soft_placement=True) 73 | logdir = os.path.join(args.log, 'train') 74 | summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task) 75 | 76 | logger.info("Events directory: %s_%s", logdir, args.task) 77 | sv = tf.train.Supervisor(is_chief=(args.task == 0), 78 | logdir=logdir, 79 | saver=saver, 80 | summary_op=None, 81 | init_op=init_op, 82 | init_fn=init_fn, 83 | summary_writer=summary_writer, 84 | ready_op=tf.report_uninitialized_variables(global_variables), 85 | global_step=trainer.global_step, 86 | save_model_secs=0, 87 | save_summaries_secs=30) 88 | 89 | 90 | logger.info( 91 | "Starting session. If this hangs, we're mostly likely waiting to connect to the parameter server. 
" + 92 | "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.") 93 | with sv.managed_session(server.target, config=config) as sess, sess.as_default(): 94 | sess.run(trainer.sync) 95 | trainer.start(sess, summary_writer) 96 | global_step = sess.run(trainer.global_step) 97 | epoch = -1 98 | logger.info("Starting training at step=%d", global_step) 99 | while not sv.should_stop() and (not args.max_step or global_step < args.max_step): 100 | if args.task == 0 and int(global_step / args.eval_freq) > epoch: 101 | epoch = int(global_step / args.eval_freq) 102 | filename = os.path.join(args.log, 'e%d' % (epoch)) 103 | sv.saver.save(sess, filename) 104 | sv.saver.save(sess, os.path.join(args.log, 'latest')) 105 | print("Saved to: %s" % filename) 106 | trainer.process(sess) 107 | global_step = sess.run(trainer.global_step) 108 | 109 | if args.task == 0 and int(global_step / args.eval_freq) > epoch: 110 | epoch = int(global_step / args.eval_freq) 111 | filename = os.path.join(args.log, 'e%d' % (epoch)) 112 | sv.saver.save(sess, filename) 113 | sv.saver.save(sess, os.path.join(args.log, 'latest')) 114 | print("Saved to: %s" % filename) 115 | # Ask for all the services to stop. 116 | sv.stop() 117 | logger.info('reached %s steps. worker stopped.', global_step) 118 | 119 | def cluster_spec(num_workers, num_ps): 120 | """ 121 | More tensorflow setup for data parallelism 122 | """ 123 | cluster = {} 124 | port = 12222 125 | 126 | all_ps = [] 127 | host = '127.0.0.1' 128 | for _ in range(num_ps): 129 | all_ps.append('{}:{}'.format(host, port)) 130 | port += 1 131 | cluster['ps'] = all_ps 132 | 133 | all_workers = [] 134 | for _ in range(num_workers + 1): 135 | all_workers.append('{}:{}'.format(host, port)) 136 | port += 1 137 | cluster['worker'] = all_workers 138 | port += 1 139 | return cluster 140 | 141 | def evaluate(env, network, num_play=3000, eps=0.0): 142 | for iter in range(0, num_play): 143 | last_state = env.reset() 144 | last_features = network.get_initial_features() 145 | last_meta = env.meta() 146 | while True: 147 | if eps == 0.0 or np.random.rand() > eps: 148 | fetched = network.act(last_state, last_features, 149 | meta=last_meta) 150 | if network.type == 'policy': 151 | action, features = fetched[0], fetched[2:] 152 | else: 153 | action, features = fetched[0], fetched[1:] 154 | else: 155 | act_idx = np.random.randint(0, env.action_space.n) 156 | action = np.zeros(env.action_space.n) 157 | action[act_idx] = 1 158 | features = [] 159 | 160 | state, reward, terminal, info, time = env.step(action.argmax()) 161 | last_state = state 162 | last_features = features 163 | last_meta = env.meta() 164 | 165 | if terminal: 166 | break 167 | 168 | return env.reward_mean(num_play) 169 | 170 | def run_tester(args, server): 171 | env = new_env(args) 172 | env.reset() 173 | env.max_history = args.eval_num 174 | if args.alg == 'A3C': 175 | agent = A3C(env, args) 176 | elif args.alg == 'Q': 177 | agent = Q(env, args) 178 | elif args.alg == 'VPN': 179 | agent = VPN(env, args) 180 | else: 181 | raise ValueError('Invalid algorithm: ' + args.alg) 182 | 183 | device = 'gpu' if args.gpu > 0 else 'cpu' 184 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.15) 185 | config = tf.ConfigProto(device_filters=["/job:ps", 186 | "/job:worker/task:{}/{}:0".format(args.task, device)], 187 | gpu_options=gpu_options, 188 | allow_soft_placement=True) 189 | variables_to_save = [v for v in tf.global_variables() if \ 190 | not v.name.startswith("global") and not 
v.name.startswith("local/target/")] 191 | global_variables = [v for v in tf.global_variables() if not v.name.startswith("local")] 192 | 193 | init_op = tf.variables_initializer(global_variables) 194 | init_all_op = tf.global_variables_initializer() 195 | 196 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 197 | logger.info('Trainable vars:') 198 | for v in var_list: 199 | logger.info(' %s %s', v.name, v.get_shape()) 200 | logger.info("Num parameters: %d", agent.local_network.num_param) 201 | 202 | def init_fn(ses): 203 | logger.info("Initializing all parameters.") 204 | ses.run(init_all_op) 205 | 206 | saver = FastSaver(variables_to_save, max_to_keep=0) 207 | sv = tf.train.Supervisor(is_chief=False, 208 | global_step=agent.global_step, 209 | summary_op=None, 210 | init_op=init_op, 211 | init_fn=init_fn, 212 | ready_op=tf.report_uninitialized_variables(global_variables), 213 | saver=saver, 214 | save_model_secs=0, 215 | save_summaries_secs=0) 216 | 217 | best_reward = -10000 218 | with sv.managed_session(server.target, config=config) as sess, sess.as_default(): 219 | epoch = args.eval_epoch 220 | while args.eval_freq * epoch <= args.max_step: 221 | path = os.path.join(args.log, "e%d" % epoch) 222 | if not os.path.exists(path + ".index"): 223 | time.sleep(10) 224 | continue 225 | logger.info("Start evaluation (Epoch %d)", epoch) 226 | saver.restore(sess, path) 227 | np.random.seed(args.seed) 228 | reward = evaluate(env, agent.local_network, args.eval_num, eps=args.eps_eval) 229 | 230 | logfile = open(os.path.join(args.log, "eval.csv"), "a") 231 | print("Epoch: %d, Reward: %.2f" % (epoch, reward)) 232 | logfile.write("%d, %.3f\n" % (epoch, reward)) 233 | logfile.close() 234 | if reward > best_reward: 235 | best_reward = reward 236 | sv.saver.save(sess, os.path.join(args.log, 'best')) 237 | print("Saved to: %s" % os.path.join(args.log, 'best')) 238 | 239 | epoch += 1 240 | 241 | logger.info('tester stopped.') 242 | 243 | def main(_): 244 | """ 245 | Setting up Tensorflow for data parallel work 246 | """ 247 | 248 | parser = argparse.ArgumentParser(description=None) 249 | parser.add_argument('-gpu', '--gpu', default=0, type=int, help='Number of GPUs') 250 | parser.add_argument('-v', '--verbose', action='count', dest='verbosity', default=0, help='Set verbosity.') 251 | parser.add_argument('--task', default=0, type=int, help='Task index') 252 | parser.add_argument('--job-name', default="worker", help='worker or ps') 253 | parser.add_argument('--num-workers', default=1, type=int, help='Number of workers') 254 | parser.add_argument('--num-ps', type=int, default=1, help="Number of parameter servers") 255 | parser.add_argument('--log', default="/tmp/vpn", help='Log directory path') 256 | parser.add_argument('--env-id', default="maze", help='Environment id') 257 | parser.add_argument('-r', '--remotes', default=None, 258 | help='References to environments to create (e.g. -r 20), ' 259 | 'or the address of pre-existing VNC servers and ' 260 | 'rewarders to use (e.g. 
-r vnc://localhost:5900+15900,vnc://localhost:5901+15901)') 261 | parser.add_argument('-a', '--alg', choices=['A3C', 'Q', 'VPN'], default="A3C") 262 | parser.add_argument('-mo', '--model', type=str, default="LSTM", help="Name of model: [CNN | LSTM]") 263 | parser.add_argument('--eval-freq', type=int, default=250000, help="Evaluation frequency") 264 | parser.add_argument('--eval-num', type=int, default=500, help="Number of evaluation episodes") 265 | parser.add_argument('--eval-epoch', type=int, default=0, help="Evaluation epoch") 266 | parser.add_argument('--seed', type=int, default=0, help="Random seed") 267 | parser.add_argument('--config', type=str, default="config/collect_deterministic.xml", 268 | help="config xml file for environment") 269 | 270 | # Hyperparameters 271 | parser.add_argument('-n', '--t-max', type=int, default=10, help="Number of unrolling steps") 272 | parser.add_argument('-g', '--gamma', type=float, default=0.98, help="Discount factor") 273 | parser.add_argument('-ld', '--ld', type=float, default=1, help="Lambda for GAE") 274 | parser.add_argument('-lr', '--lr', type=float, default=1e-4, help="Learning rate") 275 | parser.add_argument('--decay', type=float, default=0.95, help="Learning decay") 276 | parser.add_argument('-ms', '--max-step', type=int, default=int(15e6), help="Max global step") 277 | parser.add_argument('--dim', type=int, default=0, help="Number of final hidden units") 278 | parser.add_argument('--f-num', type=str, default='32,32,64', help="num of conv filters") 279 | parser.add_argument('--f-pad', type=str, default='SAME', help="padding of conv filters") 280 | parser.add_argument('--f-stride', type=str, default='1,1,2', help="stride of conv filters") 281 | parser.add_argument('--f-size', type=str, default='3,3,4', help="size of conv filters") 282 | parser.add_argument('--h-dim', type=str, default='', help="num of hidden units") 283 | 284 | # Q-Learning parameters 285 | parser.add_argument('-s', '--sync', type=int, default=10000, 286 | help="Target network synchronization frequency") 287 | parser.add_argument('-f', '--update-freq', type=int, default=1, 288 | help="Parameter update frequency") 289 | parser.add_argument('--eps-step', type=int, default=int(1e6), 290 | help="Num of local steps for epsilon scheduling") 291 | parser.add_argument('--eps', type=float, default=0.05, help="Final epsilon value") 292 | parser.add_argument('--eps-eval', type=float, default=0.0, help="Epsilon for evaluation") 293 | 294 | # VPN parameters 295 | parser.add_argument('--prediction-step', type=int, default=3, help="number of prediction steps") 296 | parser.add_argument('--branch', type=str, default="4,4,4", help="branching factor") 297 | parser.add_argument('--buf', type=int, default=10**6, help="num of steps for random buffer") 298 | 299 | args = parser.parse_args() 300 | args.f_num = util.parse_to_num(args.f_num) 301 | args.f_stride = util.parse_to_num(args.f_stride) 302 | args.f_size = util.parse_to_num(args.f_size) 303 | args.h_dim = util.parse_to_num(args.h_dim) 304 | args.eps_eval = min(args.eps, args.eps_eval) 305 | args.branch = util.parse_to_num(args.branch) 306 | spec = cluster_spec(args.num_workers, args.num_ps) 307 | cluster = tf.train.ClusterSpec(spec).as_cluster_def() 308 | 309 | def shutdown(signal, frame): 310 | logger.warn('Received signal %s: exiting', signal) 311 | sys.exit(128+signal) 312 | signal.signal(signal.SIGHUP, shutdown) 313 | signal.signal(signal.SIGINT, shutdown) 314 | signal.signal(signal.SIGTERM, shutdown) 315 | 316 | gpu_options = None 317 | 
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.15) 318 | 319 | if args.job_name == "worker": 320 | server = tf.train.Server(cluster, job_name="worker", task_index=args.task, 321 | config=tf.ConfigProto(intra_op_parallelism_threads=1, 322 | inter_op_parallelism_threads=1, 323 | gpu_options=gpu_options)) 324 | run(args, server) 325 | elif args.job_name == "test": 326 | server = tf.train.Server(cluster, job_name="worker", task_index=args.task, 327 | config=tf.ConfigProto(intra_op_parallelism_threads=1, 328 | inter_op_parallelism_threads=1, 329 | gpu_options=gpu_options)) 330 | run_tester(args, server) 331 | elif args.job_name == "ps": 332 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05) 333 | server = tf.train.Server(cluster, job_name="ps", task_index=args.task, 334 | config=tf.ConfigProto(device_filters=["/job:ps"], 335 | gpu_options=gpu_options)) 336 | while True: 337 | time.sleep(1000) 338 | 339 | if __name__ == "__main__": 340 | tf.app.run() 341 | --------------------------------------------------------------------------------