├── LICENSE
├── README.md
├── cluster_dispatch.py
├── cluster_specs.py
├── distributed_tensorflow.sh
└── example.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Peidong Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Distributed-TensorFlow-Using-MPI
Template for Deploying Distributed TensorFlow on Clusters Using MPI

## Brief Description
The scripts in this repository deploy distributed TensorFlow on dynamically allocated clusters. They combine mpi4py with the standard distributed TensorFlow cluster setup (see the template at https://www.tensorflow.org/deploy/distributed). The basic idea is to use MPI to reach every allocated node and start the appropriate TensorFlow task on each one.

## Contents
1. distributed_tensorflow.sh: the PBS script to use when submitting jobs
2. cluster_specs.py: reads the nodes allocated to the job and converts them into the corresponding ps_hosts and worker_hosts strings for TensorFlow
3. cluster_dispatch.py: determines the job_name and task_index for each node and launches the training script there
4. example.py: a simple example script for testing. Note that some functions in this example may be deprecated soon, so you should design your own model based on the template at https://www.tensorflow.org/deploy/distributed. Also note that this example script may carry a different license from the one for this repository.

## Arguments
Most of the arguments should be modified in distributed_tensorflow.sh. You can also change the stdout output (the contents of the print() calls) in cluster_dispatch.py.
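
## Usage
On a PBS cluster, submission is a single qsub call. A minimal sketch, assuming TensorFlow and mpi4py are available on the compute nodes and the `//` path placeholders in distributed_tensorflow.sh have been filled in with the repository path:

```bash
qsub distributed_tensorflow.sh
```

The job resolves the allocated nodes into a cluster spec and then launches one TensorFlow task per node; the per-file notes below walk through each step.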
--------------------------------------------------------------------------------

/cluster_dispatch.py:
--------------------------------------------------------------------------------
from mpi4py import MPI
import socket
import os
import argparse

FLAGS = None


def main():
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    ps_hosts_list = FLAGS.ps_hosts.split(',')
    worker_hosts_list = FLAGS.worker_hosts.split(',')
    num_ps_hosts = len(ps_hosts_list)
    num_worker_hosts = len(worker_hosts_list)
    num_hosts = num_ps_hosts + num_worker_hosts

    # mpiexec -ppn 1 starts one rank per node; each rank matches its own
    # hostname against the ps and worker lists to decide which TensorFlow
    # job name and task index it should run.
    for rank_rotate in range(num_hosts):
        if rank == rank_rotate:
            print("I am rank " + str(rank_rotate) + "...")
            hostname = socket.gethostname()
            print("My hostname is: " + hostname)
            for ps_hosts_rotate in range(num_ps_hosts):
                if hostname == ps_hosts_list[ps_hosts_rotate].split(':')[0]:
                    print("My job ID is: ps" + str(ps_hosts_rotate))
                    os.system("python -u " + FLAGS.script
                              + " --ps_hosts=" + FLAGS.ps_hosts
                              + " --worker_hosts=" + FLAGS.worker_hosts
                              + " --job_name=ps --task_index=" + str(ps_hosts_rotate))
            for worker_hosts_rotate in range(num_worker_hosts):
                if hostname == worker_hosts_list[worker_hosts_rotate].split(':')[0]:
                    print("My job ID is: worker" + str(worker_hosts_rotate))
                    os.system("python -u " + FLAGS.script
                              + " --ps_hosts=" + FLAGS.ps_hosts
                              + " --worker_hosts=" + FLAGS.worker_hosts
                              + " --job_name=worker --task_index=" + str(worker_hosts_rotate))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")
    parser.add_argument(
        "--ps_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--worker_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--script",
        type=str,
        default="",
        help="The .py file you want to execute"
    )

    FLAGS, unparsed = parser.parse_known_args()
    main()
--------------------------------------------------------------------------------

/cluster_specs.py:
--------------------------------------------------------------------------------
from more_itertools import unique_everseen
import argparse

FLAGS = None


def main():
    # The PBS node file lists each node once per allocated core, so
    # deduplicate while preserving order.
    with open(FLAGS.hosts_file, 'r') as f:
        hosts_list = [line.strip() for line in f]
    hosts_list = list(unique_everseen(hosts_list))

    # The first num_ps_hosts hosts become parameter servers; all remaining
    # hosts are treated as workers. The ".ten.osc.edu" suffix is specific to
    # the Owens cluster; change it accordingly for other clusters.
    ps_hosts = [hosts_list[i] + ".ten.osc.edu:2222" for i in range(FLAGS.num_ps_hosts)]
    worker_hosts = [hosts_list[i] + ".ten.osc.edu:2222" for i in range(len(ps_hosts), len(hosts_list))]

    print(','.join(ps_hosts), ','.join(worker_hosts))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")
    parser.add_argument(
        "--hosts_file",
        type=str,
        default="",
        help="Path to the hosts file (e.g. $PBS_NODEFILE)"
    )
    parser.add_argument(
        "--num_ps_hosts",
        type=int,
        default=1,
        help="Number of parameter servers; the remaining hosts become workers"
    )

    FLAGS, unparsed = parser.parse_known_args()
    main()
--------------------------------------------------------------------------------
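
Run standalone, cluster_specs.py turns a hosts file into the two comma-separated host strings. A minimal sketch with hypothetical node names (the real file comes from $PBS_NODEFILE):

```bash
$ cat hosts.txt
o0001
o0001
o0002
o0003
$ python cluster_specs.py --hosts_file=hosts.txt --num_ps_hosts=1
o0001.ten.osc.edu:2222 o0002.ten.osc.edu:2222,o0003.ten.osc.edu:2222
```

The ps string and the worker string are separated by a single space, which is what the cut -f1 / cut -f2 calls in distributed_tensorflow.sh below rely on.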

/distributed_tensorflow.sh:
--------------------------------------------------------------------------------
#PBS -N distributed_tensorflow
#PBS -l walltime=01:00:00
#PBS -l nodes=3:ppn=28:gpus=1

# Replace the "//" prefixes below with the path to this repository.
# cluster_specs.py prints "ps_hosts worker_hosts" separated by a space,
# so cut extracts the two fields.
export PS_HOSTS=$(python //cluster_specs.py --hosts_file=$PBS_NODEFILE --num_ps_hosts=1 | cut -f1 -d ' ')
export WORKER_HOSTS=$(python //cluster_specs.py --hosts_file=$PBS_NODEFILE --num_ps_hosts=1 | cut -f2 -d ' ')

# Launch one MPI rank per node; each rank decides its own TensorFlow role.
mpiexec -ppn 1 python -u //cluster_dispatch.py --ps_hosts=$PS_HOSTS --worker_hosts=$WORKER_HOSTS \
    --script=//example.py
--------------------------------------------------------------------------------
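
With nodes=3 and --num_ps_hosts=1, mpiexec starts one copy of cluster_dispatch.py per node, and each copy effectively runs one of the following commands (a sketch; the "//" placeholders again stand for the repository path):

```bash
# on the node matching ps_hosts[0]:
python -u //example.py --ps_hosts=$PS_HOSTS --worker_hosts=$WORKER_HOSTS --job_name=ps --task_index=0
# on the node matching worker_hosts[0]:
python -u //example.py --ps_hosts=$PS_HOSTS --worker_hosts=$WORKER_HOSTS --job_name=worker --task_index=0
# on the node matching worker_hosts[1]:
python -u //example.py --ps_hosts=$PS_HOSTS --worker_hosts=$WORKER_HOSTS --job_name=worker --task_index=1
```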

/example.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import argparse
import sys
import time

FLAGS = None


def main(_):

    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')

    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # start a server for a specific task
    server = tf.train.Server(
        cluster,
        job_name=FLAGS.job_name,
        task_index=FLAGS.task_index)

    # config
    batch_size = 100
    learning_rate = 0.0005
    training_epochs = 20
    logs_path = "/tmp/train_logs"

    # load mnist data set
    from tensorflow.examples.tutorials.mnist import input_data
    # mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":

        # Between-graph replication
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):

            mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

            # count the number of updates
            global_step = tf.get_variable(
                'global_step',
                [],
                initializer=tf.constant_initializer(0),
                trainable=False)

            # input images
            with tf.name_scope('input'):
                # None -> batch size can be any size, 784 -> flattened mnist image
                x = tf.placeholder(tf.float32, shape=[None, 784], name="x-input")
                # target 10 output classes
                y_ = tf.placeholder(tf.float32, shape=[None, 10], name="y-input")

            # model parameters will change during training so we use tf.Variable
            tf.set_random_seed(1)
            with tf.name_scope("weights"):
                W1 = tf.Variable(tf.random_normal([784, 100]))
                W2 = tf.Variable(tf.random_normal([100, 10]))

            # bias
            with tf.name_scope("biases"):
                b1 = tf.Variable(tf.zeros([100]))
                b2 = tf.Variable(tf.zeros([10]))

            # implement model
            with tf.name_scope("softmax"):
                # y is our prediction
                z2 = tf.add(tf.matmul(x, W1), b1)
                a2 = tf.nn.sigmoid(z2)
                z3 = tf.add(tf.matmul(a2, W2), b2)
                y = tf.nn.softmax(z3)

            # specify cost function
            with tf.name_scope('cross_entropy'):
                # this is our cost
                cross_entropy = tf.reduce_mean(
                    -tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

            # specify optimizer
            with tf.name_scope('train'):
                # optimizer is an "operation" which we can execute in a session;
                # plain gradient descent means workers update asynchronously
                grad_op = tf.train.GradientDescentOptimizer(learning_rate)
                '''
                # synchronous alternative; note that SyncReplicasOptimizer's
                # signature changed across TF versions, and replica_id is not
                # accepted by newer releases
                rep_op = tf.train.SyncReplicasOptimizer(
                    grad_op,
                    replicas_to_aggregate=len(worker_hosts),
                    replica_id=FLAGS.task_index,
                    total_num_replicas=len(worker_hosts),
                    use_locking=True)
                train_op = rep_op.minimize(cross_entropy, global_step=global_step)
                '''
                train_op = grad_op.minimize(cross_entropy, global_step=global_step)

            '''
            init_token_op = rep_op.get_init_tokens_op()
            chief_queue_runner = rep_op.get_chief_queue_runner()
            '''

            with tf.name_scope('Accuracy'):
                # accuracy
                correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
                accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

            # create a summary for our cost and accuracy
            tf.summary.scalar("cost", cross_entropy)
            tf.summary.scalar("accuracy", accuracy)

            # merge all summaries into a single "operation" which we can execute in a session
            summary_op = tf.summary.merge_all()
            init_op = tf.global_variables_initializer()
            print("Variables initialized ...")

        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                 global_step=global_step,
                                 init_op=init_op)

        begin_time = time.time()
        frequency = 100
        with sv.prepare_or_wait_for_session(server.target) as sess:
            '''
            # is chief
            if FLAGS.task_index == 0:
                sv.start_queue_runners(sess, [chief_queue_runner])
                sess.run(init_token_op)
            '''
            # create log writer object (this will log on every machine)
            writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())

            # perform training cycles
            start_time = time.time()
            for epoch in range(training_epochs):

                # number of batches in one epoch
                batch_count = int(mnist.train.num_examples / batch_size)

                count = 0
                for i in range(batch_count):
                    batch_x, batch_y = mnist.train.next_batch(batch_size)

                    # perform the operations we defined earlier on batch
                    _, cost, summary, step = sess.run(
                        [train_op, cross_entropy, summary_op, global_step],
                        feed_dict={x: batch_x, y_: batch_y})
                    writer.add_summary(summary, step)

                    count += 1
                    if count % frequency == 0 or i + 1 == batch_count:
                        elapsed_time = time.time() - start_time
                        start_time = time.time()
                        print("Step: %d," % (step + 1),
                              " Epoch: %2d," % (epoch + 1),
                              " Batch: %3d of %3d," % (i + 1, batch_count),
                              " Cost: %.4f," % cost,
                              " AvgTime: %3.2fms" % float(elapsed_time * 1000 / frequency))
                        count = 0

            print("Test-Accuracy: %2.2f" % sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
            print("Total Time: %3.2fs" % float(time.time() - begin_time))
            print("Final Cost: %.4f" % cost)

        sv.stop()
        print("done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")
    parser.add_argument(
        "--ps_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--worker_hosts",
        type=str,
        default="",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        help="One of 'ps', 'worker'"
    )
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )

    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
--------------------------------------------------------------------------------