├── .gitignore ├── LICENSE ├── README.md ├── a3c.py ├── envs.py ├── imgs ├── dusk_drive.png ├── neon_race.png ├── tb_pong.png └── vnc_pong.png ├── model.py ├── train.py └── worker.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 openai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **This repository has been deprecated in favor of the Retro (https://github.com/openai/retro) library. 
See our Retro Contest (https://blog.openai.com/retro-contest) blog post for details.** 2 | 3 | # universe-starter-agent 4 | 5 | The codebase implements a starter agent that can solve a number of `universe` environments. 6 | It contains a basic implementation of the [A3C algorithm](https://arxiv.org/abs/1602.01783), adapted for real-time environments. 7 | 8 | # Dependencies 9 | 10 | * Python 2.7 or 3.5 11 | * [Golang](https://golang.org/doc/install) 12 | * [six](https://pypi.python.org/pypi/six) (for py2/3 compatibility) 13 | * [TensorFlow](https://www.tensorflow.org/) 0.12 14 | * [tmux](https://tmux.github.io/) (the start script opens up a tmux session with multiple windows) 15 | * [htop](https://hisham.hm/htop/) (shown in one of the tmux windows) 16 | * [gym](https://pypi.python.org/pypi/gym) 17 | * gym[atari] 18 | * libjpeg-turbo (`brew install libjpeg-turbo`) 19 | * [universe](https://pypi.python.org/pypi/universe) 20 | * [opencv-python](https://pypi.python.org/pypi/opencv-python) 21 | * [numpy](https://pypi.python.org/pypi/numpy) 22 | * [scipy](https://pypi.python.org/pypi/scipy) 23 | 24 | # Getting Started 25 | 26 | ``` 27 | conda create --name universe-starter-agent python=3.5 28 | source activate universe-starter-agent 29 | 30 | brew install tmux htop cmake golang libjpeg-turbo # On Linux use sudo apt-get install -y tmux htop cmake golang libjpeg-dev 31 | 32 | pip install "gym[atari]" 33 | pip install universe 34 | pip install six 35 | pip install tensorflow 36 | conda install -y -c https://conda.binstar.org/menpo opencv3 37 | conda install -y numpy 38 | conda install -y scipy 39 | ``` 40 | 41 | 42 | Add the following to your `.bashrc` so that you'll have the correct environment when the `train.py` script spawns new bash shells: 43 | ```source activate universe-starter-agent``` 44 | 45 | ## Atari Pong 46 | 47 | `python train.py --num-workers 2 --env-id PongDeterministic-v3 --log-dir /tmp/pong` 48 | 49 | The command above will train an agent on Atari Pong using the ALE simulator. 50 | It will start two workers that learn in parallel (`--num-workers` flag) and will output intermediate results into the given log directory. 51 | 52 | The code will launch the following processes: 53 | * worker-0 - a process that runs the policy gradient training loop 54 | * worker-1 - a process identical to worker-0 that receives different random noise from the environment 55 | * ps - the parameter server, which synchronizes the parameters among the different workers 56 | * tb - a TensorBoard process for convenient display of the training statistics 57 | 58 | Once you start the training process, it will create a tmux session with a window for each of these processes. You can connect to it by typing `tmux a` in the console. 59 | Once in the tmux session, you can see all your windows with `ctrl-b w`. 60 | To switch to window number 0, type: `ctrl-b 0`. Look up the tmux documentation for more commands. 61 | 62 | To see the agent's monitoring metrics in TensorBoard, open [http://localhost:12345/](http://localhost:12345/) in a browser. 63 | 64 | Using 16 workers, the agent should be able to solve `PongDeterministic-v3` (not VNC) within 30 minutes (often less) on an `m4.10xlarge` instance. 65 | Using 32 workers, the agent is able to solve the same environment in 10 minutes on an `m4.16xlarge` instance. 66 | If you run this experiment on a high-end MacBook Pro, the above job will take just under 2 hours to solve Pong.
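Each of these processes (`ps` and the workers) is a separate `worker.py` invocation that joins a single TensorFlow cluster. As a rough illustration, the sketch below (a simplified sketch, not the actual code in `worker.py`; the helper name and keyword defaults are illustrative) shows how such a cluster can be laid out, assuming a single parameter server and the base port 12222:

```python
# Minimal sketch: how the ps/worker processes could find each other.
# With --num-workers 2, the parameter server would listen on 127.0.0.1:12222
# and the two workers on 12223 and 12224; every process receives the same
# spec, so TensorFlow can route shared-parameter reads/writes to the "ps" job.
def cluster_spec(num_workers, num_ps=1, host='127.0.0.1', base_port=12222):
    ports = iter(range(base_port, base_port + num_ps + num_workers))
    return {
        'ps': ['{}:{}'.format(host, next(ports)) for _ in range(num_ps)],
        'worker': ['{}:{}'.format(host, next(ports)) for _ in range(num_workers)],
    }

print(cluster_spec(2))
# {'ps': ['127.0.0.1:12222'], 'worker': ['127.0.0.1:12223', '127.0.0.1:12224']}
```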
67 | 68 | Add the `--visualise` flag if you want to visualise the worker using `env.render()`, as follows: 69 | 70 | `python train.py --num-workers 2 --env-id PongDeterministic-v3 --log-dir /tmp/pong --visualise` 71 | 72 | ![pong](https://github.com/openai/universe-starter-agent/raw/master/imgs/tb_pong.png "Pong") 73 | 74 | For best performance, it is recommended that the number of workers not exceed the available number of CPU cores. 75 | 76 | You can stop the experiment with the `tmux kill-session` command. 77 | 78 | ## Playing games over remote desktop 79 | 80 | The main difference from the previous experiment is that now we are going to play the game through the VNC protocol. 81 | The VNC environments are hosted on the EC2 cloud and have an interface that's different from a conventional Atari Gym 82 | environment; luckily, with the help of several wrappers (applied within the `envs.py` file), 83 | the experience looks to the agent much as if the game were being played locally. The problem itself is more difficult 84 | because the observations and actions are delayed by the latency of the network. 85 | 86 | More interestingly, you can also peek at what the agent is doing with a VNC viewer. 87 | 88 | Note that the default behavior of `train.py` is to start the remotes on a local machine. Take a look at https://github.com/openai/universe/blob/master/doc/remotes.rst for documentation on managing your remotes. Pass the additional `-r` flag to point to pre-existing instances. 89 | 90 | ### VNC Pong 91 | 92 | `python train.py --num-workers 2 --env-id gym-core.PongDeterministic-v3 --log-dir /tmp/vncpong` 93 | 94 | _Peeking into the agent's environment with TurboVNC_ 95 | 96 | You can use your system viewer via `open vnc://localhost:5900` (or `open vnc://${docker_ip}:5900`) or connect TurboVNC to that IP/port. 97 | The VNC password is `"openai"`. 98 | 99 | ![pong](https://github.com/openai/universe-starter-agent/raw/master/imgs/vnc_pong.png "Pong over VNC") 100 | 101 | #### Important caveats 102 | 103 | One of the novel challenges in using Universe environments is that 104 | they operate in *real time*, and in addition, it takes time for the 105 | environment to transmit the observation to the agent. This delay 106 | creates lag: the greater the lag, the harder it is to solve the 107 | environment with today's RL algorithms. Thus, to get the best 108 | possible results it is necessary to reduce the lag, which can be 109 | achieved by having both the environments and the agent live 110 | on the same high-speed computer network. For example, if you have 111 | a fast local network, you could host the environments on one set of 112 | machines, and the agent on another machine that can speak to the 113 | environments with low latency. Alternatively, you can run the 114 | environments and the agent in the same EC2/Azure region. Other 115 | configurations tend to have greater lag. 116 | 117 | To keep track of your lag, look for the phrase `reaction_time` in 118 | stderr. If you run both the agent and the environment on nearby 119 | machines on the cloud, your `reaction_time` should be as low as 40ms. 120 | The `reaction_time` statistic is printed to stderr because we wrap our 121 | environment with the `Logger` wrapper, as is done in 122 | `envs.py`. 123 | 124 | Generally speaking, the environments most affected by lag are 125 | games that place a lot of emphasis on reaction time.
For example, 126 | this agent is able to solve VNC Pong 127 | (`gym-core.PongDeterministic-v3`) in under 2 hours when both the agent 128 | and the environment are co-located on the cloud, but it had 129 | difficulty solving VNC Pong when the environment was on the cloud 130 | while the agent was not; environments that demand fast reactions 131 | are the ones hit hardest by this lag. 132 | 133 | ### A note on tuning 134 | 135 | This implementation has been tuned to do well on VNC Pong, and we do not guarantee 136 | its performance on other tasks. It is meant as a starting point. 137 | 138 | ### Playing flash games 139 | 140 | You may run the following command to launch the agent on the game Neon Race: 141 | 142 | `python train.py --num-workers 2 --env-id flashgames.NeonRace-v0 --log-dir /tmp/neonrace` 143 | 144 | _What the agent sees when playing Neon Race_ 145 | (you can connect to this view via the [note](#vnc-pong) above) 146 | ![neon](https://github.com/openai/universe-starter-agent/raw/master/imgs/neon_race.png "Neon Race") 147 | 148 | Getting 80% of the maximal score takes between 1 and 2 hours with 16 workers, and getting to 100% of the score 149 | takes about 12 hours. Also, flash games are run at 5fps by default, so it should be possible to productively 150 | use 16 workers on a machine with 8 (and possibly even 4) cores. 151 | 152 | ### Next steps 153 | 154 | Now that you have seen an example agent, develop agents of your own. We hope that you will find 155 | doing so to be an exciting and enjoyable task. 156 | -------------------------------------------------------------------------------- /a3c.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from collections import namedtuple 3 | import numpy as np 4 | import tensorflow as tf 5 | from model import LSTMPolicy 6 | import six.moves.queue as queue 7 | import scipy.signal 8 | import threading 9 | import distutils.version 10 | use_tf12_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('0.12.0') 11 | 12 | def discount(x, gamma): 13 | return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1] 14 | 15 | def process_rollout(rollout, gamma, lambda_=1.0): 16 | """ 17 | Given a rollout, compute its returns and the advantage. 18 | """ 19 | batch_si = np.asarray(rollout.states) 20 | batch_a = np.asarray(rollout.actions) 21 | rewards = np.asarray(rollout.rewards) 22 | vpred_t = np.asarray(rollout.values + [rollout.r]) 23 | 24 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r]) 25 | batch_r = discount(rewards_plus_v, gamma)[:-1] 26 | delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1] 27 | # this formula for the advantage comes from "Generalized Advantage Estimation": 28 | # https://arxiv.org/abs/1506.02438 29 | batch_adv = discount(delta_t, gamma * lambda_) 30 | 31 | features = rollout.features[0] 32 | return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal, features) 33 | 34 | Batch = namedtuple("Batch", ["si", "a", "adv", "r", "terminal", "features"]) 35 | 36 | class PartialRollout(object): 37 | """ 38 | A piece of a complete rollout. We run our agent, and process its experience 39 | once it has accumulated enough steps.
40 | """ 41 | def __init__(self): 42 | self.states = [] 43 | self.actions = [] 44 | self.rewards = [] 45 | self.values = [] 46 | self.r = 0.0 47 | self.terminal = False 48 | self.features = [] 49 | 50 | def add(self, state, action, reward, value, terminal, features): 51 | self.states += [state] 52 | self.actions += [action] 53 | self.rewards += [reward] 54 | self.values += [value] 55 | self.terminal = terminal 56 | self.features += [features] 57 | 58 | def extend(self, other): 59 | assert not self.terminal 60 | self.states.extend(other.states) 61 | self.actions.extend(other.actions) 62 | self.rewards.extend(other.rewards) 63 | self.values.extend(other.values) 64 | self.r = other.r 65 | self.terminal = other.terminal 66 | self.features.extend(other.features) 67 | 68 | class RunnerThread(threading.Thread): 69 | """ 70 | One of the key distinctions between a normal environment and a universe environment 71 | is that a universe environment is _real time_. This means that a dedicated thread has to 72 | constantly interact with the environment and tell it what to do. This class is that thread. 73 | """ 74 | def __init__(self, env, policy, num_local_steps, visualise): 75 | threading.Thread.__init__(self) 76 | self.queue = queue.Queue(5) 77 | self.num_local_steps = num_local_steps 78 | self.env = env 79 | self.last_features = None 80 | self.policy = policy 81 | self.daemon = True 82 | self.sess = None 83 | self.summary_writer = None 84 | self.visualise = visualise 85 | 86 | def start_runner(self, sess, summary_writer): 87 | self.sess = sess 88 | self.summary_writer = summary_writer 89 | self.start() 90 | 91 | def run(self): 92 | with self.sess.as_default(): 93 | self._run() 94 | 95 | def _run(self): 96 | rollout_provider = env_runner(self.env, self.policy, self.num_local_steps, self.summary_writer, self.visualise) 97 | while True: 98 | # the timeout variable exists because apparently, if one worker dies, the other workers 99 | # won't die with it, unless the timeout is set to some large number. This is an empirical 100 | # observation. 101 | 102 | self.queue.put(next(rollout_provider), timeout=600.0) 103 | 104 | 105 | 106 | def env_runner(env, policy, num_local_steps, summary_writer, render): 107 | """ 108 | The logic of the thread runner. In brief, it keeps running the policy, 109 | and once the rollout reaches a certain length, the thread runner 110 | places the rollout on the queue.
111 | """ 112 | last_state = env.reset() 113 | last_features = policy.get_initial_features() 114 | length = 0 115 | rewards = 0 116 | 117 | while True: 118 | terminal_end = False 119 | rollout = PartialRollout() 120 | 121 | for _ in range(num_local_steps): 122 | fetched = policy.act(last_state, *last_features) 123 | action, value_, features = fetched[0], fetched[1], fetched[2:] 124 | # argmax to convert from one-hot 125 | state, reward, terminal, info = env.step(action.argmax()) 126 | if render: 127 | env.render() 128 | 129 | # collect the experience 130 | rollout.add(last_state, action, reward, value_, terminal, last_features) 131 | length += 1 132 | rewards += reward 133 | 134 | last_state = state 135 | last_features = features 136 | 137 | if info: 138 | summary = tf.Summary() 139 | for k, v in info.items(): 140 | summary.value.add(tag=k, simple_value=float(v)) 141 | summary_writer.add_summary(summary, policy.global_step.eval()) 142 | summary_writer.flush() 143 | 144 | timestep_limit = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps') 145 | if terminal or length >= timestep_limit: 146 | terminal_end = True 147 | if length >= timestep_limit or not env.metadata.get('semantics.autoreset'): 148 | last_state = env.reset() 149 | last_features = policy.get_initial_features() 150 | print("Episode finished. Sum of rewards: %d. Length: %d" % (rewards, length)) 151 | length = 0 152 | rewards = 0 153 | break 154 | 155 | if not terminal_end: 156 | rollout.r = policy.value(last_state, *last_features) 157 | 158 | # once we have enough experience, yield it, and have the ThreadRunner place it on a queue 159 | yield rollout 160 | 161 | class A3C(object): 162 | def __init__(self, env, task, visualise): 163 | """ 164 | An implementation of the A3C algorithm that is reasonably well-tuned for the VNC environments. 165 | Below, we will have a modest amount of complexity due to the way TensorFlow handles data parallelism. 166 | But overall, we'll define the model, specify its inputs, and describe how the policy gradients step 167 | should be computed. 168 | """ 169 | 170 | self.env = env 171 | self.task = task 172 | worker_device = "/job:worker/task:{}/cpu:0".format(task) 173 | with tf.device(tf.train.replica_device_setter(1, worker_device=worker_device)): 174 | with tf.variable_scope("global"): 175 | self.network = LSTMPolicy(env.observation_space.shape, env.action_space.n) 176 | self.global_step = tf.get_variable("global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), 177 | trainable=False) 178 | 179 | with tf.device(worker_device): 180 | with tf.variable_scope("local"): 181 | self.local_network = pi = LSTMPolicy(env.observation_space.shape, env.action_space.n) 182 | pi.global_step = self.global_step 183 | 184 | self.ac = tf.placeholder(tf.float32, [None, env.action_space.n], name="ac") 185 | self.adv = tf.placeholder(tf.float32, [None], name="adv") 186 | self.r = tf.placeholder(tf.float32, [None], name="r") 187 | 188 | log_prob_tf = tf.nn.log_softmax(pi.logits) 189 | prob_tf = tf.nn.softmax(pi.logits) 190 | 191 | # the "policy gradients" loss: its derivative is precisely the policy gradient 192 | # notice that self.ac is a placeholder that is provided externally. 
193 | # adv will contain the advantages, as calculated in process_rollout 194 | pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac, [1]) * self.adv) 195 | 196 | # loss of value function 197 | vf_loss = 0.5 * tf.reduce_sum(tf.square(pi.vf - self.r)) 198 | entropy = - tf.reduce_sum(prob_tf * log_prob_tf) 199 | 200 | bs = tf.to_float(tf.shape(pi.x)[0]) 201 | self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01 202 | 203 | # 20 represents the number of "local steps": the number of timesteps 204 | # we run the policy before we update the parameters. 205 | # The larger local steps is, the lower is the variance in our policy gradients estimate 206 | # on the one hand; but on the other hand, we get less frequent parameter updates, which 207 | # slows down learning. In this code, we found that making local steps be much 208 | # smaller than 20 makes the algorithm more difficult to tune and to get to work. 209 | self.runner = RunnerThread(env, pi, 20, visualise) 210 | 211 | 212 | grads = tf.gradients(self.loss, pi.var_list) 213 | 214 | if use_tf12_api: 215 | tf.summary.scalar("model/policy_loss", pi_loss / bs) 216 | tf.summary.scalar("model/value_loss", vf_loss / bs) 217 | tf.summary.scalar("model/entropy", entropy / bs) 218 | tf.summary.image("model/state", pi.x) 219 | tf.summary.scalar("model/grad_global_norm", tf.global_norm(grads)) 220 | tf.summary.scalar("model/var_global_norm", tf.global_norm(pi.var_list)) 221 | self.summary_op = tf.summary.merge_all() 222 | 223 | else: 224 | tf.scalar_summary("model/policy_loss", pi_loss / bs) 225 | tf.scalar_summary("model/value_loss", vf_loss / bs) 226 | tf.scalar_summary("model/entropy", entropy / bs) 227 | tf.image_summary("model/state", pi.x) 228 | tf.scalar_summary("model/grad_global_norm", tf.global_norm(grads)) 229 | tf.scalar_summary("model/var_global_norm", tf.global_norm(pi.var_list)) 230 | self.summary_op = tf.merge_all_summaries() 231 | 232 | grads, _ = tf.clip_by_global_norm(grads, 40.0) 233 | 234 | # copy weights from the parameter server to the local model 235 | self.sync = tf.group(*[v1.assign(v2) for v1, v2 in zip(pi.var_list, self.network.var_list)]) 236 | 237 | grads_and_vars = list(zip(grads, self.network.var_list)) 238 | inc_step = self.global_step.assign_add(tf.shape(pi.x)[0]) 239 | 240 | # each worker has a different set of adam optimizer parameters 241 | opt = tf.train.AdamOptimizer(1e-4) 242 | self.train_op = tf.group(opt.apply_gradients(grads_and_vars), inc_step) 243 | self.summary_writer = None 244 | self.local_steps = 0 245 | 246 | def start(self, sess, summary_writer): 247 | self.runner.start_runner(sess, summary_writer) 248 | self.summary_writer = summary_writer 249 | 250 | def pull_batch_from_queue(self): 251 | """ 252 | self explanatory: take a rollout from the queue of the thread runner. 253 | """ 254 | rollout = self.runner.queue.get(timeout=600.0) 255 | while not rollout.terminal: 256 | try: 257 | rollout.extend(self.runner.queue.get_nowait()) 258 | except queue.Empty: 259 | break 260 | return rollout 261 | 262 | def process(self, sess): 263 | """ 264 | process grabs a rollout that's been produced by the thread runner, 265 | and updates the parameters. The update is then sent to the parameter 266 | server. 
267 | """ 268 | 269 | sess.run(self.sync) # copy weights from shared to local 270 | rollout = self.pull_batch_from_queue() 271 | batch = process_rollout(rollout, gamma=0.99, lambda_=1.0) 272 | 273 | should_compute_summary = self.task == 0 and self.local_steps % 11 == 0 274 | 275 | if should_compute_summary: 276 | fetches = [self.summary_op, self.train_op, self.global_step] 277 | else: 278 | fetches = [self.train_op, self.global_step] 279 | 280 | feed_dict = { 281 | self.local_network.x: batch.si, 282 | self.ac: batch.a, 283 | self.adv: batch.adv, 284 | self.r: batch.r, 285 | self.local_network.state_in[0]: batch.features[0], 286 | self.local_network.state_in[1]: batch.features[1], 287 | } 288 | 289 | fetched = sess.run(fetches, feed_dict=feed_dict) 290 | 291 | if should_compute_summary: 292 | self.summary_writer.add_summary(tf.Summary.FromString(fetched[0]), fetched[-1]) 293 | self.summary_writer.flush() 294 | self.local_steps += 1 295 | -------------------------------------------------------------------------------- /envs.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from gym.spaces.box import Box 3 | import numpy as np 4 | import gym 5 | from gym import spaces 6 | import logging 7 | import universe 8 | from universe import vectorized 9 | from universe.wrappers import BlockingReset, GymCoreAction, EpisodeID, Unvectorize, Vectorize, Vision, Logger 10 | from universe import spaces as vnc_spaces 11 | from universe.spaces.vnc_event import keycode 12 | import time 13 | logger = logging.getLogger(__name__) 14 | logger.setLevel(logging.INFO) 15 | universe.configure_logging() 16 | 17 | def create_env(env_id, client_id, remotes, **kwargs): 18 | spec = gym.spec(env_id) 19 | 20 | if spec.tags.get('flashgames', False): 21 | return create_flash_env(env_id, client_id, remotes, **kwargs) 22 | elif spec.tags.get('atari', False) and spec.tags.get('vnc', False): 23 | return create_vncatari_env(env_id, client_id, remotes, **kwargs) 24 | else: 25 | # Assume atari. 26 | assert "." not in env_id # universe environments have dots in names. 27 | return create_atari_env(env_id) 28 | 29 | def create_flash_env(env_id, client_id, remotes, **_): 30 | env = gym.make(env_id) 31 | env = Vision(env) 32 | env = Logger(env) 33 | env = BlockingReset(env) 34 | 35 | reg = universe.runtime_spec('flashgames').server_registry 36 | height = reg[env_id]["height"] 37 | width = reg[env_id]["width"] 38 | env = CropScreen(env, height, width, 84, 18) 39 | env = FlashRescale(env) 40 | 41 | keys = ['left', 'right', 'up', 'down', 'x'] 42 | if env_id == 'flashgames.NeonRace-v0': 43 | # Better key space for this game. 
44 | keys = ['left', 'right', 'up', 'left up', 'right up', 'down', 'up x'] 45 | logger.info('create_flash_env(%s): keys=%s', env_id, keys) 46 | 47 | env = DiscreteToFixedKeysVNCActions(env, keys) 48 | env = EpisodeID(env) 49 | env = DiagnosticsInfo(env) 50 | env = Unvectorize(env) 51 | env.configure(fps=5.0, remotes=remotes, start_timeout=15 * 60, client_id=client_id, 52 | vnc_driver='go', vnc_kwargs={ 53 | 'encoding': 'tight', 'compress_level': 0, 54 | 'fine_quality_level': 50, 'subsample_level': 3}) 55 | return env 56 | 57 | def create_vncatari_env(env_id, client_id, remotes, **_): 58 | env = gym.make(env_id) 59 | env = Vision(env) 60 | env = Logger(env) 61 | env = BlockingReset(env) 62 | env = GymCoreAction(env) 63 | env = AtariRescale42x42(env) 64 | env = EpisodeID(env) 65 | env = DiagnosticsInfo(env) 66 | env = Unvectorize(env) 67 | 68 | logger.info('Connecting to remotes: %s', remotes) 69 | fps = env.metadata['video.frames_per_second'] 70 | env.configure(remotes=remotes, start_timeout=15 * 60, fps=fps, client_id=client_id) 71 | return env 72 | 73 | def create_atari_env(env_id): 74 | env = gym.make(env_id) 75 | env = Vectorize(env) 76 | env = AtariRescale42x42(env) 77 | env = DiagnosticsInfo(env) 78 | env = Unvectorize(env) 79 | return env 80 | 81 | def DiagnosticsInfo(env, *args, **kwargs): 82 | return vectorized.VectorizeFilter(env, DiagnosticsInfoI, *args, **kwargs) 83 | 84 | class DiagnosticsInfoI(vectorized.Filter): 85 | def __init__(self, log_interval=503): 86 | super(DiagnosticsInfoI, self).__init__() 87 | 88 | self._episode_time = time.time() 89 | self._last_time = time.time() 90 | self._local_t = 0 91 | self._log_interval = log_interval 92 | self._episode_reward = 0 93 | self._episode_length = 0 94 | self._all_rewards = [] 95 | self._num_vnc_updates = 0 96 | self._last_episode_id = -1 97 | 98 | def _after_reset(self, observation): 99 | logger.info('Resetting environment') 100 | self._episode_reward = 0 101 | self._episode_length = 0 102 | self._all_rewards = [] 103 | return observation 104 | 105 | def _after_step(self, observation, reward, done, info): 106 | to_log = {} 107 | if self._episode_length == 0: 108 | self._episode_time = time.time() 109 | 110 | self._local_t += 1 111 | if info.get("stats.vnc.updates.n") is not None: 112 | self._num_vnc_updates += info.get("stats.vnc.updates.n") 113 | 114 | if self._local_t % self._log_interval == 0: 115 | cur_time = time.time() 116 | elapsed = cur_time - self._last_time 117 | fps = self._log_interval / elapsed 118 | self._last_time = cur_time 119 | cur_episode_id = info.get('vectorized.episode_id', 0) 120 | to_log["diagnostics/fps"] = fps 121 | if self._last_episode_id == cur_episode_id: 122 | to_log["diagnostics/fps_within_episode"] = fps 123 | self._last_episode_id = cur_episode_id 124 | if info.get("stats.gauges.diagnostics.lag.action") is not None: 125 | to_log["diagnostics/action_lag_lb"] = info["stats.gauges.diagnostics.lag.action"][0] 126 | to_log["diagnostics/action_lag_ub"] = info["stats.gauges.diagnostics.lag.action"][1] 127 | if info.get("reward.count") is not None: 128 | to_log["diagnostics/reward_count"] = info["reward.count"] 129 | if info.get("stats.gauges.diagnostics.clock_skew") is not None: 130 | to_log["diagnostics/clock_skew_lb"] = info["stats.gauges.diagnostics.clock_skew"][0] 131 | to_log["diagnostics/clock_skew_ub"] = info["stats.gauges.diagnostics.clock_skew"][1] 132 | if info.get("stats.gauges.diagnostics.lag.observation") is not None: 133 | to_log["diagnostics/observation_lag_lb"] = 
info["stats.gauges.diagnostics.lag.observation"][0] 134 | to_log["diagnostics/observation_lag_ub"] = info["stats.gauges.diagnostics.lag.observation"][1] 135 | 136 | if info.get("stats.vnc.updates.n") is not None: 137 | to_log["diagnostics/vnc_updates_n"] = info["stats.vnc.updates.n"] 138 | to_log["diagnostics/vnc_updates_n_ps"] = self._num_vnc_updates / elapsed 139 | self._num_vnc_updates = 0 140 | if info.get("stats.vnc.updates.bytes") is not None: 141 | to_log["diagnostics/vnc_updates_bytes"] = info["stats.vnc.updates.bytes"] 142 | if info.get("stats.vnc.updates.pixels") is not None: 143 | to_log["diagnostics/vnc_updates_pixels"] = info["stats.vnc.updates.pixels"] 144 | if info.get("stats.vnc.updates.rectangles") is not None: 145 | to_log["diagnostics/vnc_updates_rectangles"] = info["stats.vnc.updates.rectangles"] 146 | if info.get("env_status.state_id") is not None: 147 | to_log["diagnostics/env_state_id"] = info["env_status.state_id"] 148 | 149 | if reward is not None: 150 | self._episode_reward += reward 151 | if observation is not None: 152 | self._episode_length += 1 153 | self._all_rewards.append(reward) 154 | 155 | if done: 156 | logger.info('Episode terminating: episode_reward=%s episode_length=%s', self._episode_reward, self._episode_length) 157 | total_time = time.time() - self._episode_time 158 | to_log["global/episode_reward"] = self._episode_reward 159 | to_log["global/episode_length"] = self._episode_length 160 | to_log["global/episode_time"] = total_time 161 | to_log["global/reward_per_time"] = self._episode_reward / total_time 162 | self._episode_reward = 0 163 | self._episode_length = 0 164 | self._all_rewards = [] 165 | 166 | return observation, reward, done, to_log 167 | 168 | def _process_frame42(frame): 169 | frame = frame[34:34+160, :160] 170 | # Resize by half, then down to 42x42 (essentially mipmapping). If 171 | # we resize directly we lose pixels that, when mapped to 42x42, 172 | # aren't close enough to the pixel boundary. 173 | frame = cv2.resize(frame, (80, 80)) 174 | frame = cv2.resize(frame, (42, 42)) 175 | frame = frame.mean(2) 176 | frame = frame.astype(np.float32) 177 | frame *= (1.0 / 255.0) 178 | frame = np.reshape(frame, [42, 42, 1]) 179 | return frame 180 | 181 | class AtariRescale42x42(vectorized.ObservationWrapper): 182 | def __init__(self, env=None): 183 | super(AtariRescale42x42, self).__init__(env) 184 | self.observation_space = Box(0.0, 1.0, [42, 42, 1]) 185 | 186 | def _observation(self, observation_n): 187 | return [_process_frame42(observation) for observation in observation_n] 188 | 189 | class FixedKeyState(object): 190 | def __init__(self, keys): 191 | self._keys = [keycode(key) for key in keys] 192 | self._down_keysyms = set() 193 | 194 | def apply_vnc_actions(self, vnc_actions): 195 | for event in vnc_actions: 196 | if isinstance(event, vnc_spaces.KeyEvent): 197 | if event.down: 198 | self._down_keysyms.add(event.key) 199 | else: 200 | self._down_keysyms.discard(event.key) 201 | 202 | def to_index(self): 203 | action_n = 0 204 | for key in self._down_keysyms: 205 | if key in self._keys: 206 | # If multiple keys are pressed, just use the first one 207 | action_n = self._keys.index(key) + 1 208 | break 209 | return action_n 210 | 211 | class DiscreteToFixedKeysVNCActions(vectorized.ActionWrapper): 212 | """ 213 | Define a fixed action space. Action 0 is all keys up. 
Each element of keys can be a single key or a space-separated list of keys 214 | 215 | For example, 216 | e=DiscreteToFixedKeysVNCActions(e, ['left', 'right']) 217 | will have 3 actions: [none, left, right] 218 | 219 | You can define a state with more than one key down by separating with spaces. For example, 220 | e=DiscreteToFixedKeysVNCActions(e, ['left', 'right', 'space', 'left space', 'right space']) 221 | will have 6 actions: [none, left, right, space, left space, right space] 222 | """ 223 | def __init__(self, env, keys): 224 | super(DiscreteToFixedKeysVNCActions, self).__init__(env) 225 | 226 | self._keys = keys 227 | self._generate_actions() 228 | self.action_space = spaces.Discrete(len(self._actions)) 229 | 230 | def _generate_actions(self): 231 | self._actions = [] 232 | uniq_keys = set() 233 | for key in self._keys: 234 | for cur_key in key.split(' '): 235 | uniq_keys.add(cur_key) 236 | 237 | for key in [''] + self._keys: 238 | split_keys = key.split(' ') 239 | cur_action = [] 240 | for cur_key in uniq_keys: 241 | cur_action.append(vnc_spaces.KeyEvent.by_name(cur_key, down=(cur_key in split_keys))) 242 | self._actions.append(cur_action) 243 | self.key_state = FixedKeyState(uniq_keys) 244 | 245 | def _action(self, action_n): 246 | # Each action might be a length-1 np.array. Cast to int to 247 | # avoid warnings. 248 | return [self._actions[int(action)] for action in action_n] 249 | 250 | class CropScreen(vectorized.ObservationWrapper): 251 | """Crops out a [height]x[width] area starting from (top,left) """ 252 | def __init__(self, env, height, width, top=0, left=0): 253 | super(CropScreen, self).__init__(env) 254 | self.height = height 255 | self.width = width 256 | self.top = top 257 | self.left = left 258 | self.observation_space = Box(0, 255, shape=(height, width, 3)) 259 | 260 | def _observation(self, observation_n): 261 | return [ob[self.top:self.top+self.height, self.left:self.left+self.width, :] if ob is not None else None 262 | for ob in observation_n] 263 | 264 | def _process_frame_flash(frame): 265 | frame = cv2.resize(frame, (200, 128)) 266 | frame = frame.mean(2).astype(np.float32) 267 | frame *= (1.0 / 255.0) 268 | frame = np.reshape(frame, [128, 200, 1]) 269 | return frame 270 | 271 | class FlashRescale(vectorized.ObservationWrapper): 272 | def __init__(self, env=None): 273 | super(FlashRescale, self).__init__(env) 274 | self.observation_space = Box(0.0, 1.0, [128, 200, 1]) 275 | 276 | def _observation(self, observation_n): 277 | return [_process_frame_flash(observation) for observation in observation_n] 278 | -------------------------------------------------------------------------------- /imgs/dusk_drive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/universe-starter-agent/293904f01b4180ecf92dd9536284548108074a44/imgs/dusk_drive.png -------------------------------------------------------------------------------- /imgs/neon_race.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/universe-starter-agent/293904f01b4180ecf92dd9536284548108074a44/imgs/neon_race.png -------------------------------------------------------------------------------- /imgs/tb_pong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/universe-starter-agent/293904f01b4180ecf92dd9536284548108074a44/imgs/tb_pong.png 
-------------------------------------------------------------------------------- /imgs/vnc_pong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/universe-starter-agent/293904f01b4180ecf92dd9536284548108074a44/imgs/vnc_pong.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.rnn as rnn 4 | import distutils.version 5 | use_tf100_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('1.0.0') 6 | 7 | def normalized_columns_initializer(std=1.0): 8 | def _initializer(shape, dtype=None, partition_info=None): 9 | out = np.random.randn(*shape).astype(np.float32) 10 | out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) 11 | return tf.constant(out) 12 | return _initializer 13 | 14 | def flatten(x): 15 | return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])]) 16 | 17 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None): 18 | with tf.variable_scope(name): 19 | stride_shape = [1, stride[0], stride[1], 1] 20 | filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters] 21 | 22 | # there are "num input feature maps * filter height * filter width" 23 | # inputs to each hidden unit 24 | fan_in = np.prod(filter_shape[:3]) 25 | # each unit in the lower layer receives a gradient from: 26 | # "num output feature maps * filter height * filter width" / 27 | # pooling size 28 | fan_out = np.prod(filter_shape[:2]) * num_filters 29 | # initialize weights with random weights 30 | w_bound = np.sqrt(6. 
/ (fan_in + fan_out)) 31 | 32 | w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound), 33 | collections=collections) 34 | b = tf.get_variable("b", [1, 1, 1, num_filters], initializer=tf.constant_initializer(0.0), 35 | collections=collections) 36 | return tf.nn.conv2d(x, w, stride_shape, pad) + b 37 | 38 | def linear(x, size, name, initializer=None, bias_init=0): 39 | w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=initializer) 40 | b = tf.get_variable(name + "/b", [size], initializer=tf.constant_initializer(bias_init)) 41 | return tf.matmul(x, w) + b 42 | 43 | def categorical_sample(logits, d): 44 | value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1], keep_dims=True), 1), [1]) 45 | return tf.one_hot(value, d) 46 | 47 | class LSTMPolicy(object): 48 | def __init__(self, ob_space, ac_space): 49 | self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space)) 50 | 51 | for i in range(4): 52 | x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2])) 53 | # introduce a "fake" batch dimension of 1 after flatten so that we can do LSTM over time dim 54 | x = tf.expand_dims(flatten(x), [0]) 55 | 56 | size = 256 57 | if use_tf100_api: 58 | lstm = rnn.BasicLSTMCell(size, state_is_tuple=True) 59 | else: 60 | lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True) 61 | self.state_size = lstm.state_size 62 | step_size = tf.shape(self.x)[:1] 63 | 64 | c_init = np.zeros((1, lstm.state_size.c), np.float32) 65 | h_init = np.zeros((1, lstm.state_size.h), np.float32) 66 | self.state_init = [c_init, h_init] 67 | c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c]) 68 | h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h]) 69 | self.state_in = [c_in, h_in] 70 | 71 | if use_tf100_api: 72 | state_in = rnn.LSTMStateTuple(c_in, h_in) 73 | else: 74 | state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in) 75 | lstm_outputs, lstm_state = tf.nn.dynamic_rnn( 76 | lstm, x, initial_state=state_in, sequence_length=step_size, 77 | time_major=False) 78 | lstm_c, lstm_h = lstm_state 79 | x = tf.reshape(lstm_outputs, [-1, size]) 80 | self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01)) 81 | self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1]) 82 | self.state_out = [lstm_c[:1, :], lstm_h[:1, :]] 83 | self.sample = categorical_sample(self.logits, ac_space)[0, :] 84 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 85 | 86 | def get_initial_features(self): 87 | return self.state_init 88 | 89 | def act(self, ob, c, h): 90 | sess = tf.get_default_session() 91 | return sess.run([self.sample, self.vf] + self.state_out, 92 | {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h}) 93 | 94 | def value(self, ob, c, h): 95 | sess = tf.get_default_session() 96 | return sess.run(self.vf, {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})[0] 97 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from six.moves import shlex_quote 5 | 6 | parser = argparse.ArgumentParser(description="Run commands") 7 | parser.add_argument('-w', '--num-workers', default=1, type=int, 8 | help="Number of workers") 9 | parser.add_argument('-r', '--remotes', default=None, 10 | help='The address of pre-existing VNC servers and ' 11 | 'rewarders to use (e.g. 
-r vnc://localhost:5900+15900,vnc://localhost:5901+15901).') 12 | parser.add_argument('-e', '--env-id', type=str, default="PongDeterministic-v3", 13 | help="Environment id") 14 | parser.add_argument('-l', '--log-dir', type=str, default="/tmp/pong", 15 | help="Log directory path") 16 | parser.add_argument('-n', '--dry-run', action='store_true', 17 | help="Print out commands rather than executing them") 18 | parser.add_argument('-m', '--mode', type=str, default='tmux', 19 | help="tmux: run workers in a tmux session. nohup: run workers with nohup. child: run workers as child processes") 20 | 21 | # Add visualise tag 22 | parser.add_argument('--visualise', action='store_true', 23 | help="Visualise the gym environment by running env.render() between each timestep") 24 | 25 | 26 | def new_cmd(session, name, cmd, mode, logdir, shell): 27 | if isinstance(cmd, (list, tuple)): 28 | cmd = " ".join(shlex_quote(str(v)) for v in cmd) 29 | if mode == 'tmux': 30 | return name, "tmux send-keys -t {}:{} {} Enter".format(session, name, shlex_quote(cmd)) 31 | elif mode == 'child': 32 | return name, "{} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(cmd, logdir, session, name, logdir) 33 | elif mode == 'nohup': 34 | return name, "nohup {} -c {} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(shell, shlex_quote(cmd), logdir, session, name, logdir) 35 | 36 | 37 | def create_commands(session, num_workers, remotes, env_id, logdir, shell='bash', mode='tmux', visualise=False): 38 | # for launching the TF workers and for launching tensorboard 39 | base_cmd = [ 40 | 'CUDA_VISIBLE_DEVICES=', 41 | sys.executable, 'worker.py', 42 | '--log-dir', logdir, 43 | '--env-id', env_id, 44 | '--num-workers', str(num_workers)] 45 | 46 | if visualise: 47 | base_cmd += ['--visualise'] 48 | 49 | if remotes is None: 50 | remotes = ["1"] * num_workers 51 | else: 52 | remotes = remotes.split(',') 53 | assert len(remotes) == num_workers 54 | 55 | cmds_map = [new_cmd(session, "ps", base_cmd + ["--job-name", "ps"], mode, logdir, shell)] 56 | for i in range(num_workers): 57 | cmds_map += [new_cmd(session, 58 | "w-%d" % i, base_cmd + ["--job-name", "worker", "--task", str(i), "--remotes", remotes[i]], mode, logdir, shell)] 59 | 60 | cmds_map += [new_cmd(session, "tb", ["tensorboard", "--logdir", logdir, "--port", "12345"], mode, logdir, shell)] 61 | if mode == 'tmux': 62 | cmds_map += [new_cmd(session, "htop", ["htop"], mode, logdir, shell)] 63 | 64 | windows = [v[0] for v in cmds_map] 65 | 66 | notes = [] 67 | cmds = [ 68 | "mkdir -p {}".format(logdir), 69 | "echo {} {} > {}/cmd.sh".format(sys.executable, ' '.join([shlex_quote(arg) for arg in sys.argv if arg != '-n']), logdir), 70 | ] 71 | if mode == 'nohup' or mode == 'child': 72 | cmds += ["echo '#!/bin/sh' >{}/kill.sh".format(logdir)] 73 | notes += ["Run `source {}/kill.sh` to kill the job".format(logdir)] 74 | if mode == 'tmux': 75 | notes += ["Use `tmux attach -t {}` to watch process output".format(session)] 76 | notes += ["Use `tmux kill-session -t {}` to kill the job".format(session)] 77 | else: 78 | notes += ["Use `tail -f {}/*.out` to watch process output".format(logdir)] 79 | notes += ["Point your browser to http://localhost:12345 to see Tensorboard"] 80 | 81 | if mode == 'tmux': 82 | cmds += [ 83 | "kill $( lsof -i:12345 -t ) > /dev/null 2>&1", # kill any process using tensorboard's port 84 | "kill $( lsof -i:12222-{} -t ) > /dev/null 2>&1".format(num_workers+12222), # kill any processes using ps / worker ports 85 | "tmux kill-session -t {}".format(session), 86 | 
"tmux new-session -s {} -n {} -d {}".format(session, windows[0], shell) 87 | ] 88 | for w in windows[1:]: 89 | cmds += ["tmux new-window -t {} -n {} {}".format(session, w, shell)] 90 | cmds += ["sleep 1"] 91 | for window, cmd in cmds_map: 92 | cmds += [cmd] 93 | 94 | return cmds, notes 95 | 96 | 97 | def run(): 98 | args = parser.parse_args() 99 | cmds, notes = create_commands("a3c", args.num_workers, args.remotes, args.env_id, args.log_dir, mode=args.mode, visualise=args.visualise) 100 | if args.dry_run: 101 | print("Dry-run mode due to -n flag, otherwise the following commands would be executed:") 102 | else: 103 | print("Executing the following commands:") 104 | print("\n".join(cmds)) 105 | print("") 106 | if not args.dry_run: 107 | if args.mode == "tmux": 108 | os.environ["TMUX"] = "" 109 | os.system("\n".join(cmds)) 110 | print('\n'.join(notes)) 111 | 112 | 113 | if __name__ == "__main__": 114 | run() 115 | -------------------------------------------------------------------------------- /worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import cv2 3 | import go_vncdriver 4 | import tensorflow as tf 5 | import argparse 6 | import logging 7 | import sys, signal 8 | import time 9 | import os 10 | from a3c import A3C 11 | from envs import create_env 12 | import distutils.version 13 | use_tf12_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('0.12.0') 14 | 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.INFO) 17 | 18 | # Disables write_meta_graph argument, which freezes entire process and is mostly useless. 19 | class FastSaver(tf.train.Saver): 20 | def save(self, sess, save_path, global_step=None, latest_filename=None, 21 | meta_graph_suffix="meta", write_meta_graph=True): 22 | super(FastSaver, self).save(sess, save_path, global_step, latest_filename, 23 | meta_graph_suffix, False) 24 | 25 | def run(args, server): 26 | env = create_env(args.env_id, client_id=str(args.task), remotes=args.remotes) 27 | trainer = A3C(env, args.task, args.visualise) 28 | 29 | # Variable names that start with "local" are not saved in checkpoints. 
30 | if use_tf12_api: 31 | variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")] 32 | init_op = tf.variables_initializer(variables_to_save) 33 | init_all_op = tf.global_variables_initializer() 34 | else: 35 | variables_to_save = [v for v in tf.all_variables() if not v.name.startswith("local")] 36 | init_op = tf.initialize_variables(variables_to_save) 37 | init_all_op = tf.initialize_all_variables() 38 | saver = FastSaver(variables_to_save) 39 | 40 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 41 | logger.info('Trainable vars:') 42 | for v in var_list: 43 | logger.info(' %s %s', v.name, v.get_shape()) 44 | 45 | def init_fn(ses): 46 | logger.info("Initializing all parameters.") 47 | ses.run(init_all_op) 48 | 49 | config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)]) 50 | logdir = os.path.join(args.log_dir, 'train') 51 | 52 | if use_tf12_api: 53 | summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task) 54 | else: 55 | summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task) 56 | 57 | logger.info("Events directory: %s_%s", logdir, args.task) 58 | sv = tf.train.Supervisor(is_chief=(args.task == 0), 59 | logdir=logdir, 60 | saver=saver, 61 | summary_op=None, 62 | init_op=init_op, 63 | init_fn=init_fn, 64 | summary_writer=summary_writer, 65 | ready_op=tf.report_uninitialized_variables(variables_to_save), 66 | global_step=trainer.global_step, 67 | save_model_secs=30, 68 | save_summaries_secs=30) 69 | 70 | num_global_steps = 100000000 71 | 72 | logger.info( 73 | "Starting session. If this hangs, we're most likely waiting to connect to the parameter server. " + 74 | "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.") 75 | with sv.managed_session(server.target, config=config) as sess, sess.as_default(): 76 | sess.run(trainer.sync) 77 | trainer.start(sess, summary_writer) 78 | global_step = sess.run(trainer.global_step) 79 | logger.info("Starting training at step=%d", global_step) 80 | while not sv.should_stop() and (not num_global_steps or global_step < num_global_steps): 81 | trainer.process(sess) 82 | global_step = sess.run(trainer.global_step) 83 | 84 | # Ask for all the services to stop. 85 | sv.stop() 86 | logger.info('reached %s steps.
worker stopped.', global_step) 87 | 88 | def cluster_spec(num_workers, num_ps): 89 | """ 90 | More tensorflow setup for data parallelism 91 | """ 92 | cluster = {} 93 | port = 12222 94 | 95 | all_ps = [] 96 | host = '127.0.0.1' 97 | for _ in range(num_ps): 98 | all_ps.append('{}:{}'.format(host, port)) 99 | port += 1 100 | cluster['ps'] = all_ps 101 | 102 | all_workers = [] 103 | for _ in range(num_workers): 104 | all_workers.append('{}:{}'.format(host, port)) 105 | port += 1 106 | cluster['worker'] = all_workers 107 | return cluster 108 | 109 | def main(_): 110 | """ 111 | Setting up Tensorflow for data parallel work 112 | """ 113 | 114 | parser = argparse.ArgumentParser(description=None) 115 | parser.add_argument('-v', '--verbose', action='count', dest='verbosity', default=0, help='Set verbosity.') 116 | parser.add_argument('--task', default=0, type=int, help='Task index') 117 | parser.add_argument('--job-name', default="worker", help='worker or ps') 118 | parser.add_argument('--num-workers', default=1, type=int, help='Number of workers') 119 | parser.add_argument('--log-dir', default="/tmp/pong", help='Log directory path') 120 | parser.add_argument('--env-id', default="PongDeterministic-v3", help='Environment id') 121 | parser.add_argument('-r', '--remotes', default=None, 122 | help='References to environments to create (e.g. -r 20), ' 123 | 'or the address of pre-existing VNC servers and ' 124 | 'rewarders to use (e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901)') 125 | 126 | # Add visualisation argument 127 | parser.add_argument('--visualise', action='store_true', 128 | help="Visualise the gym environment by running env.render() between each timestep") 129 | 130 | args = parser.parse_args() 131 | spec = cluster_spec(args.num_workers, 1) 132 | cluster = tf.train.ClusterSpec(spec).as_cluster_def() 133 | 134 | def shutdown(signal, frame): 135 | logger.warn('Received signal %s: exiting', signal) 136 | sys.exit(128+signal) 137 | signal.signal(signal.SIGHUP, shutdown) 138 | signal.signal(signal.SIGINT, shutdown) 139 | signal.signal(signal.SIGTERM, shutdown) 140 | 141 | if args.job_name == "worker": 142 | server = tf.train.Server(cluster, job_name="worker", task_index=args.task, 143 | config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=2)) 144 | run(args, server) 145 | else: 146 | server = tf.train.Server(cluster, job_name="ps", task_index=args.task, 147 | config=tf.ConfigProto(device_filters=["/job:ps"])) 148 | while True: 149 | time.sleep(1000) 150 | 151 | if __name__ == "__main__": 152 | tf.app.run() 153 | --------------------------------------------------------------------------------
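In normal use, `worker.py` is not launched by hand; `train.py` assembles one invocation per process and runs them in tmux windows. For reference, the underlying commands for the two-worker Pong run from the README boil down to roughly the following (a simplified sketch: `sys.executable` is shown here as `python`, and the `tmux send-keys` wrapping and shell quoting that `train.py` actually emits are omitted; run `python train.py -n --num-workers 2 --env-id PongDeterministic-v3 --log-dir /tmp/pong` to print the exact commands):

```
CUDA_VISIBLE_DEVICES= python worker.py --log-dir /tmp/pong --env-id PongDeterministic-v3 --num-workers 2 --job-name ps
CUDA_VISIBLE_DEVICES= python worker.py --log-dir /tmp/pong --env-id PongDeterministic-v3 --num-workers 2 --job-name worker --task 0 --remotes 1
CUDA_VISIBLE_DEVICES= python worker.py --log-dir /tmp/pong --env-id PongDeterministic-v3 --num-workers 2 --job-name worker --task 1 --remotes 1
tensorboard --logdir /tmp/pong --port 12345
```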