├── .gitignore
├── config.json
├── README.md
├── create_server.py
├── compute_simple.py
└── diagram
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project
2 | .idea
3 |
4 | # Python
5 | *.pyc
6 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "worker": [
3 | "10.10.100.87:2222",
4 | "10.10.102.74:2222"
5 | ],
6 | "host": [
7 | "10.10.100.87:2223"
8 | ]
9 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # distributed-tensorflow
2 | Distributed TensorFlow example
3 |
4 |
5 | # Example
6 |
7 | The full explanation is [here](http://andersonjo.github.io/tensorflow/2017/10/17/Distributed-TensorFlow/)
8 |
9 | ## Configuration
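10 | Configure `config.json` before running the code: list the `ip:port` of each worker machine under `worker`, and the `ip:port` of the machine that runs `compute_simple.py` under `host`.
11 | The example requires at least two remote servers, each with its own GPU.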
12 |
13 | ## Create TensorFlow Server
14 |
15 | On remote server 1 (the first address under `worker` in `config.json`):
16 | ```
17 | python3 create_server.py --task=0
18 | ```
19 |
20 | On remote server 2 (the second address under `worker` in `config.json`):
21 | ```
22 | python3 create_server.py --task=1
23 | ```
24 |
25 | ## Run computation code
26 | On the host machine (the `host` address in `config.json`), once both servers are running:
27 | ```
28 | python3 compute_simple.py
29 | ```
30 |
--------------------------------------------------------------------------------
/create_server.py:
--------------------------------------------------------------------------------
"""Start one task of the TensorFlow cluster described in a JSON cluster spec.

Run one copy per machine, e.g. ``python3 create_server.py --task=0``.
The cluster spec maps job names (here "worker" and "host") to lists of
"ip:port" addresses and is read from ``config.json`` unless ``--config``
says otherwise. ``--task`` is the index of this machine's address inside
the chosen job's list.

Uses the TensorFlow 1.x API (``tf.train.Server``, ``tf.ConfigProto``,
``tf.GPUOptions``); under TensorFlow 2.x these live in ``tf.compat.v1``.
"""
import argparse
import json

import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=int, help='The task number')
parser.add_argument('--job', type=str, default='worker', help='job name ("worker" or "host")')
parser.add_argument('--config', type=str, default='config.json',
                    help='path to the JSON cluster spec (default: config.json)')
args = parser.parse_args()

# Use a context manager so the config file handle is closed right away.
with open(args.config, 'rt') as config_file:
    cluster_spec = json.load(config_file)

# Report an unknown job as a normal CLI usage error (exit status 2) instead of
# letting TensorFlow fail later with a less obvious message.
if args.job not in cluster_spec:
    parser.error('job %r not found in %s (available: %s)'
                 % (args.job, args.config, ', '.join(sorted(cluster_spec))))

cluster = tf.train.ClusterSpec(cluster_spec)

# Limit this process to 20% of GPU memory and let allocations grow on demand
# rather than reserving that whole fraction up front.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2, allow_growth=True)

# tf.train.Server starts serving on construction (its `start` argument
# defaults to True), so no separate server.start() call is needed.
server = tf.train.Server(cluster, job_name=args.job, task_index=args.task,
                         config=tf.ConfigProto(gpu_options=gpu_options))

# Block here and keep serving requests from the other tasks/clients;
# join() only returns once the server shuts down.
server.join()
18 |
--------------------------------------------------------------------------------
/compute_simple.py:
--------------------------------------------------------------------------------
"""Run a small matrix multiplication spread across the TensorFlow cluster.

Run this on the machine listed under "host" in ``config.json``, after every
worker has been started with ``create_server.py --task=N``. This process
joins the cluster as task 0 of the "host" job and acts as the client.

Uses the TensorFlow 1.x API (``tf.train.Server``, ``tf.Session``,
``tf.ConfigProto``); under TensorFlow 2.x these live in ``tf.compat.v1``.
"""
import json

import tensorflow as tf

# Use a context manager so the config file handle is closed right away.
with open('config.json', 'rt') as config_file:
    cluster_spec = json.load(config_file)
cluster = tf.train.ClusterSpec(cluster_spec)

# Pass the constructed ClusterSpec (as create_server.py does) rather than the
# raw dict, so `cluster` is the single source of the cluster layout.
server = tf.train.Server(cluster, job_name='host', task_index=0)

# Place the input constants on worker task 0 ...
with tf.device('/job:worker/task:0'):
    a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
    b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')

# ... and the matmul on worker task 1, so the computation spans both workers.
with tf.device('/job:worker/task:1'):
    d = tf.matmul(a, b) + 100

# Limit the client to 10% of GPU memory and let allocations grow on demand.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1, allow_growth=True)
with tf.Session(server.target, config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    result = sess.run(d)
    # [[1,2,3],[4,5,6]] @ [[1,2],[3,4],[5,6]] + 100 -> [[122. 128.] [149. 164.]]
    print(result)
21 |
--------------------------------------------------------------------------------
/diagram:
--------------------------------------------------------------------------------
1 | 1Vddb5swFP01PLYCG0jy2GXd+jKpUip1fXThBlAIZsZJyH79DLYBA1GzbPloIiX43Gv7cs7BNhaer8vvjOTxDxpCaiE7LC381ULIn/nitwL2CrCnEohYEkrIaYFF8hsUaCt0k4RQGImc0pQnuQkGNMsg4AZGGKM7M21JU3PWnEQwABYBSYfoaxLyWKJTz27xJ0iiWM/s2CryToJVxOgmU/NZCC/rjwyviR5L5RcxCemuA+FHC88ZpVxercs5pBW1mjbZ79uBaFM3g4wf0wHJDluSbkBX7Kei65clFSOIAvlekeL/2lAduCtqyR5EAnLzsg3WmqSUGT0ECXhafbuQH6n/erZ3DbxStgKmYVH4ez9VYLI2DSOjTFSzD9X92SK8ixMOi5wEVXQnzCqwmK9T0XKa3ltgHMqDFDqNMMLvQNfA2V6k6A6+0lJ53dPa7lrnIO2PuOsarECi3Bo1Y7eKiQsl2riA+CMBT5KjyEl2ivj/Iqqc87a1dpF3hNbembR2B1q/QFZQtkzFEoLsBbBtxfKVOUJ6rddrvzPkyBnj6H9Q5H0KityJdz2KHOfW1vwnWvDPtOKjmbniY227i6z4znDPfiHF6uqsYGw+9+5FTT22D/YIgSx8qM6GohWkpCiSoNrpOGF8CHeogTLhPysW7z3VelOcCmLYvhOqmm8GoRAOzpk9OkWFdMMCMI5joqYImrPlOOsdVr0Rq2mMQUp4sjWrGGNazfBMk3oN0C6YmFZ3pz2xZPWqV/eg2RsI954Zd9YbSN7zYCAhDdl30vIqoThcMPLH5/nbulrfyQpaFzYaHGfM4Y50BmPiScead/a9LfZgCTwDS0S5wM7mWXxrnvW83oHcPtGzrt3zhn8ez3rYPA5c2LOi2b5wyvT2pR4//gE=
--------------------------------------------------------------------------------