├── hogwild
│   ├── __init__.py
│   ├── data.py
│   ├── utils.py
│   └── mnist.py
├── README.md
├── preproc
│   └── preproc.py
├── parallel_simulator
│   ├── thread_test.py
│   └── process_test.py
├── .gitignore
└── the_thing
    ├── utils.py
    └── a3c.py

/hogwild/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# A3C implementation in Python

THIS IS STILL WORK IN PROGRESS - IT DOES NOT WORK YET!

### Lessons learned
1. Hogwild! - it turns out that to run Hogwild!-style training in TensorFlow, one needs to build a separate TensorFlow graph per worker, with the variables shared between the graphs (see the sketch below). Full example in `hogwild/mnist.py`.
2. Parallel simulators with OpenAI gym - it turns out that running more than 4 Atari simulator threads in parallel is not supported. See `parallel_simulator/thread_test.py`. On the second try, I forced gym to run on a separate process, and this way the bug does not occur (`parallel_simulator/process_test.py`).
3. To make sure that the preprocessing makes sense, I visualize it in `preproc/preproc.py`. It behaves a bit oddly for Pong, but it should still be possible to recover a policy from it.
--------------------------------------------------------------------------------
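
A minimal sketch of the shared-variable pattern from lesson 1, mirroring what `hogwild/mnist.py` does below (TF 0.x API, as used throughout this repo; `build_tower` and the tiny linear model are hypothetical names for illustration):

```python
import tensorflow as tf

def build_tower(reuse):
    # Each worker gets its own copy of the ops, but with reuse=True
    # tf.get_variable returns the *same* underlying variables, so all
    # towers read and update shared weights Hogwild!-style, lock-free.
    with tf.variable_scope('model', reuse=reuse):
        x = tf.placeholder(tf.float32, [None, 784])
        y_ = tf.placeholder(tf.float32, [None, 10])
        w = tf.get_variable('w', [784, 10],
                            initializer=tf.truncated_normal_initializer(stddev=0.1))
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y_))
        train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
    return x, y_, train_op

# One tower per thread; only the first call creates the variables.
towers = [build_tower(reuse=i > 0) for i in range(4)]
```
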
/preproc/preproc.py:
--------------------------------------------------------------------------------
import gym
import matplotlib.pyplot as plt
import tensorflow as tf

def preproc_graph(session, input_shape):
    # Max over two consecutive frames (removes Atari sprite flicker),
    # then grayscale, resize to 84x84 and scale to [0, 1].
    prev_frame = tf.placeholder(tf.uint8, [1] + list(input_shape))
    cur_frame = tf.placeholder(tf.uint8, [1] + list(input_shape))
    out = tf.maximum(tf.cast(prev_frame, tf.float32),
                     tf.cast(cur_frame, tf.float32))
    out = tf.image.rgb_to_grayscale(out)
    out = tf.image.resize_bilinear(out, (84, 84)) / 255.0
    def f(pf, cf):
        return session.run(out, {prev_frame: pf[None], cur_frame: cf[None]})[0]
    return f

def main():
    session = tf.Session()

    env = gym.make('Pong-v0')
    last_observation = env.reset()
    preproc_f = preproc_graph(session, env.observation_space.shape)

    fig, ax = plt.subplots(figsize=(6, 6))
    plt.ion()

    for _ in range(1000):
        observation, _, _, _ = env.step(env.action_space.sample())
        pp = preproc_f(last_observation, observation)

        ax.imshow(pp[:, :, 0])
        plt.pause(0.05)

        last_observation = observation


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
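
For reference, the tensor shapes flowing through `preproc_graph` for Pong, whose raw observations are 210x160x3 uint8 (a shape trace written out by hand, not program output):

```python
# prev_frame, cur_frame:    (1, 210, 160, 3) uint8 placeholders
# after tf.maximum + casts: (1, 210, 160, 3) float32 (the max removes sprite flicker)
# after rgb_to_grayscale:   (1, 210, 160, 1)
# after resize_bilinear:    (1, 84, 84, 1), divided by 255 into [0, 1]
# f(pf, cf) returns pp:     (84, 84, 1) float32; pp[:, :, 0] is what gets plotted
```
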
/parallel_simulator/thread_test.py:
--------------------------------------------------------------------------------
import gym
import time

from threading import Thread, Lock

class AtomicInt(object):
    def __init__(self, value=0):
        self._value = value
        self._lock = Lock()

    def inc(self):
        with self._lock:
            self._value += 1

    def get(self):
        return self._value

def main():
    TARGET_ITERATIONS = 10000
    NUM_THREADS = 4
    num_iterations = AtomicInt()

    start_time = time.time()
    def run_simulator():
        env = gym.make('Pong-v0')
        while num_iterations.get() < TARGET_ITERATIONS:
            env.reset()
            done = False
            while not done and num_iterations.get() < TARGET_ITERATIONS:
                # env.render()
                _, _, done, _ = env.step(env.action_space.sample())  # take a random action
                num_iterations.inc()

    threads = [Thread(target=run_simulator) for _ in range(NUM_THREADS)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    total_time = time.time() - start_time

    print('Total time for %d iterations: %.3f seconds' % (num_iterations.get(), total_time))

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject
--------------------------------------------------------------------------------

/parallel_simulator/process_test.py:
--------------------------------------------------------------------------------
import gym
import time

from multiprocessing.managers import BaseManager
from threading import Thread, Lock


def make_wrapper(*args, **kwargs):
    # BaseManager proxies only forward method calls, not attribute access,
    # so expose the spaces through getter methods.
    env = gym.make(*args, **kwargs)
    env.get_action_space = lambda: env.action_space
    env.get_observation_space = lambda: env.observation_space
    return env

class GymOnProcess(BaseManager):
    pass

GymOnProcess.register('Env', make_wrapper)

class AtomicInt(object):
    def __init__(self, value=0):
        self._value = value
        self._lock = Lock()

    def inc(self):
        with self._lock:
            self._value += 1

    def get(self):
        return self._value

def main():
    TARGET_ITERATIONS = 10000
    NUM_THREADS = 12
    num_iterations = AtomicInt()

    start_time = time.time()
    def run_simulator():
        with GymOnProcess() as gym_on_process:
            env = gym_on_process.Env('Pong-v0')
            action_space = env.get_action_space()
            while num_iterations.get() < TARGET_ITERATIONS:
                env.reset()
                done = False
                while not done and num_iterations.get() < TARGET_ITERATIONS:
                    # env.render()
                    _, _, done, _ = env.step(action_space.sample())  # take a random action
                    num_iterations.inc()

    threads = [Thread(target=run_simulator) for _ in range(NUM_THREADS)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    total_time = time.time() - start_time

    print('Total time for %d iterations: %.3f seconds' % (num_iterations.get(), total_time))

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
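
Why `make_wrapper` adds those lambdas: `BaseManager` proxies forward method calls to the object living in the manager's process, but plain attribute access is not proxied, hence the `get_action_space` / `get_observation_space` getters. A minimal illustration of the pattern, with a hypothetical `Counter` class:

```python
from multiprocessing.managers import BaseManager

class Counter(object):
    def __init__(self):
        self.value = 0        # attribute: NOT reachable through the proxy

    def get_value(self):      # method: callable through the proxy
        return self.value

class CounterManager(BaseManager):
    pass

CounterManager.register('Counter', Counter)

if __name__ == '__main__':
    with CounterManager() as manager:
        counter = manager.Counter()    # the Counter lives in the manager's process
        print(counter.get_value())     # 0, fetched over IPC
```
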
/hogwild/data.py:
--------------------------------------------------------------------------------
import numpy as np

from sklearn.datasets import fetch_mldata

class Data(object):
    def __init__(self, batch_size, validation_size):
        self.batch_size = batch_size

        # Load MNIST (70000 examples: 60000 train + 10000 test)
        mnist = fetch_mldata('MNIST original')
        X, Y_labels = mnist['data'], mnist['target']

        # normalize X to the [0.0, 1.0] range (this also makes it float32)
        X = X.astype(np.float32) / 255.0

        # one-hot encode the labels
        Y = np.zeros((len(Y_labels), 10), dtype=np.float32)
        Y[np.arange(len(Y_labels)), Y_labels.astype(np.int32)] = 1.

        # shuffle examples
        permutation = np.random.permutation(len(X))
        X = X[permutation]
        Y = Y[permutation]

        # split into train, validate, test
        train_end = 60000 - validation_size
        validation_end = 60000
        test_end = 70000

        self.X_train = X[0:train_end]
        self.X_valid = X[train_end:validation_end]
        self.X_test = X[validation_end:test_end]

        self.Y_train = Y[0:train_end]
        self.Y_valid = Y[train_end:validation_end]
        self.Y_test = Y[validation_end:test_end]

    def iterate_batches(self, data_x, data_y):
        assert len(data_x) == len(data_y)

        for batch_start in range(0, len(data_x), self.batch_size):
            batch_x = data_x[batch_start:(batch_start + self.batch_size)]
            batch_y = data_y[batch_start:(batch_start + self.batch_size)]

            yield batch_x, batch_y

    def iterate_train(self):
        return self.iterate_batches(self.X_train, self.Y_train)

    def iterate_validate(self):
        return self.iterate_batches(self.X_valid, self.Y_valid)

    def iterate_test(self):
        return self.iterate_batches(self.X_test, self.Y_test)
--------------------------------------------------------------------------------
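
A quick usage sketch of the `Data` class (hypothetical parameter values; this is how `hogwild/mnist.py` consumes it):

```python
data = Data(batch_size=32, validation_size=6000)
for batch_x, batch_y in data.iterate_train():
    # batch_x: (<=32, 784) float32 in [0, 1]
    # batch_y: (<=32, 10) one-hot float32
    pass
```
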
/hogwild/utils.py:
--------------------------------------------------------------------------------
import traceback

from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from queue import Queue
from threading import Semaphore


class BlockOnFullThreadPool(ThreadPoolExecutor):
    """ThreadPoolExecutor whose submit() blocks once the work queue is full."""
    def __init__(self, *args, **kwargs):
        max_workers = None
        if len(args) > 0:
            max_workers = args[0]
        elif 'max_workers' in kwargs:
            max_workers = kwargs['max_workers']
        queue_size = kwargs.pop('queue_size', 0)

        super(BlockOnFullThreadPool, self).__init__(*args, **kwargs)
        assert type(max_workers) is int and type(queue_size) is int
        self._block_on_full_sem = Semaphore(max_workers + queue_size)

    def submit(self, *args, **kwargs):
        self._block_on_full_sem.acquire()
        future = super(BlockOnFullThreadPool, self).submit(*args, **kwargs)
        future.add_done_callback(self._on_task_done)
        return future

    def _on_task_done(self, future):
        exp = future.exception()
        if exp is not None:
            msg = '-' * 80 + '\n'
            msg += 'Exception occurred in ThreadPoolExecutor:\n'
            msg += ''.join(traceback.format_tb(exp.__traceback__))
            msg += exp.__class__.__name__ + ': ' + str(exp) + '\n'
            msg += '-' * 80
            print(msg)
        self._block_on_full_sem.release()


class SharedResource(object):
    """Pool of objects, each leased to at most one thread at a time."""
    def __init__(self, objs):
        self.q = Queue()
        while len(objs) > 0:
            self.q.put(objs.pop())

    @contextmanager
    def lease(self):
        r = None
        try:
            r = self.q.get()
            yield r
        finally:
            if r is not None:
                self.q.put(r)


if __name__ == '__main__':
    import time

    def work(i):
        print('start work', i)
        time.sleep(1)
        if i == 5:
            print(1 / 0)  # deliberately raise to exercise the error reporting
        print('end work', i)

    with BlockOnFullThreadPool(max_workers=2) as pool:
        for i in range(10):
            print('adding to queue')
            pool.submit(work, i)
        pool.shutdown(wait=True)
    print('done')
--------------------------------------------------------------------------------
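
A short usage sketch for `SharedResource` (hypothetical objects): each thread leases one object at a time, so no two threads ever use the same object concurrently:

```python
from threading import Thread

resource = SharedResource(['graph_0', 'graph_1'])

def worker(i):
    with resource.lease() as g:
        print('thread', i, 'is using', g)  # exclusive access until the block exits

threads = [Thread(target=worker, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```
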
/the_thing/utils.py:
--------------------------------------------------------------------------------
import gym
import numpy as np

from contextlib import contextmanager
from multiprocessing.managers import BaseManager
from queue import Queue
from threading import Lock

def make_wrapper(*args, **kwargs):
    # BaseManager proxies only forward method calls, not attribute access,
    # so expose the spaces through getter methods.
    env = gym.make(*args, **kwargs)
    env.get_action_space = lambda: env.action_space
    env.get_observation_space = lambda: env.observation_space
    return env

class GymOnProcess(BaseManager):
    pass

GymOnProcess.register('Env', make_wrapper)

class AtomicInt(object):
    def __init__(self, value=0):
        self._value = value
        self._lock = Lock()

    def inc(self):
        with self._lock:
            self._value += 1

    def get(self):
        return self._value

class SharedResource(object):
    def __init__(self, objs):
        self.q = Queue()
        while len(objs) > 0:
            self.q.put(objs.pop())

    @contextmanager
    def lease(self):
        r = None
        try:
            r = self.q.get()
            yield r
        finally:
            if r is not None:
                self.q.put(r)

class ExplorationSchedule(object):
    """Mixture of linear epsilon schedules, parsed from a string of
    hyphen-separated schedules, each 'selection_p:initial:final:steps'."""
    def __init__(self, desc):
        self.schedules = []
        for schedule_desc in desc.split('-'):
            selection_p, initial, final, steps = schedule_desc.split(':')
            selection_p, initial, final = map(float, [selection_p, initial, final])
            steps = int(float(steps))

            assert 0 <= selection_p <= 1
            assert 0 <= initial <= 1
            assert 0 <= final <= 1
            self.schedules.append((selection_p, initial, final, steps))
        self.current_schedule = None

    def reset(self):
        # selection probabilities across the schedules must sum to 1
        choice_idx = np.random.choice(len(self.schedules), p=[s[0] for s in self.schedules])
        self.current_schedule = self.schedules[choice_idx]

    def random_action_p(self, steps):
        assert self.current_schedule is not None, "must call reset first"
        _, initial, final, final_steps = self.current_schedule
        fraction_complete = min(steps / final_steps, 1.0)
        current_p = initial + (final - initial) * fraction_complete
        return current_p

    def should_be_random(self, steps):
        return np.random.random() < self.random_action_p(steps)

    def __str__(self):
        ret = 'Exploration schedule:\n'
        for schedule in self.schedules:
            ret += ' - With p=%.2f, annealing from %.2f to %.2f over %d steps\n' % schedule
        return ret
--------------------------------------------------------------------------------
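
A usage sketch of `ExplorationSchedule`, using the default schedule string from `the_thing/a3c.py` (the number in the comment is computed by hand from the formula above, assuming the first sub-schedule was drawn):

```python
schedule = ExplorationSchedule('0.4:1:0.1:4e6-0.3:1:0.01:4e6-0.3:1:0.5:4e6')
schedule.reset()                    # samples one of the three sub-schedules
p = schedule.random_action_p(2e6)   # halfway through 4e6 steps: 1 + (0.1 - 1) * 0.5 = 0.55
if schedule.should_be_random(2e6):
    pass                            # take a random action instead of the policy's
```
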
/hogwild/mnist.py:
--------------------------------------------------------------------------------
"""
Tensorflow implementation of a Convolutional Network
for the MNIST dataset. Adapted from the tutorial:

https://www.tensorflow.org/versions/r0.9/tutorials/mnist/pros/index.html#deep-mnist-for-experts

Uses Hogwild!
"""

import argparse
import tensorflow as tf
import time

from threading import Lock

from data import Data
from utils import BlockOnFullThreadPool, SharedResource


def weight_variable(shape, name):
    initializer = tf.truncated_normal_initializer(stddev=0.1)
    return tf.get_variable(name, shape, initializer=initializer)

def bias_variable(shape, name):
    initializer = tf.constant_initializer(0.1)
    return tf.get_variable(name, shape, initializer=initializer)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')

def forward(image, keep_prob):
    # Reshape to 4D
    image_4d = tf.reshape(image, [-1, 28, 28, 1])

    # Conv1
    W_conv1 = weight_variable([5, 5, 1, 32], "w1")
    b_conv1 = bias_variable([32], "b1")

    h_conv1 = tf.nn.relu(conv2d(image_4d, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    # Conv2
    W_conv2 = weight_variable([5, 5, 32, 64], "w2")
    b_conv2 = bias_variable([64], "b2")

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    # FC1 (with dropout)
    W_fc1 = weight_variable([7 * 7 * 64, 1024], "w3")
    b_fc1 = bias_variable([1024], "b3")

    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    # FC2
    W_fc2 = weight_variable([1024, 10], "w4")
    b_fc2 = bias_variable([10], "b4")

    y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

    return y_conv

def build_graph(reuse):
    with tf.variable_scope('model', reuse=reuse):
        x = tf.placeholder(tf.float32, shape=[None, 784])
        y_ = tf.placeholder(tf.float32, shape=[None, 10])
        keep_prob = tf.placeholder(tf.float32)

        y_conv = forward(x, keep_prob)

        cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), reduction_indices=[1]))

        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

        correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))

        num_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))

        no_op = tf.no_op()

        return x, y_, keep_prob, train_step, num_correct, no_op


def accuracy(session, graphs, data_iter, num_threads, train=False):
    num_total = 0
    num_correct = 0
    stats_lock = Lock()  # += on plain ints is not atomic across threads

    def process_batch(batch_x, batch_y):
        nonlocal num_correct
        nonlocal num_total
        with graphs.lease() as g:
            input_placeholder, output_placeholder, keep_prob_placeholder, train_step_f, num_correct_f, no_op = g
            batch_num_correct, _ = session.run(
                [num_correct_f, train_step_f if train else no_op],
                {
                    input_placeholder: batch_x,
                    output_placeholder: batch_y,
                    keep_prob_placeholder: 0.5 if train else 1.0,
                })
        with stats_lock:
            num_correct += batch_num_correct
            num_total += len(batch_x)

    with BlockOnFullThreadPool(max_workers=num_threads, queue_size=num_threads // 2) as pool:
        for batch_x, batch_y in data_iter:
            pool.submit(process_batch, batch_x, batch_y)
        pool.shutdown(wait=True)

    return float(num_correct) / float(num_total)

def main(args):
    with tf.device("cpu"):
        data = Data(batch_size=args.batch_size, validation_size=6000)

        session = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=1,
                                                   inter_op_parallelism_threads=args.num_threads))
        graphs = SharedResource([build_graph(reuse=i > 0) for i in range(args.num_threads)])

        session.run(tf.initialize_all_variables())

        train_total_time_sum = 0
        for epoch in range(args.num_epochs):
            train_start_time = time.time()
            train_accuracy = accuracy(session, graphs, data.iterate_train(), num_threads=args.num_threads, train=True)
            train_total_time = time.time() - train_start_time
            train_total_time_sum += train_total_time

            validate_accuracy = accuracy(session, graphs, data.iterate_validate(), num_threads=args.num_threads, train=False)

            print("Training epoch number %d:" % (epoch,))
            print("    Time to train = %.3f s" % (train_total_time,))
            print("    Training set accuracy = %.1f %%" % (100.0 * train_accuracy,))
            print("    Validation set accuracy = %.1f %%" % (100.0 * validate_accuracy,))
            print("")
        print("Training done.")

        test_accuracy = accuracy(session, graphs, data.iterate_test(), num_threads=args.num_threads, train=False)
        print("    Average time per training epoch = %.3f s" % (train_total_time_sum / args.num_epochs,))
        print("    Test set accuracy = %.1f %%" % (100.0 * test_accuracy,))

def parse_args():
    parser = argparse.ArgumentParser(description='Hogwild training on MNIST.')
    parser.add_argument('--num_threads', type=int, default=9, help='number of threads to use')
    parser.add_argument('--num_epochs', type=int, default=32, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='number of examples to use in each iteration of SGD')
    return parser.parse_args()

if __name__ == '__main__':
    main(parse_args())
--------------------------------------------------------------------------------
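
One detail worth calling out in `hogwild/mnist.py` above: `accuracy()` doubles as the training loop, and whether weights get updated is selected purely by which op is fed to `session.run`. Schematically:

```python
# train=True:  session.run([num_correct_f, train_step_f], feeds)  # updates shared weights
# train=False: session.run([num_correct_f, no_op], feeds)         # pure evaluation
```
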
/the_thing/a3c.py:
--------------------------------------------------------------------------------
import argparse
import gym
import numpy as np
import tensorflow as tf

from collections import deque
from tensorflow.contrib import layers
from threading import Thread

from utils import (
    GymOnProcess,
    ExplorationSchedule,
    AtomicInt,
    SharedResource
)

def preproc_graph(session, input_shape):
    # Max over two consecutive frames (removes Atari sprite flicker),
    # then grayscale, resize to 84x84 and scale to [0, 1].
    prev_frame = tf.placeholder(tf.uint8, [1] + list(input_shape))
    cur_frame = tf.placeholder(tf.uint8, [1] + list(input_shape))
    out = tf.maximum(tf.cast(prev_frame, tf.float32),
                     tf.cast(cur_frame, tf.float32))
    out = tf.image.rgb_to_grayscale(out)
    out = tf.image.resize_bilinear(out, (84, 84)) / 255.0
    def f(pf, cf):
        return session.run(out, {prev_frame: pf[None], cur_frame: cf[None]})[0]
    return f

def forward(image, num_actions):
    out = layers.convolution2d(image, num_outputs=16, kernel_size=8, stride=4, activation_fn=tf.nn.relu, scope='conv1')
    out = layers.convolution2d(out, num_outputs=32, kernel_size=4, stride=2, activation_fn=tf.nn.relu, scope='conv2')
    out = layers.flatten(out, scope='flatten')
    out = layers.fully_connected(out, num_outputs=256, activation_fn=tf.nn.relu, scope='fc1')

    action_logprobs = tf.nn.log_softmax(layers.fully_connected(out, num_outputs=num_actions, activation_fn=None, scope='fc_actor'))
    value = layers.fully_connected(out, num_outputs=1, activation_fn=None, scope='fc_critic')
    value = tf.reshape(value, [-1])
    return action_logprobs, value

def a3c_graph(args, session, num_actions, reuse):
    with tf.variable_scope('a3c', reuse=reuse):
        states_ph = tf.placeholder(tf.float32, [None, 84, 84, args.context])
        actions_ph = tf.placeholder(tf.int32, [None])
        discounted_rewards_ph = tf.placeholder(tf.float32, [None])

        action_logprobs, value = forward(states_ph, num_actions)

        def run_actor(state):
            # Greedy w.r.t. the current policy; exploration is injected by
            # ExplorationSchedule in run_simulator. (Classic A3C would instead
            # sample from exp(action_logprobs).)
            ret = session.run(action_logprobs, {states_ph: state[None]})[0]
            return np.argmax(ret)

        def run_critic(state):
            return session.run(value, {states_ph: state[None]})[0]

        action_mask = layers.one_hot_encoding(actions_ph, num_classes=num_actions)
        chosen_actions_logprobs = tf.reduce_sum(action_mask * action_logprobs, 1)

        actor_policy_advantage = chosen_actions_logprobs * (discounted_rewards_ph - value)
        actor_policy_advantage = tf.reduce_mean(actor_policy_advantage)

        actor_policy_entropy = tf.reduce_sum(-tf.exp(action_logprobs) * action_logprobs, 1)
        actor_policy_entropy = tf.reduce_mean(actor_policy_entropy)

        critic_error = (value - discounted_rewards_ph) ** 2
        critic_error = tf.reduce_mean(critic_error)

        optimizer = tf.train.AdamOptimizer(learning_rate=args.lr, beta1=0.5)

        update_op = tf.group(
            optimizer.minimize(-(actor_policy_advantage + args.beta * actor_policy_entropy)),
            optimizer.minimize(critic_error)
        )

        def perform_update(states, actions, discounted_rewards):
            session.run(update_op, {
                states_ph: states,
                actions_ph: actions,
                discounted_rewards_ph: discounted_rewards,
            })
        return run_actor, run_critic, perform_update

def run_simulator(args, global_num_frames, preproc_graphs, a3c_graphs):
    with preproc_graphs.lease() as preproc_single_frame:
        with a3c_graphs.lease() as a3c_graph:
            with GymOnProcess() as gym_on_process:
                env = gym_on_process.Env(args.env)
                action_space = env.get_action_space()
                observation_space = env.get_observation_space()
                exploration = ExplorationSchedule(args.exploration)

                unprocessed_images = deque([observation_space.low for _ in range(args.context)], maxlen=args.context)
                processed_images = deque([np.zeros((84, 84, 1)) for _ in range(args.context)], maxlen=args.context)

                observations, actions, rewards = [], [], []

                run_actor, run_critic, perform_update = a3c_graph

                def preproc(observation):
                    unprocessed_images.append(observation)
                    processed_images.append(preproc_single_frame(unprocessed_images[-2], unprocessed_images[-1]))
                    return np.concatenate(processed_images, 2)

                while global_num_frames.get() < args.max_frames:
                    observation = env.reset()
                    exploration.reset()
                    done = False

                    observations.append(preproc(observation))

                    while not done and global_num_frames.get() < args.max_frames:
                        if exploration.should_be_random(global_num_frames.get()):
                            action = action_space.sample()
                        else:
                            action = run_actor(observations[-1])
                        next_observation, reward, done, _ = env.step(action)
                        actions.append(action)
                        rewards.append(reward)

                        next_observation_preproc = preproc(next_observation)

                        global_num_frames.inc()
                        if done or len(observations) >= args.n_step:
                            assert len(observations) == len(actions) == len(rewards)
                            n = len(observations)
                            # Bootstrap the n-step return with the critic's value
                            # estimate, unless the episode ended.
                            if done:
                                value_f = 0
                            else:
                                value_f = run_critic(next_observation_preproc)
                            discounted_rewards = [0 for _ in range(n)]
                            for i in reversed(range(n)):
                                value_f = rewards[i] + args.gamma * value_f
                                discounted_rewards[i] = value_f

                            perform_update(np.array(observations), np.array(actions), np.array(discounted_rewards))
                            observations, actions, rewards = [], [], []

                        if not done:
                            observations.append(next_observation_preproc)

def main(args):
    with tf.device("cpu"):
        session_config = tf.ConfigProto(intra_op_parallelism_threads=1,
                                        inter_op_parallelism_threads=args.num_threads)
        session = tf.Session(config=session_config)

        # Create a throwaway env just to read off the observation/action spaces.
        env = gym.make(args.env)
        assert type(env.observation_space) is gym.spaces.Box
        input_shape = env.observation_space.shape
        assert type(env.action_space) is gym.spaces.Discrete
        num_actions = env.action_space.n
        env.close()

        global_num_frames = AtomicInt()
        preproc_graphs = SharedResource([preproc_graph(session, input_shape)
                                         for _ in range(args.num_threads)])
        a3c_graphs = SharedResource([a3c_graph(args, session, num_actions, reuse=(i > 0))
                                     for i in range(args.num_threads)])

        session.run(tf.initialize_all_variables())

        threads = []

        for thread_idx in range(args.num_threads):
            thread = Thread(target=run_simulator, args=(args, global_num_frames, preproc_graphs, a3c_graphs))
            threads.append(thread)
            thread.start()
        for thread in threads:
            thread.join()

def parse_args():
    parser = argparse.ArgumentParser(description='A3C training on OpenAI gym Atari environments.')
    parser.add_argument('--num_threads', type=int, default=9, help='number of threads to use')
    parser.add_argument('--max_frames', type=int, default=int(2e8), help='max number of game frames to learn for.')
    parser.add_argument('--n_step', type=int, default=5, help='steps of simulation between every update.')
    parser.add_argument('--context', type=int, default=4, help='how many past frames to use as input to the network.')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate.')
    parser.add_argument('--beta', type=float, default=0.01, help='coefficient in front of entropy regularization.')
    parser.add_argument('--gamma', type=float, default=0.99, help='discount factor.')
    parser.add_argument('--env', type=str, default='Pong-v0', help='Which OpenAI gym env to use.')
    parser.add_argument('--exploration', type=str, default='0.4:1:0.1:4e6-0.3:1:0.01:4e6-0.3:1:0.5:4e6',
                        help='Exploration schedule. List of hyphen-separated schedules. Each schedule consists of four colon-separated numbers: probability of selecting that schedule, initial probability, final probability, and the number of frames over which the probability is linearly annealed between the initial and final value.')
    args = parser.parse_args()

    print(str(ExplorationSchedule(args.exploration)))

    return args

if __name__ == '__main__':
    main(parse_args())
--------------------------------------------------------------------------------
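
A worked example of the n-step return computation in `run_simulator` above (hand-computed numbers; gamma = 0.99 as in the defaults):

```python
# Suppose the buffer holds rewards = [0, 0, 1] and the critic bootstraps
# the tail with value_f = 0.5 (episode not done yet).
rewards, value_f, gamma = [0.0, 0.0, 1.0], 0.5, 0.99
discounted = [0.0] * len(rewards)
for i in reversed(range(len(rewards))):
    value_f = rewards[i] + gamma * value_f
    discounted[i] = value_f
print(discounted)  # approximately [1.465, 1.480, 1.495]
```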