├── README.md ├── __init__.py ├── master.py ├── pytt ├── __init__.py ├── bencode.py ├── tracker.py └── utils.py ├── requirements.txt ├── server.py ├── sync.py └── test.py
/README.md:
--------------------------------------------------------------------------------
# Tensorpeers
P2P (peer-to-peer) training of deep learning [tensorflow](https://github.com/tensorflow/tensorflow) models.

Not to be confused with [locally distributed training](https://www.tensorflow.org/how_tos/distributed/).
However, presumably a lot can be learned from tf.train.Supervisor etc.

## Community Power
In the golden age of deep learning, Baidu and others have shown that training time can scale almost linearly with the number of GPUs.
This gives large corporations an advantage over startups in the race for the best A.I. systems ... until now.

Tensorpeers will empower the community to combine their efforts and GPU time into quickly converging, wonderful models.

## Architecture
The architecture has to be slightly different from existing 'parameter server' schemes, because of relatively slow Internet connections. However, our optimistic guess is that this won't hinder the success of this project: as long as we find any merging scheme which successfully combines the *gained knowledge* of two separate runs, we should be fine.

To speed things up, we base this project on python-libtorrent.

## Install dependencies:
MAC:
`brew install libtorrent-rasterbar --with-python`
LINUX:
`apt-get install python-libtorrent` or
`apt-get install python3-libtorrent`

## Open questions
This is a wildly open research area, so if you want to make the world a better place (and/or need a PhD thesis):
this project gives you full leverage.



--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pannous/tensorpeers/f571827060f201e030f06a94ee276aa315746dda/__init__.py
--------------------------------------------------------------------------------
/master.py:
--------------------------------------------------------------------------------
# purpose of 'master tracker servers':

# keeps the *best* checkpoint up to date
# serves the current model checkpoint state to peers
# download the progress of other peers
# validate their claimed accuracy
# reject checkpoints which are too far behind
# merge the progress of good peer checkpoints into the *current best master* model
# repeat ...
# coordinate with possible other master tracker servers (later)
import tensorpeers.sync
import pytt.tracker
import tensorflow as tf

# Best validated accuracy seen so far (0.0 until the first evaluation runs).
current_score=0.0
# A peer checkpoint is only considered when its score is within this factor of
# the best.  NOTE(review): the comparison itself is lost in the truncated text
# below -- presumably score >= current_score * tolerance; confirm in the repo.
tolerance=0.9

def merge(model_name, path):
    # Fold an accepted peer checkpoint into the current best model (stub).
    print("accepting checkpoint with current best net")
    #todo

def tracker():
    # Run the torrent tracker and register the callback that fires whenever a
    # peer announces a new checkpoint.
    print ("Keep track of all models, peers and checkpoints")
    pytt.tracker.start_tracker()
    pytt.tracker.get_peer_list()
    pytt.tracker.listen(new_peer_checkpoint_available)


def evaluate(graph, checkpoint):
    # Stub: should re-run validation to confirm the peer's announced accuracy;
    # currently just echoes the best known score.
    print(" confirm the announced test accuracy (todo)")
    return current_score

def new_peer_checkpoint_available(model_name,torrent):
    # Download an announced checkpoint, evaluate it, and (presumably) merge or
    # reject it depending on its score.
    checkpoint=tensorpeers.sync.download(model_name, torrent)
    graph=tf.import_graph_def(model_name)
    score=evaluate(graph, checkpoint)
    # NOTE(review): the source dump is TRUNCATED at this point -- the rest of
    # this function, all of /pytt/__init__.py and most of /pytt/bencode.py are
    # missing.  The text resumes mid-way through bencode.py's Python-2/3 shim
    # (the next line fuses the two files and is not valid Python).
    if score= (3, 0):
    encode_func[bytes] = encode_string

    def translate(value):
        return value.translate(str.maketrans('', '', valid_chars))
else:
    encode_func[unicode] = encode_string
    import string

    def translate(value):
return value.translate(string.maketrans('', ''), valid_chars) 189 | -------------------------------------------------------------------------------- /pytt/tracker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # BitTorrent Tracker using Tornado 4 | # 5 | # @author: Sreejith K 6 | # Created on 12th May 2011 7 | # http://foobarnbaz.com 8 | import logging 9 | from optparse import OptionParser 10 | import sys 11 | 12 | import tornado.ioloop 13 | import tornado.web 14 | import tornado.httpserver 15 | 16 | from .bencode import bencode 17 | from .utils import * 18 | 19 | 20 | logger = logging.getLogger('tornado.access') 21 | 22 | 23 | class TrackerStats(BaseHandler): 24 | """Shows the Tracker statistics on this page. 25 | """ 26 | @tornado.web.asynchronous 27 | def get(self): 28 | self.send_error(404) 29 | 30 | 31 | class AnnounceHandler(BaseHandler): 32 | """Track the torrents. Respond with the peer-list. 33 | """ 34 | @tornado.web.asynchronous 35 | def get(self): 36 | failure_reason = '' 37 | warning_message = '' 38 | 39 | # get all the required parameters from the HTTP request. 40 | info_hash = self.get_argument('info_hash') 41 | peer_id = self.get_argument('peer_id') 42 | ip = self.request.remote_ip 43 | port = self.get_argument('port') 44 | 45 | # send appropirate error code. 46 | if not info_hash: 47 | return self.send_error(MISSING_INFO_HASH) 48 | if not peer_id: 49 | return self.send_error(MISSING_PEER_ID) 50 | if not port: 51 | return self.send_error(MISSING_PORT) 52 | if len(info_hash) != INFO_HASH_LEN: 53 | return self.send_error(INVALID_INFO_HASH) 54 | if len(peer_id) != PEER_ID_LEN: 55 | return self.send_error(INVALID_PEER_ID) 56 | 57 | # Shelve in Python2 doesn't support unicode 58 | info_hash = str(info_hash) 59 | 60 | # get the optional parameters. 
61 | # FIXME: these parameters will be used in future versions 62 | # uploaded = int(self.get_argument('uploaded', 0)) 63 | # downloaded = int(self.get_argument('downloaded', 0)) 64 | # left = int(self.get_argument('left', 0)) 65 | compact = int(self.get_argument('compact', 0)) 66 | no_peer_id = int(self.get_argument('no_peer_id', 0)) 67 | event = self.get_argument('event', '') 68 | numwant = int(self.get_argument('numwant', DEFAULT_ALLOWED_PEERS)) 69 | if numwant > MAX_ALLOWED_PEERS: 70 | # XXX: cannot request more than MAX_ALLOWED_PEERS. 71 | return self.send_error(INVALID_NUMWANT) 72 | 73 | # key = self.get_argument('key', '') 74 | tracker_id = self.get_argument('trackerid', '') 75 | 76 | # store the peer info 77 | if event: 78 | store_peer_info(info_hash, peer_id, ip, port, event) 79 | 80 | # generate response 81 | response = {} 82 | # Interval in seconds that the client should wait between sending 83 | # regular requests to the tracker. 84 | response['interval'] = get_config().getint('tracker', 'interval') 85 | # Minimum announce interval. If present clients must not re-announce 86 | # more frequently than this. 87 | response['min interval'] = get_config().getint('tracker', 88 | 'min_interval') 89 | # FIXME 90 | response['tracker id'] = tracker_id 91 | response['complete'] = no_of_seeders(info_hash) 92 | response['incomplete'] = no_of_leechers(info_hash) 93 | 94 | # get the peer list for this announce 95 | response['peers'] = get_peer_list(info_hash, 96 | numwant, 97 | compact, 98 | no_peer_id) 99 | 100 | # set error and warning messages for the client if any. 101 | if failure_reason: 102 | response['failure reason'] = failure_reason 103 | if warning_message: 104 | response['warning message'] = warning_message 105 | 106 | # send the bencoded response as text/plain document. 
107 | self.set_header('Content-Type', 'text/plain') 108 | self.write(bencode(response)) 109 | self.finish() 110 | 111 | 112 | class ScrapeHandler(BaseHandler): 113 | """Returns the state of all torrents this tracker is managing. 114 | """ 115 | @tornado.web.asynchronous 116 | def get(self): 117 | info_hashes = self.get_arguments('info_hash') 118 | response = {} 119 | for info_hash in info_hashes: 120 | info_hash = str(info_hash) 121 | response[info_hash] = {} 122 | response[info_hash]['complete'] = no_of_seeders(info_hash) 123 | # FIXME: number of times clients have registered completion. 124 | response[info_hash]['downloaded'] = no_of_seeders(info_hash) 125 | response[info_hash]['incomplete'] = no_of_leechers(info_hash) 126 | # this is possible typo: 127 | # response[info_hash]['name'] = bdecode(info_hash).get(name, '') 128 | 129 | # send the bencoded response as text/plain document. 130 | self.set_header('content-type', 'text/plain') 131 | self.write(bencode(response)) 132 | self.finish() 133 | 134 | 135 | def run_app(port): 136 | """Start Tornado IOLoop for this application. 137 | """ 138 | tracker = tornado.web.Application([ 139 | (r"/announce.*", AnnounceHandler), 140 | (r"/scrape.*", ScrapeHandler), 141 | (r"/", TrackerStats), 142 | ]) 143 | logging.info('Starting Pytt tracker on port %d' % port) 144 | http_server = tornado.httpserver.HTTPServer(tracker) 145 | http_server.listen(port) 146 | tornado.ioloop.IOLoop.instance().start() 147 | 148 | 149 | def start_tracker(): 150 | """Start the Torrent Tracker. 
151 | """ 152 | # parse commandline options 153 | parser = OptionParser() 154 | parser.add_option('-p', '--port', help='Tracker Port', default=0) 155 | parser.add_option('-b', '--background', action='store_true', 156 | default=False, help='Start in background') 157 | parser.add_option('-d', '--debug', action='store_true', 158 | default=False, help='Debug mode') 159 | (options, args) = parser.parse_args() 160 | 161 | # setup directories 162 | create_pytt_dirs() 163 | # setup logging 164 | setup_logging(options.debug) 165 | 166 | try: 167 | # start the torrent tracker 168 | run_app(int(options.port) or get_config().getint('tracker', 'port')) 169 | except KeyboardInterrupt: 170 | logging.info('Tracker Stopped.') 171 | close_db() 172 | sys.exit(0) 173 | except Exception as ex: 174 | logging.fatal('%s' % str(ex)) 175 | close_db() 176 | sys.exit(-1) 177 | 178 | 179 | if __name__ == '__main__': 180 | start_tracker() 181 | -------------------------------------------------------------------------------- /pytt/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Common utilities for Pytt. 4 | # 5 | # @author: Sreejith K 6 | # Created on 12th May 2011 7 | # http://foobarnbaz.com 8 | 9 | 10 | import os 11 | import logging 12 | import logging.handlers 13 | import shelve 14 | from socket import inet_aton 15 | from struct import pack 16 | import tornado.web 17 | import binascii 18 | try: 19 | from ConfigParser import RawConfigParser 20 | from httplib import responses 21 | except ImportError: 22 | from configparser import RawConfigParser 23 | from http.client import responses 24 | 25 | 26 | # Paths used by Pytt. 27 | CONFIG_PATH = os.path.expanduser('~/.pytt/config/pytt.conf') 28 | DB_PATH = os.path.expanduser('~/.pytt/db/pytt.db') 29 | LOG_PATH = os.path.expanduser('~/.pytt/log/pytt.log') 30 | 31 | # Some global constants. 
32 | PEER_INCREASE_LIMIT = 30 33 | DEFAULT_ALLOWED_PEERS = 50 34 | MAX_ALLOWED_PEERS = 55 35 | INFO_HASH_LEN = 20 * 2 # info_hash is hexified. 36 | PEER_ID_LEN = 20 37 | 38 | # HTTP Error Codes for BitTorrent Tracker 39 | INVALID_REQUEST_TYPE = 100 40 | MISSING_INFO_HASH = 101 41 | MISSING_PEER_ID = 102 42 | MISSING_PORT = 103 43 | INVALID_INFO_HASH = 150 44 | INVALID_PEER_ID = 151 45 | INVALID_NUMWANT = 152 46 | GENERIC_ERROR = 900 47 | 48 | # Pytt response messages 49 | PYTT_RESPONSE_MESSAGES = { 50 | INVALID_REQUEST_TYPE: 'Invalid Request type', 51 | MISSING_INFO_HASH: 'Missing info_hash field', 52 | MISSING_PEER_ID: 'Missing peer_id field', 53 | MISSING_PORT: 'Missing port field', 54 | INVALID_INFO_HASH: 'info_hash is not %d bytes' % INFO_HASH_LEN, 55 | INVALID_PEER_ID: 'peer_id is not %d bytes' % PEER_ID_LEN, 56 | INVALID_NUMWANT: 'Peers more than %d is not allowed.' % MAX_ALLOWED_PEERS, 57 | GENERIC_ERROR: 'Error in request', 58 | } 59 | # add our response codes to httplib.responses 60 | responses.update(PYTT_RESPONSE_MESSAGES) 61 | 62 | logger = logging.getLogger('tornado.access') 63 | 64 | 65 | def setup_logging(debug=False): 66 | """Setup application logging. 67 | """ 68 | if debug: 69 | level = logging.DEBUG 70 | else: 71 | level = logging.INFO 72 | log_handler = logging.handlers.RotatingFileHandler(LOG_PATH, 73 | maxBytes=1024*1024, 74 | backupCount=2) 75 | root_logger = logging.getLogger('') 76 | root_logger.setLevel(level) 77 | format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 78 | formatter = logging.Formatter(format) 79 | log_handler.setFormatter(formatter) 80 | root_logger.addHandler(log_handler) 81 | 82 | 83 | def create_config(path): 84 | """Create default config file. 
85 | """ 86 | logging.info('creating default config at %s' % CONFIG_PATH) 87 | config = RawConfigParser() 88 | config.add_section('tracker') 89 | config.set('tracker', 'port', '8080') 90 | config.set('tracker', 'interval', '5') 91 | config.set('tracker', 'min_interval', '1') 92 | with open(path, 'w') as f: 93 | config.write(f) 94 | 95 | 96 | def create_pytt_dirs(): 97 | """Create directories to store config, log and db files. 98 | """ 99 | logging.info('setting up directories for Pytt') 100 | for path in [CONFIG_PATH, DB_PATH, LOG_PATH]: 101 | dirname = os.path.dirname(path) 102 | if not os.path.exists(dirname): 103 | os.makedirs(dirname) 104 | # create the default config if its not there. 105 | if not os.path.exists(CONFIG_PATH): 106 | create_config(CONFIG_PATH) 107 | 108 | 109 | class BaseHandler(tornado.web.RequestHandler): 110 | """Since I dont like some tornado craps :-) 111 | """ 112 | def decode_argument(self, value, name): 113 | # info_hash is raw_bytes, hexify it. 114 | if name == 'info_hash': 115 | value = binascii.hexlify(value) 116 | return super(BaseHandler, self).decode_argument(value, name) 117 | 118 | 119 | class ConfigError(Exception): 120 | """Raised when config error occurs. 121 | """ 122 | 123 | 124 | class Config: 125 | """Provide a single entry point to the Configuration. 126 | """ 127 | __shared_state = {} 128 | 129 | def __init__(self): 130 | """Borg pattern. All instances will have same state. 131 | """ 132 | self.__dict__ = self.__shared_state 133 | 134 | def get(self): 135 | """Get the config object. 
136 | """ 137 | if not hasattr(self, '__config'): 138 | self.__config = RawConfigParser() 139 | if self.__config.read(CONFIG_PATH) == []: 140 | raise ConfigError('No config at %s' % CONFIG_PATH) 141 | return self.__config 142 | 143 | def close(self): 144 | """Close config connection 145 | """ 146 | if not hasattr(self, '__config'): 147 | return 0 148 | del self.__config 149 | 150 | 151 | class Database: 152 | """Provide a single entry point to the database. 153 | """ 154 | __shared_state = {} 155 | 156 | def __init__(self): 157 | """Borg pattern. All instances will have same state. 158 | """ 159 | self.__dict__ = self.__shared_state 160 | 161 | def get(self): 162 | """Get the shelve object. 163 | """ 164 | if not hasattr(self, '__db'): 165 | self.__db = shelve.open(DB_PATH, writeback=True) 166 | return self.__db 167 | 168 | def close(self): 169 | """Close db connection 170 | """ 171 | if not hasattr(self, '__db'): 172 | return 0 173 | self.__db.close() 174 | del self.__db 175 | 176 | 177 | def get_config(): 178 | """Get a connection to the configuration. 179 | """ 180 | return Config().get() 181 | 182 | 183 | def get_db(): 184 | """Get a persistent connection to the database. 185 | """ 186 | return Database().get() 187 | 188 | 189 | def close_db(): 190 | """Close db connection. 191 | """ 192 | Database().close() 193 | 194 | 195 | def no_of_seeders(info_hash): 196 | """Number of peers with the entire file, aka "seeders". 197 | """ 198 | db = get_db() 199 | count = 0 200 | if info_hash in db: 201 | for peer_info in db[info_hash]: 202 | if peer_info[3] == 'completed': 203 | count += 1 204 | return count 205 | 206 | 207 | def no_of_leechers(info_hash): 208 | """Number of non-seeder peers, aka "leechers". 
209 | """ 210 | db = get_db() 211 | count = 0 212 | if info_hash in db: 213 | for peer_info in db[info_hash]: 214 | if peer_info[3] == 'started': 215 | count += 1 216 | return count 217 | 218 | 219 | def store_peer_info(info_hash, peer_id, ip, port, status): 220 | """Store the information about the peer. 221 | """ 222 | db = get_db() 223 | if info_hash in db: 224 | if (peer_id, ip, port, status) not in db[info_hash]: 225 | db[info_hash].append((peer_id, ip, port, status)) 226 | else: 227 | db[info_hash] = [(peer_id, ip, port, status)] 228 | 229 | 230 | # TODO: add ipv6 support 231 | def get_peer_list(info_hash, numwant, compact, no_peer_id): 232 | """Get all the peer's info with peer_id, ip and port. 233 | Eg: [{'peer_id':'#1223&&IJM', 'ip':'162.166.112.2', 'port': '7887'}, ...] 234 | """ 235 | db = get_db() 236 | if compact: 237 | byteswant = numwant * 6 238 | compact_peers = b'' 239 | # make a compact peer list 240 | if info_hash in db: 241 | for peer_info in db[info_hash]: 242 | ip = inet_aton(peer_info[1]) 243 | port = pack('>H', int(peer_info[2])) 244 | compact_peers += (ip+port) 245 | logging.debug('compact peer list: %r' % compact_peers[:byteswant]) 246 | return compact_peers[:byteswant] 247 | else: 248 | peers = [] 249 | if info_hash in db: 250 | for peer_info in db[info_hash]: 251 | p = {} 252 | p['peer_id'], p['ip'], p['port'], _ = peer_info 253 | if no_peer_id: 254 | del p['peer_id'] 255 | peers.append(p) 256 | logging.debug('peer list: %r' % peers[:numwant]) 257 | return peers[:numwant] 258 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow 2 | numpy 3 | libtorrent 4 | # via 5 | # sudo apt-get install python-libtorrent 6 | # or 7 | # sudo apt-get install python3-libtorrent 8 | # or 9 | # sudo brew install libtorrent -------------------------------------------------------------------------------- /server.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import tensorflow as tf 3 | 4 | cluster_spec = {'ps': ['localhost:2222'], 'worker': ['localhost:2223', 'localhost:2224']} 5 | 6 | import socket 7 | from flask import Flask 8 | 9 | 10 | try: 11 | from urllib2 import urlopen 12 | from urllib import urlretrieve, quote 13 | except ImportError: 14 | from urllib.request import urlopen, urlretrieve, quote # py3 HELL 15 | 16 | app = Flask(__name__) 17 | 18 | @app.route('/') 19 | def index(): 20 | return 'TensorPeers distributed training server!\n' 21 | 22 | 23 | @app.route('/register') 24 | def register(): 25 | return 'TensorPeers training client registered!\n' 26 | 27 | 28 | @app.route('/list') 29 | def list_clients(): 30 | return 'TensorPeers training client list:\n'+ client_list() 31 | 32 | 33 | @app.route('/start') 34 | def start(): 35 | print("waiting for clients") 36 | cluster = tf.train.ClusterSpec(cluster_spec) 37 | server = tf.train.Server(cluster, job_name="ps") 38 | server.join() 39 | return "waiting for clients" 40 | 41 | def client_list(): 42 | return "\n".join(myip) 43 | 44 | 45 | def download(url): # to memory 46 | return urlopen(url).read() 47 | 48 | host = socket.gethostname() 49 | print("host", host) 50 | myip = download('http://pannous.net/ip.php').strip() 51 | print("myip", myip) 52 | # local_ip=socket.gethostbyname(host) 53 | 54 | if __name__ == '__main__': 55 | app.run(debug=True, host='0.0.0.0', port=2221) 56 | 57 | -------------------------------------------------------------------------------- /sync.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import print_function 3 | import sys 4 | 5 | import time 6 | try: 7 | import libtorrent as lt 8 | except: 9 | print (""" You need to install the libtorrent dependency: One of 10 | apt-get install python-libtorrent 11 | apt-get install python3-libtorrent 
12 | brew install libtorrent-rasterbar --with-python" 13 | 14 | pip install python-libtorrent DOESN'T WORK! use one of the above! 15 | building boost dependency might take a while 16 | """) 17 | 18 | exit(0) 19 | # alternatives to libtorrent: 20 | # https://github.com/Blender3D/torrent not a functional replacement yet (but pure python) 21 | # https://github.com/damoxc/spritzle ? 22 | 23 | def upload(model_name, path, checkpoint_nr): 24 | # Create torrent 25 | fs = lt.file_storage() 26 | lt.add_files(fs, path +"/" + checkpoint_nr) 27 | t = lt.create_torrent(fs) 28 | t.add_tracker("udp://tracker.openbittorrent.com:80/announce", 0) 29 | # t.add_tracker("udp://tracker.pannous.com:80/announce", 0) # see pytt 30 | t.set_creator('libtorrent %s' % lt.version) 31 | t.set_comment("checkpoint: " + checkpoint_nr) 32 | lt.set_piece_hashes(t, ".") 33 | torrent = t.generate() 34 | torrent_file = model_name+"_" + checkpoint_nr + ".torrent" 35 | f = open(torrent_file, "wb") 36 | f.write(lt.bencode(torrent)) 37 | f.close() 38 | 39 | # Seed torrent 40 | ses = lt.session() 41 | ses.listen_on(6881, 6891) 42 | h = ses.add_torrent({'ti': lt.torrent_info(torrent_file), 'save_path': '.', 'seed_mode': True}) 43 | print("Total size: " + str(h.status().total_wanted)) 44 | print("Name: " + h.name()) 45 | while True: 46 | s = h.status() 47 | state_str = ['queued', 'checking', 'downloading metadata', \ 48 | 'downloading', 'finished', 'seeding', 'allocating', 'checking fastresume'] 49 | 50 | msg = '\r%.2f%% complete (down: %.1f kb/s up: %.1f kB/s peers: %d) %s' 51 | print(msg % (s.progress * 100, s.download_rate / 1000, s.upload_rate / 1000, s.num_peers, state_str[s.state])) 52 | sys.stdout.flush() 53 | 54 | time.sleep(1) 55 | 56 | 57 | def download(model_name, checkpoint="current"): 58 | ses = lt.session() 59 | ses.listen_on(6881, 6891) 60 | 61 | e = lt.bdecode(open(model_name+"_"+checkpoint+".torrent", 'rb').read()) 62 | info = lt.torrent_info(e) 63 | 64 | params = {'save_path': '.', 
'storage_mode': lt.storage_mode_t.storage_mode_sparse, 'ti': info} 65 | h = ses.add_torrent(params) 66 | 67 | s = h.status() 68 | while (not s.is_seeding): 69 | s = h.status() 70 | state_str = ['queued', 'checking', 'downloading metadata', \ 71 | 'downloading', 'finished', 'seeding', 'allocating'] 72 | msg = '%.2f%% complete (down: %.1f kb/s up: %.1f kB/s peers: %d) %s' 73 | print(msg % (s.progress * 100, s.download_rate / 1000, s.upload_rate / 1000, s.num_peers, state_str[s.state])) 74 | 75 | time.sleep(1) 76 | 77 | 78 | def sync(): 79 | upload() 80 | 81 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | cluster = tf.train.ClusterSpec({'ps': ['localhost:2222'], 3 | 'worker': ['localhost:2223', 'localhost:2224'] 4 | }) 5 | 6 | TASK_INDEX = -1# -1=server 0=master-worker or n>=1:worker > 7 | if TASK_INDEX==-1: 8 | print("waiting for clients") 9 | server = tf.train.Server(cluster, job_name="ps") 10 | server.join() 11 | 12 | from tensorflow.examples.tutorials.mnist import input_data 13 | 14 | mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 15 | 16 | server = tf.train.Server(cluster, job_name="worker", task_index=TASK_INDEX)#,shared=True) 17 | 18 | with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % TASK_INDEX, cluster=cluster)): 19 | x = tf.placeholder(tf.float32, shape=[None, 784]) 20 | y_ = tf.placeholder(tf.float32, shape=[None, 10]) 21 | W = tf.Variable(tf.zeros([784, 10])) 22 | b = tf.Variable(tf.zeros([10])) 23 | y = tf.matmul(x, W) + b 24 | logits = tf.nn.softmax_cross_entropy_with_logits(logits=y,labels= y_) 25 | cross_entropy = tf.reduce_mean(logits) 26 | global_step = tf.Variable(0) 27 | 28 | train_op = tf.train.AdagradOptimizer(0.01).minimize( 29 | cross_entropy, global_step=global_step) 30 | summary_op = tf.summary.merge_all() 31 | init_op = 
tf.global_variables_initializer() 32 | 33 | sv = tf.train.Supervisor(is_chief=(TASK_INDEX == 0), 34 | init_op=init_op, 35 | summary_op=summary_op, 36 | global_step=global_step) 37 | 38 | with sv.managed_session(server.target) as sess: 39 | step = 0 40 | batch_sz = 50 41 | iters = 55000 / batch_sz 42 | while not sv.should_stop() and step < iters: 43 | bx = mnist.train.images[step * batch_sz:(step + 1) * batch_sz] 44 | by = mnist.train.labels[step * batch_sz:(step + 1) * batch_sz] 45 | feed_dict={x: bx, y_: by} 46 | _, step = sess.run([train_op, global_step], feed_dict) 47 | 48 | # Ask for all the services to stop. 49 | sv.stop() 50 | init_feed_dict={} 51 | sess.run(init_op, feed_dict=init_feed_dict) 52 | --------------------------------------------------------------------------------