├── .gitignore ├── config.json ├── README.md ├── create_server.py ├── compute_simple.py └── diagram /.gitignore: -------------------------------------------------------------------------------- 1 | # Project 2 | .idea 3 | 4 | # Python 5 | *.pyc 6 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "worker": [ 3 | "10.10.100.87:2222", 4 | "10.10.102.74:2222" 5 | ], 6 | "host": [ 7 | "10.10.100.87:2223" 8 | ] 9 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # distributed-tensorflow 2 | Distributed TensorFlow Example 3 | 4 | 5 | # Example 6 | 7 | The full explanation is [here](http://andersonjo.github.io/tensorflow/2017/10/17/Distributed-TensorFlow/) 8 | 9 | ## Configuration 10 | Configure the "config.json" file before running the code.
11 | The example requires at least two remote servers, each with its own GPU.
12 | 13 | ## Create TensorFlow Server 14 | 15 | In remote server 1... 16 | ``` 17 | python3 create_server.py --task=0 18 | ``` 19 | 20 | In remote server 2... 21 | ``` 22 | python3 create_server.py --task=1 23 | ``` 24 | 25 | ## Run computation code 26 | 27 | ``` 28 | python3 compute_simple.py 29 | ``` 30 | -------------------------------------------------------------------------------- /create_server.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import argparse 3 | import json 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--task', type=int, help='The task number') 7 | parser.add_argument('--job', type=str, default='worker', help='job name ("worker" or "host")') 8 | args = parser.parse_args() 9 | 10 | cluster_spec = json.load(open('config.json', 'rt')) 11 | cluster = tf.train.ClusterSpec(cluster_spec) 12 | 13 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2, allow_growth=True) 14 | server = tf.train.Server(cluster, job_name=args.job, task_index=args.task, 15 | config=tf.ConfigProto(gpu_options=gpu_options)) 16 | server.start() 17 | server.join() 18 | -------------------------------------------------------------------------------- /compute_simple.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import tensorflow as tf 4 | 5 | cluster_spec = json.load(open('config.json', 'rt')) 6 | cluster = tf.train.ClusterSpec(cluster_spec) 7 | server = tf.train.Server(cluster_spec, job_name='host', task_index=0) 8 | 9 | with tf.device('/job:worker/task:0'): 10 | a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a') 11 | b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b') 12 | # c = tf.matmul(a, b) 13 | 14 | with tf.device('/job:worker/task:1'): 15 | d = tf.matmul(a, b) + 100 16 | 17 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1, allow_growth=True) 18 | with 
tf.Session(server.target, config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 19 | result = sess.run(d) 20 | print(result) 21 | -------------------------------------------------------------------------------- /diagram: -------------------------------------------------------------------------------- 1 | 1Vddb5swFP01PLYCG0jy2GXd+jKpUip1fXThBlAIZsZJyH79DLYBA1GzbPloIiX43Gv7cs7BNhaer8vvjOTxDxpCaiE7LC381ULIn/nitwL2CrCnEohYEkrIaYFF8hsUaCt0k4RQGImc0pQnuQkGNMsg4AZGGKM7M21JU3PWnEQwABYBSYfoaxLyWKJTz27xJ0iiWM/s2CryToJVxOgmU/NZCC/rjwyviR5L5RcxCemuA+FHC88ZpVxercs5pBW1mjbZ79uBaFM3g4wf0wHJDluSbkBX7Kei65clFSOIAvlekeL/2lAduCtqyR5EAnLzsg3WmqSUGT0ECXhafbuQH6n/erZ3DbxStgKmYVH4ez9VYLI2DSOjTFSzD9X92SK8ixMOi5wEVXQnzCqwmK9T0XKa3ltgHMqDFDqNMMLvQNfA2V6k6A6+0lJ53dPa7lrnIO2PuOsarECi3Bo1Y7eKiQsl2riA+CMBT5KjyEl2ivj/Iqqc87a1dpF3hNbembR2B1q/QFZQtkzFEoLsBbBtxfKVOUJ6rddrvzPkyBnj6H9Q5H0KityJdz2KHOfW1vwnWvDPtOKjmbniY227i6z4znDPfiHF6uqsYGw+9+5FTT22D/YIgSx8qM6GohWkpCiSoNrpOGF8CHeogTLhPysW7z3VelOcCmLYvhOqmm8GoRAOzpk9OkWFdMMCMI5joqYImrPlOOsdVr0Rq2mMQUp4sjWrGGNazfBMk3oN0C6YmFZ3pz2xZPWqV/eg2RsI954Zd9YbSN7zYCAhDdl30vIqoThcMPLH5/nbulrfyQpaFzYaHGfM4Y50BmPiScead/a9LfZgCTwDS0S5wM7mWXxrnvW83oHcPtGzrt3zhn8ez3rYPA5c2LOi2b5wyvT2pR4//gE= --------------------------------------------------------------------------------