├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── Basic Classifier.ipynb ├── Basic Distributed Classifier.ipynb ├── Basic Federated Classifier.ipynb ├── LICENSE ├── README.md ├── advanced_classifier.py ├── advanced_distributed_classifier.py ├── advanced_federated_classifier.py ├── basic_classifier.py ├── basic_distributed_classifier.py ├── basic_federated_classifier.py ├── federated-MPI ├── README.md ├── mpi_advanced_classifier.py └── mpi_basic_classifier.py ├── federated-keras ├── README.md ├── keras_distributed_classifier.py └── keras_federated_classifier.py ├── federated-sockets ├── FederatedHook.py ├── README.md ├── advanced_socket_fed_classifier.py ├── basic_socket_fed_classifier.py └── config.py ├── federated_averaging_optimizer.py └── images ├── Logo_Acuratio.png ├── colab_logo.png ├── comindorg_logo.png ├── graph_tensorboard.png ├── slack_logo.jpg └── telegram_logo.jpg /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | 5 | --- 6 | 7 | **Describe the bug** 8 | A clear and concise description of what the bug is. 9 | 10 | **To Reproduce** 11 | Steps to reproduce the behavior: 12 | 1. Go to '...' 13 | 2. Click on '....' 14 | 3. Scroll down to '....' 15 | 4. See error 16 | 17 | **Expected behavior** 18 | A clear and concise description of what you expected to happen. 19 | 20 | **Screenshots** 21 | If applicable, add screenshots to help explain your problem. 22 | 23 | **Desktop (please complete the following information):** 24 | - OS: [e.g. iOS] 25 | - Browser [e.g. chrome, safari] 26 | - Version [e.g. 22] 27 | 28 | **Smartphone (please complete the following information):** 29 | - Device: [e.g. iPhone6] 30 | - OS: [e.g. iOS8.1] 31 | - Browser [e.g. stock browser, safari] 32 | - Version [e.g. 22] 33 | 34 | **Additional context** 35 | Add any other context about the problem here. 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | 5 | --- 6 | 7 | **Is your feature request related to a problem? Please describe.** 8 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 9 | 10 | **Describe the solution you'd like** 11 | A clear and concise description of what you want to happen. 12 | 13 | **Describe alternatives you've considered** 14 | A clear and concise description of any alternative solutions or features you've considered. 15 | 16 | **Additional context** 17 | Add any other context or screenshots about the feature request here. 18 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | IMPORTANT: Please create an issue for your Pull Request. 2 | 3 | Please provide enough information so that others can review your pull request: 4 | 5 | Explain the details for making this change. What existing problem does the pull request solve? 6 | 7 | Closing issues 8 | 9 | Put closes #XXXX in your comment to auto-close the issue that your PR fixes (if such). 
10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /Basic Federated Classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Basic federated classifier with TensorFlow" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "The code in this notebook is copyright 2018 coMind. Licensed under the Apache License, Version 2.0; you may not use this code except in compliance with the License. You may obtain a copy of the License.\n", 15 | "\n", 16 | "Join the conversation at Slack.\n", 17 | "\n", 18 | "This a series of three tutorials you are in the last one: \n", 19 | "* [Basic Classifier](https://github.com/coMindOrg/federated-averaging-tutorials/blob/master/Basic%20Classifier.ipynb)\n", 20 | "* [Basic Distributed Classifier](https://github.com/coMindOrg/federated-averaging-tutorials/blob/master/Basic%20Distributed%20Classifier.ipynb)\n", 21 | "* [Basic Federated Classifier](https://github.com/coMindOrg/federated-averaging-tutorials/blob/master/Basic%20Federated%20Classifier.ipynb)\n", 22 | "\n", 23 | "In this tutorial we will see how to train a model using federated averaging.\n", 24 | "\n", 25 | "To begin a brief explanation of what it means to train using federated averaging with respect to training using a SyncReplicasOptimizer.\n", 26 | "\n", 27 | "In the previous tutorial, we explained that with SyncReplicasOptimizer each worker generated a gradient for its weights and wrote it to the parameter server. 
The chief read those gradients (including its own), it averaged them and updated the shared model.\n", 28 | "\n", 29 | "This time each worker will be updating its weights locally, as if it were the only one training. Every certain number of steps it will send its weights (not the gradients, but the weights themselves) to the parameter server. The chief will read the weights from there, it will average and write them again to the parameter server so that all the workers can overwrite theirs." 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "The entire first part of the code is the same as the distributed classifier tutorial.\n", 37 | "\n", 38 | "Two differences only:\n", 39 | "\n", 40 | "- This time we also import __federated_average_optimizer__, the library with which we can federalize learning.\n", 41 | "- On the other hand we define the variable __INTERVAL_STEPS__. Every how many steps we will perform the average of the weights. Put another way, how many steps will each worker make in local before writing their weights in the parameter server and overwriting them with the average that the chief has made." 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# TensorFlow and tf.keras\n", 51 | "import tensorflow as tf\n", 52 | "from tensorflow import keras\n", 53 | "\n", 54 | "# Helper libraries\n", 55 | "import os\n", 56 | "import numpy as np\n", 57 | "from time import time\n", 58 | "import matplotlib.pyplot as plt\n", 59 | "import federated_averaging_optimizer\n", 60 | "\n", 61 | "flags = tf.app.flags\n", 62 | "flags.DEFINE_integer(\"task_index\", None,\n", 63 | " \"Worker task index, should be >= 0. task_index=0 is \"\n", 64 | " \"the master worker task that performs the variable \"\n", 65 | " \"initialization \")\n", 66 | "flags.DEFINE_string(\"ps_hosts\", \"localhost:2222\",\n", 67 | " \"Comma-separated list of hostname:port pairs\")\n", 68 | "flags.DEFINE_string(\"worker_hosts\", \"localhost:2223,localhost:2224\",\n", 69 | " \"Comma-separated list of hostname:port pairs\")\n", 70 | "flags.DEFINE_string(\"job_name\", None, \"job name: worker or ps\")\n", 71 | "\n", 72 | "BATCH_SIZE = 32\n", 73 | "EPOCHS = 5\n", 74 | "INTERVAL_STEPS = 10\n", 75 | "\n", 76 | "FLAGS = flags.FLAGS\n", 77 | "\n", 78 | "if FLAGS.job_name is None or FLAGS.job_name == \"\":\n", 79 | " raise ValueError(\"Must specify an explicit `job_name`\")\n", 80 | "if FLAGS.task_index is None or FLAGS.task_index == \"\":\n", 81 | " raise ValueError(\"Must specify an explicit `task_index`\")\n", 82 | "\n", 83 | "if FLAGS.task_index == 0:\n", 84 | " print('--- GPU Disabled ---')\n", 85 | " os.environ['CUDA_VISIBLE_DEVICES'] = ''\n", 86 | "\n", 87 | "#Construct the cluster and start the server\n", 88 | "ps_spec = FLAGS.ps_hosts.split(\",\")\n", 89 | "worker_spec = FLAGS.worker_hosts.split(\",\")\n", 90 | "\n", 91 | "# Get the number of workers.\n", 92 | "num_workers = len(worker_spec)\n", 93 | "print('{} workers defined'.format(num_workers))\n", 94 | "\n", 95 | "cluster = tf.train.ClusterSpec({\"ps\": ps_spec, \"worker\": worker_spec})\n", 96 | "\n", 97 | "server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)\n", 98 | "if FLAGS.job_name == \"ps\":\n", 99 | " print('--- Parameter Server Ready ---')\n", 100 | " server.join()\n", 101 | "\n", 102 | "fashion_mnist = keras.datasets.fashion_mnist\n", 103 | "(train_images, train_labels), (test_images, test_labels) = 
fashion_mnist.load_data()\n",
104 | "print('Data loaded')\n",
105 | "\n",
106 | "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
107 | " 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n",
108 | "\n",
109 | "train_images = np.split(train_images, num_workers)[FLAGS.task_index]\n",
110 | "train_labels = np.split(train_labels, num_workers)[FLAGS.task_index]\n",
111 | "print('Local dataset size: {}'.format(train_images.shape[0]))\n",
112 | "\n",
113 | "train_images = train_images / 255.0\n",
114 | "test_images = test_images / 255.0\n",
115 | "\n",
116 | "is_chief = (FLAGS.task_index == 0)\n",
117 | "\n",
118 | "checkpoint_dir='logs_dir/federated_worker_{}/{}'.format(FLAGS.task_index, time())\n",
119 | "print('Checkpoint directory: ' + checkpoint_dir)\n",
120 | "\n",
121 | "worker_device = \"/job:worker/task:%d\" % FLAGS.task_index\n",
122 | "print('Worker device: ' + worker_device + ' - is_chief: {}'.format(is_chief))"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {},
128 | "source": [
129 | "Here we begin the definition of the graph in the same way as in the basic classifier, except that we explicitly place every operation in the local worker. The rest is fairly standard until we reach the definition of the optimizer."
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "with tf.device(worker_device):\n",
139 | " global_step = tf.train.get_or_create_global_step()\n",
140 | "\n",
141 | " with tf.name_scope('dataset'), tf.device('/cpu:0'):\n",
142 | " images_placeholder = tf.placeholder(train_images.dtype, [None, train_images.shape[1], train_images.shape[2]], \n",
143 | " name='images_placeholder')\n",
144 | " labels_placeholder = tf.placeholder(train_labels.dtype, [None], name='labels_placeholder')\n",
145 | " batch_size = tf.placeholder(tf.int64, name='batch_size')\n",
146 | " shuffle_size = tf.placeholder(tf.int64, name='shuffle_size')\n",
147 | "\n",
148 | " dataset = tf.data.Dataset.from_tensor_slices((images_placeholder, labels_placeholder))\n",
149 | " dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True)\n",
150 | " dataset = dataset.repeat(EPOCHS)\n",
151 | " dataset = dataset.batch(batch_size)\n",
152 | " iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)\n",
153 | " dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')\n",
154 | " X, y = iterator.get_next()\n",
155 | "\n",
156 | " flatten_layer = tf.layers.flatten(X, name='flatten')\n",
157 | "\n",
158 | " dense_layer = tf.layers.dense(flatten_layer, 128, activation=tf.nn.relu, name='relu')\n",
159 | "\n",
160 | " predictions = tf.layers.dense(dense_layer, 10, activation=tf.nn.softmax, name='softmax')\n",
161 | "\n",
162 | " summary_averages = tf.train.ExponentialMovingAverage(0.9)\n",
163 | "\n",
164 | " with tf.name_scope('loss'):\n",
165 | " loss = tf.reduce_mean(keras.losses.sparse_categorical_crossentropy(y, predictions))\n",
166 | " loss_averages_op = summary_averages.apply([loss])\n",
167 | " tf.summary.scalar('cross_entropy', summary_averages.average(loss))\n",
168 | "\n",
169 | " with tf.name_scope('accuracy'):\n",
170 | " with tf.name_scope('correct_prediction'):\n",
171 | " correct_prediction = tf.equal(tf.argmax(predictions, 1), tf.cast(y, tf.int64))\n",
172 | " accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy_metric')\n",
173 | " accuracy_averages_op = summary_averages.apply([accuracy])\n",
174 | " tf.summary.scalar('accuracy', summary_averages.average(accuracy))\n"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "metadata": {},
180 | "source": [
181 | "We used the __replica_device_setter__ in the distributed learning tutorial to automatically choose the device on which to place each defined op. Here we create it just to pass it as an argument to the custom optimizer that we have created to contain the logic of the federated averaging.\n",
182 | "\n",
183 | "This custom optimizer will use the __replica_device_setter__ to place a copy of each trainable variable in the ps. These new variables will store the averaged values of all the local models.\n",
184 | "\n",
185 | "Once this optimizer has been defined, we create the training operation and, in the same way as we did with SyncReplicasOptimizer, a hook that will run inside the MonitoredTrainingSession, which handles the initialization."
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": null,
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | " with tf.name_scope('train'):\n",
195 | " device_setter = tf.train.replica_device_setter(worker_device=worker_device, cluster=cluster)\n",
196 | " optimizer = federated_averaging_optimizer.FederatedAveragingOptimizer(\n",
197 | " tf.train.AdamOptimizer(np.sqrt(num_workers) * 0.001), \n",
198 | " replicas_to_aggregate=num_workers, interval_steps=INTERVAL_STEPS, is_chief=is_chief, \n",
199 | " device_setter=device_setter)\n",
200 | " with tf.control_dependencies([loss_averages_op, accuracy_averages_op]):\n",
201 | " train_op = optimizer.minimize(loss, global_step=global_step)\n",
202 | " model_average_hook = optimizer.make_session_run_hook()"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {},
208 | "source": [
209 | "We keep defining our hooks as usual."
210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "n_batches = int(train_images.shape[0] / BATCH_SIZE)\n", 219 | "last_step = int(n_batches * EPOCHS)\n", 220 | "\n", 221 | "print('Graph definition finished')\n", 222 | "\n", 223 | "sess_config = tf.ConfigProto(\n", 224 | " allow_soft_placement=True,\n", 225 | " log_device_placement=False,\n", 226 | " operation_timeout_in_ms=20000,\n", 227 | " device_filters=[\"/job:ps\",\n", 228 | " \"/job:worker/task:%d\" % FLAGS.task_index])\n", 229 | "\n", 230 | "print('Training {} batches...'.format(last_step))\n", 231 | "\n", 232 | "class _LoggerHook(tf.train.SessionRunHook):\n", 233 | " def begin(self):\n", 234 | " self._total_loss = 0\n", 235 | " self._total_acc = 0\n", 236 | "\n", 237 | " def before_run(self, run_context):\n", 238 | " return tf.train.SessionRunArgs([loss, accuracy, global_step])\n", 239 | "\n", 240 | " def after_run(self, run_context, run_values):\n", 241 | " loss_value, acc_value, step_value = run_values.results\n", 242 | " self._total_loss += loss_value\n", 243 | " self._total_acc += acc_value\n", 244 | " if (step_value + 1) % n_batches == 0 and not step_value == 0:\n", 245 | " print(\"Epoch {}/{} - loss: {:.4f} - acc: {:.4f}\".format(\n", 246 | " int(step_value / n_batches) + 1, EPOCHS, self._total_loss / n_batches, self._total_acc / n_batches))\n", 247 | " self._total_loss = 0\n", 248 | " self._total_acc = 0\n", 249 | "\n", 250 | "class _InitHook(tf.train.SessionRunHook):\n", 251 | " def after_create_session(self, session, coord):\n", 252 | " session.run(dataset_init_op, feed_dict={\n", 253 | " images_placeholder: train_images, labels_placeholder: train_labels, \n", 254 | " batch_size: BATCH_SIZE, shuffle_size: train_images.shape[0]})" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "The shared variables generated within the custom optimizer get their initialized value from their corresponding trainable variables in the local worker. Therefore their initialization ops will be unavailable out of this session even if we try to restore a saved checkpoint.\n", 262 | "\n", 263 | "We need to define a custom saver which ignores this shared variables. In this case, we only save the trainable_variables ." 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "class _SaverHook(tf.train.SessionRunHook):\n", 273 | " def begin(self):\n", 274 | " self._saver = tf.train.Saver(tf.trainable_variables())\n", 275 | "\n", 276 | " def before_run(self, run_context):\n", 277 | " return tf.train.SessionRunArgs(global_step)\n", 278 | "\n", 279 | " def after_run(self, run_context, run_values):\n", 280 | " step_value = run_values.results\n", 281 | " if step_value % n_batches == 0 and not step_value == 0:\n", 282 | " self._saver.save(run_context.session, checkpoint_dir+'/model.ckpt', step_value)\n", 283 | "\n", 284 | " def end(self, session):\n", 285 | " self._saver.save(session, checkpoint_dir+'/model.ckpt', session.run(global_step))" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "The execution of the training session is standard. Notice the new hooks that we have added to the hook lists.\n", 293 | "\n", 294 | "WARNING! Do not define a chief worker. We need each worker to initialize their local session and train on its own!" 
295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "with tf.name_scope('monitored_session'):\n", 304 | " with tf.train.MonitoredTrainingSession(\n", 305 | " master=server.target,\n", 306 | " checkpoint_dir=checkpoint_dir,\n", 307 | " hooks=[_LoggerHook(), _InitHook(), _SaverHook(), model_average_hook],\n", 308 | " config=sess_config,\n", 309 | " stop_grace_period_secs=10,\n", 310 | " save_checkpoint_secs=None) as mon_sess:\n", 311 | " while not mon_sess.should_stop():\n", 312 | " mon_sess.run(train_op)" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": {}, 318 | "source": [ 319 | "Finally, we evaluate the model." 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "if is_chief:\n", 329 | " print('--- Begin Evaluation ---')\n", 330 | " tf.reset_default_graph()\n", 331 | " with tf.Session() as sess:\n", 332 | " ckpt = tf.train.get_checkpoint_state(checkpoint_dir)\n", 333 | " saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True)\n", 334 | " saver.restore(sess, ckpt.model_checkpoint_path)\n", 335 | " print('Model restored')\n", 336 | " graph = tf.get_default_graph()\n", 337 | " images_placeholder = graph.get_tensor_by_name('dataset/images_placeholder:0')\n", 338 | " labels_placeholder = graph.get_tensor_by_name('dataset/labels_placeholder:0')\n", 339 | " batch_size = graph.get_tensor_by_name('dataset/batch_size:0')\n", 340 | " accuracy = graph.get_tensor_by_name('accuracy/accuracy_metric:0')\n", 341 | " predictions = graph.get_tensor_by_name('softmax/BiasAdd:0')\n", 342 | " dataset_init_op = graph.get_operation_by_name('dataset/dataset_init')\n", 343 | " sess.run(dataset_init_op, feed_dict={\n", 344 | " images_placeholder: test_images, labels_placeholder: test_labels, \n", 345 | " batch_size: test_images.shape[0], shuffle_size: 1})\n", 346 | " print('Test accuracy: {:4f}'.format(sess.run(accuracy)))\n", 347 | " predicted = sess.run(predictions)\n", 348 | "\n", 349 | " # Plot the first 25 test images, their predicted label, and the true label\n", 350 | " # Color correct predictions in green, incorrect predictions in red\n", 351 | " plt.figure(figsize=(10, 10))\n", 352 | " for i in range(25):\n", 353 | " plt.subplot(5, 5, i + 1)\n", 354 | " plt.xticks([])\n", 355 | " plt.yticks([])\n", 356 | " plt.grid(False)\n", 357 | " plt.imshow(test_images[i], cmap=plt.cm.binary)\n", 358 | " predicted_label = np.argmax(predicted[i])\n", 359 | " true_label = test_labels[i]\n", 360 | " if predicted_label == true_label:\n", 361 | " color = 'green'\n", 362 | " else:\n", 363 | " color = 'red'\n", 364 | " plt.xlabel(\"{} ({})\".format(class_names[predicted_label],\n", 365 | " class_names[true_label]),\n", 366 | " color=color)\n", 367 | "\n", 368 | " plt.show(True)" 369 | ] 370 | } 371 | ], 372 | "metadata": { 373 | "kernelspec": { 374 | "display_name": "Python 3", 375 | "language": "python", 376 | "name": "python3" 377 | }, 378 | "language_info": { 379 | "codemirror_mode": { 380 | "name": "ipython", 381 | "version": 3 382 | }, 383 | "file_extension": ".py", 384 | "mimetype": "text/x-python", 385 | "name": "python", 386 | "nbconvert_exporter": "python", 387 | "pygments_lexer": "ipython3", 388 | "version": "3.6.7" 389 | } 390 | }, 391 | "nbformat": 4, 392 | "nbformat_minor": 2 393 | } 394 | 
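To make the averaging scheme described in the notebook above easier to picture outside of TensorFlow, here is a minimal, framework-free sketch of the same idea: every worker takes `INTERVAL_STEPS` local SGD steps on its own shard, then the weights themselves (not the gradients) are averaged and broadcast back to every worker. This is only an illustration of the algorithm, not the repository's `FederatedAveragingOptimizer`; apart from `INTERVAL_STEPS`, all names and constants (`local_sgd_steps`, `federated_average`, `make_shard`, `ROUNDS`, etc.) are invented for the example.

```python
# Minimal NumPy simulation of federated averaging (illustrative only).
# Each "worker" fits a linear model on its own private shard; every
# INTERVAL_STEPS local updates, the weights are averaged and broadcast back.
import numpy as np

INTERVAL_STEPS = 10   # local steps between averaging rounds
ROUNDS = 20           # number of averaging rounds
LR = 0.1              # learning rate of the local SGD updates
NUM_WORKERS = 2

rng = np.random.RandomState(0)
true_w = np.array([2.0, -3.0])

def make_shard(n=256):
    """Generate a private data shard for one worker."""
    X = rng.randn(n, 2)
    y = X @ true_w + 0.1 * rng.randn(n)
    return X, y

def local_sgd_steps(w, X, y, steps, lr=LR, batch_size=32):
    """Run `steps` plain SGD steps on one worker's local data."""
    for _ in range(steps):
        idx = rng.choice(len(X), batch_size, replace=False)
        grad = 2.0 * X[idx].T @ (X[idx] @ w - y[idx]) / batch_size
        w = w - lr * grad
    return w

def federated_average(weights):
    """The chief averages the workers' weights (not their gradients)."""
    return np.mean(weights, axis=0)

shards = [make_shard() for _ in range(NUM_WORKERS)]
workers = [np.zeros(2) for _ in range(NUM_WORKERS)]

for round_idx in range(ROUNDS):
    # 1) every worker trains locally for INTERVAL_STEPS steps
    workers = [local_sgd_steps(w, X, y, INTERVAL_STEPS)
               for w, (X, y) in zip(workers, shards)]
    # 2) weights are sent to the parameter server and averaged by the chief
    global_w = federated_average(workers)
    # 3) every worker overwrites its local weights with the average
    workers = [global_w.copy() for _ in range(NUM_WORKERS)]

print('Averaged weights:', global_w, '(true weights:', true_w, ')')
```

With `INTERVAL_STEPS = 1` the scheme averages after every step and behaves much like synchronous distributed SGD; larger values reduce communication at the cost of letting the local models drift apart between rounds.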
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # coMind: Acuratio's open source project
2 |
3 | [Acuratio_Logo](https://acuratio.com)
4 |
5 | Check out our Multicloud Platform at [acuratio.com](https://acuratio.com)
6 |
7 | This library is deprecated; contact us at hello@acuratio.com. Let us know what your problem or use case is and we'll get in touch with a privacy-preserving solution.
8 |
9 | Federated averaging has a set of features that makes it perfect for training models in a collaborative way while preserving the privacy of sensitive data. In this repository you can learn how to start training ML models in a federated setup.
10 |
11 |
12 |
13 | ## What can you expect to find here?
14 |
15 | We have developed a custom optimizer for TensorFlow to easily train neural networks in a federated way (NOTE: every time we refer to federated here, we mean federated averaging).
16 |
17 | What is federated machine learning? In short, it is a step forward from distributed learning that can improve performance and training times. In our tutorials we explain in depth how it works, so we definitely encourage you to have a look!
18 |
19 | In addition to this custom optimizer, you can find some tutorials and examples to help you get started with TensorFlow and federated learning. They range from a basic training example, where all the steps of a local classification model are shown, to more elaborate distributed and federated learning setups.
20 |
21 | In this repository you will find three different types of files:
22 |
23 | - `federated_averaging_optimizer.py`, which is the custom optimizer we have created to implement federated averaging in TensorFlow.
24 |
25 | - `basic_classifier.py`, `basic_distributed_classifier.py`, `basic_federated_classifier.py`, `advanced_classifier.py`, `advanced_distributed_classifier.py`, `advanced_federated_classifier.py`, which are three basic and three advanced examples of how to train and evaluate TensorFlow models in a local, distributed and federated way.
26 | 27 | - `Basic Classifier.ipynb`, `Basic Distributed Classifier.ipynb`, `Basic Federated Classifier.ipynb` which are three IPython Notebooks where you can find the three basic examples named above and in depth documentation to walk you through. 28 | 29 | ## Installation dependencies 30 | 31 | - Python 3 32 | - TensorFlow 33 | - matplotlib (for the examples and tutorials) 34 | 35 | ## Usage 36 | 37 | Download and open the notebooks with Jupyter or Google Colab. The notebook with the local training example `Basic Classifier.ipynb` and the python scripts `basic_classifier.py` and `advanced_classifier.py` can be run right away. For the others you will need to open three different shells. One of them will be executing the parameter server and the other two the workers. 38 | 39 | For example, to run the `basic_distributed_classifier.py`: 40 | 41 | * 1st shell command should look like this: `python3 basic_distributed_classifier.py --job_name=ps --task_index=0` 42 | 43 | * 2nd shell: `python3 basic_distributed_classifier.py --job_name=worker --task_index=0` 44 | 45 | * 3rd shell: `python3 basic_distributed_classifier.py --job_name=worker --task_index=1` 46 | 47 | Follow the same steps for the `basic_federated_classifier.py`, `advanced_distributed_classifier.py` and `advanced_federated_classifier.py`. 48 | 49 | ### Colab Notebooks 50 | 51 | * [Basic Classifier](https://colab.research.google.com/drive/1hJ6UhELZ9sK3eX2_c-MamjxNt4gzgCis) 52 | * [Basic Distributed Classifier](https://colab.research.google.com/drive/1ZsSOD_J9aFRL4xACVUw0lau0Bc9IPD-C) 53 | * [Basic Federated Classifier](https://colab.research.google.com/drive/1zMNAJlqnNSziKYECTWhPyj4HSzg1g8sx) 54 | 55 | ## Additional resources 56 | 57 | Check [MPI](https://github.com/coMindOrg/federated-averaging-tutorials/tree/master/federated-MPI) to find an implementation of Federated Averaging with [Message Passing Interface](https://www.mpich.org/). This takes the communication out of TensorFlow and averages the weights with a custom hook. 58 | 59 | Check [sockets](https://github.com/coMindOrg/federated-averaging-tutorials/tree/master/federated-sockets) to find an implementation with python sockets. The same idea as with MPI but in this case we only need to know the public IP of the chief worker, and a custom hook will take care of the synchronization for us! 60 | 61 | Check [this](https://github.com/coMindOrg/federated-averaging-tutorials/tree/master/federated-keras) to see an easier implementation with keras! 62 | 63 | Check [this script](https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py) to see how to generate CIFAR-10 TFRecords. 64 | 65 | ## Troubleshooting and Help 66 | 67 | coMind has public Slack and Telegram channels which are a great place to ask questions and all things related to federated machine learning. 68 | 69 | ## Bugs and Issues 70 | 71 | Have a bug or an issue? [Open a new issue](https://github.com/coMindOrg/federated-averaging-tutorials/issues) here on GitHub or join our community in Slack or Telegram. 72 | 73 | *[Click here to join the Slack channel!](https://comindorg.slack.com/join/shared_invite/enQtNDMxMzc0NDA5OTEwLWIyZTg5MTg1MTM4NjhiNDM4YTU1OTI1NTgwY2NkNzZjYWY1NmI0ZjIyNWJiMTNkZmRhZDg2Nzc3YTYyNGQzM2I)* 74 | 75 | *[Click here to join the Telegram channel!](https://t.me/comind)* 76 | 77 | ## References 78 | 79 | The Federated Averaging algorithm is explained in more detail in the following paper: 80 | 81 | H. B. McMahan, E. Moore, D. Ramage, S. Hampson, and B. 
A. y Arcas. [Communication-efficient learning of deep networks from decentralized data](https://arxiv.org/pdf/1602.05629.pdf). In Conference on Artificial Intelligence and Statistics, 2017. 82 | 83 | The datsets used in these examples were: 84 | 85 | Alex Krizhevsky. [Learning Multiple Layers of Features from Tiny Images](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf). 86 | 87 | Han Xiao, Kashif Rasul, Roland Vollgraf. [Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms](https://arxiv.org/abs/1708.07747). 88 | 89 | ## About 90 | 91 | coMind is an open source project for training privacy-preserving federated deep learning models. 92 | 93 | * https://comind.org/ 94 | * [Twitter](https://twitter.com/coMindOrg) 95 | -------------------------------------------------------------------------------- /advanced_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 coMind. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # https://comind.org/ 16 | # ============================================================================== 17 | 18 | # TensorFlow 19 | import tensorflow as tf 20 | 21 | # Helper libraries 22 | import os 23 | import numpy as np 24 | from time import time 25 | import multiprocessing 26 | 27 | # You can safely tune these variables 28 | BATCH_SIZE = 128 29 | SHUFFLE_SIZE = BATCH_SIZE * 100 30 | EPOCHS = 250 31 | EPOCHS_PER_DECAY = 50 32 | BATCHES_TO_PREFETCH = 1 33 | # ---------------- 34 | 35 | # Dataset dependent constants 36 | num_train_images = 50000 37 | num_test_images = 10000 38 | height = 32 39 | width = 32 40 | channels = 3 41 | num_batch_files = 5 42 | 43 | # Path to TFRecord files (check readme for instructions on how to get these files) 44 | cifar10_train_files = ['cifar-10-tf-records/train{}.tfrecords'.format(i) for i in range(num_batch_files)] 45 | cifar10_test_file = 'cifar-10-tf-records/test.tfrecords' 46 | 47 | # Shuffle filenames before loading them 48 | np.random.shuffle(cifar10_train_files) 49 | 50 | checkpoint_dir='logs_dir/{}'.format(time()) 51 | print('Checkpoint directory: ' + checkpoint_dir) 52 | 53 | global_step = tf.train.get_or_create_global_step() 54 | 55 | # Check number of available CPUs 56 | cpu_count = multiprocessing.cpu_count() 57 | 58 | # Define input pipeline, place these ops in the cpu 59 | with tf.name_scope('dataset'), tf.device('/cpu:0'): 60 | # Map function to decode data and preprocess it 61 | def preprocess(serialized_examples): 62 | # Parse a batch 63 | features = tf.parse_example(serialized_examples, {'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)}) 64 | # Decode and reshape image 65 | image = tf.map_fn(lambda img: tf.reshape(tf.decode_raw(img, tf.uint8), tf.stack([height, width, channels])), features['image'], dtype=tf.uint8, name='decode') 66 | # Cast image 67 | casted_image = tf.cast(image, tf.float32, name='input_cast') 68 
| # Resize image for testing 69 | resized_image = tf.image.resize_image_with_crop_or_pad(casted_image, 24, 24) 70 | # Augment images for training 71 | distorted_image = tf.map_fn(lambda img: tf.random_crop(img, [24, 24, 3]), casted_image, name='random_crop') 72 | distorted_image = tf.image.random_flip_left_right(distorted_image) 73 | distorted_image = tf.image.random_brightness(distorted_image, 63) 74 | distorted_image = tf.image.random_contrast(distorted_image, 0.2, 1.8) 75 | # Check if test or train mode 76 | result = tf.cond(train_mode, lambda: distorted_image, lambda: resized_image) 77 | # Standardize images 78 | processed_image = tf.map_fn(lambda img: tf.image.per_image_standardization(img), result, name='standardization') 79 | return processed_image, features['label'] 80 | # Placeholders for the iterator 81 | filename_placeholder = tf.placeholder(tf.string, name='input_filename') 82 | batch_size = tf.placeholder(tf.int64, name='batch_size') 83 | shuffle_size = tf.placeholder(tf.int64, name='shuffle_size') 84 | train_mode = tf.placeholder(tf.bool, name='train_mode') 85 | 86 | # Create dataset, shuffle, repeat, batch, map and prefetch 87 | dataset = tf.data.TFRecordDataset(filename_placeholder) 88 | dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True) 89 | dataset = dataset.repeat(EPOCHS) 90 | dataset = dataset.batch(batch_size) 91 | dataset = dataset.map(preprocess, cpu_count) 92 | dataset = dataset.prefetch(BATCHES_TO_PREFETCH) 93 | # Define a feedable iterator and the initialization op 94 | iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 95 | dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') 96 | X, y = iterator.get_next() 97 | 98 | # Define our model 99 | first_conv = tf.layers.conv2d(X, 64, 5, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2), name='first_conv') 100 | 101 | first_pool = tf.nn.max_pool(first_conv, [1, 3, 3 ,1], [1, 2, 2, 1], padding='SAME', name='first_pool') 102 | 103 | first_norm = tf.nn.lrn(first_pool, 4, alpha=0.001 / 9.0, beta=0.75, name='first_norm') 104 | 105 | second_conv = tf.layers.conv2d(first_norm, 64, 5, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2), name='second_conv') 106 | 107 | second_norm = tf.nn.lrn(second_conv, 4, alpha=0.001 / 9.0, beta=0.75, name='second_norm') 108 | 109 | second_pool = tf.nn.max_pool(second_norm, [1, 3, 3, 1], [1, 2, 2, 1], padding='SAME', name='second_pool') 110 | 111 | flatten_layer = tf.layers.flatten(second_pool, name='flatten') 112 | 113 | first_relu = tf.layers.dense(flatten_layer, 384, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.04), name='first_relu') 114 | 115 | second_relu = tf.layers.dense(first_relu, 192, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.04), name='second_relu') 116 | 117 | logits = tf.layers.dense(second_relu, 10, kernel_initializer=tf.truncated_normal_initializer(stddev=1/192.0), name='logits') 118 | 119 | # Object to keep moving averages of our metrics (for tensorboard) 120 | summary_averages = tf.train.ExponentialMovingAverage(0.9) 121 | 122 | # Define cross_entropy loss 123 | with tf.name_scope('loss'): 124 | base_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits), name='base_loss') 125 | # Add regularization loss to both relu layers 126 | regularizer_loss = tf.add_n([tf.nn.l2_loss(v) for v 
in tf.trainable_variables() if 'relu/kernel' in v.name], name='regularizer_loss') * 0.004 127 | loss = tf.add(base_loss, regularizer_loss) 128 | loss_averages_op = summary_averages.apply([loss]) 129 | # Store moving average of the loss 130 | tf.summary.scalar('cross_entropy', summary_averages.average(loss)) 131 | 132 | with tf.name_scope('accuracy'): 133 | with tf.name_scope('correct_prediction'): 134 | # Compare prediction with actual label 135 | correct_prediction = tf.equal(tf.argmax(logits, 1), y) 136 | # Average correct predictions in the current batch 137 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy_metric') 138 | accuracy_averages_op = summary_averages.apply([accuracy]) 139 | # Store moving average of the accuracy 140 | tf.summary.scalar('accuracy', summary_averages.average(accuracy)) 141 | 142 | n_batches = int(num_train_images / BATCH_SIZE) 143 | last_step = int(n_batches * EPOCHS) 144 | 145 | # Define moving averages of the trainable variables. This sometimes improve 146 | # the performance of the trained model 147 | with tf.name_scope('variable_averages'): 148 | variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step) 149 | variable_averages_op = variable_averages.apply(tf.trainable_variables()) 150 | 151 | # Define optimizer and training op 152 | with tf.name_scope('train'): 153 | # Make decaying learning rate 154 | lr = tf.train.exponential_decay(0.1, global_step, n_batches * EPOCHS_PER_DECAY, 0.1, staircase=True) 155 | tf.summary.scalar('learning_rate', lr) 156 | # Make train_op dependent on moving averages ops. Otherwise they will be 157 | # disconnected from the graph 158 | with tf.control_dependencies([loss_averages_op, accuracy_averages_op, variable_averages_op]): 159 | train_op = tf.train.GradientDescentOptimizer(lr).minimize(loss, global_step=global_step) 160 | 161 | print('Graph definition finished') 162 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 163 | 164 | print('Training {} batches...'.format(last_step)) 165 | 166 | # Logger hook to keep track of the training 167 | class _LoggerHook(tf.train.SessionRunHook): 168 | def begin(self): 169 | self._total_loss = 0 170 | self._total_acc = 0 171 | 172 | def before_run(self, run_context): 173 | return tf.train.SessionRunArgs([loss, accuracy, global_step]) 174 | 175 | def after_run(self, run_context, run_values): 176 | loss_value, acc_value, step_value = run_values.results 177 | self._total_loss += loss_value 178 | self._total_acc += acc_value 179 | if (step_value + 1) % n_batches == 0: 180 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format(int(step_value / n_batches) + 1, EPOCHS, self._total_loss / n_batches, self._total_acc / n_batches)) 181 | self._total_loss = 0 182 | self._total_acc = 0 183 | 184 | # Hook to initialize the dataset 185 | class _InitHook(tf.train.SessionRunHook): 186 | def after_create_session(self, session, coord): 187 | session.run(dataset_init_op, feed_dict={filename_placeholder: cifar10_train_files, batch_size: BATCH_SIZE, shuffle_size: SHUFFLE_SIZE, train_mode: True}) 188 | 189 | with tf.name_scope('monitored_session'): 190 | with tf.train.MonitoredTrainingSession( 191 | checkpoint_dir=checkpoint_dir, 192 | hooks=[_LoggerHook(), _InitHook(), tf.train.CheckpointSaverHook(checkpoint_dir=checkpoint_dir, save_steps=n_batches, saver=tf.train.Saver(variable_averages.variables_to_restore()))], 193 | config=sess_config, 194 | save_checkpoint_secs=None) as mon_sess: 195 | while not mon_sess.should_stop(): 
196 | mon_sess.run(train_op) 197 | 198 | print('--- Begin Evaluation ---') 199 | # Reset graph and place ops in cpu to avoid OOM 200 | tf.reset_default_graph() 201 | with tf.device('/cpu:0'), tf.Session() as sess: 202 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 203 | saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True) 204 | saver.restore(sess, ckpt.model_checkpoint_path) 205 | print('Model restored') 206 | graph = tf.get_default_graph() 207 | filename_placeholder = graph.get_tensor_by_name('dataset/input_filename:0') 208 | batch_size = graph.get_tensor_by_name('dataset/batch_size:0') 209 | shuffle_size = graph.get_tensor_by_name('dataset/shuffle_size:0') 210 | train_mode = graph.get_tensor_by_name('dataset/train_mode:0') 211 | accuracy = graph.get_tensor_by_name('accuracy/accuracy_metric:0') 212 | dataset_init_op = graph.get_operation_by_name('dataset/dataset_init') 213 | sess.run(dataset_init_op, feed_dict={filename_placeholder: cifar10_test_file, batch_size: num_test_images, shuffle_size: 1, train_mode: False}) 214 | print('Test accuracy: {:4f}'.format(sess.run(accuracy))) 215 | -------------------------------------------------------------------------------- /advanced_distributed_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 coMind. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # https://comind.org/ 16 | # ============================================================================== 17 | 18 | # TensorFlow 19 | import tensorflow as tf 20 | 21 | # Helper libraries 22 | import os 23 | import numpy as np 24 | from time import time 25 | import multiprocessing 26 | 27 | flags = tf.app.flags 28 | flags.DEFINE_integer("task_index", None, 29 | "Worker task index, should be >= 0. 
task_index=0 is " 30 | "the master worker task that performs the variable " 31 | "initialization ") 32 | flags.DEFINE_string("ps_hosts", "localhost:2222", 33 | "Comma-separated list of hostname:port pairs") 34 | flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", 35 | "Comma-separated list of hostname:port pairs") 36 | flags.DEFINE_string("job_name", None, "job name: worker or ps") 37 | 38 | # You can safely tune these variables 39 | BATCH_SIZE = 128 40 | SHUFFLE_SIZE = BATCH_SIZE * 100 41 | EPOCHS = 250 42 | EPOCHS_PER_DECAY = 50 43 | BATCHES_TO_PREFETCH = 1 44 | # ---------------- 45 | 46 | FLAGS = flags.FLAGS 47 | 48 | if FLAGS.job_name is None or FLAGS.job_name == "": 49 | raise ValueError("Must specify an explicit `job_name`") 50 | if FLAGS.task_index is None or FLAGS.task_index == "": 51 | raise ValueError("Must specify an explicit `task_index`") 52 | 53 | # Only enable GPU for worker 1 (not needed if training with separate machines) 54 | if FLAGS.task_index == 0: 55 | print('--- GPU Disabled ---') 56 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 57 | 58 | #Construct the cluster and start the server 59 | ps_spec = FLAGS.ps_hosts.split(",") 60 | worker_spec = FLAGS.worker_hosts.split(",") 61 | 62 | # Get the number of workers. 63 | num_workers = len(worker_spec) 64 | print('{} workers defined'.format(num_workers)) 65 | 66 | # Dataset dependent constants 67 | num_train_images = int(50000 / num_workers) 68 | num_test_images = 10000 69 | height = 32 70 | width = 32 71 | channels = 3 72 | num_batch_files = 5 73 | 74 | cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec}) 75 | 76 | server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) 77 | 78 | # ps will block here 79 | if FLAGS.job_name == "ps": 80 | print('--- Parameter Server Ready ---') 81 | server.join() 82 | 83 | # Path to TFRecord files (check readme for instructions on how to get these files) 84 | cifar10_train_files = ['cifar-10-tf-records/train{}.tfrecords'.format(i) for i in range(num_batch_files)] 85 | cifar10_test_file = 'cifar-10-tf-records/test.tfrecords' 86 | 87 | # Shuffle filenames before loading them 88 | np.random.shuffle(cifar10_train_files) 89 | 90 | is_chief = (FLAGS.task_index == 0) 91 | 92 | checkpoint_dir='logs_dir/{}'.format(time()) 93 | print('Checkpoint directory: ' + checkpoint_dir) 94 | 95 | worker_device = "/job:worker/task:%d" % FLAGS.task_index 96 | print('Worker device: ' + worker_device + ' - is_chief: {}'.format(is_chief)) 97 | 98 | # Check number of available CPUs 99 | cpu_count = int(multiprocessing.cpu_count() / num_workers) 100 | 101 | # replica device setter will place ops in the appropriate devices 102 | with tf.device( 103 | tf.train.replica_device_setter( 104 | worker_device=worker_device, 105 | cluster=cluster)): 106 | global_step = tf.train.get_or_create_global_step() 107 | 108 | # Define input pipeline, place these ops in the cpu 109 | with tf.name_scope('dataset'), tf.device('/cpu:0'): 110 | # Map function to decode data and preprocess it 111 | def preprocess(serialized_examples): 112 | # Parse a batch 113 | features = tf.parse_example(serialized_examples, {'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)}) 114 | # Decode and reshape image 115 | image = tf.map_fn(lambda img: tf.reshape(tf.decode_raw(img, tf.uint8), tf.stack([height, width, channels])), features['image'], dtype=tf.uint8, name='decode') 116 | # Cast image 117 | casted_image = tf.cast(image, tf.float32, name='input_cast') 118 | 
# Resize image for testing 119 | resized_image = tf.image.resize_image_with_crop_or_pad(casted_image, 24, 24) 120 | # Augment images for training 121 | distorted_image = tf.map_fn(lambda img: tf.random_crop(img, [24, 24, 3]), casted_image, name='random_crop') 122 | distorted_image = tf.image.random_flip_left_right(distorted_image) 123 | distorted_image = tf.image.random_brightness(distorted_image, 63) 124 | distorted_image = tf.image.random_contrast(distorted_image, 0.2, 1.8) 125 | # Check if test or train mode 126 | result = tf.cond(train_mode, lambda: distorted_image, lambda: resized_image) 127 | # Standardize images 128 | processed_image = tf.map_fn(lambda img: tf.image.per_image_standardization(img), result, name='standardization') 129 | return processed_image, features['label'] 130 | # Placeholders for the iterator 131 | filename_placeholder = tf.placeholder(tf.string, name='input_filename') 132 | batch_size = tf.placeholder(tf.int64, name='batch_size') 133 | shuffle_size = tf.placeholder(tf.int64, name='shuffle_size') 134 | train_mode = tf.placeholder(tf.bool, name='train_mode') 135 | 136 | # Create dataset, shuffle, repeat, batch, map and prefetch 137 | dataset = tf.data.TFRecordDataset(filename_placeholder) 138 | dataset = dataset.shard(num_workers, FLAGS.task_index) 139 | dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True) 140 | dataset = dataset.repeat(EPOCHS) 141 | dataset = dataset.batch(batch_size) 142 | dataset = dataset.map(preprocess, cpu_count) 143 | dataset = dataset.prefetch(BATCHES_TO_PREFETCH) 144 | # Define a feedable iterator and the initialization op 145 | iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 146 | dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') 147 | X, y = iterator.get_next() 148 | 149 | # Define our model 150 | first_conv = tf.layers.conv2d(X, 64, 5, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2), name='first_conv') 151 | 152 | first_pool = tf.nn.max_pool(first_conv, [1, 3, 3 ,1], [1, 2, 2, 1], padding='SAME', name='first_pool') 153 | 154 | first_norm = tf.nn.lrn(first_pool, 4, alpha=0.001 / 9.0, beta=0.75, name='first_norm') 155 | 156 | second_conv = tf.layers.conv2d(first_norm, 64, 5, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2), name='second_conv') 157 | 158 | second_norm = tf.nn.lrn(second_conv, 4, alpha=0.001 / 9.0, beta=0.75, name='second_norm') 159 | 160 | second_pool = tf.nn.max_pool(second_norm, [1, 3, 3, 1], [1, 2, 2, 1], padding='SAME', name='second_pool') 161 | 162 | flatten_layer = tf.layers.flatten(second_pool, name='flatten') 163 | 164 | first_relu = tf.layers.dense(flatten_layer, 384, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.04), name='first_relu') 165 | 166 | second_relu = tf.layers.dense(first_relu, 192, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.04), name='second_relu') 167 | 168 | logits = tf.layers.dense(second_relu, 10, kernel_initializer=tf.truncated_normal_initializer(stddev=1/192.0), name='logits') 169 | 170 | # Object to keep moving averages of our metrics (for tensorboard) 171 | summary_averages = tf.train.ExponentialMovingAverage(0.9) 172 | n_batches = int(num_train_images / (BATCH_SIZE * num_workers)) 173 | 174 | # Define cross_entropy loss 175 | with tf.name_scope('loss'): 176 | base_loss = 
tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits), name='base_loss') 177 | # Add regularization loss to both relu layers 178 | regularizer_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'relu/kernel' in v.name], name='regularizer_loss') * 0.004 179 | loss = tf.add(base_loss, regularizer_loss) 180 | loss_averages_op = summary_averages.apply([loss]) 181 | # Store moving average of the loss 182 | tf.summary.scalar('cross_entropy', summary_averages.average(loss)) 183 | 184 | with tf.name_scope('accuracy'): 185 | with tf.name_scope('correct_prediction'): 186 | # Compare prediction with actual label 187 | correct_prediction = tf.equal(tf.argmax(logits, 1), y) 188 | # Average correct predictions in the current batch 189 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy_metric') 190 | accuracy_averages_op = summary_averages.apply([accuracy]) 191 | # Store moving average of the accuracy 192 | tf.summary.scalar('accuracy', summary_averages.average(accuracy)) 193 | 194 | # Define moving averages of the trainable variables. This sometimes improve 195 | # the performance of the trained model 196 | with tf.name_scope('variable_averages'): 197 | variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step) 198 | variable_averages_op = variable_averages.apply(tf.trainable_variables()) 199 | 200 | # Define optimizer and training op 201 | with tf.name_scope('train'): 202 | # Make decaying learning rate 203 | lr = tf.train.exponential_decay(0.1, global_step, n_batches * EPOCHS_PER_DECAY, 0.1, staircase=True) 204 | tf.summary.scalar('learning_rate', lr) 205 | # Wrap the optimizer in a SyncReplicasOptimizer for distributed training 206 | optimizer = tf.train.SyncReplicasOptimizer(tf.train.GradientDescentOptimizer(np.sqrt(num_workers) * lr), replicas_to_aggregate=num_workers) 207 | # Make train_op dependent on moving averages ops. 
Otherwise they will be 208 | # disconnected from the graph 209 | with tf.control_dependencies([loss_averages_op, accuracy_averages_op, variable_averages_op]): 210 | train_op = optimizer.minimize(loss, global_step=global_step) 211 | sync_replicas_hook = optimizer.make_session_run_hook(is_chief) 212 | 213 | print('Graph definition finished') 214 | 215 | last_step = int(n_batches * EPOCHS) 216 | 217 | sess_config = tf.ConfigProto( 218 | allow_soft_placement=True, 219 | log_device_placement=False, 220 | device_filters=["/job:ps", 221 | "/job:worker/task:%d" % FLAGS.task_index]) 222 | 223 | print('Training {} batches...'.format(last_step)) 224 | 225 | # Logger hook to keep track of the training 226 | class _LoggerHook(tf.train.SessionRunHook): 227 | def begin(self): 228 | self._total_loss = 0 229 | self._total_acc = 0 230 | 231 | def before_run(self, run_context): 232 | return tf.train.SessionRunArgs([loss, accuracy, global_step]) 233 | 234 | def after_run(self, run_context, run_values): 235 | loss_value, acc_value, step_value = run_values.results 236 | self._total_loss += loss_value 237 | self._total_acc += acc_value 238 | if (step_value + 1) % n_batches == 0: 239 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format(int(step_value / n_batches) + 1, EPOCHS, self._total_loss / n_batches, self._total_acc / n_batches)) 240 | self._total_loss = 0 241 | self._total_acc = 0 242 | 243 | def end(self, session): 244 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format(int(session.run(global_step) / n_batches) + 1, EPOCHS, self._total_loss / n_batches, self._total_acc / n_batches)) 245 | 246 | # Hook to initialize the dataset 247 | class _InitHook(tf.train.SessionRunHook): 248 | def after_create_session(self, session, coord): 249 | session.run(dataset_init_op, feed_dict={filename_placeholder: cifar10_train_files, batch_size: BATCH_SIZE, shuffle_size: SHUFFLE_SIZE, train_mode: True}) 250 | 251 | with tf.name_scope('monitored_session'): 252 | with tf.train.MonitoredTrainingSession( 253 | master=server.target, 254 | is_chief=is_chief, 255 | checkpoint_dir=checkpoint_dir, 256 | hooks=[_LoggerHook(), _InitHook(), sync_replicas_hook], 257 | chief_only_hooks=[tf.train.CheckpointSaverHook(checkpoint_dir=checkpoint_dir, save_steps=n_batches, saver=tf.train.Saver(variable_averages.variables_to_restore()))], 258 | config=sess_config, 259 | stop_grace_period_secs=10, 260 | save_checkpoint_secs=None) as mon_sess: 261 | while not mon_sess.should_stop(): 262 | mon_sess.run(train_op) 263 | 264 | if is_chief: 265 | print('--- Begin Evaluation ---') 266 | # Reset graph to clear any ops stored in other devices 267 | tf.reset_default_graph() 268 | with tf.Session() as sess: 269 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 270 | saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True) 271 | saver.restore(sess, ckpt.model_checkpoint_path) 272 | print('Model restored') 273 | graph = tf.get_default_graph() 274 | filename_placeholder = graph.get_tensor_by_name('dataset/input_filename:0') 275 | batch_size = graph.get_tensor_by_name('dataset/batch_size:0') 276 | shuffle_size = graph.get_tensor_by_name('dataset/shuffle_size:0') 277 | train_mode = graph.get_tensor_by_name('dataset/train_mode:0') 278 | accuracy = graph.get_tensor_by_name('accuracy/accuracy_metric:0') 279 | dataset_init_op = graph.get_operation_by_name('dataset/dataset_init') 280 | sess.run(dataset_init_op, feed_dict={filename_placeholder: cifar10_test_file, batch_size: num_test_images, shuffle_size: 1, 
train_mode: False}) 281 | print('Test accuracy: {:4f}'.format(sess.run(accuracy))) 282 | -------------------------------------------------------------------------------- /advanced_federated_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 coMind. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # https://comind.org/ 16 | # ============================================================================== 17 | 18 | # TensorFlow 19 | import tensorflow as tf 20 | 21 | # Helper libraries 22 | import os 23 | import numpy as np 24 | from time import time 25 | import multiprocessing 26 | 27 | # Import custom optimizer 28 | import federated_averaging_optimizer 29 | 30 | flags = tf.app.flags 31 | flags.DEFINE_integer("task_index", None, 32 | "Worker task index, should be >= 0. task_index=0 is " 33 | "the master worker task that performs the variable " 34 | "initialization ") 35 | flags.DEFINE_string("ps_hosts", "localhost:2222", 36 | "Comma-separated list of hostname:port pairs") 37 | flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", 38 | "Comma-separated list of hostname:port pairs") 39 | flags.DEFINE_string("job_name", None, "job name: worker or ps") 40 | 41 | # You can safely tune these variables 42 | BATCH_SIZE = 128 43 | SHUFFLE_SIZE = BATCH_SIZE * 100 44 | EPOCHS = 250 45 | EPOCHS_PER_DECAY = 50 46 | INTERVAL_STEPS = 100 47 | BATCHES_TO_PREFETCH = 1 48 | # ---------------- 49 | 50 | FLAGS = flags.FLAGS 51 | 52 | if FLAGS.job_name is None or FLAGS.job_name == "": 53 | raise ValueError("Must specify an explicit `job_name`") 54 | if FLAGS.task_index is None or FLAGS.task_index == "": 55 | raise ValueError("Must specify an explicit `task_index`") 56 | 57 | # Only enable GPU for worker 1 (not needed if training with separate machines) 58 | if FLAGS.task_index == 0: 59 | print('--- GPU Disabled ---') 60 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 61 | 62 | #Construct the cluster and start the server 63 | ps_spec = FLAGS.ps_hosts.split(",") 64 | worker_spec = FLAGS.worker_hosts.split(",") 65 | 66 | # Get the number of workers. 
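Every distributed and federated script in this repository is configured through the same ps_hosts / worker_hosts / job_name / task_index flags. As a minimal, standalone sketch (TF 1.x API, the default localhost ports above assumed), this is how those flags become a ClusterSpec and how each process would be launched; the script itself then derives the number of workers from the worker list, as shown next.

# cluster_sketch.py -- illustrative only, file and variable names assumed
import tensorflow as tf

ps_hosts = "localhost:2222".split(",")
worker_hosts = "localhost:2223,localhost:2224".split(",")
example_cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
print(example_cluster.num_tasks("worker"))   # -> 2
print(example_cluster.job_tasks("worker"))   # -> ['localhost:2223', 'localhost:2224']

# Each process is then started with matching flags, e.g.:
#   python3 advanced_federated_classifier.py --job_name=ps --task_index=0
#   python3 advanced_federated_classifier.py --job_name=worker --task_index=0
#   python3 advanced_federated_classifier.py --job_name=worker --task_index=1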
67 | num_workers = len(worker_spec) 68 | print('{} workers defined'.format(num_workers)) 69 | 70 | # Dataset dependent constants 71 | num_train_images = int(50000 / num_workers) 72 | num_test_images = 10000 73 | height = 32 74 | width = 32 75 | channels = 3 76 | num_batch_files = 5 77 | 78 | cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec}) 79 | 80 | server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) 81 | 82 | # ps will block here 83 | if FLAGS.job_name == "ps": 84 | print('--- Parameter Server Ready ---') 85 | server.join() 86 | 87 | # Path to TFRecord files (check readme for instructions on how to get these files) 88 | cifar10_train_files = ['cifar-10-tf-records/train{}.tfrecords'.format(i) for i in range(num_batch_files)] 89 | cifar10_test_file = 'cifar-10-tf-records/test.tfrecords' 90 | 91 | # Shuffle filenames before loading them 92 | np.random.shuffle(cifar10_train_files) 93 | 94 | is_chief = (FLAGS.task_index == 0) 95 | 96 | checkpoint_dir='logs_dir/federated_worker_{}/{}'.format(FLAGS.task_index, time()) 97 | print('Checkpoint directory: ' + checkpoint_dir) 98 | 99 | worker_device = "/job:worker/task:%d" % FLAGS.task_index 100 | print('Worker device: ' + worker_device + ' - is_chief: {}'.format(is_chief)) 101 | 102 | # Check number of available CPUs 103 | cpu_count = int(multiprocessing.cpu_count() / num_workers) 104 | 105 | # Place all ops in local worker by default 106 | with tf.device(worker_device): 107 | global_step = tf.train.get_or_create_global_step() 108 | 109 | # Define input pipeline, place these ops in the cpu 110 | with tf.name_scope('dataset'), tf.device('/cpu:0'): 111 | # Map function to decode data and preprocess it 112 | def preprocess(serialized_examples): 113 | # Parse a batch 114 | features = tf.parse_example(serialized_examples, {'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)}) 115 | # Decode and reshape imag 116 | image = tf.map_fn(lambda img: tf.reshape(tf.decode_raw(img, tf.uint8), tf.stack([height, width, channels])), features['image'], dtype=tf.uint8, name='decode') 117 | # Cast image 118 | casted_image = tf.cast(image, tf.float32, name='input_cast') 119 | # Resize image for testing 120 | resized_image = tf.image.resize_image_with_crop_or_pad(casted_image, 24, 24) 121 | # Augment images for training 122 | distorted_image = tf.map_fn(lambda img: tf.random_crop(img, [24, 24, 3]), casted_image, name='random_crop') 123 | distorted_image = tf.image.random_flip_left_right(distorted_image) 124 | distorted_image = tf.image.random_brightness(distorted_image, 63) 125 | distorted_image = tf.image.random_contrast(distorted_image, 0.2, 1.8) 126 | # Check if test or train mode 127 | result = tf.cond(train_mode, lambda: distorted_image, lambda: resized_image) 128 | # Standardize images 129 | processed_image = tf.map_fn(lambda img: tf.image.per_image_standardization(img), result, name='standardization') 130 | return processed_image, features['label'] 131 | # Placeholders for the iterator 132 | filename_placeholder = tf.placeholder(tf.string, name='input_filename') 133 | batch_size = tf.placeholder(tf.int64, name='batch_size') 134 | shuffle_size = tf.placeholder(tf.int64, name='shuffle_size') 135 | train_mode = tf.placeholder(tf.bool, name='train_mode') 136 | 137 | # Create dataset, shuffle, repeat, batch, map and prefetch 138 | dataset = tf.data.TFRecordDataset(filename_placeholder) 139 | dataset = dataset.shard(num_workers, FLAGS.task_index) 140 | dataset = 
dataset.shuffle(shuffle_size, reshuffle_each_iteration=True) 141 | dataset = dataset.repeat(EPOCHS) 142 | dataset = dataset.batch(batch_size) 143 | dataset = dataset.map(preprocess, cpu_count) 144 | dataset = dataset.prefetch(BATCHES_TO_PREFETCH) 145 | # Define a feedable iterator and the initialization op 146 | iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 147 | dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') 148 | X, y = iterator.get_next() 149 | 150 | # Define our model 151 | first_conv = tf.layers.conv2d(X, 64, 5, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2), name='first_conv') 152 | 153 | first_pool = tf.nn.max_pool(first_conv, [1, 3, 3 ,1], [1, 2, 2, 1], padding='SAME', name='first_pool') 154 | 155 | first_norm = tf.nn.lrn(first_pool, 4, alpha=0.001 / 9.0, beta=0.75, name='first_norm') 156 | 157 | second_conv = tf.layers.conv2d(first_norm, 64, 5, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2), name='second_conv') 158 | 159 | second_norm = tf.nn.lrn(second_conv, 4, alpha=0.001 / 9.0, beta=0.75, name='second_norm') 160 | 161 | second_pool = tf.nn.max_pool(second_norm, [1, 3, 3, 1], [1, 2, 2, 1], padding='SAME', name='second_pool') 162 | 163 | flatten_layer = tf.layers.flatten(second_pool, name='flatten') 164 | 165 | first_relu = tf.layers.dense(flatten_layer, 384, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.04), name='first_relu') 166 | 167 | second_relu = tf.layers.dense(first_relu, 192, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.04), name='second_relu') 168 | 169 | logits = tf.layers.dense(second_relu, 10, kernel_initializer=tf.truncated_normal_initializer(stddev=1/192.0), name='logits') 170 | 171 | # Object to keep moving averages of our metrics (for tensorboard) 172 | summary_averages = tf.train.ExponentialMovingAverage(0.9) 173 | n_batches = int(num_train_images / BATCH_SIZE) 174 | 175 | # Define cross_entropy loss 176 | with tf.name_scope('loss'): 177 | base_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits), name='base_loss') 178 | # Add regularization loss to both relu layers 179 | regularizer_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'relu/kernel' in v.name], name='regularizer_loss') * 0.004 180 | loss = tf.add(base_loss, regularizer_loss) 181 | loss_averages_op = summary_averages.apply([loss]) 182 | # Store moving average of the loss 183 | tf.summary.scalar('cross_entropy', summary_averages.average(loss)) 184 | 185 | with tf.name_scope('accuracy'): 186 | with tf.name_scope('correct_prediction'): 187 | # Compare prediction with actual label 188 | correct_prediction = tf.equal(tf.argmax(logits, 1), y) 189 | # Average correct predictions in the current batch 190 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy_metric') 191 | accuracy_averages_op = summary_averages.apply([accuracy]) 192 | # Store moving average of the accuracy 193 | tf.summary.scalar('accuracy', summary_averages.average(accuracy)) 194 | 195 | # Define moving averages of the trainable variables. 
This sometimes improve 196 | # the performance of the trained model 197 | with tf.name_scope('variable_averages'): 198 | variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step) 199 | variable_averages_op = variable_averages.apply(tf.trainable_variables()) 200 | 201 | # Define optimizer and training op 202 | with tf.name_scope('train'): 203 | # Define device setter to place copies of local variables 204 | device_setter = tf.train.replica_device_setter(worker_device=worker_device, cluster=cluster) 205 | # Make decaying learning rate 206 | lr = tf.train.exponential_decay(0.1, global_step, n_batches * EPOCHS_PER_DECAY, 0.1, staircase=True) 207 | tf.summary.scalar('learning_rate', lr) 208 | # Wrap the optimizer in a FederatedAveragingOptimizer for federated training 209 | optimizer = federated_averaging_optimizer.FederatedAveragingOptimizer(tf.train.GradientDescentOptimizer(lr), replicas_to_aggregate=num_workers, interval_steps=INTERVAL_STEPS, is_chief=is_chief, device_setter=device_setter) 210 | # Make train_op dependent on moving averages ops. Otherwise they will be 211 | # disconnected from the graph 212 | with tf.control_dependencies([loss_averages_op, accuracy_averages_op, variable_averages_op]): 213 | train_op = optimizer.minimize(loss, global_step=global_step) 214 | model_average_hook = optimizer.make_session_run_hook() 215 | 216 | print('Graph definition finished') 217 | 218 | last_step = int(n_batches * EPOCHS) 219 | 220 | sess_config = tf.ConfigProto( 221 | allow_soft_placement=True, 222 | log_device_placement=False, 223 | device_filters=["/job:ps", 224 | "/job:worker/task:%d" % FLAGS.task_index]) 225 | 226 | print('Training {} batches...'.format(last_step)) 227 | 228 | # Logger hook to keep track of the training 229 | class _LoggerHook(tf.train.SessionRunHook): 230 | def begin(self): 231 | self._total_loss = 0 232 | self._total_acc = 0 233 | 234 | def before_run(self, run_context): 235 | return tf.train.SessionRunArgs([loss, accuracy, global_step]) 236 | 237 | def after_run(self, run_context, run_values): 238 | loss_value, acc_value, step_value = run_values.results 239 | self._total_loss += loss_value 240 | self._total_acc += acc_value 241 | if (step_value + 1) % n_batches == 0: 242 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format(int(step_value / n_batches) + 1, EPOCHS, self._total_loss / n_batches, self._total_acc / n_batches)) 243 | self._total_loss = 0 244 | self._total_acc = 0 245 | 246 | # Hook to initialize the dataset 247 | class _InitHook(tf.train.SessionRunHook): 248 | def after_create_session(self, session, coord): 249 | session.run(dataset_init_op, feed_dict={filename_placeholder: cifar10_train_files, batch_size: BATCH_SIZE, shuffle_size: SHUFFLE_SIZE, train_mode: True}) 250 | 251 | # Hook to save just trainable_variables 252 | class _SaverHook(tf.train.SessionRunHook): 253 | def begin(self): 254 | self._saver = tf.train.Saver(variable_averages.variables_to_restore()) 255 | 256 | def before_run(self, run_context): 257 | return tf.train.SessionRunArgs(global_step) 258 | 259 | def after_run(self, run_context, run_values): 260 | step_value = run_values.results 261 | if step_value % n_batches == 0 and not step_value == 0: 262 | self._saver.save(run_context.session, checkpoint_dir+'/model.ckpt', step_value) 263 | 264 | def end(self, session): 265 | self._saver.save(session, checkpoint_dir+'/model.ckpt', session.run(global_step)) 266 | 267 | # Make sure we do not define a chief worker 268 | with tf.name_scope('monitored_session'): 269 | with 
tf.train.MonitoredTrainingSession( 270 | master=server.target, 271 | checkpoint_dir=checkpoint_dir, 272 | hooks=[_LoggerHook(), _InitHook(), _SaverHook(), model_average_hook], 273 | config=sess_config, 274 | stop_grace_period_secs=10, 275 | save_checkpoint_secs=None) as mon_sess: 276 | while not mon_sess.should_stop(): 277 | mon_sess.run(train_op) 278 | 279 | if is_chief: 280 | print('--- Begin Evaluation ---') 281 | # Reset graph to clear any ops stored in other devices 282 | tf.reset_default_graph() 283 | with tf.Session() as sess: 284 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 285 | saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True) 286 | saver.restore(sess, ckpt.model_checkpoint_path) 287 | print('Model restored') 288 | graph = tf.get_default_graph() 289 | filename_placeholder = graph.get_tensor_by_name('dataset/input_filename:0') 290 | batch_size = graph.get_tensor_by_name('dataset/batch_size:0') 291 | shuffle_size = graph.get_tensor_by_name('dataset/shuffle_size:0') 292 | train_mode = graph.get_tensor_by_name('dataset/train_mode:0') 293 | accuracy = graph.get_tensor_by_name('accuracy/accuracy_metric:0') 294 | dataset_init_op = graph.get_operation_by_name('dataset/dataset_init') 295 | sess.run(dataset_init_op, feed_dict={filename_placeholder: cifar10_test_file, batch_size: num_test_images, shuffle_size: 1, train_mode: False}) 296 | print('Test accuracy: {:4f}'.format(sess.run(accuracy))) 297 | -------------------------------------------------------------------------------- /basic_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 coMind. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # https://comind.org/ 16 | # ============================================================================== 17 | 18 | # TensorFlow and tf.keras 19 | import tensorflow as tf 20 | from tensorflow import keras 21 | 22 | # Helper libraries 23 | import numpy as np 24 | import matplotlib.pyplot as plt 25 | from time import time 26 | 27 | # You can safely tune these variables 28 | BATCH_SIZE = 32 29 | EPOCHS = 5 30 | # ---------------- 31 | 32 | # Load dataset as numpy arrays 33 | fashion_mnist = keras.datasets.fashion_mnist 34 | (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() 35 | print('Data loaded') 36 | print('Local dataset size: {}'.format(train_images.shape[0])) 37 | 38 | # List with class names to see the labels of the images with matplotlib 39 | class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 40 | 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] 41 | 42 | # Normalize dataset 43 | train_images = train_images / 255.0 44 | test_images = test_images / 255.0 45 | 46 | checkpoint_dir='logs_dir/{}'.format(time()) 47 | print('Checkpoint directory: ' + checkpoint_dir) 48 | 49 | global_step = tf.train.get_or_create_global_step() 50 | 51 | # Define input pipeline, place these ops in the cpu 52 | with tf.name_scope('dataset'), tf.device('/cpu:0'): 53 | # Placeholders for the iterator 54 | images_placeholder = tf.placeholder(train_images.dtype, [None, train_images.shape[1], train_images.shape[2]], name='images_placeholder') 55 | labels_placeholder = tf.placeholder(train_labels.dtype, [None], name='labels_placeholder') 56 | batch_size = tf.placeholder(tf.int64, name='batch_size') 57 | shuffle_size = tf.placeholder(tf.int64, name='shuffle_size') 58 | 59 | # Create dataset from numpy arrays, shuffle, repeat and batch 60 | dataset = tf.data.Dataset.from_tensor_slices((images_placeholder, labels_placeholder)) 61 | dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True) 62 | dataset = dataset.repeat(EPOCHS) 63 | dataset = dataset.batch(batch_size) 64 | # Define a feedable iterator and the initialization op 65 | iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 66 | dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') 67 | X, y = iterator.get_next() 68 | 69 | # Define our model 70 | flatten_layer = tf.layers.flatten(X, name='flatten') 71 | 72 | dense_layer = tf.layers.dense(flatten_layer, 128, activation=tf.nn.relu, name='relu') 73 | 74 | predictions = tf.layers.dense(dense_layer, 10, activation=tf.nn.softmax, name='softmax') 75 | 76 | # Object to keep moving averages of our metrics (for tensorboard) 77 | summary_averages = tf.train.ExponentialMovingAverage(0.9) 78 | 79 | # Define cross_entropy loss 80 | with tf.name_scope('loss'): 81 | loss = tf.reduce_mean(keras.losses.sparse_categorical_crossentropy(y, predictions)) 82 | loss_averages_op = summary_averages.apply([loss]) 83 | # Store moving average of the loss 84 | tf.summary.scalar('cross_entropy', summary_averages.average(loss)) 85 | 86 | # Define accuracy metric 87 | with tf.name_scope('accuracy'): 88 | with tf.name_scope('correct_prediction'): 89 | # Compare prediction with actual label 90 | correct_prediction = tf.equal(tf.argmax(predictions, 1), tf.cast(y, tf.int64)) 91 | # Average correct predictions in the current batch 92 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 93 | accuracy_averages_op = summary_averages.apply([accuracy]) 94 | # Store moving average of the accuracy 95 | 
tf.summary.scalar('accuracy', summary_averages.average(accuracy)) 96 | 97 | # Define optimizer and training op 98 | with tf.name_scope('train'): 99 | # Make train_op dependent on moving averages ops. Otherwise they will be 100 | # disconnected from the graph 101 | with tf.control_dependencies([loss_averages_op, accuracy_averages_op]): 102 | train_op = tf.train.AdamOptimizer(0.001).minimize(loss, global_step=global_step) 103 | 104 | print('Graph definition finished') 105 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 106 | 107 | n_batches = int(train_images.shape[0] / BATCH_SIZE) 108 | last_step = int(n_batches * EPOCHS) 109 | print('Training {} batches...'.format(last_step)) 110 | 111 | # Logger hook to keep track of the training 112 | class _LoggerHook(tf.train.SessionRunHook): 113 | def begin(self): 114 | self._total_loss = 0 115 | self._total_acc = 0 116 | 117 | def before_run(self, run_context): 118 | return tf.train.SessionRunArgs([loss, accuracy, global_step]) 119 | 120 | def after_run(self, run_context, run_values): 121 | loss_value, acc_value, step_value = run_values.results 122 | self._total_loss += loss_value 123 | self._total_acc += acc_value 124 | if (step_value + 1) % n_batches == 0: 125 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format(int(step_value / n_batches) + 1, EPOCHS, self._total_loss / n_batches, self._total_acc / n_batches)) 126 | self._total_loss = 0 127 | self._total_acc = 0 128 | 129 | # Hook to initialize the dataset 130 | class _InitHook(tf.train.SessionRunHook): 131 | def after_create_session(self, session, coord): 132 | session.run(dataset_init_op, feed_dict={images_placeholder: train_images, labels_placeholder: train_labels, batch_size: BATCH_SIZE, shuffle_size: train_images.shape[0]}) 133 | 134 | with tf.name_scope('monitored_session'): 135 | with tf.train.MonitoredTrainingSession( 136 | checkpoint_dir=checkpoint_dir, 137 | hooks=[_LoggerHook(), _InitHook()], 138 | config=sess_config, 139 | save_checkpoint_steps=n_batches) as mon_sess: 140 | while not mon_sess.should_stop(): 141 | mon_sess.run(train_op) 142 | 143 | print('--- Begin Evaluation ---') 144 | with tf.device('/cpu:0'), tf.Session() as sess: 145 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 146 | tf.train.Saver().restore(sess, ckpt.model_checkpoint_path) 147 | print('Model restored') 148 | sess.run(dataset_init_op, feed_dict={images_placeholder: test_images, labels_placeholder: test_labels, batch_size: test_images.shape[0], shuffle_size: 1}) 149 | print('Test accuracy: {:4f}'.format(sess.run(accuracy))) 150 | predicted = sess.run(predictions) 151 | 152 | # Plot the first 25 test images, their predicted label, and the true label 153 | # Color correct predictions in green, incorrect predictions in red 154 | plt.figure(figsize=(10, 10)) 155 | for i in range(25): 156 | plt.subplot(5, 5, i + 1) 157 | plt.xticks([]) 158 | plt.yticks([]) 159 | plt.grid(False) 160 | plt.imshow(test_images[i], cmap=plt.cm.binary) 161 | predicted_label = np.argmax(predicted[i]) 162 | true_label = test_labels[i] 163 | if predicted_label == true_label: 164 | color = 'green' 165 | else: 166 | color = 'red' 167 | plt.xlabel("{} ({})".format(class_names[predicted_label], 168 | class_names[true_label]), 169 | color=color) 170 | 171 | plt.show(True) 172 | -------------------------------------------------------------------------------- /basic_distributed_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 coMind. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # https://comind.org/ 16 | # ============================================================================== 17 | 18 | # TensorFlow and tf.keras 19 | import tensorflow as tf 20 | from tensorflow import keras 21 | 22 | # Helper libraries 23 | import os 24 | import numpy as np 25 | from time import time 26 | import matplotlib.pyplot as plt 27 | 28 | flags = tf.app.flags 29 | flags.DEFINE_integer("task_index", None, 30 | "Worker task index, should be >= 0. task_index=0 is " 31 | "the master worker task that performs the variable " 32 | "initialization ") 33 | flags.DEFINE_string("ps_hosts", "localhost:2222", 34 | "Comma-separated list of hostname:port pairs") 35 | flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", 36 | "Comma-separated list of hostname:port pairs") 37 | flags.DEFINE_string("job_name", None, "job name: worker or ps") 38 | 39 | # You can safely tune these variables 40 | BATCH_SIZE = 32 41 | EPOCHS = 5 42 | # ---------------- 43 | 44 | FLAGS = flags.FLAGS 45 | 46 | if FLAGS.job_name is None or FLAGS.job_name == "": 47 | raise ValueError("Must specify an explicit `job_name`") 48 | if FLAGS.task_index is None or FLAGS.task_index == "": 49 | raise ValueError("Must specify an explicit `task_index`") 50 | 51 | # Only enable GPU for worker 1 (not needed if training with separate machines) 52 | if FLAGS.task_index == 0: 53 | print('--- GPU Disabled ---') 54 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 55 | 56 | #Construct the cluster and start the server 57 | ps_spec = FLAGS.ps_hosts.split(",") 58 | worker_spec = FLAGS.worker_hosts.split(",") 59 | 60 | # Get the number of workers. 
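In the distributed and federated Fashion-MNIST scripts, each worker trains only on its own shard of the training set, taken with np.array_split and indexed by task_index a few lines further down. Below is a small, standalone NumPy illustration of that split (array shape and dtype assumed from keras.datasets.fashion_mnist, two workers assumed); the script itself first derives the number of workers from the worker list, as shown next.

# split_sketch.py -- illustrative only
import numpy as np

fake_train_images = np.zeros((60000, 28, 28), dtype=np.uint8)  # Fashion-MNIST training images shape/dtype
num_workers, task_index = 2, 1                                 # assumed two-worker cluster, second worker
local_images = np.array_split(fake_train_images, num_workers)[task_index]
print(local_images.shape)                                      # -> (30000, 28, 28)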
61 | num_workers = len(worker_spec) 62 | print('{} workers defined'.format(num_workers)) 63 | 64 | cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec}) 65 | 66 | server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) 67 | 68 | # Parameter server will block here 69 | if FLAGS.job_name == "ps": 70 | print('--- Parameter Server Ready ---') 71 | server.join() 72 | 73 | # Load dataset as numpy arrays 74 | fashion_mnist = keras.datasets.fashion_mnist 75 | (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() 76 | print('Data loaded') 77 | 78 | # List with class names to see the labels of the images with matplotlib 79 | class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 80 | 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] 81 | 82 | # Split dataset between workers 83 | train_images = np.array_split(train_images, num_workers)[FLAGS.task_index] 84 | train_labels = np.array_split(train_labels, num_workers)[FLAGS.task_index] 85 | print('Local dataset size: {}'.format(train_images.shape[0])) 86 | 87 | # Normalize dataset 88 | train_images = train_images / 255.0 89 | test_images = test_images / 255.0 90 | 91 | is_chief = (FLAGS.task_index == 0) 92 | 93 | checkpoint_dir='logs_dir/{}'.format(time()) 94 | print('Checkpoint directory: ' + checkpoint_dir) 95 | 96 | worker_device = "/job:worker/task:%d" % FLAGS.task_index 97 | print('Worker device: ' + worker_device + ' - is_chief: {}'.format(is_chief)) 98 | 99 | # replica_device_setter will place vars in the corresponding device 100 | with tf.device( 101 | tf.train.replica_device_setter( 102 | worker_device=worker_device, 103 | cluster=cluster)): 104 | global_step = tf.train.get_or_create_global_step() 105 | 106 | # Define input pipeline, place these ops in the cpu 107 | with tf.name_scope('dataset'), tf.device('/cpu:0'): 108 | # Placeholders for the iterator 109 | images_placeholder = tf.placeholder(train_images.dtype, [None, train_images.shape[1], train_images.shape[2]], name='images_placeholder') 110 | labels_placeholder = tf.placeholder(train_labels.dtype, [None], name='labels_placeholder') 111 | batch_size = tf.placeholder(tf.int64, name='batch_size') 112 | shuffle_size = tf.placeholder(tf.int64, name='shuffle_size') 113 | 114 | # Create dataset from numpy arrays, shuffle, repeat and batch 115 | dataset = tf.data.Dataset.from_tensor_slices((images_placeholder, labels_placeholder)) 116 | dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True) 117 | dataset = dataset.repeat(EPOCHS) 118 | dataset = dataset.batch(batch_size) 119 | # Define a feedable iterator and the initialization op 120 | iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 121 | dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') 122 | X, y = iterator.get_next() 123 | 124 | # Define our model 125 | flatten_layer = tf.layers.flatten(X, name='flatten') 126 | 127 | dense_layer = tf.layers.dense(flatten_layer, 128, activation=tf.nn.relu, name='relu') 128 | 129 | predictions = tf.layers.dense(dense_layer, 10, activation=tf.nn.softmax, name='softmax') 130 | 131 | # Object to keep moving averages of our metrics (for tensorboard) 132 | summary_averages = tf.train.ExponentialMovingAverage(0.9) 133 | 134 | # Define cross_entropy loss 135 | with tf.name_scope('loss'): 136 | loss = tf.reduce_mean(keras.losses.sparse_categorical_crossentropy(y, predictions)) 137 | loss_averages_op = summary_averages.apply([loss]) 138 | # Store 
moving average of the loss 139 | tf.summary.scalar('cross_entropy', summary_averages.average(loss)) 140 | 141 | # Define accuracy metric 142 | with tf.name_scope('accuracy'): 143 | with tf.name_scope('correct_prediction'): 144 | # Compare prediction with actual label 145 | correct_prediction = tf.equal(tf.argmax(predictions, 1), tf.cast(y, tf.int64)) 146 | # Average correct predictions in the current batch 147 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy_metric') 148 | accuracy_averages_op = summary_averages.apply([accuracy]) 149 | # Store moving average of the accuracy 150 | tf.summary.scalar('accuracy', summary_averages.average(accuracy)) 151 | 152 | # Define optimizer and training op 153 | with tf.name_scope('train'): 154 | # Wrap optimizer in a SyncReplicasOptimizer for distributed training 155 | optimizer = tf.train.SyncReplicasOptimizer(tf.train.AdamOptimizer(np.sqrt(num_workers) * 0.001), replicas_to_aggregate=num_workers) 156 | # Make train_op dependent on moving averages ops. Otherwise they will be 157 | # disconnected from the graph 158 | with tf.control_dependencies([loss_averages_op, accuracy_averages_op]): 159 | train_op = optimizer.minimize(loss, global_step=global_step) 160 | # Define a hook for optimizer initialization 161 | sync_replicas_hook = optimizer.make_session_run_hook(is_chief) 162 | 163 | print('Graph definition finished') 164 | 165 | sess_config = tf.ConfigProto( 166 | allow_soft_placement=True, 167 | log_device_placement=False, 168 | device_filters=["/job:ps", 169 | "/job:worker/task:%d" % FLAGS.task_index]) 170 | 171 | n_batches = int(train_images.shape[0] / (BATCH_SIZE * num_workers)) 172 | last_step = int(n_batches * EPOCHS) 173 | 174 | print('Training {} batches...'.format(last_step)) 175 | 176 | # Logger hook to keep track of the training 177 | class _LoggerHook(tf.train.SessionRunHook): 178 | def begin(self): 179 | self._total_loss = 0 180 | self._total_acc = 0 181 | 182 | def before_run(self, run_context): 183 | return tf.train.SessionRunArgs([loss, accuracy, global_step]) 184 | 185 | def after_run(self, run_context, run_values): 186 | loss_value, acc_value, step_value = run_values.results 187 | self._total_loss += loss_value 188 | self._total_acc += acc_value 189 | if (step_value + 1) % n_batches == 0: 190 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format(int(step_value / n_batches) + 1, EPOCHS, self._total_loss / n_batches, self._total_acc / n_batches)) 191 | self._total_loss = 0 192 | self._total_acc = 0 193 | 194 | def end(self, session): 195 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format(int(session.run(global_step) / n_batches) + 1, EPOCHS, self._total_loss / n_batches, self._total_acc / n_batches)) 196 | 197 | # Hook to initialize the dataset 198 | class _InitHook(tf.train.SessionRunHook): 199 | def after_create_session(self, session, coord): 200 | session.run(dataset_init_op, feed_dict={images_placeholder: train_images, labels_placeholder: train_labels, batch_size: BATCH_SIZE, shuffle_size: train_images.shape[0]}) 201 | 202 | with tf.name_scope('monitored_session'): 203 | with tf.train.MonitoredTrainingSession( 204 | master=server.target, 205 | is_chief=is_chief, 206 | checkpoint_dir=checkpoint_dir, 207 | hooks=[_LoggerHook(), _InitHook(), sync_replicas_hook], 208 | config=sess_config, 209 | stop_grace_period_secs=10, 210 | save_checkpoint_steps=n_batches) as mon_sess: 211 | while not mon_sess.should_stop(): 212 | mon_sess.run(train_op) 213 | 214 | if is_chief: 215 | print('--- Begin 
Evaluation ---') 216 | # Reset graph and load it again to clean tensors placed in other devices 217 | tf.reset_default_graph() 218 | with tf.Session() as sess: 219 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 220 | saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True) 221 | saver.restore(sess, ckpt.model_checkpoint_path) 222 | print('Model restored') 223 | graph = tf.get_default_graph() 224 | images_placeholder = graph.get_tensor_by_name('dataset/images_placeholder:0') 225 | labels_placeholder = graph.get_tensor_by_name('dataset/labels_placeholder:0') 226 | batch_size = graph.get_tensor_by_name('dataset/batch_size:0') 227 | shuffle_size = graph.get_tensor_by_name('dataset/shuffle_size:0') 228 | accuracy = graph.get_tensor_by_name('accuracy/accuracy_metric:0') 229 | predictions = graph.get_tensor_by_name('softmax/BiasAdd:0') 230 | dataset_init_op = graph.get_operation_by_name('dataset/dataset_init') 231 | sess.run(dataset_init_op, feed_dict={images_placeholder: test_images, labels_placeholder: test_labels, batch_size: test_images.shape[0], shuffle_size: 1}) 232 | print('Test accuracy: {:4f}'.format(sess.run(accuracy))) 233 | predicted = sess.run(predictions) 234 | 235 | # Plot the first 25 test images, their predicted label, and the true label 236 | # Color correct predictions in green, incorrect predictions in red 237 | plt.figure(figsize=(10, 10)) 238 | for i in range(25): 239 | plt.subplot(5, 5, i + 1) 240 | plt.xticks([]) 241 | plt.yticks([]) 242 | plt.grid(False) 243 | plt.imshow(test_images[i], cmap=plt.cm.binary) 244 | predicted_label = np.argmax(predicted[i]) 245 | true_label = test_labels[i] 246 | if predicted_label == true_label: 247 | color = 'green' 248 | else: 249 | color = 'red' 250 | plt.xlabel("{} ({})".format(class_names[predicted_label], 251 | class_names[true_label]), 252 | color=color) 253 | 254 | plt.show(True) 255 | -------------------------------------------------------------------------------- /basic_federated_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 coMind. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # https://comind.org/ 16 | # ============================================================================== 17 | 18 | # TensorFlow and tf.keras 19 | import tensorflow as tf 20 | from tensorflow import keras 21 | 22 | # Helper libraries 23 | import os 24 | import numpy as np 25 | from time import time 26 | import matplotlib.pyplot as plt 27 | 28 | # Import custom optimizer 29 | import federated_averaging_optimizer 30 | 31 | flags = tf.app.flags 32 | flags.DEFINE_integer("task_index", None, 33 | "Worker task index, should be >= 0. 
task_index=0 is " 34 | "the master worker task that performs the variable " 35 | "initialization ") 36 | flags.DEFINE_string("ps_hosts", "localhost:2222", 37 | "Comma-separated list of hostname:port pairs") 38 | flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", 39 | "Comma-separated list of hostname:port pairs") 40 | flags.DEFINE_string("job_name", None, "job name: worker or ps") 41 | 42 | # You can safely tune these variables 43 | BATCH_SIZE = 32 44 | EPOCHS = 5 45 | INTERVAL_STEPS = 100 46 | # ---------------- 47 | 48 | FLAGS = flags.FLAGS 49 | 50 | if FLAGS.job_name is None or FLAGS.job_name == "": 51 | raise ValueError("Must specify an explicit `job_name`") 52 | if FLAGS.task_index is None or FLAGS.task_index == "": 53 | raise ValueError("Must specify an explicit `task_index`") 54 | 55 | # Only enable GPU for worker 1 (not needed if training with separate machines) 56 | if FLAGS.task_index == 0: 57 | print('--- GPU Disabled ---') 58 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 59 | 60 | #Construct the cluster and start the server 61 | ps_spec = FLAGS.ps_hosts.split(",") 62 | worker_spec = FLAGS.worker_hosts.split(",") 63 | 64 | # Get the number of workers. 65 | num_workers = len(worker_spec) 66 | print('{} workers defined'.format(num_workers)) 67 | 68 | cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec}) 69 | 70 | server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) 71 | # Parameter server will block here 72 | if FLAGS.job_name == "ps": 73 | print('--- Parameter Server Ready ---') 74 | server.join() 75 | 76 | # Load dataset as numpy arrays 77 | fashion_mnist = keras.datasets.fashion_mnist 78 | (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() 79 | print('Data loaded') 80 | 81 | # List with class names to see the labels of the images with matplotlib 82 | class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 83 | 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] 84 | 85 | # Split dataset between workers 86 | train_images = np.array_split(train_images, num_workers)[FLAGS.task_index] 87 | train_labels = np.array_split(train_labels, num_workers)[FLAGS.task_index] 88 | print('Local dataset size: {}'.format(train_images.shape[0])) 89 | 90 | # Normalize dataset 91 | train_images = train_images / 255.0 92 | test_images = test_images / 255.0 93 | 94 | is_chief = (FLAGS.task_index == 0) 95 | 96 | checkpoint_dir='logs_dir/federated_worker_{}/{}'.format(FLAGS.task_index, time()) 97 | print('Checkpoint directory: ' + checkpoint_dir) 98 | 99 | worker_device = "/job:worker/task:%d" % FLAGS.task_index 100 | print('Worker device: ' + worker_device + ' - is_chief: {}'.format(is_chief)) 101 | 102 | # Place all ops in the local worker by default 103 | with tf.device(worker_device): 104 | global_step = tf.train.get_or_create_global_step() 105 | 106 | # Define input pipeline, place these ops in the cpu 107 | with tf.name_scope('dataset'), tf.device('/cpu:0'): 108 | # Placeholders for the iterator 109 | images_placeholder = tf.placeholder(train_images.dtype, [None, train_images.shape[1], train_images.shape[2]], name='images_placeholder') 110 | labels_placeholder = tf.placeholder(train_labels.dtype, [None], name='labels_placeholder') 111 | batch_size = tf.placeholder(tf.int64, name='batch_size') 112 | shuffle_size = tf.placeholder(tf.int64, name='shuffle_size') 113 | 114 | # Create dataset from numpy arrays, shuffle, repeat and batch 115 | dataset = 
tf.data.Dataset.from_tensor_slices((images_placeholder, labels_placeholder)) 116 | dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True) 117 | dataset = dataset.repeat(EPOCHS) 118 | dataset = dataset.batch(batch_size) 119 | # Define a feedable iterator and the initialization op 120 | iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 121 | dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') 122 | X, y = iterator.get_next() 123 | 124 | # Define our model 125 | flatten_layer = tf.layers.flatten(X, name='flatten') 126 | 127 | dense_layer = tf.layers.dense(flatten_layer, 128, activation=tf.nn.relu, name='relu') 128 | 129 | predictions = tf.layers.dense(dense_layer, 10, activation=tf.nn.softmax, name='softmax') 130 | 131 | # Object to keep moving averages of our metrics (for tensorboard) 132 | summary_averages = tf.train.ExponentialMovingAverage(0.9) 133 | 134 | # Define cross_entropy loss 135 | with tf.name_scope('loss'): 136 | loss = tf.reduce_mean(keras.losses.sparse_categorical_crossentropy(y, predictions)) 137 | loss_averages_op = summary_averages.apply([loss]) 138 | # Store moving average of the loss 139 | tf.summary.scalar('cross_entropy', summary_averages.average(loss)) 140 | 141 | # Define accuracy metric 142 | with tf.name_scope('accuracy'): 143 | with tf.name_scope('correct_prediction'): 144 | # Compare prediction with actual label 145 | correct_prediction = tf.equal(tf.argmax(predictions, 1), tf.cast(y, tf.int64)) 146 | # Average correct predictions in the current batch 147 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy_metric') 148 | accuracy_averages_op = summary_averages.apply([accuracy]) 149 | # Store moving average of the accuracy 150 | tf.summary.scalar('accuracy', summary_averages.average(accuracy)) 151 | 152 | # Define optimizer and training op 153 | with tf.name_scope('train'): 154 | # Define device setter to place copies of local variables 155 | device_setter = tf.train.replica_device_setter(worker_device=worker_device, cluster=cluster) 156 | # Wrap optimizer in a FederatedAveragingOptimizer for federated training 157 | optimizer = federated_averaging_optimizer.FederatedAveragingOptimizer(tf.train.AdamOptimizer(0.001), replicas_to_aggregate=num_workers, interval_steps=INTERVAL_STEPS, is_chief=is_chief, device_setter=device_setter) 158 | # Make train_op dependent on moving averages ops. 
Otherwise they will be 159 | # disconnected from the graph 160 | with tf.control_dependencies([loss_averages_op, accuracy_averages_op]): 161 | train_op = optimizer.minimize(loss, global_step=global_step) 162 | # Define a hook for optimizer initialization 163 | federated_hook = optimizer.make_session_run_hook() 164 | 165 | n_batches = int(train_images.shape[0] / BATCH_SIZE) 166 | last_step = int(n_batches * EPOCHS) 167 | 168 | print('Graph definition finished') 169 | 170 | sess_config = tf.ConfigProto( 171 | allow_soft_placement=True, 172 | log_device_placement=False, 173 | operation_timeout_in_ms=20000, 174 | device_filters=["/job:ps", 175 | "/job:worker/task:%d" % FLAGS.task_index]) 176 | 177 | print('Training {} batches...'.format(last_step)) 178 | 179 | # Logger hook to keep track of the training 180 | class _LoggerHook(tf.train.SessionRunHook): 181 | def begin(self): 182 | self._total_loss = 0 183 | self._total_acc = 0 184 | 185 | def before_run(self, run_context): 186 | return tf.train.SessionRunArgs([loss, accuracy, global_step]) 187 | 188 | def after_run(self, run_context, run_values): 189 | loss_value, acc_value, step_value = run_values.results 190 | self._total_loss += loss_value 191 | self._total_acc += acc_value 192 | if (step_value + 1) % n_batches == 0: 193 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format(int(step_value / n_batches) + 1, EPOCHS, self._total_loss / n_batches, self._total_acc / n_batches)) 194 | self._total_loss = 0 195 | self._total_acc = 0 196 | 197 | # Hook to initialize the dataset 198 | class _InitHook(tf.train.SessionRunHook): 199 | def after_create_session(self, session, coord): 200 | session.run(dataset_init_op, feed_dict={images_placeholder: train_images, labels_placeholder: train_labels, batch_size: BATCH_SIZE, shuffle_size: train_images.shape[0]}) 201 | 202 | # Hook to save just trainable_variables 203 | class _SaverHook(tf.train.SessionRunHook): 204 | def begin(self): 205 | self._saver = tf.train.Saver(tf.trainable_variables()) 206 | 207 | def before_run(self, run_context): 208 | return tf.train.SessionRunArgs(global_step) 209 | 210 | def after_run(self, run_context, run_values): 211 | step_value = run_values.results 212 | if step_value % n_batches == 0 and not step_value == 0: 213 | self._saver.save(run_context.session, checkpoint_dir+'/model.ckpt', step_value) 214 | 215 | def end(self, session): 216 | self._saver.save(session, checkpoint_dir+'/model.ckpt', session.run(global_step)) 217 | 218 | # Make sure we do not define a chief worker 219 | with tf.name_scope('monitored_session'): 220 | with tf.train.MonitoredTrainingSession( 221 | master=server.target, 222 | checkpoint_dir=checkpoint_dir, 223 | hooks=[_LoggerHook(), _InitHook(), _SaverHook(), federated_hook], 224 | config=sess_config, 225 | stop_grace_period_secs=10, 226 | save_checkpoint_secs=None) as mon_sess: 227 | while not mon_sess.should_stop(): 228 | mon_sess.run(train_op) 229 | 230 | if is_chief: 231 | print('--- Begin Evaluation ---') 232 | # Reset graph and load it again to clean tensors placed in other devices 233 | tf.reset_default_graph() 234 | with tf.Session() as sess: 235 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 236 | saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True) 237 | saver.restore(sess, ckpt.model_checkpoint_path) 238 | print('Model restored') 239 | graph = tf.get_default_graph() 240 | images_placeholder = graph.get_tensor_by_name('dataset/images_placeholder:0') 241 | labels_placeholder = 
graph.get_tensor_by_name('dataset/labels_placeholder:0') 242 | batch_size = graph.get_tensor_by_name('dataset/batch_size:0') 243 | shuffle_size = graph.get_tensor_by_name('dataset/shuffle_size:0') 244 | accuracy = graph.get_tensor_by_name('accuracy/accuracy_metric:0') 245 | predictions = graph.get_tensor_by_name('softmax/BiasAdd:0') 246 | dataset_init_op = graph.get_operation_by_name('dataset/dataset_init') 247 | sess.run(dataset_init_op, feed_dict={images_placeholder: test_images, labels_placeholder: test_labels, batch_size: test_images.shape[0], shuffle_size: 1}) 248 | print('Test accuracy: {:4f}'.format(sess.run(accuracy))) 249 | predicted = sess.run(predictions) 250 | 251 | # Plot the first 25 test images, their predicted label, and the true label 252 | # Color correct predictions in green, incorrect predictions in red 253 | plt.figure(figsize=(10, 10)) 254 | for i in range(25): 255 | plt.subplot(5, 5, i + 1) 256 | plt.xticks([]) 257 | plt.yticks([]) 258 | plt.grid(False) 259 | plt.imshow(test_images[i], cmap=plt.cm.binary) 260 | predicted_label = np.argmax(predicted[i]) 261 | true_label = test_labels[i] 262 | if predicted_label == true_label: 263 | color = 'green' 264 | else: 265 | color = 'red' 266 | plt.xlabel("{} ({})".format(class_names[predicted_label], 267 | class_names[true_label]), 268 | color=color) 269 | 270 | plt.show(True) 271 | -------------------------------------------------------------------------------- /federated-MPI/README.md: -------------------------------------------------------------------------------- 1 | # Implementation with MPI 2 | 3 | This is the implementation of Federated Averaging using Message Passing Interface. It is harder to set-up but much easier to run, you can launch the whole cluster with just one command! 4 | 5 | ## Installation dependencies 6 | 7 | Same as previous, and: 8 | - [Mpich3](https://www.mpich.org/) 9 | - [mpi4py](https://mpi4py.readthedocs.io/en/stable/) 10 | 11 | ## Usage 12 | 13 | To run two processes in the same computer with the basic classifier type in the shell: `mpiexec -n 2 python3 mpi_basic_classifier.py` 14 | 15 | To run a cluster of nodes list their IP's in a file and run: `mpiexec -f your_file python3 mpi_basic_classifier.py` 16 | 17 | ## Useful resources 18 | 19 | Check [this tutorial](https://lleksah.wordpress.com/2016/04/11/configuring-a-raspberry-cluster-with-mpi/) to set-up a cluster of Raspberry Pi's. 20 | 21 | Check [this thread](https://raspberrypi.stackexchange.com/questions/54103/how-to-install-mpi4py-on-for-python3-on-raspberry-pi-after-installing-mpich) to solve common problems with mpi4py after installing mpich. 22 | 23 | Check [this script](https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py) to see how to generate CIFAR-10 TFRecords. 24 | 25 | ## Troubleshooting and Help 26 | 27 | coMind has public Slack and Telegram channels which are a great place to ask questions and all things related to federated machine learning. 28 | 29 | ## About 30 | 31 | coMind is an open source project for training privacy-preserving federated deep learning models. 32 | 33 | * https://comind.org/ 34 | * [Twitter](https://twitter.com/coMindOrg) 35 | -------------------------------------------------------------------------------- /federated-MPI/mpi_advanced_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 coMind. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # https://comind.org/ 16 | # ============================================================================== 17 | 18 | # TensorFlow 19 | import tensorflow as tf 20 | 21 | # Helper libraries 22 | import numpy as np 23 | from time import time 24 | from mpi4py import MPI 25 | import sys 26 | import multiprocessing 27 | 28 | # You can safely tune these variables 29 | BATCH_SIZE = 128 30 | SHUFFLE_SIZE = BATCH_SIZE * 100 31 | EPOCHS = 250 32 | EPOCHS_PER_DECAY = 50 33 | INTERVAL_STEPS = 100 # Steps between averages 34 | BATCHES_TO_PREFETCH = 1 35 | # ----------------- 36 | 37 | # Let the code know about the MPI config 38 | COMM = MPI.COMM_WORLD 39 | 40 | num_workers = COMM.size 41 | 42 | # Dataset dependent constants 43 | NUM_TRAIN_IMAGES = int(50000 / num_workers) 44 | NUM_TEST_IMAGES = 10000 45 | HEIGHT = 32 46 | WIDTH = 32 47 | CHANNELS = 3 48 | NUM_BATCH_FILES = 5 49 | 50 | # Path to TFRecord files (check readme for instructions on how to get these files) 51 | cifar10_train_files = ['cifar-10-tf-records/train{}.tfrecords'.format(i) 52 | for i in range(NUM_BATCH_FILES)] 53 | cifar10_test_file = 'cifar-10-tf-records/test.tfrecords' 54 | 55 | # Shuffle filenames before loading them 56 | np.random.shuffle(cifar10_train_files) 57 | 58 | CHECKPOINT_DIR = 'logs_dir/{}'.format(time()) 59 | print('Checkpoint directory: ' + CHECKPOINT_DIR) 60 | sys.stdout.flush() 61 | 62 | global_step = tf.train.get_or_create_global_step() 63 | 64 | CPU_COUNT = int(multiprocessing.cpu_count() / num_workers) 65 | 66 | # Define input pipeline, place these ops in the cpu 67 | with tf.name_scope('dataset'), tf.device('/cpu:0'): 68 | # Map function to decode data and preprocess it 69 | def preprocess(serialized_examples): 70 | """ Preprocess data """ 71 | # Parse a batch 72 | features = tf.parse_example( 73 | serialized_examples, 74 | {'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)}) 75 | # Decode and reshape imag 76 | image = tf.map_fn(lambda img: tf.reshape(tf.decode_raw(img, tf.uint8), 77 | tf.stack([HEIGHT, WIDTH, CHANNELS])), 78 | features['image'], dtype=tf.uint8, name='decode') 79 | # Cast image 80 | casted_image = tf.cast(image, tf.float32, name='input_cast') 81 | # Resize image for testing 82 | resized_image = tf.image.resize_image_with_crop_or_pad(casted_image, 24, 24) 83 | # Augment images for training 84 | distorted_image = tf.map_fn(lambda img: tf.random_crop(img, [24, 24, 3]), 85 | casted_image, name='random_crop') 86 | distorted_image = tf.image.random_flip_left_right(distorted_image) 87 | distorted_image = tf.image.random_brightness(distorted_image, 63) 88 | distorted_image = tf.image.random_contrast(distorted_image, 0.2, 1.8) 89 | # Check if test or train mode 90 | result = tf.cond(train_mode, lambda: distorted_image, lambda: resized_image) 91 | # Standardize images 92 | processed_image = tf.map_fn(lambda img: tf.image.per_image_standardization(img), 93 | result, 
name='standardization') 94 | return processed_image, features['label'] 95 | # Placeholders for the iterator 96 | filename_placeholder = tf.placeholder(tf.string, name='input_filename') 97 | batch_size = tf.placeholder(tf.int64, name='batch_size') 98 | shuffle_size = tf.placeholder(tf.int64, name='shuffle_size') 99 | train_mode = tf.placeholder(tf.bool, name='train_mode') 100 | 101 | # Create dataset, shuffle, repeat, batch, map and prefetch 102 | dataset = tf.data.TFRecordDataset(filename_placeholder) 103 | dataset = dataset.shard(num_workers, COMM.rank) 104 | dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True) 105 | dataset = dataset.repeat(EPOCHS) 106 | dataset = dataset.batch(batch_size) 107 | dataset = dataset.map(preprocess, CPU_COUNT) 108 | dataset = dataset.prefetch(BATCHES_TO_PREFETCH) 109 | # Define a feedable iterator and the initialization op 110 | iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 111 | dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') 112 | X, y = iterator.get_next() 113 | 114 | # Define our model 115 | first_conv = tf.layers.conv2d(X, 64, 5, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2), name='first_conv') 116 | 117 | first_pool = tf.nn.max_pool(first_conv, [1, 3, 3 ,1], [1, 2, 2, 1], padding='SAME', name='first_pool') 118 | 119 | first_norm = tf.nn.lrn(first_pool, 4, alpha=0.001 / 9.0, beta=0.75, name='first_norm') 120 | 121 | second_conv = tf.layers.conv2d(first_norm, 64, 5, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2), name='second_conv') 122 | 123 | second_norm = tf.nn.lrn(second_conv, 4, alpha=0.001 / 9.0, beta=0.75, name='second_norm') 124 | 125 | second_pool = tf.nn.max_pool(second_norm, [1, 3, 3, 1], [1, 2, 2, 1], padding='SAME', name='second_pool') 126 | 127 | flatten_layer = tf.layers.flatten(second_pool, name='flatten') 128 | 129 | first_relu = tf.layers.dense(flatten_layer, 384, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.04), name='first_relu') 130 | 131 | second_relu = tf.layers.dense(first_relu, 192, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.04), name='second_relu') 132 | 133 | logits = tf.layers.dense(second_relu, 10, kernel_initializer=tf.truncated_normal_initializer(stddev=1/192.0), name='logits') 134 | 135 | # Object to keep moving averages of our metrics (for tensorboard) 136 | summary_averages = tf.train.ExponentialMovingAverage(0.9) 137 | 138 | # Define cross_entropy loss 139 | with tf.name_scope('loss'): 140 | base_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, 141 | logits=logits), 142 | name='base_loss') 143 | # Add regularization loss to both relu layers 144 | regularizer_loss = tf.add_n( 145 | [tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'relu/kernel' in v.name], 146 | name='regularizer_loss') * 0.004 147 | loss = tf.add(base_loss, regularizer_loss) 148 | loss_averages_op = summary_averages.apply([loss]) 149 | # Store moving average of the loss 150 | tf.summary.scalar('cross_entropy', summary_averages.average(loss)) 151 | 152 | with tf.name_scope('accuracy'): 153 | with tf.name_scope('correct_prediction'): 154 | # Compare prediction with actual label 155 | correct_prediction = tf.equal(tf.argmax(logits, 1), y) 156 | # Average correct predictions in the current batch 157 | accuracy = 
tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy_metric') 158 | accuracy_averages_op = summary_averages.apply([accuracy]) 159 | # Store moving average of the accuracy 160 | tf.summary.scalar('accuracy', summary_averages.average(accuracy)) 161 | 162 | N_BATCHES = int(NUM_TRAIN_IMAGES / BATCH_SIZE) 163 | LAST_STEP = int(N_BATCHES * EPOCHS) 164 | 165 | # Define moving averages of the trainable variables. This sometimes improve 166 | # the performance of the trained model 167 | with tf.name_scope('variable_averages'): 168 | variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step) 169 | variable_averages_op = variable_averages.apply(tf.trainable_variables()) 170 | 171 | # Define optimizer and training op 172 | with tf.name_scope('train'): 173 | # Make decaying learning rate 174 | lr = tf.train.exponential_decay(0.1, global_step, N_BATCHES * EPOCHS_PER_DECAY, 175 | 0.1, staircase=True) 176 | tf.summary.scalar('learning_rate', lr) 177 | # Make train_op dependent on moving averages ops. Otherwise they will be 178 | # disconnected from the graph 179 | with tf.control_dependencies([loss_averages_op, accuracy_averages_op, variable_averages_op]): 180 | train_op = tf.train.GradientDescentOptimizer(lr).minimize(loss, global_step=global_step) 181 | 182 | print('Graph definition finished') 183 | sys.stdout.flush() 184 | SESS_CONFIG = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 185 | 186 | print('Training {} batches...'.format(LAST_STEP)) 187 | sys.stdout.flush() 188 | 189 | # Logger hook to keep track of the training 190 | class _LoggerHook(tf.train.SessionRunHook): 191 | def begin(self): 192 | """ Run this in session begin """ 193 | self._total_loss = 0 194 | self._total_acc = 0 195 | 196 | def before_run(self, run_context): 197 | """ Run this in session before_run """ 198 | return tf.train.SessionRunArgs([loss, accuracy, global_step]) 199 | 200 | def after_run(self, run_context, run_values): 201 | """ Run this in session after_run """ 202 | loss_value, acc_value, step_value = run_values.results 203 | self._total_loss += loss_value 204 | self._total_acc += acc_value 205 | if (step_value + 1) % N_BATCHES == 0 and COMM.rank == 0: 206 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format( 207 | int(step_value / N_BATCHES) + 1, EPOCHS, self._total_loss / N_BATCHES, 208 | self._total_acc / N_BATCHES)) 209 | sys.stdout.flush() 210 | self._total_loss = 0 211 | self._total_acc = 0 212 | 213 | # Custom hook 214 | class _FederatedHook(tf.train.SessionRunHook): 215 | def __init__(self, comm): 216 | """ Initialize Hook """ 217 | # Store the MPI config 218 | self._comm = comm 219 | 220 | def _create_placeholders(self): 221 | """ Create placeholders for all the trainable variables """ 222 | # Create placeholders for all the trainable variables 223 | for var in tf.trainable_variables(): 224 | self._placeholders.append(tf.placeholder_with_default( 225 | var, var.shape, name="%s/%s" % ("FedAvg", var.op.name))) 226 | 227 | def _assign_vars(self, local_vars): 228 | """ Assign value feeded to placeholders to local vars """ 229 | reassign_ops = [] 230 | for var, fvar in zip(local_vars, self._placeholders): 231 | reassign_ops.append(tf.assign(var, fvar)) 232 | return tf.group(*(reassign_ops)) 233 | 234 | def _gather_weights(self, session): 235 | """Gather all weights in the chief worker""" 236 | gathered_weights = [] 237 | for var in tf.trainable_variables(): 238 | value = session.run(var) 239 | value = self._comm.gather(value, root=0) 240 | 
gathered_weights.append(np.array(value)) 241 | return gathered_weights 242 | 243 | def _broadcast_weights(self, session): 244 | """Broadcast averaged weights to all workers""" 245 | broadcasted_weights = [] 246 | for var in tf.trainable_variables(): 247 | value = session.run(var) 248 | value = self._comm.bcast(value, root=0) 249 | broadcasted_weights.append(np.array(value)) 250 | return broadcasted_weights 251 | 252 | def begin(self): 253 | """ Run this in session begin """ 254 | self._placeholders = [] 255 | self._create_placeholders() 256 | # Op to initialize update the weight 257 | self._update_local_vars_op = self._assign_vars(tf.trainable_variables()) 258 | 259 | def after_create_session(self, session, coord): 260 | """ Run this after creating session """ 261 | # Broadcast weights 262 | broadcasted_weights = self._broadcast_weights(session) 263 | # Initialize the workers at the same point 264 | if self._comm.rank != 0: 265 | feed_dict = {} 266 | for placeh, bweight in zip(self._placeholders, broadcasted_weights): 267 | feed_dict[placeh] = bweight 268 | session.run(self._update_local_vars_op, feed_dict=feed_dict) 269 | 270 | def before_run(self, run_context): 271 | """ Run this in session before_run """ 272 | return tf.train.SessionRunArgs(global_step) 273 | 274 | def after_run(self, run_context, run_values): 275 | """ Run this in session after_run """ 276 | step_value = run_values.results 277 | session = run_context.session 278 | # Check if we should average 279 | if step_value % INTERVAL_STEPS == 0 and not step_value == 0: 280 | gathered_weights = self._gather_weights(session) 281 | # Chief gather weights and averages 282 | if self._comm.rank == 0: 283 | print('Average applied, iter: {}/{}'.format(step_value, LAST_STEP)) 284 | sys.stdout.flush() 285 | for i, elem in enumerate(gathered_weights): 286 | gathered_weights[i] = np.mean(elem, axis=0) 287 | feed_dict = {} 288 | for placeh, gweight in zip(self._placeholders, gathered_weights): 289 | feed_dict[placeh] = gweight 290 | session.run(self._update_local_vars_op, feed_dict=feed_dict) 291 | # The rest get the averages and update their local model 292 | broadcasted_weights = self._broadcast_weights(session) 293 | if self._comm.rank != 0: 294 | feed_dict = {} 295 | for placeh, bweight in zip(self._placeholders, broadcasted_weights): 296 | feed_dict[placeh] = bweight 297 | session.run(self._update_local_vars_op, feed_dict=feed_dict) 298 | 299 | # Hook to initialize the dataset 300 | class _InitHook(tf.train.SessionRunHook): 301 | def after_create_session(self, session, coord): 302 | session.run(dataset_init_op, feed_dict={ 303 | filename_placeholder: cifar10_train_files, batch_size: BATCH_SIZE, 304 | shuffle_size: SHUFFLE_SIZE, train_mode: True}) 305 | 306 | print("Worker {} ready".format(COMM.rank)) 307 | sys.stdout.flush() 308 | 309 | with tf.name_scope('monitored_session'): 310 | with tf.train.MonitoredTrainingSession( 311 | checkpoint_dir=CHECKPOINT_DIR, 312 | hooks=[_LoggerHook(), _InitHook(), _FederatedHook(COMM), 313 | tf.train.CheckpointSaverHook(checkpoint_dir=CHECKPOINT_DIR, 314 | save_steps=N_BATCHES, 315 | saver=tf.train.Saver( 316 | variable_averages.variables_to_restore()))], 317 | config=SESS_CONFIG, 318 | save_checkpoint_secs=None) as mon_sess: 319 | while not mon_sess.should_stop(): 320 | mon_sess.run(train_op) 321 | 322 | if COMM.rank == 0: 323 | print('--- Begin Evaluation ---') 324 | sys.stdout.flush() 325 | tf.reset_default_graph() 326 | with tf.Session() as sess: 327 | ckpt = 
tf.train.get_checkpoint_state(CHECKPOINT_DIR) 328 | saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True) 329 | saver.restore(sess, ckpt.model_checkpoint_path) 330 | print('Model restored') 331 | sys.stdout.flush() 332 | graph = tf.get_default_graph() 333 | images_placeholder = graph.get_tensor_by_name('dataset/images_placeholder:0') 334 | labels_placeholder = graph.get_tensor_by_name('dataset/labels_placeholder:0') 335 | batch_size = graph.get_tensor_by_name('dataset/batch_size:0') 336 | train_mode = graph.get_tensor_by_name('dataset/train_mode:0') 337 | accuracy = graph.get_tensor_by_name('accuracy/accuracy_metric:0') 338 | dataset_init_op = graph.get_operation_by_name('dataset/dataset_init') 339 | sess.run(dataset_init_op, feed_dict={ 340 | filename_placeholder: cifar10_test_file, batch_size: NUM_TEST_IMAGES, 341 | shuffle_size: 1, train_mode: False}) 342 | print('Test accuracy: {:4f}'.format(sess.run(accuracy))) 343 | sys.stdout.flush() 344 | -------------------------------------------------------------------------------- /federated-MPI/mpi_basic_classifier.py: -------------------------------------------------------------------------------- 1 | """# Copyright 2018 coMind. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # https://comind.org/ 16 | # ==============================================================================""" 17 | 18 | # TensorFlow and tf.keras 19 | import tensorflow as tf 20 | from tensorflow import keras 21 | 22 | # Helper libraries 23 | import numpy as np 24 | from time import time 25 | from mpi4py import MPI 26 | import sys 27 | 28 | # Let the code know about the MPI config 29 | COMM = MPI.COMM_WORLD 30 | 31 | # Load dataset as numpy arrays 32 | fashion_mnist = keras.datasets.fashion_mnist 33 | (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() 34 | 35 | # Split dataset 36 | train_images = np.array_split(train_images, COMM.size)[COMM.rank] 37 | train_labels = np.array_split(train_labels, COMM.size)[COMM.rank] 38 | 39 | # You can safely tune these variables 40 | BATCH_SIZE = 32 41 | SHUFFLE_SIZE = train_images.shape[0] 42 | EPOCHS = 5 43 | INTERVAL_STEPS = 100 44 | # ----------------- 45 | 46 | # Normalize dataset 47 | train_images = train_images / 255.0 48 | test_images = test_images / 255.0 49 | 50 | CHECKPOINT_DIR = 'logs_dir/{}'.format(time()) 51 | 52 | global_step = tf.train.get_or_create_global_step() 53 | 54 | # Define input pipeline, place these ops in the cpu 55 | with tf.name_scope('dataset'), tf.device('/cpu:0'): 56 | # Placeholders for the iterator 57 | images_placeholder = tf.placeholder(train_images.dtype, [None, train_images.shape[1], train_images.shape[2]]) 58 | labels_placeholder = tf.placeholder(train_labels.dtype, [None]) 59 | batch_size = tf.placeholder(tf.int64) 60 | shuffle_size = tf.placeholder(tf.int64, name='shuffle_size') 61 | 62 | # Create dataset, shuffle, repeat and batch 63 | dataset = tf.data.Dataset.from_tensor_slices((images_placeholder, labels_placeholder)) 64 | dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True) 65 | dataset = dataset.repeat(EPOCHS) 66 | dataset = dataset.batch(batch_size) 67 | iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 68 | dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') 69 | X, y = iterator.get_next() 70 | 71 | # Define our model 72 | flatten_layer = tf.layers.flatten(X, name='flatten') 73 | 74 | dense_layer = tf.layers.dense(flatten_layer, 128, activation=tf.nn.relu, name='relu') 75 | 76 | predictions = tf.layers.dense(dense_layer, 10, activation=tf.nn.softmax, name='softmax') 77 | 78 | # Object to keep moving averages of our metrics (for tensorboard) 79 | summary_averages = tf.train.ExponentialMovingAverage(0.9) 80 | 81 | # Define cross_entropy loss 82 | with tf.name_scope('loss'): 83 | loss = tf.reduce_mean(keras.losses.sparse_categorical_crossentropy(y, predictions)) 84 | loss_averages_op = summary_averages.apply([loss]) 85 | # Store moving average of the loss 86 | tf.summary.scalar('cross_entropy', summary_averages.average(loss)) 87 | 88 | with tf.name_scope('accuracy'): 89 | with tf.name_scope('correct_prediction'): 90 | # Compare prediction with actual label 91 | correct_prediction = tf.equal(tf.argmax(predictions, 1), tf.cast(y, tf.int64)) 92 | # Average correct predictions in the current batch 93 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 94 | accuracy_averages_op = summary_averages.apply([accuracy]) 95 | # Store moving average of the accuracy 96 | tf.summary.scalar('accuracy', summary_averages.average(accuracy)) 97 | 98 | # Define optimizer and training op 99 | with tf.name_scope('train'): 100 | # Make train_op dependent on moving averages ops. 
Otherwise they will be 101 | # disconnected from the graph 102 | with tf.control_dependencies([loss_averages_op, accuracy_averages_op]): 103 | train_op = tf.train.AdamOptimizer(0.001).minimize(loss, global_step=global_step) 104 | 105 | SESS_CONFIG = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 106 | 107 | N_BATCHES = int(train_images.shape[0] / BATCH_SIZE) 108 | LAST_STEP = int(N_BATCHES * EPOCHS) 109 | 110 | # Logger hook to keep track of the training 111 | class _LoggerHook(tf.train.SessionRunHook): 112 | def begin(self): 113 | """ Run this in session begin """ 114 | self._total_loss = 0 115 | self._total_acc = 0 116 | 117 | def before_run(self, run_context): 118 | """ Run this in session before_run """ 119 | return tf.train.SessionRunArgs([loss, accuracy, global_step]) 120 | 121 | def after_run(self, run_context, run_values): 122 | """ Run this in session after_run """ 123 | loss_value, acc_value, step_value = run_values.results 124 | self._total_loss += loss_value 125 | self._total_acc += acc_value 126 | if (step_value + 1) % N_BATCHES == 0 and COMM.rank == 0: 127 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format( 128 | int(step_value / N_BATCHES) + 1, 129 | EPOCHS, self._total_loss / N_BATCHES, self._total_acc / N_BATCHES)) 130 | sys.stdout.flush() 131 | self._total_loss = 0 132 | self._total_acc = 0 133 | 134 | # Custom hook 135 | class _FederatedHook(tf.train.SessionRunHook): 136 | def __init__(self, comm): 137 | """ Initialize Hook """ 138 | # Store the MPI config 139 | self._comm = comm 140 | 141 | def _create_placeholders(self): 142 | """ Create placeholders for all the trainable variables """ 143 | for var in tf.trainable_variables(): 144 | self._placeholders.append( 145 | tf.placeholder_with_default( 146 | var, var.shape, name="%s/%s" % ("FedAvg", var.op.name))) 147 | 148 | def _assign_vars(self, local_vars): 149 | """ Assign value feeded to placeholders to local vars """ 150 | reassign_ops = [] 151 | for var, fvar in zip(local_vars, self._placeholders): 152 | reassign_ops.append(tf.assign(var, fvar)) 153 | return tf.group(*(reassign_ops)) 154 | 155 | def _gather_weights(self, session): 156 | """Gather all weights in the chief worker""" 157 | gathered_weights = [] 158 | for var in tf.trainable_variables(): 159 | value = session.run(var) 160 | value = self._comm.gather(value, root=0) 161 | gathered_weights.append(np.array(value)) 162 | return gathered_weights 163 | 164 | def _broadcast_weights(self, session): 165 | """Broadcast averaged weights to all workers""" 166 | broadcasted_weights = [] 167 | for var in tf.trainable_variables(): 168 | value = session.run(var) 169 | value = self._comm.bcast(value, root=0) 170 | broadcasted_weights.append(np.array(value)) 171 | return broadcasted_weights 172 | 173 | def begin(self): 174 | """ Run this in session begin """ 175 | self._placeholders = [] 176 | self._create_placeholders() 177 | # Op to initialize update the weights 178 | self._update_local_vars_op = self._assign_vars(tf.trainable_variables()) 179 | 180 | def after_create_session(self, session, coord): 181 | """ Run this after creating session """ 182 | # Broadcast weights 183 | broadcasted_weights = self._broadcast_weights(session) 184 | # Initialize the workers at the same point 185 | if self._comm.rank != 0: 186 | feed_dict = {} 187 | for placeh, bweight in zip(self._placeholders, broadcasted_weights): 188 | feed_dict[placeh] = bweight 189 | session.run(self._update_local_vars_op, feed_dict=feed_dict) 190 | 191 | def before_run(self, 
run_context): 192 | """ Run this in session before_run """ 193 | return tf.train.SessionRunArgs(global_step) 194 | 195 | def after_run(self, run_context, run_values): 196 | """ Run this in session after_run """ 197 | step_value = run_values.results 198 | session = run_context.session 199 | # Check if we should average 200 | if step_value % INTERVAL_STEPS == 0 and not step_value == 0: 201 | gathered_weights = self._gather_weights(session) 202 | # Chief gather weights and averages 203 | if self._comm.rank == 0: 204 | print('Average applied, iter: {}/{}'.format(step_value, LAST_STEP)) 205 | sys.stdout.flush() 206 | for i, elem in enumerate(gathered_weights): 207 | gathered_weights[i] = np.mean(elem, axis=0) 208 | feed_dict = {} 209 | for placeh, gweight in zip(self._placeholders, gathered_weights): 210 | feed_dict[placeh] = gweight 211 | session.run(self._update_local_vars_op, feed_dict=feed_dict) 212 | # The rest get the averages and update their local model 213 | broadcasted_weights = self._broadcast_weights(session) 214 | if self._comm.rank != 0: 215 | feed_dict = {} 216 | for placeh, bweight in zip(self._placeholders, broadcasted_weights): 217 | feed_dict[placeh] = bweight 218 | session.run(self._update_local_vars_op, feed_dict=feed_dict) 219 | 220 | # Hook to initialize the dataset 221 | class _InitHook(tf.train.SessionRunHook): 222 | def after_create_session(self, session, coord): 223 | """ Run this after creating session """ 224 | session.run(dataset_init_op, feed_dict={ 225 | images_placeholder: train_images, 226 | labels_placeholder: train_labels, 227 | batch_size: BATCH_SIZE, shuffle_size: SHUFFLE_SIZE}) 228 | 229 | print("Worker {} ready".format(COMM.rank)) 230 | sys.stdout.flush() 231 | 232 | with tf.name_scope('monitored_session'): 233 | with tf.train.MonitoredTrainingSession( 234 | checkpoint_dir=CHECKPOINT_DIR, 235 | hooks=[_LoggerHook(), _InitHook(), _FederatedHook(COMM)], 236 | config=SESS_CONFIG, 237 | save_checkpoint_steps=N_BATCHES) as mon_sess: 238 | while not mon_sess.should_stop(): 239 | mon_sess.run(train_op) 240 | 241 | if COMM.rank == 0: 242 | print('--- Begin Evaluation ---') 243 | sys.stdout.flush() 244 | with tf.Session() as sess: 245 | ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR) 246 | tf.train.Saver().restore(sess, ckpt.model_checkpoint_path) 247 | print('Model restored') 248 | sys.stdout.flush() 249 | sess.run(dataset_init_op, feed_dict={ 250 | images_placeholder: test_images, labels_placeholder: test_labels, 251 | batch_size: test_images.shape[0], shuffle_size: 1}) 252 | print('Test accuracy: {:4f}'.format(sess.run(accuracy))) 253 | sys.stdout.flush() 254 | -------------------------------------------------------------------------------- /federated-keras/README.md: -------------------------------------------------------------------------------- 1 | # Federated with Keras 2 | 3 | This shows the usage of the distributed and federated set-ups with keras. 4 | 5 | ## Dependencies 6 | 7 | You will need the custom `federated_averaging_optimizer.py` to be able to run the keras example. You can [find it](https://github.com/coMindOrg/federated-averaging-tutorials/blob/master/federated_averaging_optimizer.py) in this same repository. 
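For orientation before the usage steps, here is a rough sketch of how the custom optimizer is wired into a training script (a condensed excerpt of `keras_federated_classifier.py` in this folder; `loss`, `global_step`, `NUM_WORKERS`, `INTERVAL_STEPS`, `IS_CHIEF` and `device_setter` are whatever your own script defines):

```python
import tensorflow as tf
import federated_averaging_optimizer

# Wrap any standard tf.train optimizer with the federated averaging logic
optimizer = federated_averaging_optimizer.FederatedAveragingOptimizer(
    tf.train.AdamOptimizer(0.001),
    replicas_to_aggregate=NUM_WORKERS,   # workers taking part in each average
    interval_steps=INTERVAL_STEPS,       # local steps between two averaging rounds
    is_chief=IS_CHIEF,
    device_setter=device_setter)         # places the global copy of the variables
train_op = optimizer.minimize(loss, global_step=global_step)

# This hook has to be passed to the MonitoredTrainingSession
federated_average_hook = optimizer.make_session_run_hook()
```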
8 | 9 | ## Usage 10 | 11 | For example, to run the `keras_distributed_classifier.py`: 12 | 13 | * 1st shell command should look like this: `python3 keras_distributed_classifier.py --job_name=ps --task_index=0` 14 | 15 | * 2nd shell: `python3 keras_distributed_classifier.py --job_name=worker --task_index=0` 16 | 17 | * 3rd shell: `python3 keras_distributed_classifier.py --job_name=worker --task_index=1` 18 | 19 | Follow the same steps for the `keras_federated_classifier.py`. 20 | 21 | ## Useful resources 22 | 23 | Check [Keras](https://keras.io/) to learn more about this great API. 24 | 25 | ## Troubleshooting and Help 26 | 27 | coMind has public Slack and Telegram channels which are a great place to ask questions and all things related to federated machine learning. 28 | 29 | ## About 30 | 31 | coMind is an open source project for training privacy-preserving federated deep learning models. 32 | 33 | * https://comind.org/ 34 | * [Twitter](https://twitter.com/coMindOrg) 35 | -------------------------------------------------------------------------------- /federated-keras/keras_distributed_classifier.py: -------------------------------------------------------------------------------- 1 | # Helper libraries 2 | import os 3 | import numpy as np 4 | from time import time 5 | 6 | # TensorFlow and tf.keras 7 | import tensorflow as tf 8 | from tensorflow import keras 9 | 10 | flags = tf.app.flags 11 | flags.DEFINE_integer("task_index", None, 12 | "Worker task index, should be >= 0. task_index=0 is " 13 | "the master worker task the performs the variable " 14 | "initialization ") 15 | flags.DEFINE_integer("train_steps", 1000, 16 | "Number of (global) training steps to perform") 17 | flags.DEFINE_string("ps_hosts", "localhost:2222", 18 | "Comma-separated list of hostname:port pairs") 19 | flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", 20 | "Comma-separated list of hostname:port pairs") 21 | flags.DEFINE_string("job_name", None, "job name: worker or ps") 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | # Steps between averages 26 | INTERVAL_STEPS = 100 27 | 28 | # Disable GPU to avoid OOM issues (could enable it for just one of the workers) 29 | # Not necessary if workers are hosted in different machines 30 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 31 | 32 | if FLAGS.job_name is None or FLAGS.job_name == "": 33 | raise ValueError("Must specify an explicit `job_name`") 34 | if FLAGS.task_index is None or FLAGS.task_index == "": 35 | raise ValueError("Must specify an explicit `task_index`") 36 | print("job name = %s" % FLAGS.job_name) 37 | print("task index = %d" % FLAGS.task_index) 38 | 39 | #Construct the cluster and start the server 40 | ps_spec = FLAGS.ps_hosts.split(",") 41 | worker_spec = FLAGS.worker_hosts.split(",") 42 | 43 | # Get the number of workers. 
44 | NUM_WORKERS = len(worker_spec) 45 | 46 | cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec}) 47 | 48 | server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) 49 | 50 | # The server will block here 51 | if FLAGS.job_name == "ps": 52 | server.join() 53 | 54 | fashion_mnist = keras.datasets.fashion_mnist 55 | (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() 56 | 57 | CLASS_NAMES = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 58 | 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] 59 | 60 | # Split dataset between workers 61 | train_images = np.array_split(train_images, NUM_WORKERS)[FLAGS.task_index] 62 | train_labels = np.array_split(train_labels, NUM_WORKERS)[FLAGS.task_index] 63 | print('Local dataset size: {}'.format(train_images.shape[0])) 64 | 65 | # Normalize dataset 66 | train_images = train_images / 255.0 67 | test_images = test_images / 255.0 68 | 69 | IS_CHIEF = (FLAGS.task_index == 0) 70 | 71 | WORKER_DEVICE = "/job:worker/task:%d" % FLAGS.task_index 72 | 73 | # Device setter will place vars in the appropriate device 74 | with tf.device( 75 | tf.train.replica_device_setter( 76 | worker_device=WORKER_DEVICE, 77 | cluster=cluster)): 78 | global_step = tf.train.get_or_create_global_step() 79 | 80 | # Define the model 81 | model = keras.Sequential([ 82 | keras.layers.Flatten(input_shape=(28, 28)), 83 | keras.layers.Dense(128, activation=tf.nn.relu, name='relu'), 84 | keras.layers.Dense(10, activation=tf.nn.softmax, name='softmax') 85 | ]) 86 | 87 | # Get placeholder for the labels 88 | y = tf.placeholder(tf.float32, shape=[None], name='labels') 89 | 90 | # Store reference to the output of the model 91 | predictions = model.output 92 | 93 | with tf.name_scope('loss'): 94 | loss = tf.reduce_mean(keras.losses.sparse_categorical_crossentropy(y, predictions)) 95 | 96 | tf.summary.scalar('cross_entropy', loss) 97 | 98 | with tf.name_scope('train'): 99 | # Define the distributed optimizer 100 | optimizer = tf.train.SyncReplicasOptimizer(tf.train.AdamOptimizer(0.001), 101 | replicas_to_aggregate=NUM_WORKERS) 102 | train_op = optimizer.minimize(loss, global_step=global_step) 103 | # Define the hook which initializes the optimizer 104 | sync_replicas_hook = optimizer.make_session_run_hook(is_chief=IS_CHIEF) 105 | 106 | # ConfiProto for our session 107 | SESS_CONFIG = tf.ConfigProto( 108 | allow_soft_placement=True, 109 | log_device_placement=False, 110 | device_filters=["/job:ps", 111 | "/job:worker/task:%d" % FLAGS.task_index]) 112 | 113 | # We need to let the MonitoredSession initialize the variables 114 | keras.backend.manual_variable_initialization(True) 115 | # Define the training feed 116 | train_feed = {model.inputs[0]: train_images, y: train_labels} 117 | 118 | # Hook to log training progress 119 | class _LoggerHook(tf.train.SessionRunHook): 120 | def before_run(self, run_context): 121 | """ Run this in session before_run """ 122 | return tf.train.SessionRunArgs(global_step) 123 | 124 | def after_run(self, run_context, run_values): 125 | """ Run this in session after_run """ 126 | step = run_values.results 127 | if step % 100 == 0: 128 | print('Iter {}/{}'.format(step, FLAGS.train_steps)) 129 | 130 | with tf.train.MonitoredTrainingSession( 131 | master=server.target, 132 | is_chief=IS_CHIEF, 133 | checkpoint_dir='logs_dir/{}'.format(time()), 134 | hooks=[tf.train.StopAtStepHook(last_step=FLAGS.train_steps), 135 | _LoggerHook(), sync_replicas_hook], 136 | save_checkpoint_steps=100, 137 
| config=SESS_CONFIG) as mon_sess: 138 | keras.backend.set_session(mon_sess) 139 | while not mon_sess.should_stop(): 140 | mon_sess.run(train_op, feed_dict=train_feed) 141 | -------------------------------------------------------------------------------- /federated-keras/keras_federated_classifier.py: -------------------------------------------------------------------------------- 1 | # Helper libraries 2 | import os 3 | import numpy as np 4 | from time import time 5 | import sys 6 | import federated_averaging_optimizer 7 | 8 | # TensorFlow and tf.keras 9 | import tensorflow as tf 10 | from tensorflow import keras 11 | 12 | # Trick to import from parent directory 13 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 14 | 15 | flags = tf.app.flags 16 | flags.DEFINE_integer("task_index", None, 17 | "Worker task index, should be >= 0. task_index=0 is " 18 | "the master worker task the performs the variable " 19 | "initialization ") 20 | flags.DEFINE_integer("train_steps", 1000, 21 | "Number of (global) training steps to perform") 22 | flags.DEFINE_string("ps_hosts", "localhost:2222", 23 | "Comma-separated list of hostname:port pairs") 24 | flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", 25 | "Comma-separated list of hostname:port pairs") 26 | flags.DEFINE_string("job_name", None, "job name: worker or ps") 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | # Steps between averages 31 | INTERVAL_STEPS = 100 32 | 33 | # Disable GPU to avoid OOM issues (could enable it for just one of the workers) 34 | # Not necessary if workers are hosted in different machines 35 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 36 | 37 | if FLAGS.job_name is None or FLAGS.job_name == "": 38 | raise ValueError("Must specify an explicit `job_name`") 39 | if FLAGS.task_index is None or FLAGS.task_index == "": 40 | raise ValueError("Must specify an explicit `task_index`") 41 | print("job name = %s" % FLAGS.job_name) 42 | print("task index = %d" % FLAGS.task_index) 43 | 44 | #Construct the cluster and start the server 45 | ps_spec = FLAGS.ps_hosts.split(",") 46 | worker_spec = FLAGS.worker_hosts.split(",") 47 | 48 | # Get the number of workers. 49 | NUM_WORKERS = len(worker_spec) 50 | 51 | cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec}) 52 | 53 | server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) 54 | 55 | # The server will block here 56 | if FLAGS.job_name == "ps": 57 | server.join() 58 | 59 | fashion_mnist = keras.datasets.fashion_mnist 60 | (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() 61 | 62 | CLASS_NAMES = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 63 | 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] 64 | 65 | # Split dataset between workers 66 | train_images = np.array_split(train_images, NUM_WORKERS)[FLAGS.task_index] 67 | train_labels = np.array_split(train_labels, NUM_WORKERS)[FLAGS.task_index] 68 | print('Local dataset size: {}'.format(train_images.shape[0])) 69 | 70 | # Normalize dataset 71 | train_images = train_images / 255.0 72 | test_images = test_images / 255.0 73 | 74 | IS_CHIEF = (FLAGS.task_index == 0) 75 | 76 | # We are not telling the MonitoredSession who is the chief so we need to 77 | # prevent non-chief workers from saving checkpoints or summaries. 
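# (Note: tf.train.MonitoredTrainingSession defaults to is_chief=True, so every
# worker would otherwise behave as a chief and write its own checkpoints;
# passing checkpoint_dir=None below disables saving on the non-chief workers.)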
78 | if IS_CHIEF: 79 | CHECKPOINT_DIR = 'logs_dir/{}'.format(time()) 80 | else: 81 | CHECKPOINT_DIR = None 82 | 83 | WORKER_DEVICE = "/job:worker/task:%d" % FLAGS.task_index 84 | 85 | # Place all ops in the local worker by default 86 | with tf.device(WORKER_DEVICE): 87 | global_step = tf.train.get_or_create_global_step() 88 | 89 | # Define the model 90 | model = keras.Sequential([ 91 | keras.layers.Flatten(input_shape=(28, 28)), 92 | keras.layers.Dense(128, activation=tf.nn.relu, name='relu'), 93 | keras.layers.Dense(10, activation=tf.nn.softmax, name='softmax') 94 | ]) 95 | 96 | # Get placeholder for the labels 97 | y = tf.placeholder(tf.float32, shape=[None], name='labels') 98 | 99 | # Store reference to the output of the model 100 | predictions = model.output 101 | 102 | with tf.name_scope('loss'): 103 | loss = tf.reduce_mean(keras.losses.sparse_categorical_crossentropy(y, predictions)) 104 | 105 | tf.summary.scalar('cross_entropy', loss) 106 | 107 | with tf.name_scope('train'): 108 | # Define a device setter which will place a global copy of trainable variables 109 | # in the parameter server. 110 | device_setter = tf.train.replica_device_setter(worker_device=WORKER_DEVICE, cluster=cluster) 111 | # Define our custom optimizer 112 | optimizer = federated_averaging_optimizer.FederatedAveragingOptimizer( 113 | tf.train.AdamOptimizer(0.001), 114 | replicas_to_aggregate=NUM_WORKERS, interval_steps=INTERVAL_STEPS, 115 | is_chief=IS_CHIEF, device_setter=device_setter) 116 | train_op = optimizer.minimize(loss, global_step=global_step) 117 | # Define the hook which initializes the optimizer 118 | federated_average_hook = optimizer.make_session_run_hook() 119 | 120 | # ConfiProto for our session 121 | SESS_CONFIG = tf.ConfigProto( 122 | allow_soft_placement=True, 123 | log_device_placement=False, 124 | device_filters=["/job:ps", 125 | "/job:worker/task:%d" % FLAGS.task_index]) 126 | 127 | # We need to let the MonitoredSession initialize the variables 128 | keras.backend.manual_variable_initialization(True) 129 | # Define the training feed 130 | train_feed = {model.inputs[0]: train_images, y: train_labels} 131 | 132 | # Hook to log training progress 133 | class _LoggerHook(tf.train.SessionRunHook): 134 | def before_run(self, run_context): 135 | """ Run this in session before_run """ 136 | return tf.train.SessionRunArgs(global_step) 137 | 138 | def after_run(self, run_context, run_values): 139 | """ Run this in session after_run """ 140 | step = run_values.results 141 | if step % 100 == 0: 142 | print('Iter {}/{}'.format(step, FLAGS.train_steps)) 143 | 144 | with tf.train.MonitoredTrainingSession( 145 | master=server.target, 146 | checkpoint_dir=CHECKPOINT_DIR, 147 | hooks=[tf.train.StopAtStepHook(last_step=FLAGS.train_steps), 148 | _LoggerHook(), federated_average_hook], 149 | save_checkpoint_steps=100, 150 | config=SESS_CONFIG) as mon_sess: 151 | keras.backend.set_session(mon_sess) 152 | while not mon_sess.should_stop(): 153 | mon_sess.run(train_op, feed_dict=train_feed) 154 | -------------------------------------------------------------------------------- /federated-sockets/FederatedHook.py: -------------------------------------------------------------------------------- 1 | """# Copyright 2018 coMind. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # https://comind.org/ 16 | # ==============================================================================""" 17 | 18 | import socket 19 | import time 20 | import ssl 21 | import hmac 22 | import tensorflow as tf 23 | import numpy as np 24 | from config import SSL_CONF as SC 25 | from config import SEND_RECEIVE_CONF as SRC 26 | 27 | try: 28 | import cPickle as pickle 29 | except ImportError: 30 | import pickle 31 | 32 | 33 | class _FederatedHook(tf.train.SessionRunHook): 34 | """Provides a hook to implement federated averaging with TensorFlow. 35 | 36 | In a typical synchronous training environment, gradients will be averaged each 37 | step and then applied to the variables in one shot, after which replicas can 38 | fetch the new variables and continue. In a federated averaging training environment, 39 | model variables will be averaged every 'interval_steps' steps, and then the 40 | replicas will fetch the new variables and continue training locally. In the 41 | interval between two average operations, there is no data transfer, which can 42 | accelerate training. 43 | 44 | The hook works in two different ways depending on whether it is the chief worker or not. 45 | 46 | The chief starts by creating a socket that will act as server. Then it stays 47 | waiting _wait_time seconds, accepting connections from all those workers that 48 | want to join the training, and distributes a task index to each of them. 49 | This task index is not always necessary. In our demos we use it to tell 50 | each worker which part of the dataset it has to use for the training, and it 51 | could have other applications. 52 | 53 | Remember that if your training is not going to be performed in a LAN you will 54 | need to do some port forwarding; we recommend you have a look at this 55 | article we wrote about it: 56 | https://medium.com/comind/raspberry-pis-federated-learning-751b10fc92c9 57 | 58 | Once the training is about to start, the chief sends its weights to the other workers, 59 | so that they all start with the same initial ones. 60 | After each batch is trained, it checks if _interval_steps has been completed, 61 | and if so, it gathers the weights of all the workers and its own, averages them 62 | and sends the average to all those workers that sent weights to it. 63 | 64 | Workers open a socket connection with the chief and wait to get their worker number. 65 | Once the training is about to start, they wait for the chief to send them its weights. 66 | After each training round they check if _interval_steps has been completed, 67 | and if so, they send their weights to the chief and wait for its response, 68 | the averaged weights with which they will continue training. 69 | """ 70 | 71 | def __init__(self, is_chief, private_ip, public_ip, wait_time=30, interval_steps=100): 72 | 73 | """Constructs a FederatedHook object 74 | Args: 75 | is_chief (bool): whether it is going to act as chief or not. 76 | private_ip (str): complete local IP on which the chief is going to 77 | serve its socket.
Example: 172.134.65.123:7777 78 | public_ip (str): ip to which the workers are going to connect. 79 | interval_steps (int, optional): number of steps between two 80 | "average op", which specifies how frequent a model 81 | synchronization is performed. 82 | wait_time: how mucht time the chief should wait at the beginning 83 | for the workers to connect. 84 | """ 85 | self._is_chief = is_chief 86 | self._private_ip = private_ip.split(':')[0] 87 | self._private_port = int(private_ip.split(':')[1]) 88 | self._public_ip = public_ip.split(':')[0] 89 | self._public_port = int(public_ip.split(':')[1]) 90 | self._interval_steps = interval_steps 91 | self._wait_time = wait_time 92 | self._nex_task_index = 0 93 | # We get the number of connections that have been made, and which task_index 94 | # corresponds to this worker. 95 | self.task_index, self.num_workers = self._get_task_index() 96 | 97 | def _get_task_index(self): 98 | 99 | """Chief distributes task index number to workers that connect to it and 100 | lets them know how many workers are there in total. 101 | Returns: 102 | task_index (int): task index corresponding to this worker. 103 | num_workers (int): number of total workers. 104 | """ 105 | 106 | if self._is_chief: 107 | self._server_socket = self._start_socket_server() 108 | self._server_socket.settimeout(5) 109 | users = [] 110 | t_end = time.time() + self._wait_time 111 | 112 | while time.time() < t_end: 113 | try: 114 | sock, _ = self._server_socket.accept() 115 | connection_socket = ssl.wrap_socket( 116 | sock, 117 | server_side=True, 118 | certfile=SC.cert_path, 119 | keyfile=SC.key_path, 120 | ssl_version=ssl.PROTOCOL_TLSv1) 121 | if connection_socket not in users: 122 | users.append(connection_socket) 123 | except socket.timeout: 124 | pass 125 | 126 | num_workers = len(users) + 1 127 | _ = [us.send((str(i+1) + ':' + str(num_workers)).encode('utf-8')) \ 128 | for i, us in enumerate(users)] 129 | self._nex_task_index = len(users) + 1 130 | _ = [us.close() for us in users] 131 | 132 | self._server_socket.settimeout(120) 133 | return 0, num_workers 134 | 135 | client_socket = self._start_socket_worker() 136 | message = client_socket.recv(1024).decode('utf-8').split(':') 137 | client_socket.close() 138 | return int(message[0]), int(message[1]) 139 | 140 | def _create_placeholders(self): 141 | """Creates the placeholders that we will use to inject the weights into the graph""" 142 | for var in tf.trainable_variables(): 143 | self._placeholders.append(tf.placeholder_with_default(var, var.shape, 144 | name="%s/%s" % ("FedAvg", 145 | var.op.name))) 146 | 147 | def _assign_vars(self, local_vars): 148 | """Utility to refresh local variables. 149 | 150 | Args: 151 | local_vars: List of local variables. 152 | global_vars: List of global variables. 153 | 154 | Returns: 155 | refresh_ops: The ops to assign value of global vars to local vars. 156 | """ 157 | reassign_ops = [] 158 | for var, fvar in zip(local_vars, self._placeholders): 159 | reassign_ops.append(tf.assign(var, fvar)) 160 | return tf.group(*(reassign_ops)) 161 | 162 | @staticmethod 163 | def _receiving_subroutine(connection_socket): 164 | """Subroutine inside _get_np_array to recieve a list of numpy arrays. 165 | If the sending was not correctly recieved it sends back an error message 166 | to the sender in order to try it again. 167 | Args: 168 | connection_socket (socket): a socket with a connection already 169 | established. 
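Returns:
    message (bytes): the received payload, once its HMAC signature has been verified.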
170 | """ 171 | timeout = 0.5 172 | while True: 173 | ultimate_buffer = b'' 174 | connection_socket.settimeout(240) 175 | first_round = True 176 | while True: 177 | try: 178 | receiving_buffer = connection_socket.recv(SRC.buffer) 179 | except socket.timeout: 180 | break 181 | if first_round: 182 | connection_socket.settimeout(timeout) 183 | first_round = False 184 | if not receiving_buffer: 185 | break 186 | ultimate_buffer += receiving_buffer 187 | 188 | pos_signature = SRC.hashsize 189 | signature = ultimate_buffer[:pos_signature] 190 | message = ultimate_buffer[pos_signature:] 191 | good_signature = hmac.new(SRC.key, message, SRC.hashfunction).digest() 192 | 193 | if signature != good_signature: 194 | connection_socket.send(SRC.error) 195 | timeout += 0.5 196 | continue 197 | else: 198 | connection_socket.send(SRC.recv) 199 | connection_socket.settimeout(120) 200 | return message 201 | 202 | def _get_np_array(self, connection_socket): 203 | """Routine to recieve a list of numpy arrays. 204 | Args: 205 | connection_socket (socket): a socket with a connection already 206 | established. 207 | """ 208 | 209 | message = self._receiving_subroutine(connection_socket) 210 | final_image = pickle.loads(message) 211 | return final_image 212 | 213 | @staticmethod 214 | def _send_np_array(arrays_to_send, connection_socket): 215 | """Routine to send a list of numpy arrays. It sends it as many time as necessary 216 | Args: 217 | connection_socket (socket): a socket with a connection already 218 | established. 219 | """ 220 | serialized = pickle.dumps(arrays_to_send) 221 | signature = hmac.new(SRC.key, serialized, SRC.hashfunction).digest() 222 | assert len(signature) == SRC.hashsize 223 | message = signature + serialized 224 | connection_socket.settimeout(240) 225 | connection_socket.sendall(message) 226 | while True: 227 | check = connection_socket.recv(len(SRC.error)) 228 | if check == SRC.error: 229 | connection_socket.sendall(message) 230 | elif check == SRC.recv: 231 | connection_socket.settimeout(120) 232 | break 233 | 234 | def _start_socket_server(self): 235 | """Creates a socket with ssl protection that will act as server. 236 | Returns: 237 | sever_socket (socket): ssl secured socket that will act as server. 238 | """ 239 | server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 240 | server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 241 | context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) 242 | context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 # optional 243 | context.set_ciphers('EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH') 244 | server_socket.bind((self._private_ip, self._private_port)) 245 | server_socket.listen() 246 | return server_socket 247 | 248 | def _start_socket_worker(self): 249 | """Creates a socket with ssl protection that will act as client. 250 | Returns: 251 | sever_socket (socket): ssl secured socket that will work as client. 
252 | """ 253 | to_wrap_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 254 | context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) 255 | context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 # optional 256 | 257 | client_socket = ssl.wrap_socket(to_wrap_socket) 258 | client_socket.connect((self._public_ip, self._public_port)) 259 | return client_socket 260 | 261 | def begin(self): 262 | """Session begin""" 263 | self._placeholders = [] 264 | self._create_placeholders() 265 | self._update_local_vars_op = self._assign_vars(tf.trainable_variables()) 266 | self._global_step = tf.get_collection(tf.GraphKeys.GLOBAL_STEP)[0] 267 | 268 | def after_create_session(self, session, coord): 269 | """ 270 | If chief: 271 | Once the training is going to start sends it's weights to the other 272 | workers, so that they all start with the same initial ones. 273 | Once it has send the weights to all the workers it sends them a 274 | signal to start training. 275 | Workers: 276 | Wait for the chief to send them its weights and inject them into 277 | the graph. 278 | """ 279 | if self._is_chief: 280 | users = [] 281 | addresses = [] 282 | while len(users) < (self.num_workers - 1): 283 | try: 284 | self._server_socket.settimeout(30) 285 | sock, address = self._server_socket.accept() 286 | connection_socket = ssl.wrap_socket( 287 | sock, 288 | server_side=True, 289 | certfile=SC.cert_path, 290 | keyfile=SC.key_path, 291 | ssl_version=ssl.PROTOCOL_TLSv1) 292 | 293 | print('Connected: ' + address[0] + ':' + str(address[1])) 294 | except socket.timeout: 295 | print('Some workers could not connect') 296 | break 297 | try: 298 | print('SENDING Worker: ' + address[0] + ':' + str(address[1])) 299 | self._send_np_array(session.run(tf.trainable_variables()), connection_socket) 300 | print('SENT Worker {}'.format(len(users))) 301 | users.append(connection_socket) 302 | addresses.append(address) 303 | except (ConnectionResetError, BrokenPipeError): 304 | print('Could not send to : ' 305 | + address[0] + ':' + str(address[1]) 306 | + ', fallen worker') 307 | connection_socket.close() 308 | for i, user in enumerate(users): 309 | try: 310 | user.send(SRC.signal) 311 | user.close() 312 | except (ConnectionResetError, BrokenPipeError): 313 | print('Fallen Worker: ' + addresses[i][0] + ':' + str(address[i][1])) 314 | self.num_workers -= 1 315 | try: 316 | user.close() 317 | except (ConnectionResetError, BrokenPipeError): 318 | pass 319 | else: 320 | print('Starting Initialization') 321 | client_socket = self._start_socket_worker() 322 | broadcasted_weights = self._get_np_array(client_socket) 323 | feed_dict = {} 324 | for placeh, brweigh in zip(self._placeholders, broadcasted_weights): 325 | feed_dict[placeh] = brweigh 326 | session.run(self._update_local_vars_op, feed_dict=feed_dict) 327 | print('Initialization finished') 328 | client_socket.settimeout(120) 329 | client_socket.recv(len(SRC.signal)) 330 | client_socket.close() 331 | 332 | def before_run(self, run_context): 333 | """ Session before_run""" 334 | return tf.train.SessionRunArgs(self._global_step) 335 | 336 | def after_run(self, run_context, run_values): 337 | """ 338 | Both chief and workers, check if they should average their weights in 339 | this roud. Is this is the case: 340 | 341 | If chief: 342 | Tries to gather the weights of all the workers, but ignores those 343 | that lost connection at some point. 344 | It averages them and then send them back to the workers. 345 | Finally in injects the averaged weights to its own graph. 
346 | Workers: 347 | Send their weights to the chief. 348 | Wait for the chief to send them the averaged weights and inject them into 349 | their graph. 350 | """ 351 | step_value = run_values.results 352 | session = run_context.session 353 | if step_value % self._interval_steps == 0 and not step_value == 0: 354 | if self._is_chief: 355 | self._server_socket.listen(self.num_workers - 1) 356 | gathered_weights = [session.run(tf.trainable_variables())] 357 | users = [] 358 | addresses = [] 359 | for i in range(self.num_workers - 1): 360 | try: 361 | self._server_socket.settimeout(30) 362 | sock, address = self._server_socket.accept() 363 | connection_socket = ssl.wrap_socket( 364 | sock, 365 | server_side=True, 366 | certfile=SC.cert_path, 367 | keyfile=SC.key_path, 368 | ssl_version=ssl.PROTOCOL_TLSv1) 369 | 370 | print('Connected: ' + address[0] + ':' + str(address[1])) 371 | except socket.timeout: 372 | print('Some workers could not connect') 373 | break 374 | try: 375 | recieved = self._get_np_array(connection_socket) 376 | gathered_weights.append(recieved) 377 | users.append(connection_socket) 378 | addresses.append(address) 379 | print('Received from ' + address[0] + ':' + str(address[1])) 380 | except (ConnectionResetError, BrokenPipeError): 381 | print('Could not recieve from : ' 382 | + address[0] + ':' + str(address[1]) 383 | + ', fallen worker') 384 | connection_socket.close() 385 | 386 | self.num_workers = len(users) + 1 387 | 388 | print('Average applied ' 389 | + 'with {} workers, iter: {}'.format(self.num_workers, step_value)) 390 | rearranged_weights = [] 391 | 392 | #In gathered_weights, each list represents the weights of each worker. 393 | #We want to gahter in each list the weights of a single layer so 394 | #to average them afterwards 395 | for i in range(len(gathered_weights[0])): 396 | rearranged_weights.append([elem[i] for elem in gathered_weights]) 397 | for i, elem in enumerate(rearranged_weights): 398 | rearranged_weights[i] = np.mean(elem, axis=0) 399 | 400 | for i, user in enumerate(users): 401 | try: 402 | self._send_np_array(rearranged_weights, user) 403 | user.close() 404 | except (ConnectionResetError, BrokenPipeError): 405 | print('Fallen Worker: ' + addresses[i][0] + ':' + str(address[i][1])) 406 | self.num_workers -= 1 407 | try: 408 | user.close() 409 | except socket.timeout: 410 | pass 411 | 412 | feed_dict = {} 413 | for placeh, reweigh in zip(self._placeholders, rearranged_weights): 414 | feed_dict[placeh] = reweigh 415 | session.run(self._update_local_vars_op, feed_dict=feed_dict) 416 | 417 | else: 418 | worker_socket = self._start_socket_worker() 419 | print('Sending weights') 420 | value = session.run(tf.trainable_variables()) 421 | self._send_np_array(value, worker_socket) 422 | 423 | broadcasted_weights = self._get_np_array(worker_socket) 424 | feed_dict = {} 425 | for placeh, brweigh in zip(self._placeholders, broadcasted_weights): 426 | feed_dict[placeh] = brweigh 427 | session.run(self._update_local_vars_op, feed_dict=feed_dict) 428 | print('Weights succesfully updated, iter: {}'.format(step_value)) 429 | worker_socket.close() 430 | 431 | def end(self, session): 432 | """ Session end """ 433 | if self._is_chief: 434 | self._server_socket.close() 435 | -------------------------------------------------------------------------------- /federated-sockets/README.md: -------------------------------------------------------------------------------- 1 | # Implementation with custom hook 2 | 3 | This is the implementation of Federated Averaging using our 
custom hook. If you wish to use this same implementation with your own code just import the FederatedHook, set the config file and launch! 4 | 5 | ## Usage 6 | 7 | First of all set the config file: 8 | 9 | > `SEND_RECEIVE_CONF.key = Shared key to sign messages and guarantee integrity` (a Bytearray, you can leave it as is) 10 | 11 | Generate a private key and a certificate with: `openssl req -new -x509 -days 365 -nodes -out server.pem -keyout server.key` 12 | 13 | > `SSL_CONF.key_path = Path to your private key` 14 | 15 | > `SSL_CONF.cert_path = Path to your certificate` 16 | 17 | Next set the IP's in the main code to your own. No need to change this if you are using localhost. 18 | 19 | Finally, set the `WAIT_TIME`, the chief worker will wait for new workers during this amount of seconds before the training. 20 | 21 | And launch the shells, as many as you want: 22 | 23 | * 1st shell: `python3 basic_socket_fed_classifier.py --is_chief=True` 24 | 25 | * Next shells: `python3 basic_socket_fed_classifier.py` 26 | 27 | ## Troubleshooting and Help 28 | 29 | coMind has public Slack and Telegram channels which are a great place to ask questions and all things related to federated machine learning. 30 | 31 | ## Useful resources 32 | 33 | Check the [medium post](https://medium.com/comind/raspberry-pis-federated-learning-751b10fc92c9) to learn about port forwarding and how to set-up your chief host. 34 | 35 | Check [this script](https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py) to see how to generate CIFAR-10 TFRecords. 36 | 37 | ## About 38 | 39 | coMind is an open source project for training privacy-preserving federated deep learning models. 40 | 41 | * https://comind.org/ 42 | * [Twitter](https://twitter.com/coMindOrg) 43 | -------------------------------------------------------------------------------- /federated-sockets/advanced_socket_fed_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 coMind. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # https://comind.org/ 16 | # ============================================================================== 17 | 18 | # TensorFlow 19 | import tensorflow as tf 20 | 21 | # Helper libraries 22 | import numpy as np 23 | from time import time 24 | import multiprocessing 25 | 26 | # Custom federated hook 27 | from FederatedHook import _FederatedHook 28 | 29 | flags = tf.app.flags 30 | 31 | flags.DEFINE_boolean("is_chief", False, "True if this worker is chief") 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | # You can safely tune these variables 36 | BATCH_SIZE = 128 37 | SHUFFLE_SIZE = BATCH_SIZE * 100 38 | EPOCHS = 250 39 | EPOCHS_PER_DECAY = 50 40 | INTERVAL_STEPS = 100 # Steps between averages 41 | WAIT_TIME = 30 # How many seconds to wait for new workers to connect 42 | BATCHES_TO_PREFETCH = 1 43 | # ----------------- 44 | 45 | # Set these IPs to your own, can leave as localhost for local testing 46 | CHIEF_PUBLIC_IP = 'localhost:7777' # Public IP of the chief worker 47 | CHIEF_PRIVATE_IP = 'localhost:7777' # Private IP of the chief worker 48 | 49 | # Create the custom hook 50 | federated_hook = _FederatedHook(FLAGS.is_chief, CHIEF_PRIVATE_IP, CHIEF_PUBLIC_IP, WAIT_TIME, INTERVAL_STEPS) 51 | 52 | # Dataset dependent constants 53 | NUM_TRAIN_IMAGES = int(50000 / federated_hook.num_workers) 54 | NUM_TEST_IMAGES = 10000 55 | HEIGHT = 32 56 | WIDTH = 32 57 | CHANNELS = 3 58 | NUM_BATCH_FILES = 5 59 | 60 | # Path to TFRecord files (check readme for instructions on how to get these files) 61 | cifar10_train_files = ['cifar-10-tf-records/train{}.tfrecords'.format(i) for i in range(NUM_BATCH_FILES)] 62 | cifar10_test_file = 'cifar-10-tf-records/test.tfrecords' 63 | 64 | # Shuffle filenames before loading them 65 | np.random.shuffle(cifar10_train_files) 66 | 67 | CHECKPOINT_DIR = 'logs_dir/{}'.format(time()) 68 | print('Checkpoint directory: ' + CHECKPOINT_DIR) 69 | 70 | global_step = tf.train.get_or_create_global_step() 71 | 72 | # Check number of available CPUs 73 | CPU_COUNT = int(multiprocessing.cpu_count() / federated_hook.num_workers) 74 | 75 | # Define input pipeline, place these ops in the cpu 76 | with tf.name_scope('dataset'), tf.device('/cpu:0'): 77 | # Map function to decode data and preprocess it 78 | def preprocess(serialized_examples): 79 | # Parse a batch 80 | features = tf.parse_example(serialized_examples, {'image': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64)}) 81 | # Decode and reshape imag 82 | image = tf.map_fn(lambda img: tf.reshape(tf.decode_raw(img, tf.uint8), tf.stack([HEIGHT, WIDTH, CHANNELS])), features['image'], dtype=tf.uint8, name='decode') 83 | # Cast image 84 | casted_image = tf.cast(image, tf.float32, name='input_cast') 85 | # Resize image for testing 86 | resized_image = tf.image.resize_image_with_crop_or_pad(casted_image, 24, 24) 87 | # Augment images for training 88 | distorted_image = tf.map_fn(lambda img: tf.random_crop(img, [24, 24, 3]), 89 | casted_image, name='random_crop') 90 | distorted_image = tf.image.random_flip_left_right(distorted_image) 91 | distorted_image = tf.image.random_brightness(distorted_image, 63) 92 | distorted_image = tf.image.random_contrast(distorted_image, 0.2, 1.8) 93 | # Check if test or train mode 94 | result = tf.cond(train_mode, lambda: distorted_image, lambda: resized_image) 95 | # Standardize images 96 | processed_image = tf.map_fn(lambda img: tf.image.per_image_standardization(img), 97 | result, name='standardization') 98 | return processed_image, features['label'] 99 | # Placeholders for the 
iterator 100 | filename_placeholder = tf.placeholder(tf.string, name='input_filename') 101 | batch_size = tf.placeholder(tf.int64, name='batch_size') 102 | shuffle_size = tf.placeholder(tf.int64, name='shuffle_size') 103 | train_mode = tf.placeholder(tf.bool, name='train_mode') 104 | 105 | # Create dataset, shuffle, repeat, batch, map and prefetch 106 | dataset = tf.data.TFRecordDataset(filename_placeholder) 107 | dataset = dataset.shard(federated_hook.num_workers, federated_hook.task_index) 108 | dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True) 109 | dataset = dataset.repeat(EPOCHS) 110 | dataset = dataset.batch(batch_size) 111 | dataset = dataset.map(preprocess, CPU_COUNT) 112 | dataset = dataset.prefetch(BATCHES_TO_PREFETCH) 113 | # Define a feedable iterator and the initialization op 114 | iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 115 | dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') 116 | X, y = iterator.get_next() 117 | 118 | # Define our model 119 | first_conv = tf.layers.conv2d(X, 64, 5, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2), name='first_conv') 120 | 121 | first_pool = tf.nn.max_pool(first_conv, [1, 3, 3 ,1], [1, 2, 2, 1], padding='SAME', name='first_pool') 122 | 123 | first_norm = tf.nn.lrn(first_pool, 4, alpha=0.001 / 9.0, beta=0.75, name='first_norm') 124 | 125 | second_conv = tf.layers.conv2d(first_norm, 64, 5, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=5e-2), name='second_conv') 126 | 127 | second_norm = tf.nn.lrn(second_conv, 4, alpha=0.001 / 9.0, beta=0.75, name='second_norm') 128 | 129 | second_pool = tf.nn.max_pool(second_norm, [1, 3, 3, 1], [1, 2, 2, 1], padding='SAME', name='second_pool') 130 | 131 | flatten_layer = tf.layers.flatten(second_pool, name='flatten') 132 | 133 | first_relu = tf.layers.dense(flatten_layer, 384, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.04), name='first_relu') 134 | 135 | second_relu = tf.layers.dense(first_relu, 192, activation=tf.nn.relu, kernel_initializer=tf.truncated_normal_initializer(stddev=0.04), name='second_relu') 136 | 137 | logits = tf.layers.dense(second_relu, 10, kernel_initializer=tf.truncated_normal_initializer(stddev=1/192.0), name='logits') 138 | 139 | # Object to keep moving averages of our metrics (for tensorboard) 140 | summary_averages = tf.train.ExponentialMovingAverage(0.9) 141 | 142 | # Define cross_entropy loss 143 | with tf.name_scope('loss'): 144 | base_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits), name='base_loss') 145 | # Add regularization loss to both relu layers 146 | regularizer_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'relu/kernel' in v.name], name='regularizer_loss') * 0.004 147 | loss = tf.add(base_loss, regularizer_loss) 148 | loss_averages_op = summary_averages.apply([loss]) 149 | # Store moving average of the loss 150 | tf.summary.scalar('cross_entropy', summary_averages.average(loss)) 151 | 152 | with tf.name_scope('accuracy'): 153 | with tf.name_scope('correct_prediction'): 154 | # Compare prediction with actual label 155 | correct_prediction = tf.equal(tf.argmax(logits, 1), y) 156 | # Average correct predictions in the current batch 157 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy_metric') 158 | accuracy_averages_op = 
summary_averages.apply([accuracy]) 159 | # Store moving average of the accuracy 160 | tf.summary.scalar('accuracy', summary_averages.average(accuracy)) 161 | 162 | N_BATCHES = int(NUM_TRAIN_IMAGES / BATCH_SIZE) 163 | LAST_STEP = int(N_BATCHES * EPOCHS) 164 | 165 | # Define moving averages of the trainable variables. This sometimes improve 166 | # the performance of the trained mode 167 | with tf.name_scope('variable_averages'): 168 | variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step) 169 | variable_averages_op = variable_averages.apply(tf.trainable_variables()) 170 | 171 | # Define optimizer and training op 172 | with tf.name_scope('train'): 173 | # Make decaying learning rate 174 | lr = tf.train.exponential_decay(0.1, global_step, N_BATCHES * EPOCHS_PER_DECAY, 0.1, staircase=True) 175 | tf.summary.scalar('learning_rate', lr) 176 | # Make train_op dependent on moving averages ops. Otherwise they will be 177 | # disconnected from the graph 178 | with tf.control_dependencies([loss_averages_op, accuracy_averages_op, variable_averages_op]): 179 | train_op = tf.train.GradientDescentOptimizer(lr).minimize(loss, global_step=global_step) 180 | 181 | print('Graph definition finished') 182 | sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 183 | 184 | print('Training {} batches...'.format(LAST_STEP)) 185 | 186 | # Logger hook to keep track of the training 187 | class _LoggerHook(tf.train.SessionRunHook): 188 | def begin(self): 189 | """ Run this in session begin """ 190 | self._total_loss = 0 191 | self._total_acc = 0 192 | 193 | def before_run(self, run_context): 194 | """ Run this in session before_run """ 195 | return tf.train.SessionRunArgs([loss, accuracy, global_step]) 196 | 197 | def after_run(self, run_context, run_values): 198 | """ Run this in session after_run """ 199 | loss_value, acc_value, step_value = run_values.results 200 | self._total_loss += loss_value 201 | self._total_acc += acc_value 202 | if (step_value + 1) % N_BATCHES == 0: 203 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format( 204 | int(step_value / N_BATCHES) + 1, 205 | EPOCHS, self._total_loss / N_BATCHES, self._total_acc / N_BATCHES)) 206 | self._total_loss = 0 207 | self._total_acc = 0 208 | 209 | # Hook to initialize the dataset 210 | class _InitHook(tf.train.SessionRunHook): 211 | def after_create_session(self, session, coord): 212 | """ Run this after creating session """ 213 | session.run(dataset_init_op, feed_dict={ 214 | filename_placeholder: cifar10_train_files, 215 | batch_size: BATCH_SIZE, shuffle_size: SHUFFLE_SIZE, train_mode: True}) 216 | 217 | with tf.name_scope('monitored_session'): 218 | with tf.train.MonitoredTrainingSession( 219 | checkpoint_dir=CHECKPOINT_DIR, 220 | hooks=[_LoggerHook(), _InitHook(), federated_hook, 221 | tf.train.CheckpointSaverHook(checkpoint_dir=CHECKPOINT_DIR, 222 | save_steps=N_BATCHES, 223 | saver=tf.train.Saver( 224 | variable_averages.variables_to_restore()))], 225 | config=sess_config, 226 | save_checkpoint_secs=None) as mon_sess: 227 | while not mon_sess.should_stop(): 228 | mon_sess.run(train_op) 229 | 230 | print('--- Begin Evaluation ---') 231 | # Reset graph to clear any ops stored in other devices 232 | tf.reset_default_graph() 233 | with tf.Session() as sess: 234 | ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR) 235 | saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True) 236 | saver.restore(sess, ckpt.model_checkpoint_path) 237 | print('Model restored') 238 | 
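    # The lookups below recover tensors and ops from the imported meta graph by
    # name, so each string must match the name_scope and op name used when the
    # training graph was defined: 'dataset/input_filename:0' is the placeholder
    # named 'input_filename' inside the 'dataset' scope,
    # 'accuracy/accuracy_metric:0' is the tf.reduce_mean op named
    # 'accuracy_metric', and ':0' selects the op's first output tensor.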
graph = tf.get_default_graph() 239 | filename_placeholder = graph.get_tensor_by_name('dataset/input_filename:0') 240 | batch_size = graph.get_tensor_by_name('dataset/batch_size:0') 241 | shuffle_size = graph.get_tensor_by_name('dataset/shuffle_size:0') 242 | train_mode = graph.get_tensor_by_name('dataset/train_mode:0') 243 | accuracy = graph.get_tensor_by_name('accuracy/accuracy_metric:0') 244 | dataset_init_op = graph.get_operation_by_name('dataset/dataset_init') 245 | sess.run(dataset_init_op, feed_dict={filename_placeholder: cifar10_test_file, batch_size: NUM_TEST_IMAGES, shuffle_size: 1, train_mode: False}) 246 | print('Test accuracy: {:4f}'.format(sess.run(accuracy))) 247 | -------------------------------------------------------------------------------- /federated-sockets/basic_socket_fed_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 coMind. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # https://comind.org/ 16 | # ============================================================================== 17 | 18 | # TensorFlow and tf.keras 19 | import tensorflow as tf 20 | from tensorflow import keras 21 | 22 | # Custom federated hook 23 | from FederatedHook import _FederatedHook 24 | 25 | # Helper libraries 26 | import os 27 | import numpy as np 28 | from time import time 29 | 30 | flags = tf.app.flags 31 | 32 | flags.DEFINE_boolean("is_chief", False, "True if this worker is chief") 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 37 | 38 | # You can safely tune these variables 39 | BATCH_SIZE = 32 40 | EPOCHS = 5 41 | INTERVAL_STEPS = 100 # Steps between averages 42 | WAIT_TIME = 30 # How many seconds to wait for new workers to connect 43 | # ----------------- 44 | 45 | # Set these IPs to your own, can leave as localhost for local testing 46 | CHIEF_PUBLIC_IP = 'localhost:7777' # Public IP of the chief worker 47 | CHIEF_PRIVATE_IP = 'localhost:7777' # Private IP of the chief worker 48 | 49 | # Create the custom hook 50 | federated_hook = _FederatedHook(FLAGS.is_chief, CHIEF_PRIVATE_IP, CHIEF_PUBLIC_IP, WAIT_TIME, INTERVAL_STEPS) 51 | 52 | # Load dataset as numpy arrays 53 | fashion_mnist = keras.datasets.fashion_mnist 54 | (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() 55 | 56 | # Split dataset 57 | train_images = np.array_split(train_images, federated_hook.num_workers)[federated_hook.task_index] 58 | train_labels = np.array_split(train_labels, federated_hook.num_workers)[federated_hook.task_index] 59 | 60 | # You can safely tune this variable 61 | SHUFFLE_SIZE = train_images.shape[0] 62 | # ----------------- 63 | 64 | print('Local dataset size: {}'.format(train_images.shape[0])) 65 | 66 | # Normalize dataset 67 | train_images = train_images / 255.0 68 | test_images = test_images / 255.0 69 | 70 | CHECKPOINT_DIR = 'logs_dir/{}'.format(time()) 71 | 72 | global_step = 
tf.train.get_or_create_global_step() 73 | 74 | # Define input pipeline, place these ops in the cpu 75 | with tf.name_scope('dataset'), tf.device('/cpu:0'): 76 | # Placeholders for the iterator 77 | images_placeholder = tf.placeholder(train_images.dtype, [None, train_images.shape[1], train_images.shape[2]]) 78 | labels_placeholder = tf.placeholder(train_labels.dtype, [None]) 79 | batch_size = tf.placeholder(tf.int64) 80 | shuffle_size = tf.placeholder(tf.int64, name='shuffle_size') 81 | 82 | # Create dataset, shuffle, repeat and batch 83 | dataset = tf.data.Dataset.from_tensor_slices((images_placeholder, labels_placeholder)) 84 | dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True) 85 | dataset = dataset.repeat(EPOCHS) 86 | dataset = dataset.batch(batch_size) 87 | iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 88 | dataset_init_op = iterator.make_initializer(dataset, name='dataset_init') 89 | X, y = iterator.get_next() 90 | 91 | # Define our model 92 | flatten_layer = tf.layers.flatten(X, name='flatten') 93 | 94 | dense_layer = tf.layers.dense(flatten_layer, 128, activation=tf.nn.relu, name='relu') 95 | 96 | predictions = tf.layers.dense(dense_layer, 10, activation=tf.nn.softmax, name='softmax') 97 | 98 | # Object to keep moving averages of our metrics (for tensorboard) 99 | summary_averages = tf.train.ExponentialMovingAverage(0.9) 100 | 101 | # Define cross_entropy loss 102 | with tf.name_scope('loss'): 103 | loss = tf.reduce_mean(keras.losses.sparse_categorical_crossentropy(y, predictions)) 104 | loss_averages_op = summary_averages.apply([loss]) 105 | # Store moving average of the loss 106 | tf.summary.scalar('cross_entropy', summary_averages.average(loss)) 107 | 108 | with tf.name_scope('accuracy'): 109 | with tf.name_scope('correct_prediction'): 110 | # Compare prediction with actual label 111 | correct_prediction = tf.equal(tf.argmax(predictions, 1), tf.cast(y, tf.int64)) 112 | # Average correct predictions in the current batch 113 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 114 | accuracy_averages_op = summary_averages.apply([accuracy]) 115 | # Store moving average of the accuracy 116 | tf.summary.scalar('accuracy', summary_averages.average(accuracy)) 117 | 118 | # Define optimizer and training op 119 | with tf.name_scope('train'): 120 | # Make train_op dependent on moving averages ops. 
Otherwise they will be 121 | # disconnected from the graph 122 | with tf.control_dependencies([loss_averages_op, accuracy_averages_op]): 123 | train_op = tf.train.AdamOptimizer(0.001).minimize(loss, global_step=global_step) 124 | 125 | SESS_CONFIG = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 126 | 127 | N_BATCHES = int(train_images.shape[0] / BATCH_SIZE) 128 | LAST_STEP = int(N_BATCHES * EPOCHS) 129 | 130 | # Logger hook to keep track of the training 131 | class _LoggerHook(tf.train.SessionRunHook): 132 | def begin(self): 133 | """ Run this in session begin """ 134 | self._total_loss = 0 135 | self._total_acc = 0 136 | 137 | def before_run(self, run_context): 138 | """ Run this in session before_run """ 139 | return tf.train.SessionRunArgs([loss, accuracy, global_step]) 140 | 141 | def after_run(self, run_context, run_values): 142 | """ Run this in session after_run """ 143 | loss_value, acc_value, step_value = run_values.results 144 | self._total_loss += loss_value 145 | self._total_acc += acc_value 146 | if (step_value + 1) % N_BATCHES == 0: 147 | print("Epoch {}/{} - loss: {:.4f} - acc: {:.4f}".format( 148 | int(step_value / N_BATCHES) + 1, 149 | EPOCHS, self._total_loss / N_BATCHES, 150 | self._total_acc / N_BATCHES)) 151 | self._total_loss = 0 152 | self._total_acc = 0 153 | 154 | class _InitHook(tf.train.SessionRunHook): 155 | """ Hook to initialize the dataset """ 156 | def after_create_session(self, session, coord): 157 | """ Run this after creating session """ 158 | session.run(dataset_init_op, feed_dict={ 159 | images_placeholder: train_images, 160 | labels_placeholder: train_labels, 161 | shuffle_size: SHUFFLE_SIZE, batch_size: BATCH_SIZE}) 162 | 163 | print("Worker {} ready".format(federated_hook.task_index)) 164 | 165 | with tf.name_scope('monitored_session'): 166 | with tf.train.MonitoredTrainingSession( 167 | checkpoint_dir=CHECKPOINT_DIR, 168 | hooks=[_LoggerHook(), _InitHook(), federated_hook], 169 | config=SESS_CONFIG, 170 | save_checkpoint_steps=N_BATCHES) as mon_sess: 171 | while not mon_sess.should_stop(): 172 | mon_sess.run(train_op) 173 | 174 | print('--- Begin Evaluation ---') 175 | with tf.Session() as sess: 176 | ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR) 177 | tf.train.Saver().restore(sess, ckpt.model_checkpoint_path) 178 | print('Model restored') 179 | sess.run(dataset_init_op, feed_dict={ 180 | images_placeholder: test_images, 181 | labels_placeholder: test_labels, 182 | shuffle_size: 1, batch_size: test_images.shape[0]}) 183 | print('Test accuracy: {:4f}'.format(sess.run(accuracy))) 184 | -------------------------------------------------------------------------------- /federated-sockets/config.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | SEND_RECEIVE_CONF = lambda x: x 4 | SEND_RECEIVE_CONF.key = b'4C5jwen4wpNEjBeq1YmdBayIQ1oD' 5 | SEND_RECEIVE_CONF.hashfunction = hashlib.sha1 6 | SEND_RECEIVE_CONF.hashsize = int(160 / 8) 7 | SEND_RECEIVE_CONF.error = b'error' 8 | SEND_RECEIVE_CONF.recv = b'reciv' 9 | SEND_RECEIVE_CONF.signal = b'go!go!go!' 10 | SEND_RECEIVE_CONF.buffer = 8192*2 11 | 12 | SSL_CONF = lambda x: x 13 | SSL_CONF.key_path = 'server.key' 14 | SSL_CONF.cert_path = 'server.pem' 15 | -------------------------------------------------------------------------------- /federated_averaging_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Modifications copyright (C) 2018 coMind. 16 | # ============================================================================== 17 | 18 | """Synchronize replicas for FedAvg training.""" 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | from tensorflow.python.framework import constant_op 24 | from tensorflow.python.framework import dtypes 25 | from tensorflow.python.framework import ops 26 | from tensorflow.python.ops import array_ops 27 | from tensorflow.python.ops import control_flow_ops 28 | from tensorflow.python.ops import data_flow_ops 29 | from tensorflow.python.ops import math_ops 30 | from tensorflow.python.ops import state_ops 31 | from tensorflow.python.ops import variables 32 | from tensorflow.python.ops import variable_scope 33 | from tensorflow.python.platform import tf_logging as logging 34 | from tensorflow.python.training import optimizer 35 | from tensorflow.python.training import session_run_hook 36 | 37 | # Please note that the parameters from replicas are averaged so you need to 38 | # increase the learning rate according to the number of replicas. This change is 39 | # introduced to be consistent with how parameters are aggregated within a batch 40 | class FederatedAveragingOptimizer(optimizer.Optimizer): 41 | """Class to synchronize and aggregate model params. 42 | 43 | In a typical synchronous training environment, gradients will be averaged each 44 | step and then applied to the variables in one shot, after which replicas can 45 | fetch the new variables and continue. In a federated average training environment, 46 | model variables will be averaged every 'interval_steps' steps, and then the 47 | replicas will fetch the new variables and continue training locally. In the 48 | interval between two average operations, there is no data transfer, which can 49 | accelerate training. 50 | 51 | The following accumulators/queue are created: 52 | 53 | * N `parameter accumulators`, one per variable to train. Local variables are 54 | pushed to them and the chief worker will wait until enough variables are 55 | collected and then average them. The accumulator will drop all stale variables 56 | (more details in the accumulator op). 57 | * 1 `token` queue where the optimizer pushes the new global_step value after 58 | all variables are updated. 59 | 60 | The following local variable is created: 61 | * `global_step`, one per replica. Updated after every average operation. 62 | 63 | The optimizer adds nodes to the graph to collect local variables and pause 64 | the trainers until variables are updated. 65 | For the Parameter Server job: 66 | 67 | 1. An accumulator is created for each variable, and each replica pushes the 68 | local variables into the accumulators. 69 | 2. Each accumulator averages once enough variables (replicas_to_aggregate) 70 | have been accumulated. 71 | 3. 
Apply the averaged variables to global variables. 72 | 4. Only after all variables have been updated, increment the global step. 73 | 5. Only after step 4, pushes a token in the `token_queue`, once for 74 | each worker replica. The workers can now fetch the token and start 75 | the next round. 76 | 77 | For the replicas: 78 | 79 | 1. Start a training block: fetch variables and train for "interval_steps" steps. 80 | 2. Once the training block has been computed, push local variables into variable 81 | accumulators. Each accumulator will check the staleness and drop the stale. 82 | 3. After pushing all the variables, dequeue a token from the token queue and 83 | continue training. Note that this is effectively a barrier. 84 | 4. Fetch new variables and start the next block. 85 | 86 | ### Usage 87 | 88 | ```python 89 | # Create any optimizer to update the variables, say a simple SGD: 90 | opt = GradientDescentOptimizer(learning_rate=0.1) 91 | 92 | # Wrap the optimizer with fed_avg_optimizer with 50 replicas: at each 93 | # step the FederatedAveragingOptimizer collects "replicas_to_aggregate" variables 94 | # before applying the average. Note that if you want to have 2 backup replicas, 95 | # you can change total_num_replicas=52 and make sure this number matches how 96 | # many physical replicas you started in your job. 97 | opt = fed_avg_optimizer.FederatedAveragingOptimizer(opt, 98 | replicas_to_aggregate=50, 99 | is_chief=True, 100 | interval_steps=100, 101 | device_setter) 102 | 103 | # Some models have startup_delays to help stabilize the model but when using 104 | # federated_average training, set it to 0. 105 | 106 | # Now you can call 'minimize() normally' 107 | # train_op = opt.minimize(loss, global_step=global_step) 108 | 109 | # And also, create the hook which handles initialization. 110 | fed_avg_hook = opt.make_session_run_hook() 111 | ``` 112 | 113 | In the training program, every worker will run the train_op as if not 114 | averaged or synchronized. Note that if you want to run other ops like 115 | test op, you should use common session instead of MonitoredSession: 116 | 117 | ```python 118 | with training.MonitoredTrainingSession( 119 | master=workers[worker_id].target, 120 | hooks=[fed_avg_hook]) as mon_sess: 121 | while not mon_sess.should_stop(): 122 | mon_sess.run(training_op) 123 | sess = mon_sess._tf_sess() 124 | sess.run(testing_op) 125 | ``` 126 | """ 127 | 128 | def __init__(self, 129 | opt, 130 | replicas_to_aggregate, 131 | interval_steps, 132 | is_chief=False, 133 | total_num_replicas=None, 134 | device_setter=None, 135 | use_locking=False, 136 | name="fedAverage"): 137 | """Construct a fedAverage optimizer. 138 | 139 | Args: 140 | opt: The actual optimizer that will be used to compute and apply the 141 | gradients. Must be one of the Optimizer classes. 142 | replicas_to_aggregate: number of replicas to aggregate for each variable 143 | update. 144 | interval_steps: number of steps between two "average op", which specifies 145 | how frequent a model synchronization is performed. 146 | is_chief: whether this worker is chief or not. 147 | total_num_replicas: Total number of tasks/workers/replicas, could be 148 | If total_num_replicas > replicas_to_aggregate: it is backup_replicas + 149 | replicas_to_aggregate. 150 | If total_num_replicas < replicas_to_aggregate: Replicas compute 151 | multiple blocks per update to variables. 152 | device_setter: A replica_device_setter that will be used to place copies 153 | of the trainable variables in the parameter server. 
154 | use_locking: If True use locks for update operations. 155 | name: string. Name of the global variables and related operation on ps. 156 | """ 157 | if total_num_replicas is None: 158 | total_num_replicas = replicas_to_aggregate 159 | 160 | super(FederatedAveragingOptimizer, self).__init__(use_locking, name) 161 | logging.info( 162 | "FedAvgV4: replicas_to_aggregate=%s; total_num_replicas=%s", 163 | replicas_to_aggregate, total_num_replicas) 164 | self._opt = opt 165 | self._replicas_to_aggregate = replicas_to_aggregate 166 | self._interval_steps = interval_steps 167 | self._is_chief = is_chief 168 | self._total_num_replicas = total_num_replicas 169 | self._tokens_per_step = max(total_num_replicas, replicas_to_aggregate) - 1 170 | self._device_setter = device_setter 171 | self._name = name 172 | 173 | # Remember which accumulator is on which device to set the initial step in 174 | # the accumulator to be global step. This list contains list of the 175 | # following format: (accumulator, device). 176 | self._accumulator_list = [] 177 | 178 | def _generate_shared_variables(self): 179 | """Generate a global variable placed on ps for each trainable variable. 180 | 181 | This creates a new copy of each user-defined trainable variable and places 182 | them on ps_device. These variables store the averaged parameters. 183 | """ 184 | # Only the chief should initialize the variables 185 | if self._is_chief: 186 | collections = [ops.GraphKeys.GLOBAL_VARIABLES, "global_model"] 187 | else: 188 | collections = ["global_model"] 189 | 190 | # Generate new global variables dependent on trainable variables. 191 | with ops.device(self._device_setter): 192 | for v in variables.trainable_variables(): 193 | _ = variable_scope.variable( 194 | name="%s/%s" % (self._name, v.op.name), 195 | initial_value=v.initialized_value(), trainable=False, 196 | collections=collections) 197 | 198 | # Place the global step in the ps so that all the workers can see it 199 | self._global_step = variables.Variable(0, name="%s_global_step" % 200 | self._name, trainable=False) 201 | 202 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 203 | """Apply gradients to variables. 204 | This contains most of the synchronization implementation. 205 | 206 | Args: 207 | grads_and_vars: List of (local_vars, gradients) pairs. 208 | global_step: Variable to increment by one after the variables have been 209 | updated. We need it to check staleness. 210 | name: Optional name for the returned operation. Default to the 211 | name passed to the Optimizer constructor. 212 | 213 | Returns: 214 | train_op: The op to dequeue a token so the replicas can exit this batch 215 | and apply averages to local vars or an op to update vars locally. 216 | 217 | Raises: 218 | ValueError: If the grads_and_vars is empty. 219 | ValueError: If global step is not provided, the staleness cannot be 220 | checked. 
221 | """ 222 | if not grads_and_vars: 223 | raise ValueError("Must supply at least one variable") 224 | if global_step is None: 225 | raise ValueError("Global step is required") 226 | 227 | # Generate copy of all trainable variables 228 | self._generate_shared_variables() 229 | 230 | # Wraps the apply_gradients op of the parent optimizer 231 | apply_updates = self._opt.apply_gradients(grads_and_vars, global_step) 232 | 233 | # This function will be called whenever the global_step divides interval steps 234 | def _apply_averages(): # pylint: disable=missing-docstring 235 | # Collect local and global vars 236 | local_vars = [v for g, v in grads_and_vars if g is not None] 237 | global_vars = ops.get_collection_ref("global_model") 238 | # sync queue, place it in the ps 239 | with ops.colocate_with(self._global_step): 240 | sync_queue = data_flow_ops.FIFOQueue( 241 | -1, [dtypes.bool], shapes=[[]], shared_name="sync_queue") 242 | train_ops = [] 243 | aggregated_vars = [] 244 | with ops.name_scope(None, self._name + "/global"): 245 | for var, gvar in zip(local_vars, global_vars): 246 | # pylint: disable=protected-access 247 | # Get reference to the tensor, this works with Variable and ResourceVariable 248 | var = ops.convert_to_tensor(var) 249 | # Place the accumulator in the same ps as the corresponding global_var 250 | with ops.device(gvar.device): 251 | var_accum = data_flow_ops.ConditionalAccumulator( 252 | var.dtype, 253 | shape=var.get_shape(), 254 | shared_name=gvar.name + "/var_accum") 255 | # Add op to push local_var to accumulator 256 | train_ops.append( 257 | var_accum.apply_grad(var, local_step=global_step)) 258 | # Op to average the vars in the accumulator 259 | aggregated_vars.append(var_accum.take_grad(self._replicas_to_aggregate)) 260 | # Remember accumulator and corresponding device 261 | self._accumulator_list.append((var_accum, gvar.device)) 262 | # chief worker updates global vars and enqueues tokens to the sync queue 263 | if self._is_chief: 264 | update_ops = [] 265 | # Make sure train_ops are run 266 | with ops.control_dependencies(train_ops): 267 | # Update global_vars with average values 268 | for avg_var, gvar in zip(aggregated_vars, global_vars): 269 | with ops.device(gvar.device): 270 | update_ops.append(state_ops.assign(gvar, avg_var)) 271 | # Update shared global_step 272 | with ops.device(global_step.device): 273 | update_ops.append(state_ops.assign_add(self._global_step, 1)) 274 | # After averaging, push tokens to the queue 275 | with ops.control_dependencies(update_ops), ops.device( 276 | global_step.device): 277 | tokens = array_ops.fill([self._tokens_per_step], 278 | constant_op.constant(False)) 279 | sync_op = sync_queue.enqueue_many(tokens) 280 | # non chief workers deque a token, they will block here until chief is done 281 | else: 282 | # Make sure train_ops are run 283 | with ops.control_dependencies(train_ops), ops.device( 284 | global_step.device): 285 | sync_op = sync_queue.dequeue() 286 | 287 | # All workers pull averaged values 288 | with ops.control_dependencies([sync_op]): 289 | local_update_op = self._assign_vars(local_vars, global_vars) 290 | return local_update_op 291 | 292 | # Check if we should push and average or not 293 | with ops.control_dependencies([apply_updates]): 294 | condition = math_ops.equal( 295 | math_ops.mod(global_step, self._interval_steps), 0) 296 | conditional_update = control_flow_ops.cond( 297 | condition, _apply_averages, control_flow_ops.no_op) 298 | 299 | chief_init_ops = [] 300 | # Initialize accumulators, ops placed 
in ps 301 | for accum, dev in self._accumulator_list: 302 | with ops.device(dev): 303 | chief_init_ops.append( 304 | accum.set_global_step(global_step, name="SetGlobalStep")) 305 | self._chief_init_op = control_flow_ops.group(*(chief_init_ops)) 306 | 307 | return conditional_update 308 | 309 | def _assign_vars(self, local_vars, global_vars): 310 | """Utility to refresh local variables. 311 | 312 | Args: 313 | local_vars: List of local variables. 314 | global_vars: List of global variables. 315 | 316 | Returns: 317 | refresh_ops: The ops to assign value of global vars to local vars. 318 | """ 319 | reassign_ops = [] 320 | for local_var, global_var in zip(local_vars, global_vars): 321 | reassign_ops.append(state_ops.assign(local_var, global_var)) 322 | refresh_ops = control_flow_ops.group(*(reassign_ops)) 323 | return refresh_ops 324 | 325 | def make_session_run_hook(self): 326 | """Creates a hook to handle federated average init operations.""" 327 | return _FederatedAverageHook(self) 328 | 329 | class _FederatedAverageHook(session_run_hook.SessionRunHook): 330 | """A SessionRunHook that handles ops related to FederatedAveragingOptimizer.""" 331 | 332 | def __init__(self, fed_avg_optimizer): 333 | """Creates hook to handle FederatedAveragingOptimizer 334 | 335 | Args: 336 | fed_avg_optimizer: 'FederatedAveragingOptimizer' which this hook will 337 | initialize. 338 | """ 339 | self._fed_avg_optimizer = fed_avg_optimizer 340 | 341 | def begin(self): 342 | local_vars = variables.trainable_variables() 343 | global_vars = ops.get_collection_ref("global_model") 344 | self._variable_init_op = self._fed_avg_optimizer._assign_vars( 345 | local_vars, 346 | global_vars) 347 | 348 | def after_create_session(self, session, coord): 349 | # Make sure all models start at the same point 350 | session.run(self._variable_init_op) 351 | -------------------------------------------------------------------------------- /images/Logo_Acuratio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coMindOrg/federated-averaging-tutorials/123974939ffb3f14d50795e854cb698685fff46a/images/Logo_Acuratio.png -------------------------------------------------------------------------------- /images/colab_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coMindOrg/federated-averaging-tutorials/123974939ffb3f14d50795e854cb698685fff46a/images/colab_logo.png -------------------------------------------------------------------------------- /images/comindorg_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coMindOrg/federated-averaging-tutorials/123974939ffb3f14d50795e854cb698685fff46a/images/comindorg_logo.png -------------------------------------------------------------------------------- /images/graph_tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coMindOrg/federated-averaging-tutorials/123974939ffb3f14d50795e854cb698685fff46a/images/graph_tensorboard.png -------------------------------------------------------------------------------- /images/slack_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coMindOrg/federated-averaging-tutorials/123974939ffb3f14d50795e854cb698685fff46a/images/slack_logo.jpg 
-------------------------------------------------------------------------------- /images/telegram_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coMindOrg/federated-averaging-tutorials/123974939ffb3f14d50795e854cb698685fff46a/images/telegram_logo.jpg --------------------------------------------------------------------------------
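Everything in `federated_averaging_optimizer.py` and the socket hook above boils down to the same update: train locally for `interval_steps` steps, collect every worker's variables, replace them with their average, and continue. The snippet below is a minimal NumPy sketch of that averaging step only, written for illustration and not taken from the repository; the names `federated_average`, `worker_weights` and `num_samples` are made up for this example, and the sample-size weighting is shown as an option (the `ConditionalAccumulator` used by the optimizer above takes a plain mean).

```python
import numpy as np

def federated_average(worker_weights, num_samples=None):
    """Average a list of per-worker weight lists, layer by layer.

    worker_weights: list with one entry per worker; each entry is a list of
        np.ndarray with identical shapes across workers.
    num_samples: optional list with each worker's local dataset size, used to
        weight the average; if omitted, a plain mean is taken.
    """
    num_workers = len(worker_weights)
    if num_samples is None:
        coeffs = [1.0 / num_workers] * num_workers
    else:
        total = float(sum(num_samples))
        coeffs = [n / total for n in num_samples]
    averaged = []
    for layer_idx in range(len(worker_weights[0])):
        layer = sum(c * w[layer_idx] for c, w in zip(coeffs, worker_weights))
        averaged.append(layer)
    return averaged

# Example: two workers, one 2x2 kernel each, weighted 60/40 by dataset size.
w0 = [np.ones((2, 2))]
w1 = [np.zeros((2, 2))]
print(federated_average([w0, w1], num_samples=[600, 400])[0])  # 0.6 everywhere
```

In the TensorFlow optimizer this same mean is produced by one `ConditionalAccumulator` per variable, and the token queue only exists so that every worker waits until the chief has written the averaged values back to the shared variables before fetching them and starting the next local block.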