├── README.md
└── multi_code
    ├── ps_local_sync.py
    ├── ps_distribute.py
    └── allreduce_local_sync.py

/README.md:
--------------------------------------------------------------------------------
# Introduction
This repo contains several multi-GPU / distributed training implementations in TensorFlow.

They serve as good starter training scripts:

* allreduce_local_sync.py: synchronous updates with allreduce. Recommended.
* ps_local_sync.py: synchronous updates with a parameter server. This is the approach recommended by the TensorFlow documentation, though it is the least efficient.
* ps_distribute.py: synchronous/asynchronous updates with a parameter server. Distributed (multi-process) implementation.

There is a complementary article written in Chinese: https://zhuanlan.zhihu.com/p/50116885

# Usage
The code does not consume any data: it generates fake data and fits it.
The code should be run with TensorFlow >= 1.12. For TensorFlow 2.0, the `tf.distribute.Strategy` API is recommended instead.
Make sure you have multiple GPUs :)

Simply run
```
python allreduce_local_sync.py --gpus 0,1 --max_step 10000
python ps_local_sync.py --gpus 0,1 --max_step 10000
```
For `ps_distribute.py`, make sure you have 2 GPUs, and run the following commands in order:
```
# The cluster configuration is specified in the code: in this case, 1 ps and 2 workers.
# First start the ps:
$ python ps_distribute.py --job ps --index 0

# Then start the workers:
$ python ps_distribute.py --job worker --index 0 --gpu 0 --max_step 10000
$ python ps_distribute.py --job worker --index 1 --gpu 1 --max_step 10000
```

For benchmarking, typically to explore different communication/computation ratios, simply modify the complexity of the model, as sketched below.
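
For example (an illustrative sketch only; the `hidden` and `depth` knobs below are not flags of the scripts), you can widen or deepen `build_tower`. Adding loop iterations over the reused `middle_matrix` increases computation per step without growing the gradients that have to be communicated, while widening the matrices increases both:
```
def build_tower(batch):
    feature, label = batch
    hidden = 2000  # wider -> more computation AND larger gradients to exchange
    depth = 40     # deeper -> more computation per step, same gradient size
    matrix = tf.get_variable('matrix', shape=[300, hidden])
    middle_matrix = tf.get_variable('middle_matrix', shape=[hidden, hidden])
    out_matrix = tf.get_variable('out_matrix', shape=[hidden, 10])
    feature = tf.matmul(feature, matrix)
    for _ in range(depth):
        feature = tf.matmul(feature, middle_matrix)
    logits = tf.matmul(feature, out_matrix)
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label))
```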
--------------------------------------------------------------------------------
/multi_code/ps_local_sync.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import division
from __future__ import print_function
import argparse
import os
import tensorflow as tf


def split_batches(num_splits, batches):
    batch_size = tf.shape(batches[0])[0]
    # evenly distributed split sizes
    divisible_sizes = tf.fill([num_splits], tf.floor_div(batch_size, num_splits))
    remainder_sizes = tf.sequence_mask(tf.mod(batch_size, num_splits),
                                       maxlen=num_splits,
                                       dtype=tf.int32)
    frag_sizes = divisible_sizes + remainder_sizes

    batch_frags_list = []
    for batch in batches:
        batch_frags = tf.split(batch, frag_sizes, axis=0)
        batch_frags_list.append(batch_frags)

    frag_batches_list = list(zip(*batch_frags_list))
    # corner case: if only one tensor was passed in, unwrap the 1-element tuples
    for i, frag_batches in enumerate(frag_batches_list):
        if len(frag_batches) == 1:
            frag_batches_list[i] = frag_batches[0]

    return frag_batches_list


def build_dataset(num_gpus):
    rand = tf.random_normal([10000, 300])
    rand_labels = tf.random_uniform([10000], minval=0, maxval=9, dtype=tf.int32)
    dataset = tf.data.Dataset.from_tensor_slices((rand, rand_labels))
    dataset = dataset.repeat()
    dataset = dataset.batch(100)
    dataset = dataset.map(
        lambda rand, rand_labels: split_batches(num_splits=num_gpus,
                                                batches=[rand, rand_labels]))
    return dataset


def build_tower(batch):
    feature, label = batch
    matrix = tf.get_variable('matrix', shape=[300, 500])
    middle_matrix = tf.get_variable('middle_matrix', shape=[500, 500])
    out_matrix = tf.get_variable('out_matrix', shape=[500, 10])
    feature = tf.matmul(feature, matrix)
    for i in range(10):
        feature = tf.matmul(feature, middle_matrix)
    logits = tf.matmul(feature, out_matrix)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label))
    return loss


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpus', default='0,1', type=str)
    parser.add_argument('--max_step', default=10000, type=int)
    args = parser.parse_args()
    args.num_gpus = len(args.gpus.split(","))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    dataset = build_dataset(args.num_gpus)
    iterator = dataset.make_initializable_iterator()
    tower_batches = iterator.get_next()

    # build train graph
    tower_grads_list = []
    tower_loss_list = []
    # global variable scope shared by all towers
    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        for index, tower_batch in enumerate(tower_batches):
            with tf.device('/gpu:%d' % index):
                tower_loss = build_tower(tower_batch)
                if index == 0:
                    # variables are first created on gpu 0,
                    # so gpu 0 is the parameter server
                    tvars = tf.trainable_variables()
                tower_grads = tf.gradients(tower_loss, tvars)
                tower_grads_list.append(tower_grads)
                tower_loss_list.append(tower_loss)

    with tf.device('/gpu:0'):
        # update on the parameter server
        loss = tf.add_n(tower_loss_list) / args.num_gpus
        avg_grads = []
        for grad_list in zip(*tower_grads_list):
            # avoid making sparse gradients dense with a simple tf.add_n
            grad_avg = tf.reduce_mean(tf.stack(grad_list, axis=0), axis=0)
            avg_grads.append(grad_avg)
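
    # Because every tower is built inside the shared "model" variable scope with
    # reuse=tf.AUTO_REUSE and the variables are first created under /gpu:0, all
    # towers read the same parameter copies living on gpu 0. Averaging the tower
    # gradients and applying a single update to those shared variables is what
    # makes gpu 0 act as an in-graph parameter server.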

    step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer()
    train_op = optimizer.apply_gradients(zip(avg_grads, tvars),
                                         global_step=step)
    saver = tf.train.Saver()

    # start running
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(iterator.initializer)
        while True:
            try:
                fetch_loss, fetch_step, _ = sess.run([loss, step, train_op])
                if fetch_step % 20 == 0:
                    print("step: %d, loss: %.4f" % (fetch_step, fetch_loss))
                if fetch_step > args.max_step:
                    break
            except tf.errors.OutOfRangeError:
                break
        saver.save(sess, "./model")


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/multi_code/ps_distribute.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import division
from __future__ import print_function
import argparse
import os

import tensorflow as tf


def build_dataset():
    rand = tf.random_normal([10000, 300])
    rand_labels = tf.random_uniform([10000], minval=0, maxval=9, dtype=tf.int32)
    dataset = tf.data.Dataset.from_tensor_slices((rand, rand_labels))
    dataset = dataset.repeat()
    dataset = dataset.batch(100)
    return dataset


def build_tower(batch):
    feature, label = batch
    matrix = tf.get_variable('matrix', shape=[300, 500])
    middle_matrix = tf.get_variable('middle_matrix', shape=[500, 500])
    out_matrix = tf.get_variable('out_matrix', shape=[500, 10])
    feature = tf.matmul(feature, matrix)
    for i in range(10):
        feature = tf.matmul(feature, middle_matrix)
    logits = tf.matmul(feature, out_matrix)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label))
    return loss


def main():
    usage = """
    First start the ps:
    $ python ps_distribute.py --job ps --index 0

    Then start the workers:
    $ python ps_distribute.py --job worker --index 0 --gpu 0 --max_step 10000
    $ python ps_distribute.py --job worker --index 1 --gpu 1 --max_step 10000
    """

    parser = argparse.ArgumentParser(usage=usage)
    parser.add_argument('--job', choices=['worker', 'ps'])
    parser.add_argument('--index', type=int)
    parser.add_argument('--gpu', type=int)
    parser.add_argument('--sync', action="store_true",
                        help="Turn on synchronized gradient updates")
    parser.add_argument('--max_step', default=10000, type=int)

    # one ps, two workers
    ps = ["127.0.0.1:60000"]
    worker = ["127.0.0.1:60001",
              "127.0.0.1:60002"]
    args = parser.parse_args()

    cluster = tf.train.ClusterSpec(
        {"ps": ps,
         "worker": worker}
    )

    if args.job == "ps":
        # ps runs on cpu
        os.environ["CUDA_VISIBLE_DEVICES"] = ''
        server = tf.train.Server(cluster,
                                 job_name=args.job,
                                 task_index=args.index)
        server.join()

    elif args.job == "worker":
        # worker runs on gpu, one gpu per worker, selected via --gpu
        os.environ["CUDA_VISIBLE_DEVICES"] = '%d' % args.gpu
        server = tf.train.Server(cluster,
                                 job_name=args.job,
                                 task_index=args.index)

        dataset = build_dataset()
        iterator = dataset.make_initializable_iterator()
        batch = iterator.get_next()
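
        # tf.train.replica_device_setter returns a device function that places
        # variables on the /job:ps tasks (round-robin across them) and keeps the
        # remaining ops on this worker's device, so each worker computes its
        # gradients locally while all workers share one set of parameters.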
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:%d" % args.index,
                    cluster=cluster)):
            loss = build_tower(batch)
            tvars = tf.trainable_variables()
            grads = tf.gradients(loss, tvars)
            step = tf.train.get_or_create_global_step()
            # use_locking avoids concurrent updates to the same variable
            optimizer = tf.train.GradientDescentOptimizer(0.1, use_locking=True)

            if args.sync:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer, replicas_to_aggregate=len(worker))
                hooks = [optimizer.make_session_run_hook(args.index == 0)]
            else:
                hooks = []

            train_op = optimizer.apply_gradients(zip(grads, tvars),
                                                 global_step=step)

        # if this throws uninitialized-variable errors, put your initializers here
        local_init_op = tf.group(tf.local_variables_initializer(),
                                 tf.tables_initializer(),
                                 iterator.initializer)
        scaffold = tf.train.Scaffold(local_init_op=local_init_op)
        hooks.append(tf.train.StopAtStepHook(args.max_step))

        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(args.index == 0),
                scaffold=scaffold,
                checkpoint_dir='./',
                save_summaries_steps=15,
                save_checkpoint_steps=10000,
                hooks=hooks) as sess:

            while not sess.should_stop():
                fetch_loss, fetch_step, _ = sess.run([loss,
                                                      step,
                                                      train_op])

                if fetch_step % 20 == 0:
                    print("job: %s, task_index: %d, step: %d, loss: %.4f" %
                          (args.job, args.index, fetch_step, fetch_loss))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/multi_code/allreduce_local_sync.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from __future__ import division
from __future__ import print_function
import argparse
import os
import re
import tensorflow as tf
from packaging import version

if version.parse(tf.__version__) >= version.parse("1.13.0"):
    # tf 1.13.0 moved nccl from contrib into core
    from tensorflow.python.ops.nccl_ops import all_sum
else:
    from tensorflow.contrib.nccl import all_sum
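
# all_sum performs an NCCL allreduce: given a list of tensors, one per GPU, it
# returns a list of the same length where every element is the element-wise sum
# of all inputs, and each output is placed on its corresponding input's device.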


def split_batches(num_splits, batches):
    batch_size = tf.shape(batches[0])[0]
    # evenly distributed split sizes
    divisible_sizes = tf.fill([num_splits], tf.floor_div(batch_size, num_splits))
    remainder_sizes = tf.sequence_mask(tf.mod(batch_size, num_splits),
                                       maxlen=num_splits,
                                       dtype=tf.int32)
    frag_sizes = divisible_sizes + remainder_sizes

    batch_frags_list = []
    for batch in batches:
        batch_frags = tf.split(batch, frag_sizes, axis=0)
        batch_frags_list.append(batch_frags)

    frag_batches_list = list(zip(*batch_frags_list))
    # corner case: if only one tensor was passed in, unwrap the 1-element tuples
    for i, frag_batches in enumerate(frag_batches_list):
        if len(frag_batches) == 1:
            frag_batches_list[i] = frag_batches[0]

    return frag_batches_list


def build_dataset(num_gpus):
    rand = tf.random_normal([10000, 300])
    rand_labels = tf.random_uniform([10000], minval=0, maxval=9, dtype=tf.int32)
    dataset = tf.data.Dataset.from_tensor_slices((rand, rand_labels))
    dataset = dataset.repeat()
    dataset = dataset.batch(100)
    dataset = dataset.map(
        lambda rand, rand_labels: split_batches(num_splits=num_gpus,
                                                batches=[rand, rand_labels]))
    return dataset


def build_tower(batch):
    feature, label = batch
    matrix = tf.get_variable('matrix', shape=[300, 500])
    middle_matrix = tf.get_variable('middle_matrix', shape=[500, 500])
    out_matrix = tf.get_variable('out_matrix', shape=[500, 10])
    feature = tf.matmul(feature, matrix)
    for i in range(10):
        feature = tf.matmul(feature, middle_matrix)
    logits = tf.matmul(feature, out_matrix)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label))
    return loss


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpus', default='0,1', type=str)
    parser.add_argument('--max_step', default=10000, type=int)
    args = parser.parse_args()
    args.num_gpus = len(args.gpus.split(","))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    # avoid unimplemented-GPU-kernel errors
    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:

        dataset = build_dataset(args.num_gpus)
        iterator = dataset.make_initializable_iterator()
        tower_batches = iterator.get_next()

        tower_grads_list = []
        tower_tvars_list = []
        tower_gvars_list = []
        tower_loss_list = []
        for index, tower_batch in enumerate(tower_batches):
            # per-device variable scope: each tower owns its own variable copies
            with tf.variable_scope("tower_%d" % index) as scope, \
                    tf.device('/gpu:%d' % index):

                tower_loss = build_tower(tower_batch)
                tower_gvars = tf.global_variables(scope.name)
                tower_tvars = tf.trainable_variables(scope.name)
                tower_grads = tf.gradients(tower_loss, tower_tvars)

                tower_loss_list.append(tower_loss)
                tower_tvars_list.append(tower_tvars)
                tower_gvars_list.append(tower_gvars)
                tower_grads_list.append(tower_grads)

                if index == 0:
                    # a single saver over tower_0's variables, with the
                    # tower scope prefix stripped from the checkpoint names
                    def clean(name):
                        name = re.sub(r'^tower_\d+/', '', name)
                        name = re.sub(r':\d+$', '', name)
                        return name

                    save_dict = {clean(var.name): var
                                 for var in tower_gvars}
                    saver = tf.train.Saver(save_dict)

        with tf.name_scope("tower_gvar_sync"):
            # each device initializes its copies with a different random seed,
            # so they need explicit synchronization before training!
            if len(tower_gvars_list) == 1:
                tower_gvar_sync = tf.no_op()
            else:
                sync_ops = []
                for vars in zip(*tower_gvars_list):
                    for var in vars[1:]:
                        sync_ops.append(tf.assign(var, vars[0]))
                tower_gvar_sync = tf.group(*sync_ops)

        with tf.name_scope('all_reduce'):
            avg_tower_grads_list = []
            for grads_to_avg in zip(*tower_grads_list):
                # nccl.all_sum automatically converts sparse gradients into dense ones
                avg_tower_grads_list.append(all_sum(grads_to_avg))
            avg_tower_grads_list = list(zip(*avg_tower_grads_list))

        with tf.name_scope('metrics'):
            loss = tf.add_n(tower_loss_list) / len(tower_loss_list)

        train_ops = []
        for index, (tower_vars, tower_grads) in \
                enumerate(zip(tower_tvars_list, avg_tower_grads_list)):
            with tf.variable_scope("tower_%d" % index), \
                    tf.device('/gpu:%d' % index):
                # all_sum returns sums, so divide by the number of towers to average
                tower_grads = [grad / len(tower_batches) for grad in tower_grads]
                if index == 0:
                    # only increment the global step with the first tower
                    step = tf.train.get_or_create_global_step()

                tower_optimizer = tf.train.AdamOptimizer()
                tower_train_op = tower_optimizer.apply_gradients(
                    zip(tower_grads, tower_vars),
                    global_step=step if index == 0 else None)
                train_ops.append(tower_train_op)
        train_op = tf.group(*train_ops)

        # start running
        sess.run(tf.global_variables_initializer())
        sess.run(iterator.initializer)
        # important: sync the per-tower variable copies before training!
        sess.run(tower_gvar_sync)
        while True:
            try:
                fetch_loss, fetch_step, _ = sess.run([loss, step, train_op])
                if fetch_step % 20 == 0:
                    print("step: %d, loss: %.4f" % (fetch_step, fetch_loss))
                if fetch_step > args.max_step:
                    break
            except tf.errors.OutOfRangeError:
                break
        saver.save(sess, "./model")


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------