├── .gitignore ├── LICENSE ├── README.md └── Shakespeare ├── FedAvg.ipynb ├── FedMed.ipynb ├── FedProx.ipynb └── qFedAvg.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Santiago Gonzalez Toral 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fl-algorithms 2 | 3 | A repository containing notebooks that implement different Federated Learning algorithms using PyTorch 4 | 5 | # Experiments 6 | 7 | ### FL using the Shakespeare dataset 8 | - **Federated Averaging** (FedAvg) 9 | - [Notebook](Shakespeare/FedAvg.ipynb) 10 | - [Paper: Communication-Efficient Learning of Deep Networks from Decentralized Data](http://proceedings.mlr.press/v54/mcmahan17a.html) 11 | - **Federated Prox** (FedProx) 12 | - [Notebook](Shakespeare/FedProx.ipynb) 13 | - [Paper: Federated Optimization in Heterogenous Networks](https://arxiv.org/abs/1812.06127) 14 | - **Federated Median** (FedMed) 15 | - [Notebook](Shakespeare/FedMed.ipynb) 16 | - [Paper: Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](http://proceedings.mlr.press/v80/yin18a.html) 17 | - **q-Federated Averaging** (qFedAvg) a.k.a. **q-Fair Federated Learning** (qFFL) 18 | - [Notebook](Shakespeare/qFedAvg.ipynb) 19 | - [Paper: Fair Resource Allocation In Federated Learning](https://arxiv.org/abs/1905.10497) 20 | 21 | # License 22 | 23 | [MIT](LICENSE) 24 | -------------------------------------------------------------------------------- /Shakespeare/FedAvg.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "FedAvg.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [ 9 | "Y-ZegBArf4xQ", 10 | "cce_-qnxhD4n", 11 | "XelbyPsDlfgb", 12 | "IOBblyFGlwlU", 13 | "sWVOxcAao2_t", 14 | "vFFAfTOwpk4j", 15 | "c640e4NnpksE", 16 | "3crFDN0xqGu6", 17 | "YXtGLkoAqLIW" 18 | ], 19 | "toc_visible": true, 20 | "authorship_tag": "ABX9TyODmfkWrZZ1tNBNYaII6os2", 21 | "include_colab_link": true 22 | }, 23 | "kernelspec": { 24 | "name": "python3", 25 | "display_name": "Python 3" 26 | }, 27 | "language_info": { 28 | "name": "python" 29 | }, 30 | "accelerator": "GPU" 31 | }, 32 | "cells": [ 33 | { 34 | "cell_type": "markdown", 35 | "metadata": { 36 | "id": "view-in-github", 37 | "colab_type": "text" 38 | }, 39 | "source": [ 40 | "\"Open" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": { 46 | "id": "ZXnVpPCFfQkm" 47 | }, 48 | "source": [ 49 | "# FedPerf - Shakespeare + FedAvg algorithm" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "TWDjzbCzfUFt" 56 | }, 57 | "source": [ 58 | "## Setup & Dependencies Installation" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "metadata": { 64 | "id": "z_f-rWudW0Fo" 65 | }, 66 | "source": [ 67 | "%%capture\n", 68 | "!pip install torchsummaryX unidecode" 69 | ], 70 | "execution_count": null, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "metadata": { 76 | "id": "3TmMInb_fnw_" 77 | }, 78 | "source": [ 79 | "%load_ext tensorboard\n", 80 | "\n", 81 | "import copy\n", 82 | "from functools 
import reduce\n", 83 | "import json\n", 84 | "import matplotlib.pyplot as plt\n", 85 | "import numpy as np\n", 86 | "import os\n", 87 | "import pandas as pd\n", 88 | "import pickle\n", 89 | "import random\n", 90 | "from sklearn.model_selection import train_test_split\n", 91 | "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", 92 | "import time\n", 93 | "import torch\n", 94 | "from torch.autograd import Variable\n", 95 | "import torch.nn as nn\n", 96 | "import torch.nn.functional as F\n", 97 | "from torch.utils.data import Dataset\n", 98 | "from torch.utils.data.dataloader import DataLoader\n", 99 | "from torch.utils.data.sampler import Sampler\n", 100 | "from torch.utils.tensorboard import SummaryWriter\n", 101 | "from torchsummary import summary\n", 102 | "from torchsummaryX import summary as summaryx\n", 103 | "from torchvision import transforms, utils, datasets\n", 104 | "from tqdm.notebook import tqdm\n", 105 | "from unidecode import unidecode\n", 106 | "\n", 107 | "%matplotlib inline\n", 108 | "\n", 109 | "# Check assigned GPU\n", 110 | "gpu_info = !nvidia-smi\n", 111 | "gpu_info = '\\n'.join(gpu_info)\n", 112 | "if gpu_info.find('failed') >= 0:\n", 113 | " print('Select the Runtime > \"Change runtime type\" menu to enable a GPU accelerator, ')\n", 114 | " print('and then re-execute this cell.')\n", 115 | "else:\n", 116 | " print(gpu_info)\n", 117 | "\n", 118 | "# set manual seed for reproducibility\n", 119 | "RANDOM_SEED = 42\n", 120 | "\n", 121 | "# general reproducibility\n", 122 | "random.seed(RANDOM_SEED)\n", 123 | "np.random.seed(RANDOM_SEED)\n", 124 | "torch.manual_seed(RANDOM_SEED)\n", 125 | "torch.cuda.manual_seed(RANDOM_SEED)\n", 126 | "\n", 127 | "# gpu training specific\n", 128 | "torch.backends.cudnn.deterministic = True\n", 129 | "torch.backends.cudnn.benchmark = False" 130 | ], 131 | "execution_count": null, 132 | "outputs": [] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": { 137 | "id": "8Mfqv6uHfwh-" 138 | }, 139 | "source": [ 140 | "## Mount GDrive" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "metadata": { 146 | "id": "75qbJwxsGj-k" 147 | }, 148 | "source": [ 149 | "BASE_DIR = '/content/drive/MyDrive/FedPerf/shakespeare/FedAvg'" 150 | ], 151 | "execution_count": null, 152 | "outputs": [] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "metadata": { 157 | "id": "Ost-6lXSfveC" 158 | }, 159 | "source": [ 160 | "try:\n", 161 | " from google.colab import drive\n", 162 | " drive.mount('/content/drive')\n", 163 | " os.makedirs(BASE_DIR, exist_ok=True)\n", 164 | "except:\n", 165 | " print(\"WARNING: Results won't be stored on GDrive\")\n", 166 | " BASE_DIR = './'\n", 167 | "\n" 168 | ], 169 | "execution_count": null, 170 | "outputs": [] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": { 175 | "id": "Y-ZegBArf4xQ" 176 | }, 177 | "source": [ 178 | "## Loading Dataset" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "metadata": { 184 | "id": "hf03LRxof7Zj" 185 | }, 186 | "source": [ 187 | "!rm -Rf data\n", 188 | "!mkdir -p data scripts" 189 | ], 190 | "execution_count": null, 191 | "outputs": [] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "metadata": { 196 | "id": "ngygA4-Fgobx" 197 | }, 198 | "source": [ 199 | "GENERATE_DATASET = False # If False, download the dataset provided by the q-FFL paper\n", 200 | "DATA_DIR = 'data/'\n", 201 | "# Dataset generation params\n", 202 | "SAMPLES_FRACTION = 1. 
# If using an already generated dataset\n", 203 | "# SAMPLES_FRACTION = 0.2 # Fraction of total samples in the dataset - FedProx default script\n", 204 | "# SAMPLES_FRACTION = 0.05 # Fraction of total samples in the dataset - qFFL\n", 205 | "TRAIN_FRACTION = 0.8 # Train set size\n", 206 | "MIN_SAMPLES = 0 # Min samples per client (for filtering purposes) - FedProx\n", 207 | "# MIN_SAMPLES = 64 # Min samples per client (for filtering purposes) - qFFL" 208 | ], 209 | "execution_count": null, 210 | "outputs": [] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "metadata": { 215 | "id": "nUmwJgJygoYD" 216 | }, 217 | "source": [ 218 | "# Download raw dataset\n", 219 | "# !wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt -O data/shakespeare.txt\n", 220 | "!wget --adjust-extension http://www.gutenberg.org/files/100/100-0.txt -O data/shakespeare.txt" 221 | ], 222 | "execution_count": null, 223 | "outputs": [] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "metadata": { 228 | "id": "4dCvx80BgoVr" 229 | }, 230 | "source": [ 231 | "if not GENERATE_DATASET:\n", 232 | " !rm -Rf data/train data/test\n", 233 | " !gdown --id 1n46Mftp3_ahRi1Z6jYhEriyLtdRDS1tD # Download Shakespeare dataset used by the FedProx paper\n", 234 | " !unzip shakespeare.zip\n", 235 | " !mv -f shakespeare_paper/train data/\n", 236 | " !mv -f shakespeare_paper/test data/\n", 237 | " !rm -R shakespeare_paper/ shakespeare.zip\n" 238 | ], 239 | "execution_count": null, 240 | "outputs": [] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "metadata": { 245 | "id": "a4pzFvPvhQhq" 246 | }, 247 | "source": [ 248 | "corpus = []\n", 249 | "with open('data/shakespeare.txt', 'r') as f:\n", 250 | " data = list(unidecode(f.read()))\n", 251 | " corpus = list(set(list(data)))\n", 252 | "print('Corpus Length:', len(corpus))" 253 | ], 254 | "execution_count": null, 255 | "outputs": [] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": { 260 | "id": "cce_-qnxhD4n" 261 | }, 262 | "source": [ 263 | "#### Dataset Preprocessing script" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "metadata": { 269 | "id": "Rt13M4IcgoTV" 270 | }, 271 | "source": [ 272 | "%%capture\n", 273 | "if GENERATE_DATASET:\n", 274 | " # Download dataset generation scripts\n", 275 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/preprocess_shakespeare.py -O scripts/preprocess_shakespeare.py\n", 276 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/shake_utils.py -O scripts/shake_utils.py\n", 277 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/gen_all_data.py -O scripts/gen_all_data.py\n", 278 | "\n", 279 | " # Download data preprocessing scripts\n", 280 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/sample.py -O scripts/sample.py\n", 281 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/remove_users.py -O scripts/remove_users.py" 282 | ], 283 | "execution_count": null, 284 | "outputs": [] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "metadata": { 289 | "id": "EIEyRW27goPo" 290 | }, 291 | "source": [ 292 | "# Running scripts\n", 293 | "if GENERATE_DATASET:\n", 294 | " !mkdir -p data/raw_data data/all_data data/train data/test\n", 295 | " !python scripts/preprocess_shakespeare.py data/shakespeare.txt data/raw_data\n", 296 | " !python scripts/gen_all_data.py" 297 | ], 298 | "execution_count": null, 299 | "outputs": [] 300 | }, 301 | { 302 
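Before wrapping the data in PyTorch `Dataset` objects, it helps to see how a single next-character sample is stored. The sketch below is illustrative only: it assumes the per-author JSON layout that the partition functions further down rely on (`users`, `num_samples`, and `user_data` with per-author `x`/`y` lists, where each `x` entry is a short text window and `y` the character that follows it), and it should be run after the download cell above.

```python
# Peek at one (input window, next character) pair from the downloaded train split.
import json
import os

from unidecode import unidecode

train_dir = os.path.join('data', 'train')
train_file = [os.path.join(train_dir, f)
              for f in os.listdir(train_dir) if f.endswith('.json')][0]

with open(train_file, 'r') as f:
    data_train = json.loads(unidecode(f.read()))

author = data_train['users'][0]
x0 = data_train['user_data'][author]['x'][0]  # input sequence of characters
y0 = data_train['user_data'][author]['y'][0]  # target: the single next character
print(author, '|', repr(x0), '->', repr(y0))
```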
| "cell_type": "markdown", 303 | "metadata": { 304 | "id": "mq8V6v_4hhhD" 305 | }, 306 | "source": [ 307 | "#### Dataset class" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "metadata": { 313 | "id": "H2SjEBKoWDxv" 314 | }, 315 | "source": [ 316 | "class ShakespeareDataset(Dataset):\n", 317 | " def __init__(self, x, y, corpus, seq_length):\n", 318 | " self.x = x\n", 319 | " self.y = y\n", 320 | " self.corpus = corpus\n", 321 | " self.corpus_size = len(self.corpus)\n", 322 | " super(ShakespeareDataset, self).__init__()\n", 323 | "\n", 324 | " def __len__(self):\n", 325 | " return len(self.x)\n", 326 | "\n", 327 | " def __repr__(self):\n", 328 | " return f'{self.__class__} - (length: {self.__len__()})'\n", 329 | "\n", 330 | " def __getitem__(self, i):\n", 331 | " input_seq = self.x[i]\n", 332 | " next_char = self.y[i]\n", 333 | " # print('\\tgetitem', i, input_seq, next_char)\n", 334 | " input_value = self.text2charindxs(input_seq)\n", 335 | " target_value = self.get_label_from_char(next_char)\n", 336 | " return input_value, target_value\n", 337 | "\n", 338 | " def text2charindxs(self, text):\n", 339 | " tensor = torch.zeros(len(text), dtype=torch.int32)\n", 340 | " for i, c in enumerate(text):\n", 341 | " tensor[i] = self.get_label_from_char(c)\n", 342 | " return tensor\n", 343 | "\n", 344 | " def get_label_from_char(self, c):\n", 345 | " return self.corpus.index(c)\n", 346 | "\n", 347 | " def get_char_from_label(self, l):\n", 348 | " return self.corpus[l]" 349 | ], 350 | "execution_count": null, 351 | "outputs": [] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": { 356 | "id": "9fgJtS62lYAN" 357 | }, 358 | "source": [ 359 | "##### Federated Dataset" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "metadata": { 365 | "id": "5DqL5pTmgn5X" 366 | }, 367 | "source": [ 368 | "class ShakespeareFedDataset(ShakespeareDataset):\n", 369 | " def __init__(self, x, y, corpus, seq_length):\n", 370 | " super(ShakespeareFedDataset, self).__init__(x, y, corpus, seq_length)\n", 371 | "\n", 372 | " def dataloader(self, batch_size, shuffle=True):\n", 373 | " return DataLoader(self,\n", 374 | " batch_size=batch_size,\n", 375 | " shuffle=shuffle,\n", 376 | " num_workers=0)\n" 377 | ], 378 | "execution_count": null, 379 | "outputs": [] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": { 384 | "id": "XelbyPsDlfgb" 385 | }, 386 | "source": [ 387 | "## Partitioning & Data Loaders" 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "metadata": { 393 | "id": "IOBblyFGlwlU" 394 | }, 395 | "source": [ 396 | "### IID" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "metadata": { 402 | "id": "cSZFWKmsgn1p" 403 | }, 404 | "source": [ 405 | "def iid_partition_(dataset, clients):\n", 406 | " \"\"\"\n", 407 | " I.I.D paritioning of data over clients\n", 408 | " Shuffle the data\n", 409 | " Split it between clients\n", 410 | " \n", 411 | " params:\n", 412 | " - dataset (torch.utils.Dataset): Dataset\n", 413 | " - clients (int): Number of Clients to split the data between\n", 414 | "\n", 415 | " returns:\n", 416 | " - Dictionary of image indexes for each client\n", 417 | " \"\"\"\n", 418 | "\n", 419 | " num_items_per_client = int(len(dataset)/clients)\n", 420 | " client_dict = {}\n", 421 | " image_idxs = [i for i in range(len(dataset))]\n", 422 | "\n", 423 | " for i in range(clients):\n", 424 | " client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False))\n", 425 | " image_idxs = list(set(image_idxs) - client_dict[i])\n", 426 
| "\n", 427 | " return client_dict" 428 | ], 429 | "execution_count": null, 430 | "outputs": [] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "metadata": { 435 | "id": "-lGwDyhSll9h" 436 | }, 437 | "source": [ 438 | "def iid_partition(corpus, seq_length=80, val_split=False):\n", 439 | "\n", 440 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n", 441 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n", 442 | "\n", 443 | " with open(train_file, 'r') as file:\n", 444 | " data_train = json.loads(unidecode(file.read()))\n", 445 | "\n", 446 | " with open(test_file, 'r') as file:\n", 447 | " data_test = json.loads(unidecode(file.read()))\n", 448 | "\n", 449 | " \n", 450 | " total_samples_train = sum(data_train['num_samples'])\n", 451 | "\n", 452 | " data_dict = {}\n", 453 | "\n", 454 | " x_train, y_train = [], []\n", 455 | " x_test, y_test = [], []\n", 456 | " # x_val, y_val = [], []\n", 457 | "\n", 458 | " users = list(zip(data_train['users'], data_train['num_samples']))\n", 459 | " # random.shuffle(users)\n", 460 | "\n", 461 | "\n", 462 | "\n", 463 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n", 464 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n", 465 | " sample_count = 0\n", 466 | " \n", 467 | " for i, (author_id, samples) in enumerate(users):\n", 468 | "\n", 469 | " if sample_count >= total_samples:\n", 470 | " print('Max samples reached', sample_count, '/', total_samples)\n", 471 | " break\n", 472 | "\n", 473 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n", 474 | " print('SKIP', author_id, samples)\n", 475 | " continue\n", 476 | " else:\n", 477 | " udata_train = data_train['user_data'][author_id]\n", 478 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n", 479 | " \n", 480 | " sample_count += max_samples\n", 481 | " # print('sample_count', sample_count)\n", 482 | "\n", 483 | " x_train.extend(data_train['user_data'][author_id]['x'][:max_samples])\n", 484 | " y_train.extend(data_train['user_data'][author_id]['y'][:max_samples])\n", 485 | "\n", 486 | " author_data = data_test['user_data'][author_id]\n", 487 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n", 488 | "\n", 489 | " if val_split:\n", 490 | " x_test.extend(author_data['x'][:int(test_size / 2)])\n", 491 | " y_test.extend(author_data['y'][:int(test_size / 2)])\n", 492 | " # x_val.extend(author_data['x'][int(test_size / 2):])\n", 493 | " # y_val.extend(author_data['y'][int(test_size / 2):int(test_size)])\n", 494 | "\n", 495 | " else:\n", 496 | " x_test.extend(author_data['x'][:int(test_size)])\n", 497 | " y_test.extend(author_data['y'][:int(test_size)])\n", 498 | "\n", 499 | " train_ds = ShakespeareDataset(x_train, y_train, corpus, seq_length)\n", 500 | " test_ds = ShakespeareDataset(x_test, y_test, corpus, seq_length)\n", 501 | " # val_ds = ShakespeareDataset(x_val, y_val, corpus, seq_length)\n", 502 | "\n", 503 | " data_dict = iid_partition_(train_ds, clients=len(users))\n", 504 | "\n", 505 | " return train_ds, data_dict, test_ds" 506 | ], 507 | "execution_count": null, 508 | "outputs": [] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": { 513 | "id": "MFvc8mLoouKa" 514 | }, 515 | "source": [ 516 | "### Non-IID" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "metadata": { 522 | "id": 
"GZ76WsCZot9s" 523 | }, 524 | "source": [ 525 | "def noniid_partition(corpus, seq_length=80, val_split=False):\n", 526 | "\n", 527 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n", 528 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n", 529 | "\n", 530 | " with open(train_file, 'r') as file:\n", 531 | " data_train = json.loads(unidecode(file.read()))\n", 532 | "\n", 533 | " with open(test_file, 'r') as file:\n", 534 | " data_test = json.loads(unidecode(file.read()))\n", 535 | "\n", 536 | " \n", 537 | " total_samples_train = sum(data_train['num_samples'])\n", 538 | "\n", 539 | " data_dict = {}\n", 540 | "\n", 541 | " x_test, y_test = [], []\n", 542 | "\n", 543 | " users = list(zip(data_train['users'], data_train['num_samples']))\n", 544 | " # random.shuffle(users)\n", 545 | "\n", 546 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n", 547 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n", 548 | " sample_count = 0\n", 549 | " \n", 550 | " for i, (author_id, samples) in enumerate(users):\n", 551 | "\n", 552 | " if sample_count >= total_samples:\n", 553 | " print('Max samples reached', sample_count, '/', total_samples)\n", 554 | " break\n", 555 | "\n", 556 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n", 557 | " print('SKIP', author_id, samples)\n", 558 | " continue\n", 559 | " else:\n", 560 | " udata_train = data_train['user_data'][author_id]\n", 561 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n", 562 | " \n", 563 | " sample_count += max_samples\n", 564 | " # print('sample_count', sample_count)\n", 565 | "\n", 566 | " x_train = data_train['user_data'][author_id]['x'][:max_samples]\n", 567 | " y_train = data_train['user_data'][author_id]['y'][:max_samples]\n", 568 | "\n", 569 | " train_ds = ShakespeareFedDataset(x_train, y_train, corpus, seq_length)\n", 570 | "\n", 571 | " x_val, y_val = None, None\n", 572 | " val_ds = None\n", 573 | " author_data = data_test['user_data'][author_id]\n", 574 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n", 575 | " if val_split:\n", 576 | " x_test += author_data['x'][:int(test_size / 2)]\n", 577 | " y_test += author_data['y'][:int(test_size / 2)]\n", 578 | " x_val = author_data['x'][int(test_size / 2):]\n", 579 | " y_val = author_data['y'][int(test_size / 2):int(test_size)]\n", 580 | "\n", 581 | " val_ds = ShakespeareFedDataset(x_val, y_val, corpus, seq_length)\n", 582 | "\n", 583 | " else:\n", 584 | " x_test += author_data['x'][:int(test_size)]\n", 585 | " y_test += author_data['y'][:int(test_size)]\n", 586 | "\n", 587 | " data_dict[author_id] = {\n", 588 | " 'train_ds': train_ds,\n", 589 | " 'val_ds': val_ds\n", 590 | " }\n", 591 | "\n", 592 | " test_ds = ShakespeareFedDataset(x_test, y_test, corpus, seq_length)\n", 593 | "\n", 594 | " return data_dict, test_ds" 595 | ], 596 | "execution_count": null, 597 | "outputs": [] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "metadata": { 602 | "id": "sWVOxcAao2_t" 603 | }, 604 | "source": [ 605 | "## Models" 606 | ] 607 | }, 608 | { 609 | "cell_type": "markdown", 610 | "metadata": { 611 | "id": "gQQQ2mLeo6EA" 612 | }, 613 | "source": [ 614 | "### Shakespeare LSTM" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "metadata": { 620 | "id": "2mGXTrXRot7R" 621 | }, 622 | "source": [ 623 | 
"class ShakespeareLSTM(nn.Module):\n", 624 | " \"\"\"\n", 625 | " \"\"\"\n", 626 | "\n", 627 | " def __init__(self, input_dim, embedding_dim, hidden_dim, classes, lstm_layers=2, dropout=0.1, batch_first=True):\n", 628 | " super(ShakespeareLSTM, self).__init__()\n", 629 | " self.input_dim = input_dim\n", 630 | " self.embedding_dim = embedding_dim\n", 631 | " self.hidden_dim = hidden_dim\n", 632 | " self.classes = classes\n", 633 | " self.no_layers = lstm_layers\n", 634 | " \n", 635 | " self.embedding = nn.Embedding(num_embeddings=self.classes,\n", 636 | " embedding_dim=self.embedding_dim)\n", 637 | " self.lstm = nn.LSTM(input_size=self.embedding_dim, \n", 638 | " hidden_size=self.hidden_dim,\n", 639 | " num_layers=self.no_layers,\n", 640 | " batch_first=batch_first, \n", 641 | " dropout=dropout if self.no_layers > 1 else 0.)\n", 642 | " self.fc = nn.Linear(hidden_dim, self.classes)\n", 643 | "\n", 644 | " def forward(self, x, hc=None):\n", 645 | " batch_size = x.size(0)\n", 646 | " x_emb = self.embedding(x)\n", 647 | " out, (ht, ct) = self.lstm(x_emb.view(batch_size, -1, self.embedding_dim), hc)\n", 648 | " dense = self.fc(ht[-1])\n", 649 | " return dense\n", 650 | " \n", 651 | " def init_hidden(self, batch_size):\n", 652 | " return (Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)),\n", 653 | " Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)))\n" 654 | ], 655 | "execution_count": null, 656 | "outputs": [] 657 | }, 658 | { 659 | "cell_type": "markdown", 660 | "metadata": { 661 | "id": "5QsuJlVipMc8" 662 | }, 663 | "source": [ 664 | "#### Model Summary" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "metadata": { 670 | "id": "n_Vb0BYpot5I" 671 | }, 672 | "source": [ 673 | "batch_size = 10\n", 674 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n", 675 | "\n", 676 | "shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n", 677 | " embedding_dim=8, # mcmahan17a, fedprox, qFFL\n", 678 | " hidden_dim=256, # mcmahan17a, fedprox impl\n", 679 | " # hidden_dim=100, # fedprox paper\n", 680 | " classes=len(corpus),\n", 681 | " lstm_layers=2,\n", 682 | " dropout=0.1, # TODO:\n", 683 | " batch_first=True\n", 684 | " )\n", 685 | "\n", 686 | "if torch.cuda.is_available():\n", 687 | " shakespeare_lstm.cuda()\n", 688 | "\n", 689 | "\n", 690 | "\n", 691 | "hc = shakespeare_lstm.init_hidden(batch_size)\n", 692 | "\n", 693 | "x_sample = torch.zeros((batch_size, seq_length),\n", 694 | " dtype=torch.long,\n", 695 | " device=(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')))\n", 696 | "\n", 697 | "x_sample[0][0] = 1\n", 698 | "x_sample\n", 699 | "\n", 700 | "print(\"\\nShakespeare LSTM SUMMARY\")\n", 701 | "print(summaryx(shakespeare_lstm, x_sample))" 702 | ], 703 | "execution_count": null, 704 | "outputs": [] 705 | }, 706 | { 707 | "cell_type": "markdown", 708 | "metadata": { 709 | "id": "qn7egnzTpeks" 710 | }, 711 | "source": [ 712 | "## FedAvg Algorithm" 713 | ] 714 | }, 715 | { 716 | "cell_type": "markdown", 717 | "metadata": { 718 | "id": "vFFAfTOwpk4j" 719 | }, 720 | "source": [ 721 | "### Plot Utils" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "metadata": { 727 | "id": "oyYjWa6IpnTY" 728 | }, 729 | "source": [ 730 | "from sklearn.metrics import f1_score" 731 | ], 732 | "execution_count": null, 733 | "outputs": [] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "metadata": { 738 | "id": "367THsiTpo-C" 739 | }, 740 | "source": [ 741 | "def plot_scores(history, exp_id, title, suffix):\n", 742 | " accuracies = 
[x['accuracy'] for x in history]\n", 743 | " f1_macro = [x['f1_macro'] for x in history]\n", 744 | " f1_weighted = [x['f1_weighted'] for x in history]\n", 745 | "\n", 746 | " fig, ax = plt.subplots()\n", 747 | " ax.plot(accuracies, 'tab:orange')\n", 748 | " ax.set(xlabel='Rounds', ylabel='Test Accuracy', title=title)\n", 749 | " ax.grid()\n", 750 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Accuracy_{suffix}.jpg', format='jpg', dpi=300)\n", 751 | " plt.show()\n", 752 | "\n", 753 | " fig, ax = plt.subplots()\n", 754 | " ax.plot(f1_macro, 'tab:orange')\n", 755 | " ax.set(xlabel='Rounds', ylabel='Test F1 (macro)', title=title)\n", 756 | " ax.grid()\n", 757 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Macro_{suffix}.jpg', format='jpg')\n", 758 | " plt.show()\n", 759 | "\n", 760 | " fig, ax = plt.subplots()\n", 761 | " ax.plot(f1_weighted, 'tab:orange')\n", 762 | " ax.set(xlabel='Rounds', ylabel='Test F1 (weighted)', title=title)\n", 763 | " ax.grid()\n", 764 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Weighted_{suffix}.jpg', format='jpg')\n", 765 | " plt.show()\n", 766 | "\n", 767 | "\n", 768 | "def plot_losses(history, exp_id, title, suffix):\n", 769 | " val_losses = [x['loss'] for x in history]\n", 770 | " train_losses = [x['train_loss'] for x in history]\n", 771 | "\n", 772 | " fig, ax = plt.subplots()\n", 773 | " ax.plot(train_losses, 'tab:orange')\n", 774 | " ax.set(xlabel='Rounds', ylabel='Train Loss', title=title)\n", 775 | " ax.grid()\n", 776 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Train_Loss_{suffix}.jpg', format='jpg')\n", 777 | " plt.show()\n", 778 | "\n", 779 | " fig, ax = plt.subplots()\n", 780 | " ax.plot(val_losses, 'tab:orange')\n", 781 | " ax.set(xlabel='Rounds', ylabel='Test Loss', title=title)\n", 782 | " ax.grid()\n", 783 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Loss_{suffix}.jpg', format='jpg')\n", 784 | " plt.show()\n" 785 | ], 786 | "execution_count": null, 787 | "outputs": [] 788 | }, 789 | { 790 | "cell_type": "markdown", 791 | "metadata": { 792 | "id": "c640e4NnpksE" 793 | }, 794 | "source": [ 795 | "### Systems Heterogeneity Simulations\n", 796 | "\n", 797 | "Generate epochs for selected clients based on percentage of devices that corresponds to heterogeneity. \n", 798 | "\n", 799 | "Assign x number of epochs (chosen unifirmly at random between [1, E]) to 0%, 50% or 90% of the selected devices, respectively. Settings where 0% devices perform fewer than E epochs of work correspond to the environments without system heterogeneity, while 90% of the devices sending their partial solutions corresponds to highly heterogenous system." 
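As a quick illustration of the behaviour described above, the sketch below calls the `GenerateLocalEpochs` helper defined in the next cell with arbitrary example values; the exact entries depend on the RNG state.

```python
# 10 clients selected this round, E = 20 local epochs, 50% stragglers:
# about half of the entries are drawn uniformly from [1, E), the rest stay at E.
epoch_list = GenerateLocalEpochs(percentage=50, size=10, max_epochs=20)
print(epoch_list)

# percentage=0 reproduces the homogeneous setting: every client runs E epochs.
print(GenerateLocalEpochs(percentage=0, size=10, max_epochs=20))
```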
800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "metadata": { 805 | "id": "zuEZYnl5ot2m" 806 | }, 807 | "source": [ 808 | "def GenerateLocalEpochs(percentage, size, max_epochs):\n", 809 | " ''' Method generates list of epochs for selected clients\n", 810 | " to replicate system heteroggeneity\n", 811 | "\n", 812 | " Params:\n", 813 | " percentage: percentage of clients to have fewer than E epochs\n", 814 | " size: total size of the list\n", 815 | " max_epochs: maximum value for local epochs\n", 816 | " \n", 817 | " Returns:\n", 818 | " List of size epochs for each Client Update\n", 819 | "\n", 820 | " '''\n", 821 | "\n", 822 | " # if percentage is 0 then each client runs for E epochs\n", 823 | " if percentage == 0:\n", 824 | " return np.array([max_epochs]*size)\n", 825 | " else:\n", 826 | " # get the number of clients to have fewer than E epochs\n", 827 | " heterogenous_size = int((percentage/100) * size)\n", 828 | "\n", 829 | " # generate random uniform epochs of heterogenous size between 1 and E\n", 830 | " epoch_list = np.random.randint(1, max_epochs, heterogenous_size)\n", 831 | "\n", 832 | " # the rest of the clients will have E epochs\n", 833 | " remaining_size = size - heterogenous_size\n", 834 | " rem_list = [max_epochs]*remaining_size\n", 835 | "\n", 836 | " epoch_list = np.append(epoch_list, rem_list, axis=0)\n", 837 | " \n", 838 | " # shuffle the list and return\n", 839 | " np.random.shuffle(epoch_list)\n", 840 | "\n", 841 | " return epoch_list" 842 | ], 843 | "execution_count": null, 844 | "outputs": [] 845 | }, 846 | { 847 | "cell_type": "markdown", 848 | "metadata": { 849 | "id": "VQ9PZM0Gp9ve" 850 | }, 851 | "source": [ 852 | "### Local Training (Client Update)" 853 | ] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "metadata": { 858 | "id": "EDJFltwdotzZ" 859 | }, 860 | "source": [ 861 | "class CustomDataset(Dataset):\n", 862 | " def __init__(self, dataset, idxs):\n", 863 | " self.dataset = dataset\n", 864 | " self.idxs = list(idxs)\n", 865 | "\n", 866 | " def __len__(self):\n", 867 | " return len(self.idxs)\n", 868 | "\n", 869 | " def __getitem__(self, item):\n", 870 | " data, label = self.dataset[self.idxs[item]]\n", 871 | " return data, label" 872 | ], 873 | "execution_count": null, 874 | "outputs": [] 875 | }, 876 | { 877 | "cell_type": "code", 878 | "metadata": { 879 | "id": "HtRzU5Yepddq" 880 | }, 881 | "source": [ 882 | "class ClientUpdate(object):\n", 883 | " def __init__(self, dataset, batchSize, learning_rate, epochs, idxs, mu, algorithm):\n", 884 | " # self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batchSize, shuffle=True)\n", 885 | " if hasattr(dataset, 'dataloader'):\n", 886 | " self.train_loader = dataset.dataloader(batch_size=batch_size, shuffle=True)\n", 887 | " else:\n", 888 | " self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batch_size, shuffle=True)\n", 889 | "\n", 890 | " self.algorithm = algorithm\n", 891 | " self.learning_rate = learning_rate\n", 892 | " self.epochs = epochs\n", 893 | " self.mu = mu\n", 894 | "\n", 895 | " def train(self, model):\n", 896 | " # print(\"Client training for {} epochs.\".format(self.epochs))\n", 897 | " criterion = nn.CrossEntropyLoss()\n", 898 | " proximal_criterion = nn.MSELoss(reduction='mean')\n", 899 | " optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.5)\n", 900 | "\n", 901 | " # use the weights of global model for proximal term calculation\n", 902 | " global_model = copy.deepcopy(model)\n", 903 | "\n", 904 | " # calculate 
local training time\n", 905 | " start_time = time.time()\n", 906 | "\n", 907 | "\n", 908 | " e_loss = []\n", 909 | " for epoch in range(1, self.epochs+1):\n", 910 | "\n", 911 | " train_loss = 0.0\n", 912 | "\n", 913 | " model.train()\n", 914 | " for data, labels in self.train_loader:\n", 915 | "\n", 916 | " if torch.cuda.is_available():\n", 917 | " data, labels = data.cuda(), labels.cuda()\n", 918 | "\n", 919 | " # clear the gradients\n", 920 | " optimizer.zero_grad()\n", 921 | " # make a forward pass\n", 922 | " output = model(data)\n", 923 | "\n", 924 | " # calculate the loss + the proximal term\n", 925 | " _, pred = torch.max(output, 1)\n", 926 | "\n", 927 | " if self.algorithm == 'fedprox':\n", 928 | " proximal_term = 0.0\n", 929 | "\n", 930 | " # iterate through the current and global model parameters\n", 931 | " for w, w_t in zip(model.parameters(), global_model.parameters()) :\n", 932 | " # update the proximal term \n", 933 | " #proximal_term += torch.sum(torch.abs((w-w_t)**2))\n", 934 | " proximal_term += (w-w_t).norm(2)\n", 935 | "\n", 936 | " loss = criterion(output, labels) + (self.mu/2)*proximal_term\n", 937 | " else:\n", 938 | " loss = criterion(output, labels)\n", 939 | " \n", 940 | " # do a backwards pass\n", 941 | " loss.backward()\n", 942 | " # perform a single optimization step\n", 943 | " optimizer.step()\n", 944 | " # update training loss\n", 945 | " train_loss += loss.item()*data.size(0)\n", 946 | "\n", 947 | " # average losses\n", 948 | " train_loss = train_loss/len(self.train_loader.dataset)\n", 949 | " e_loss.append(train_loss)\n", 950 | "\n", 951 | " total_loss = sum(e_loss)/len(e_loss)\n", 952 | "\n", 953 | " return model.state_dict(), total_loss, (time.time() - start_time)" 954 | ], 955 | "execution_count": null, 956 | "outputs": [] 957 | }, 958 | { 959 | "cell_type": "markdown", 960 | "metadata": { 961 | "id": "3crFDN0xqGu6" 962 | }, 963 | "source": [ 964 | "### Server Side Training" 965 | ] 966 | }, 967 | { 968 | "cell_type": "code", 969 | "metadata": { 970 | "id": "c085xSOoqEHk" 971 | }, 972 | "source": [ 973 | "def training(model, rounds, batch_size, lr, ds, data_dict, test_ds, C, K, E, mu, percentage, plt_title, plt_color, target_test_accuracy,\n", 974 | " classes, algorithm=\"fedprox\", history=[], eval_every=1, tb_logger=None):\n", 975 | " \"\"\"\n", 976 | " Function implements the Federated Averaging Algorithm from the FedAvg paper.\n", 977 | " Specifically, this function is used for the server side training and weight update\n", 978 | "\n", 979 | " Params:\n", 980 | " - model: PyTorch model to train\n", 981 | " - rounds: Number of communication rounds for the client update\n", 982 | " - batch_size: Batch size for client update training\n", 983 | " - lr: Learning rate used for client update training\n", 984 | " - ds: Dataset used for training\n", 985 | " - data_dict: Type of data partition used for training (IID or non-IID)\n", 986 | " - test_data_dict: Data used for testing the model\n", 987 | " - C: Fraction of clients randomly chosen to perform computation on each round\n", 988 | " - K: Total number of clients\n", 989 | " - E: Number of training passes each client makes over its local dataset per round\n", 990 | " - mu: proximal term constant\n", 991 | " - percentage: percentage of selected client to have fewer than E epochs\n", 992 | " Returns:\n", 993 | " - model: Trained model on the server\n", 994 | " \"\"\"\n", 995 | "\n", 996 | " start = time.time()\n", 997 | "\n", 998 | " # global model weights\n", 999 | " global_weights = model.state_dict()\n", 
1000 | "\n", 1001 | " # training loss\n", 1002 | " train_loss = []\n", 1003 | "\n", 1004 | " # test accuracy\n", 1005 | " test_acc = []\n", 1006 | "\n", 1007 | " # store last loss for convergence\n", 1008 | " last_loss = 0.0\n", 1009 | "\n", 1010 | " # total time taken \n", 1011 | " total_time = 0\n", 1012 | "\n", 1013 | " print(f\"System heterogeneity set to {percentage}% stragglers.\\n\")\n", 1014 | " print(f\"Picking {max(int(C*K),1 )} random clients per round.\\n\")\n", 1015 | "\n", 1016 | " users_id = list(data_dict.keys())\n", 1017 | "\n", 1018 | " for curr_round in range(1, rounds+1):\n", 1019 | " w, local_loss, lst_local_train_time = [], [], []\n", 1020 | "\n", 1021 | " m = max(int(C*K), 1)\n", 1022 | "\n", 1023 | " heterogenous_epoch_list = GenerateLocalEpochs(percentage, size=m, max_epochs=E)\n", 1024 | " heterogenous_epoch_list = np.array(heterogenous_epoch_list)\n", 1025 | " # print('heterogenous_epoch_list', len(heterogenous_epoch_list))\n", 1026 | "\n", 1027 | " S_t = np.random.choice(range(K), m, replace=False)\n", 1028 | " S_t = np.array(S_t)\n", 1029 | " print('Clients: {}/{} -> {}'.format(len(S_t), K, S_t))\n", 1030 | " \n", 1031 | " # For Federated Averaging, drop all the clients that are stragglers\n", 1032 | " if algorithm == 'fedavg':\n", 1033 | " stragglers_indices = np.argwhere(heterogenous_epoch_list < E)\n", 1034 | " heterogenous_epoch_list = np.delete(heterogenous_epoch_list, stragglers_indices)\n", 1035 | " S_t = np.delete(S_t, stragglers_indices)\n", 1036 | "\n", 1037 | " # for _, (k, epoch) in tqdm(enumerate(zip(S_t, heterogenous_epoch_list))):\n", 1038 | " for i in tqdm(range(len(S_t))):\n", 1039 | " # print('k', k)\n", 1040 | " k = S_t[i]\n", 1041 | " epoch = heterogenous_epoch_list[i]\n", 1042 | " key = users_id[k]\n", 1043 | " ds_ = ds if ds else data_dict[key]['train_ds']\n", 1044 | " idxs = data_dict[key] if ds else None\n", 1045 | " # print(f'Client {k}: {len(idxs) if idxs else len(ds_)} samples')\n", 1046 | " local_update = ClientUpdate(dataset=ds_, batchSize=batch_size, learning_rate=lr, epochs=epoch, idxs=idxs, mu=mu, algorithm=algorithm)\n", 1047 | " weights, loss, local_train_time = local_update.train(model=copy.deepcopy(model))\n", 1048 | " # print(f'Local train time for {k} on {len(idxs) if idxs else len(ds_)} samples: {local_train_time}')\n", 1049 | " # print(f'Local train time: {local_train_time}')\n", 1050 | "\n", 1051 | " w.append(copy.deepcopy(weights))\n", 1052 | " local_loss.append(copy.deepcopy(loss))\n", 1053 | " lst_local_train_time.append(local_train_time)\n", 1054 | "\n", 1055 | " # calculate time to update the global weights\n", 1056 | " global_start_time = time.time()\n", 1057 | "\n", 1058 | " # updating the global weights\n", 1059 | " weights_avg = copy.deepcopy(w[0])\n", 1060 | " for k in weights_avg.keys():\n", 1061 | " for i in range(1, len(w)):\n", 1062 | " weights_avg[k] += w[i][k]\n", 1063 | "\n", 1064 | " weights_avg[k] = torch.div(weights_avg[k], len(w))\n", 1065 | "\n", 1066 | " global_weights = weights_avg\n", 1067 | "\n", 1068 | " global_end_time = time.time()\n", 1069 | "\n", 1070 | " # calculate total time \n", 1071 | " total_time += (global_end_time - global_start_time) + sum(lst_local_train_time)/len(lst_local_train_time)\n", 1072 | "\n", 1073 | " # move the updated weights to our model state dict\n", 1074 | " model.load_state_dict(global_weights)\n", 1075 | "\n", 1076 | " # loss\n", 1077 | " loss_avg = sum(local_loss) / len(local_loss)\n", 1078 | " print('Round: {}... 
\\tAverage Loss: {}'.format(curr_round, round(loss_avg, 3)))\n", 1079 | " train_loss.append(loss_avg)\n", 1080 | " if tb_logger:\n", 1081 | " tb_logger.add_scalar(f'Train/Loss', loss_avg, curr_round)\n", 1082 | "\n", 1083 | " # testing\n", 1084 | " # if curr_round % eval_every == 0:\n", 1085 | " test_scores = testing(model, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(classes), classes)\n", 1086 | " test_scores['train_loss'] = loss_avg\n", 1087 | " test_loss, test_accuracy = test_scores['loss'], test_scores['accuracy']\n", 1088 | " history.append(test_scores)\n", 1089 | " \n", 1090 | " # print('Round: {}... \\tAverage Loss: {} \\tTest Loss: {} \\tTest Acc: {}'.format(curr_round, round(loss_avg, 3), round(test_loss, 3), round(test_accuracy, 3)))\n", 1091 | "\n", 1092 | " if tb_logger:\n", 1093 | " tb_logger.add_scalar(f'Test/Loss', test_scores['loss'], curr_round)\n", 1094 | " tb_logger.add_scalars(f'Test/Scores', {\n", 1095 | " 'accuracy': test_scores['accuracy'], 'f1_macro': test_scores['f1_macro'], 'f1_weighted': test_scores['f1_weighted']\n", 1096 | " }, curr_round)\n", 1097 | "\n", 1098 | " test_acc.append(test_accuracy)\n", 1099 | " # break if we achieve the target test accuracy\n", 1100 | " if test_accuracy >= target_test_accuracy:\n", 1101 | " rounds = curr_round\n", 1102 | " break\n", 1103 | "\n", 1104 | " # break if we achieve convergence, i.e., loss between two consecutive rounds is <0.0001\n", 1105 | " if algorithm == 'fedprox' and abs(loss_avg - last_loss) < 1e-5:\n", 1106 | " rounds = curr_round\n", 1107 | " break\n", 1108 | " \n", 1109 | " # update the last loss\n", 1110 | " last_loss = loss_avg\n", 1111 | "\n", 1112 | " end = time.time()\n", 1113 | " \n", 1114 | " # plot train loss\n", 1115 | " fig, ax = plt.subplots()\n", 1116 | " x_axis = np.arange(1, rounds+1)\n", 1117 | " y_axis = np.array(train_loss)\n", 1118 | " ax.plot(x_axis, y_axis)\n", 1119 | "\n", 1120 | " ax.set(xlabel='Number of Rounds', ylabel='Train Loss', title=plt_title)\n", 1121 | " ax.grid()\n", 1122 | " # fig.savefig(plt_title+'.jpg', format='jpg')\n", 1123 | "\n", 1124 | " # plot test accuracy\n", 1125 | " fig1, ax1 = plt.subplots()\n", 1126 | " x_axis1 = np.arange(1, rounds+1)\n", 1127 | " y_axis1 = np.array(test_acc)\n", 1128 | " ax1.plot(x_axis1, y_axis1)\n", 1129 | "\n", 1130 | " ax1.set(xlabel='Number of Rounds', ylabel='Test Accuracy', title=plt_title)\n", 1131 | " ax1.grid()\n", 1132 | " # fig1.savefig(plt_title+'-test.jpg', format='jpg')\n", 1133 | " \n", 1134 | " print(\"Training Done! Total time taken to Train: {}\".format(end-start))\n", 1135 | "\n", 1136 | " return model, history" 1137 | ], 1138 | "execution_count": null, 1139 | "outputs": [] 1140 | }, 1141 | { 1142 | "cell_type": "markdown", 1143 | "metadata": { 1144 | "id": "YXtGLkoAqLIW" 1145 | }, 1146 | "source": [ 1147 | "### Testing Loop" 1148 | ] 1149 | }, 1150 | { 1151 | "cell_type": "code", 1152 | "metadata": { 1153 | "id": "dQJIJno4qKvc" 1154 | }, 1155 | "source": [ 1156 | "def testing(model, dataset, bs, criterion, num_classes, classes, print_all=False):\n", 1157 | " #test loss \n", 1158 | " test_loss = 0.0\n", 1159 | " correct_class = list(0. for i in range(num_classes))\n", 1160 | " total_class = list(0. 
for i in range(num_classes))\n", 1161 | "\n", 1162 | " test_loader = DataLoader(dataset, batch_size=bs)\n", 1163 | " l = len(test_loader)\n", 1164 | " model.eval()\n", 1165 | " print('running validation...')\n", 1166 | " for i, (data, labels) in enumerate(tqdm(test_loader)):\n", 1167 | "\n", 1168 | " if torch.cuda.is_available():\n", 1169 | " data, labels = data.cuda(), labels.cuda()\n", 1170 | "\n", 1171 | " output = model(data)\n", 1172 | " loss = criterion(output, labels)\n", 1173 | " test_loss += loss.item()*data.size(0)\n", 1174 | "\n", 1175 | " _, pred = torch.max(output, 1)\n", 1176 | "\n", 1177 | " # For F1Score\n", 1178 | " y_true = np.append(y_true, labels.data.view_as(pred).cpu().numpy()) if i != 0 else labels.data.view_as(pred).cpu().numpy()\n", 1179 | " y_hat = np.append(y_hat, pred.cpu().numpy()) if i != 0 else pred.cpu().numpy()\n", 1180 | "\n", 1181 | " correct_tensor = pred.eq(labels.data.view_as(pred))\n", 1182 | " correct = np.squeeze(correct_tensor.numpy()) if not torch.cuda.is_available() else np.squeeze(correct_tensor.cpu().numpy())\n", 1183 | "\n", 1184 | " #test accuracy for each object class\n", 1185 | " # for i in range(num_classes):\n", 1186 | " # label = labels.data[i]\n", 1187 | " # correct_class[label] += correct[i].item()\n", 1188 | " # total_class[label] += 1\n", 1189 | "\n", 1190 | " for i, lbl in enumerate(labels.data):\n", 1191 | " # print('lbl', i, lbl)\n", 1192 | " correct_class[lbl] += correct.data[i]\n", 1193 | " total_class[lbl] += 1\n", 1194 | " \n", 1195 | " # avg test loss\n", 1196 | " test_loss = test_loss/len(test_loader.dataset)\n", 1197 | " print(\"Test Loss: {:.6f}\\n\".format(test_loss))\n", 1198 | "\n", 1199 | " # Avg F1 Score\n", 1200 | " f1_macro = f1_score(y_true, y_hat, average='macro')\n", 1201 | " # F1-Score -> weigthed to consider class imbalance\n", 1202 | " f1_weighted = f1_score(y_true, y_hat, average='weighted')\n", 1203 | " print(\"F1 Score: {:.6f} (macro) {:.6f} (weighted) %\\n\".format(f1_macro, f1_weighted))\n", 1204 | "\n", 1205 | " # print test accuracy\n", 1206 | " if print_all:\n", 1207 | " for i in range(num_classes):\n", 1208 | " if total_class[i]>0:\n", 1209 | " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % \n", 1210 | " (classes[i], 100 * correct_class[i] / total_class[i],\n", 1211 | " np.sum(correct_class[i]), np.sum(total_class[i])))\n", 1212 | " else:\n", 1213 | " print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))\n", 1214 | "\n", 1215 | " overall_accuracy = np.sum(correct_class) / np.sum(total_class)\n", 1216 | "\n", 1217 | " print('\\nFinal Test Accuracy: {:.3f} ({}/{})'.format(overall_accuracy, np.sum(correct_class), np.sum(total_class)))\n", 1218 | "\n", 1219 | " return {'loss': test_loss, 'accuracy': overall_accuracy, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted}" 1220 | ], 1221 | "execution_count": null, 1222 | "outputs": [] 1223 | }, 1224 | { 1225 | "cell_type": "markdown", 1226 | "metadata": { 1227 | "id": "uxqXLBd8qbC2" 1228 | }, 1229 | "source": [ 1230 | "## Experiments" 1231 | ] 1232 | }, 1233 | { 1234 | "cell_type": "code", 1235 | "metadata": { 1236 | "id": "VRKlrkVHO8Na" 1237 | }, 1238 | "source": [ 1239 | "FAIL-ON-PURPOSE" 1240 | ], 1241 | "execution_count": null, 1242 | "outputs": [] 1243 | }, 1244 | { 1245 | "cell_type": "code", 1246 | "metadata": { 1247 | "id": "E2CfSkNVqKtL" 1248 | }, 1249 | "source": [ 1250 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n", 1251 | "embedding_dim = 8 # mcmahan17a, fedprox, qFFL\n", 1252 | "# hidden_dim = 100 # fedprox paper\n", 1253 | 
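The server-side round loop above combines the returned client models with a plain unweighted mean over their state dicts. For comparison, the aggregation step in the original FedAvg paper weights each client's update by its share of the training samples; the following is a minimal sketch of that weighted variant (not what this notebook uses), with hypothetical inputs `state_dicts` (one per client) and `sample_counts`.

```python
import copy

def weighted_average(state_dicts, sample_counts):
    """Average client state dicts, weighting client k by n_k / sum_j(n_j)."""
    total = float(sum(sample_counts))
    avg = copy.deepcopy(state_dicts[0])
    for key in avg.keys():
        avg[key] = state_dicts[0][key] * (sample_counts[0] / total)
        for sd, n_k in zip(state_dicts[1:], sample_counts[1:]):
            avg[key] = avg[key] + sd[key] * (n_k / total)
    return avg
```

With equal sample counts this reduces to the uniform mean computed in the round loop.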
"hidden_dim = 256 # mcmahan17a, fedprox impl\n", 1254 | "num_classes = len(corpus)\n", 1255 | "classes = list(range(num_classes))\n", 1256 | "lstm_layers = 2 # mcmahan17a, fedprox, qFFL\n", 1257 | "dropout = 0.1 # TODO\n" 1258 | ], 1259 | "execution_count": null, 1260 | "outputs": [] 1261 | }, 1262 | { 1263 | "cell_type": "code", 1264 | "metadata": { 1265 | "id": "OvCgT_qbmFAO" 1266 | }, 1267 | "source": [ 1268 | "class Hyperparameters():\n", 1269 | "\n", 1270 | " def __init__(self, total_clients):\n", 1271 | " # number of training rounds\n", 1272 | " self.rounds = 50\n", 1273 | " # client fraction\n", 1274 | " self.C = 0.5\n", 1275 | " # number of clients\n", 1276 | " self.K = total_clients\n", 1277 | " # number of training passes on local dataset for each roung\n", 1278 | " # self.E = 20\n", 1279 | " self.E = 1\n", 1280 | " # batch size\n", 1281 | " self.batch_size = 10\n", 1282 | " # learning Rate\n", 1283 | " self.lr = 0.8\n", 1284 | " # proximal term constant\n", 1285 | " # self.mu = 0.0\n", 1286 | " self.mu = 0.001\n", 1287 | " # percentage of clients to have fewer than E epochs\n", 1288 | " self.percentage = 0\n", 1289 | " # self.percentage = 50\n", 1290 | " # self.percentage = 90\n", 1291 | " # target test accuracy\n", 1292 | " self.target_test_accuracy= 99.0\n", 1293 | " # self.target_test_accuracy=96.0" 1294 | ], 1295 | "execution_count": null, 1296 | "outputs": [] 1297 | }, 1298 | { 1299 | "cell_type": "code", 1300 | "metadata": { 1301 | "id": "m_JVF83mfM3f" 1302 | }, 1303 | "source": [ 1304 | "exp_log = dict()" 1305 | ], 1306 | "execution_count": null, 1307 | "outputs": [] 1308 | }, 1309 | { 1310 | "cell_type": "markdown", 1311 | "metadata": { 1312 | "id": "rYOPtnYoqhWd" 1313 | }, 1314 | "source": [ 1315 | "### IID" 1316 | ] 1317 | }, 1318 | { 1319 | "cell_type": "code", 1320 | "metadata": { 1321 | "id": "FRKc7NrzqKpU" 1322 | }, 1323 | "source": [ 1324 | "train_ds, data_dict, test_ds = iid_partition(corpus, seq_length, val_split=True) # Not using val_ds but makes train eval periods faster\n", 1325 | "\n", 1326 | "total_clients = len(data_dict.keys())\n", 1327 | "'Total users:', total_clients" 1328 | ], 1329 | "execution_count": null, 1330 | "outputs": [] 1331 | }, 1332 | { 1333 | "cell_type": "code", 1334 | "metadata": { 1335 | "id": "eaKtpKT5q1q_" 1336 | }, 1337 | "source": [ 1338 | "hparams = Hyperparameters(total_clients)\n", 1339 | "hparams.__dict__" 1340 | ], 1341 | "execution_count": null, 1342 | "outputs": [] 1343 | }, 1344 | { 1345 | "cell_type": "code", 1346 | "metadata": { 1347 | "id": "5ilYMClTV_WR" 1348 | }, 1349 | "source": [ 1350 | "# Sweeping parameter\n", 1351 | "PARAM_NAME = 'clients_fraction'\n", 1352 | "PARAM_VALUE = hparams.C\n", 1353 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n", 1354 | "exp_id" 1355 | ], 1356 | "execution_count": null, 1357 | "outputs": [] 1358 | }, 1359 | { 1360 | "cell_type": "code", 1361 | "metadata": { 1362 | "id": "xAhy4CWVZy3F" 1363 | }, 1364 | "source": [ 1365 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n", 1366 | "os.makedirs(EXP_DIR, exist_ok=True)\n", 1367 | "\n", 1368 | "# tb_logger = SummaryWriter(log_dir)\n", 1369 | "# print(f'TBoard logger created at: {log_dir}')\n", 1370 | "\n", 1371 | "title = 'LSTM FedProx on IID'" 1372 | ], 1373 | "execution_count": null, 1374 | "outputs": [] 1375 | }, 1376 | { 1377 | "cell_type": "code", 1378 | "metadata": { 1379 | "id": "LwTdeiv8q8_L" 1380 | }, 1381 | "source": [ 1382 | "def run_experiment(run_id):\n", 1383 | "\n", 1384 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n", 1385 | " 
embedding_dim=embedding_dim, \n", 1386 | " hidden_dim=hidden_dim,\n", 1387 | " classes=num_classes,\n", 1388 | " lstm_layers=lstm_layers,\n", 1389 | " dropout=dropout,\n", 1390 | " batch_first=True\n", 1391 | " )\n", 1392 | "\n", 1393 | " if torch.cuda.is_available():\n", 1394 | " shakespeare_lstm.cuda()\n", 1395 | " \n", 1396 | " test_history = []\n", 1397 | "\n", 1398 | " lstm_iid_trained, test_history = training(shakespeare_lstm,\n", 1399 | " hparams.rounds, hparams.batch_size, hparams.lr,\n", 1400 | " train_ds,\n", 1401 | " data_dict,\n", 1402 | " test_ds,\n", 1403 | " hparams.C, hparams.K, hparams.E, hparams.mu, hparams.percentage,\n", 1404 | " title, \"green\",\n", 1405 | " hparams.target_test_accuracy,\n", 1406 | " corpus, # classes\n", 1407 | " history=test_history,\n", 1408 | " algorithm='fedavg',\n", 1409 | " # tb_logger=tb_writer\n", 1410 | " )\n", 1411 | " \n", 1412 | "\n", 1413 | " final_scores = testing(lstm_iid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n", 1414 | " print(f'\\n\\n========================================================\\n\\n')\n", 1415 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n", 1416 | "\n", 1417 | " log = {\n", 1418 | " 'history': test_history,\n", 1419 | " 'hyperparams': hparams.__dict__\n", 1420 | " }\n", 1421 | "\n", 1422 | " with open(f'{EXP_DIR}/results_iid_{run_id}.pkl', 'wb') as file:\n", 1423 | " pickle.dump(log, file)\n", 1424 | "\n", 1425 | " return test_history\n" 1426 | ], 1427 | "execution_count": null, 1428 | "outputs": [] 1429 | }, 1430 | { 1431 | "cell_type": "code", 1432 | "metadata": { 1433 | "id": "gSU61KsSq87G" 1434 | }, 1435 | "source": [ 1436 | "exp_history = list()\n", 1437 | "for run_id in range(2): # TOTAL RUNS\n", 1438 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n", 1439 | " exp_history.append(run_experiment(run_id))\n", 1440 | " print(f'\\n\\n========================================================\\n\\n')" 1441 | ], 1442 | "execution_count": null, 1443 | "outputs": [] 1444 | }, 1445 | { 1446 | "cell_type": "code", 1447 | "metadata": { 1448 | "id": "us-HifGq3Uhf" 1449 | }, 1450 | "source": [ 1451 | "exp_log[title] = {\n", 1452 | " 'history': exp_history,\n", 1453 | " 'hyperparams': hparams.__dict__\n", 1454 | "}" 1455 | ], 1456 | "execution_count": null, 1457 | "outputs": [] 1458 | }, 1459 | { 1460 | "cell_type": "code", 1461 | "metadata": { 1462 | "id": "qDGpo4ug33dN" 1463 | }, 1464 | "source": [ 1465 | "df = None\n", 1466 | "for i, e in enumerate(exp_history):\n", 1467 | " if i == 0:\n", 1468 | " df = pd.json_normalize(e)\n", 1469 | " continue\n", 1470 | " df = df + pd.json_normalize(e)\n", 1471 | " \n", 1472 | "df_avg = df / len(exp_history)\n", 1473 | "avg_history = df_avg.to_dict(orient='records')" 1474 | ], 1475 | "execution_count": null, 1476 | "outputs": [] 1477 | }, 1478 | { 1479 | "cell_type": "code", 1480 | "metadata": { 1481 | "id": "Hf77BQAD36Eq" 1482 | }, 1483 | "source": [ 1484 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='IID')" 1485 | ], 1486 | "execution_count": null, 1487 | "outputs": [] 1488 | }, 1489 | { 1490 | "cell_type": "code", 1491 | "metadata": { 1492 | "id": "wJClynRJ38Dh" 1493 | }, 1494 | "source": [ 1495 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='IID')" 1496 | ], 1497 | "execution_count": null, 1498 | "outputs": [] 1499 | }, 1500 | { 1501 | "cell_type": "code", 1502 | "metadata": { 1503 | "id": "N36h40Zs1h_f" 1504 | }, 1505 | "source": [ 1506 | "with 
open(f'{EXP_DIR}/results_iid.pkl', 'wb') as file:\n", 1507 | " pickle.dump(exp_log, file)" 1508 | ], 1509 | "execution_count": null, 1510 | "outputs": [] 1511 | }, 1512 | { 1513 | "cell_type": "markdown", 1514 | "metadata": { 1515 | "id": "BaoYWkWgqvUQ" 1516 | }, 1517 | "source": [ 1518 | "### Non-IID" 1519 | ] 1520 | }, 1521 | { 1522 | "cell_type": "code", 1523 | "metadata": { 1524 | "id": "lVmhd_791lk7" 1525 | }, 1526 | "source": [ 1527 | "exp_log = dict()" 1528 | ], 1529 | "execution_count": null, 1530 | "outputs": [] 1531 | }, 1532 | { 1533 | "cell_type": "code", 1534 | "metadata": { 1535 | "id": "pILgaho8qKgF" 1536 | }, 1537 | "source": [ 1538 | "data_dict, test_ds = noniid_partition(corpus, seq_length=seq_length, val_split=True)\n", 1539 | "\n", 1540 | "total_clients = len(data_dict.keys())\n", 1541 | "'Total users:', total_clients" 1542 | ], 1543 | "execution_count": null, 1544 | "outputs": [] 1545 | }, 1546 | { 1547 | "cell_type": "code", 1548 | "metadata": { 1549 | "id": "Y3o7qgBcqKX_" 1550 | }, 1551 | "source": [ 1552 | "hparams = Hyperparameters(total_clients)\n", 1553 | "hparams.__dict__" 1554 | ], 1555 | "execution_count": null, 1556 | "outputs": [] 1557 | }, 1558 | { 1559 | "cell_type": "code", 1560 | "metadata": { 1561 | "id": "VANr1h0Pq51N" 1562 | }, 1563 | "source": [ 1564 | "# Sweeping parameter\n", 1565 | "PARAM_NAME = 'clients_fraction'\n", 1566 | "PARAM_VALUE = hparams.C\n", 1567 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n", 1568 | "exp_id" 1569 | ], 1570 | "execution_count": null, 1571 | "outputs": [] 1572 | }, 1573 | { 1574 | "cell_type": "code", 1575 | "metadata": { 1576 | "id": "yXgYFIyZ4ipm" 1577 | }, 1578 | "source": [ 1579 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n", 1580 | "os.makedirs(EXP_DIR, exist_ok=True)\n", 1581 | "\n", 1582 | "# tb_logger = SummaryWriter(log_dir)\n", 1583 | "# print(f'TBoard logger created at: {log_dir}')\n", 1584 | "\n", 1585 | "title = 'LSTM FedProx on Non-IID'" 1586 | ], 1587 | "execution_count": null, 1588 | "outputs": [] 1589 | }, 1590 | { 1591 | "cell_type": "code", 1592 | "metadata": { 1593 | "id": "Vnv7UaE0q6dG" 1594 | }, 1595 | "source": [ 1596 | "def run_experiment(run_id):\n", 1597 | "\n", 1598 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length,\n", 1599 | " embedding_dim=embedding_dim,\n", 1600 | " hidden_dim=hidden_dim,\n", 1601 | " classes=num_classes,\n", 1602 | " lstm_layers=lstm_layers,\n", 1603 | " dropout=dropout,\n", 1604 | " batch_first=True\n", 1605 | " )\n", 1606 | "\n", 1607 | " if torch.cuda.is_available():\n", 1608 | " shakespeare_lstm.cuda()\n", 1609 | "\n", 1610 | " test_history = []\n", 1611 | "\n", 1612 | " lstm_non_iid_trained, test_history = training(shakespeare_lstm,\n", 1613 | " hparams.rounds, hparams.batch_size, hparams.lr,\n", 1614 | " None, # ds empty as it is included in data_dict\n", 1615 | " data_dict,\n", 1616 | " test_ds,\n", 1617 | " hparams.C, hparams.K, hparams.E, hparams.mu, hparams.percentage,\n", 1618 | " title, \"green\",\n", 1619 | " hparams.target_test_accuracy,\n", 1620 | " corpus, # classes\n", 1621 | " history=test_history,\n", 1622 | " algorithm='fedavg',\n", 1623 | " # tb_logger=tb_writer\n", 1624 | " )\n", 1625 | "\n", 1626 | " \n", 1627 | "\n", 1628 | " final_scores = testing(lstm_non_iid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n", 1629 | " print(f'\\n\\n========================================================\\n\\n')\n", 1630 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n", 1631 | "\n", 1632 | " log = {\n", 1633 | " 
'history': test_history,\n", 1634 | " 'hyperparams': hparams.__dict__\n", 1635 | " }\n", 1636 | "\n", 1637 | " with open(f'{EXP_DIR}/results_niid_{run_id}.pkl', 'wb') as file:\n", 1638 | " pickle.dump(log, file)\n", 1639 | "\n", 1640 | " return test_history\n" 1641 | ], 1642 | "execution_count": null, 1643 | "outputs": [] 1644 | }, 1645 | { 1646 | "cell_type": "code", 1647 | "metadata": { 1648 | "id": "0pLbVBwVq6Uw" 1649 | }, 1650 | "source": [ 1651 | "exp_history = list()\n", 1652 | "for run_id in range(2): # TOTAL RUNS\n", 1653 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n", 1654 | " exp_history.append(run_experiment(run_id))\n", 1655 | " print(f'\\n\\n========================================================\\n\\n')" 1656 | ], 1657 | "execution_count": null, 1658 | "outputs": [] 1659 | }, 1660 | { 1661 | "cell_type": "code", 1662 | "metadata": { 1663 | "id": "n5F38z5C4qw9" 1664 | }, 1665 | "source": [ 1666 | "exp_log[title] = {\n", 1667 | " 'history': exp_history,\n", 1668 | " 'hyperparams': hparams.__dict__\n", 1669 | "}" 1670 | ], 1671 | "execution_count": null, 1672 | "outputs": [] 1673 | }, 1674 | { 1675 | "cell_type": "code", 1676 | "metadata": { 1677 | "id": "inIGn3Mh4qpO" 1678 | }, 1679 | "source": [ 1680 | "df = None\n", 1681 | "for i, e in enumerate(exp_history):\n", 1682 | " if i == 0:\n", 1683 | " df = pd.json_normalize(e)\n", 1684 | " continue\n", 1685 | " df = df + pd.json_normalize(e)\n", 1686 | " \n", 1687 | "df_avg = df / len(exp_history)\n", 1688 | "avg_history = df_avg.to_dict(orient='records')" 1689 | ], 1690 | "execution_count": null, 1691 | "outputs": [] 1692 | }, 1693 | { 1694 | "cell_type": "code", 1695 | "metadata": { 1696 | "id": "z8ngcls64qjc" 1697 | }, 1698 | "source": [ 1699 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')" 1700 | ], 1701 | "execution_count": null, 1702 | "outputs": [] 1703 | }, 1704 | { 1705 | "cell_type": "code", 1706 | "metadata": { 1707 | "id": "GR9vjtYs4qBX" 1708 | }, 1709 | "source": [ 1710 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')" 1711 | ], 1712 | "execution_count": null, 1713 | "outputs": [] 1714 | }, 1715 | { 1716 | "cell_type": "markdown", 1717 | "metadata": { 1718 | "id": "adK1OTS-40Z8" 1719 | }, 1720 | "source": [ 1721 | "### Pickle Experiment Results" 1722 | ] 1723 | }, 1724 | { 1725 | "cell_type": "code", 1726 | "metadata": { 1727 | "id": "i5nl-hsa4zqw" 1728 | }, 1729 | "source": [ 1730 | "with open(f'{EXP_DIR}/results_niid.pkl', 'wb') as file:\n", 1731 | " pickle.dump(exp_log, file)" 1732 | ], 1733 | "execution_count": null, 1734 | "outputs": [] 1735 | } 1736 | ] 1737 | } -------------------------------------------------------------------------------- /Shakespeare/FedProx.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "FedProx.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [ 9 | "TWDjzbCzfUFt", 10 | "cce_-qnxhD4n", 11 | "MFvc8mLoouKa", 12 | "sWVOxcAao2_t", 13 | "vFFAfTOwpk4j", 14 | "c640e4NnpksE", 15 | "VQ9PZM0Gp9ve", 16 | "3crFDN0xqGu6", 17 | "YXtGLkoAqLIW" 18 | ], 19 | "toc_visible": true, 20 | "authorship_tag": "ABX9TyOeIyifUjZfud7miIF289x8", 21 | "include_colab_link": true 22 | }, 23 | "kernelspec": { 24 | "name": "python3", 25 | "display_name": "Python 3" 26 | }, 27 | "language_info": { 28 | "name": "python" 29 | }, 30 | "accelerator": "GPU" 31 | }, 32 | "cells": [ 33 | { 34 | 
"cell_type": "markdown", 35 | "metadata": { 36 | "id": "view-in-github", 37 | "colab_type": "text" 38 | }, 39 | "source": [ 40 | "\"Open" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": { 46 | "id": "ZXnVpPCFfQkm" 47 | }, 48 | "source": [ 49 | "# FedPerf - Shakespeare + FedProx algorithm" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "TWDjzbCzfUFt" 56 | }, 57 | "source": [ 58 | "## Setup & Dependencies Installation" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "metadata": { 64 | "id": "z_f-rWudW0Fo" 65 | }, 66 | "source": [ 67 | "%%capture\n", 68 | "!pip install torchsummaryX unidecode" 69 | ], 70 | "execution_count": null, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "metadata": { 76 | "id": "3TmMInb_fnw_" 77 | }, 78 | "source": [ 79 | "%load_ext tensorboard\n", 80 | "\n", 81 | "import copy\n", 82 | "from functools import reduce\n", 83 | "import json\n", 84 | "import matplotlib.pyplot as plt\n", 85 | "import numpy as np\n", 86 | "import os\n", 87 | "import pickle\n", 88 | "import pandas as pd\n", 89 | "import random\n", 90 | "from sklearn.model_selection import train_test_split\n", 91 | "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", 92 | "import time\n", 93 | "import torch\n", 94 | "from torch.autograd import Variable\n", 95 | "import torch.nn as nn\n", 96 | "import torch.nn.functional as F\n", 97 | "from torch.utils.data import Dataset\n", 98 | "from torch.utils.data.dataloader import DataLoader\n", 99 | "from torch.utils.data.sampler import Sampler\n", 100 | "from torch.utils.tensorboard import SummaryWriter\n", 101 | "from torchsummary import summary\n", 102 | "from torchsummaryX import summary as summaryx\n", 103 | "from torchvision import transforms, utils, datasets\n", 104 | "from tqdm.notebook import tqdm\n", 105 | "from unidecode import unidecode\n", 106 | "\n", 107 | "%matplotlib inline\n", 108 | "\n", 109 | "# Check assigned GPU\n", 110 | "gpu_info = !nvidia-smi\n", 111 | "gpu_info = '\\n'.join(gpu_info)\n", 112 | "if gpu_info.find('failed') >= 0:\n", 113 | " print('Select the Runtime > \"Change runtime type\" menu to enable a GPU accelerator, ')\n", 114 | " print('and then re-execute this cell.')\n", 115 | "else:\n", 116 | " print(gpu_info)\n", 117 | "\n", 118 | "# set manual seed for reproducibility\n", 119 | "RANDOM_SEED = 42\n", 120 | "\n", 121 | "# general reproducibility\n", 122 | "random.seed(RANDOM_SEED)\n", 123 | "np.random.seed(RANDOM_SEED)\n", 124 | "torch.manual_seed(RANDOM_SEED)\n", 125 | "torch.cuda.manual_seed(RANDOM_SEED)\n", 126 | "\n", 127 | "# gpu training specific\n", 128 | "torch.backends.cudnn.deterministic = True\n", 129 | "torch.backends.cudnn.benchmark = False" 130 | ], 131 | "execution_count": null, 132 | "outputs": [] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": { 137 | "id": "8Mfqv6uHfwh-" 138 | }, 139 | "source": [ 140 | "## Mount GDrive" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "metadata": { 146 | "id": "75qbJwxsGj-k" 147 | }, 148 | "source": [ 149 | "BASE_DIR = '/content/drive/MyDrive/FedPerf/shakespeare/FedProx'" 150 | ], 151 | "execution_count": null, 152 | "outputs": [] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "metadata": { 157 | "id": "Ost-6lXSfveC" 158 | }, 159 | "source": [ 160 | "try:\n", 161 | " from google.colab import drive\n", 162 | " drive.mount('/content/drive')\n", 163 | " os.makedirs(BASE_DIR, exist_ok=True)\n", 164 | "except:\n", 165 | " print(\"WARNING: Results won't be stored on 
GDrive\")\n", 166 | " BASE_DIR = './'\n", 167 | "\n" 168 | ], 169 | "execution_count": null, 170 | "outputs": [] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": { 175 | "id": "Y-ZegBArf4xQ" 176 | }, 177 | "source": [ 178 | "## Loading Dataset" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "metadata": { 184 | "id": "hf03LRxof7Zj" 185 | }, 186 | "source": [ 187 | "!rm -Rf data\n", 188 | "!mkdir -p data scripts" 189 | ], 190 | "execution_count": null, 191 | "outputs": [] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "metadata": { 196 | "id": "ngygA4-Fgobx" 197 | }, 198 | "source": [ 199 | "GENERATE_DATASET = False # If False, download the dataset provided by the q-FFL paper\n", 200 | "DATA_DIR = 'data/'\n", 201 | "# Dataset generation params\n", 202 | "SAMPLES_FRACTION = 1. # If using an already generated dataset\n", 203 | "# SAMPLES_FRACTION = 0.2 # Fraction of total samples in the dataset - FedProx default script\n", 204 | "# SAMPLES_FRACTION = 0.05 # Fraction of total samples in the dataset - qFFL\n", 205 | "TRAIN_FRACTION = 0.8 # Train set size\n", 206 | "MIN_SAMPLES = 0 # Min samples per client (for filtering purposes) - FedProx\n", 207 | "# MIN_SAMPLES = 64 # Min samples per client (for filtering purposes) - qFFL" 208 | ], 209 | "execution_count": null, 210 | "outputs": [] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "metadata": { 215 | "id": "nUmwJgJygoYD" 216 | }, 217 | "source": [ 218 | "# Download raw dataset\n", 219 | "# !wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt -O data/shakespeare.txt\n", 220 | "!wget --adjust-extension http://www.gutenberg.org/files/100/100-0.txt -O data/shakespeare.txt" 221 | ], 222 | "execution_count": null, 223 | "outputs": [] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "metadata": { 228 | "id": "4dCvx80BgoVr" 229 | }, 230 | "source": [ 231 | "if not GENERATE_DATASET:\n", 232 | " !rm -Rf data/train data/test\n", 233 | " !gdown --id 1n46Mftp3_ahRi1Z6jYhEriyLtdRDS1tD # Download Shakespeare dataset used by the FedProx paper\n", 234 | " !unzip shakespeare.zip\n", 235 | " !mv -f shakespeare_paper/train data/\n", 236 | " !mv -f shakespeare_paper/test data/\n", 237 | " !rm -R shakespeare_paper/ shakespeare.zip\n" 238 | ], 239 | "execution_count": null, 240 | "outputs": [] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "metadata": { 245 | "id": "a4pzFvPvhQhq" 246 | }, 247 | "source": [ 248 | "corpus = []\n", 249 | "with open('data/shakespeare.txt', 'r') as f:\n", 250 | " data = list(unidecode(f.read()))\n", 251 | " corpus = list(set(list(data)))\n", 252 | "print('Corpus Length:', len(corpus))" 253 | ], 254 | "execution_count": null, 255 | "outputs": [] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": { 260 | "id": "cce_-qnxhD4n" 261 | }, 262 | "source": [ 263 | "#### Dataset Preprocessing script" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "metadata": { 269 | "id": "Rt13M4IcgoTV" 270 | }, 271 | "source": [ 272 | "%%capture\n", 273 | "if GENERATE_DATASET:\n", 274 | " # Download dataset generation scripts\n", 275 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/preprocess_shakespeare.py -O scripts/preprocess_shakespeare.py\n", 276 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/shake_utils.py -O scripts/shake_utils.py\n", 277 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/gen_all_data.py -O 
scripts/gen_all_data.py\n", 278 | "\n", 279 | " # Download data preprocessing scripts\n", 280 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/sample.py -O scripts/sample.py\n", 281 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/remove_users.py -O scripts/remove_users.py" 282 | ], 283 | "execution_count": null, 284 | "outputs": [] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "metadata": { 289 | "id": "EIEyRW27goPo" 290 | }, 291 | "source": [ 292 | "# Running scripts\n", 293 | "if GENERATE_DATASET:\n", 294 | " !mkdir -p data/raw_data data/all_data data/train data/test\n", 295 | " !python scripts/preprocess_shakespeare.py data/shakespeare.txt data/raw_data\n", 296 | " !python scripts/gen_all_data.py" 297 | ], 298 | "execution_count": null, 299 | "outputs": [] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": { 304 | "id": "mq8V6v_4hhhD" 305 | }, 306 | "source": [ 307 | "#### Dataset class" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "metadata": { 313 | "id": "H2SjEBKoWDxv" 314 | }, 315 | "source": [ 316 | "class ShakespeareDataset(Dataset):\n", 317 | " def __init__(self, x, y, corpus, seq_length):\n", 318 | " self.x = x\n", 319 | " self.y = y\n", 320 | " self.corpus = corpus\n", 321 | " self.corpus_size = len(self.corpus)\n", 322 | " super(ShakespeareDataset, self).__init__()\n", 323 | "\n", 324 | " def __len__(self):\n", 325 | " return len(self.x)\n", 326 | "\n", 327 | " def __repr__(self):\n", 328 | " return f'{self.__class__} - (length: {self.__len__()})'\n", 329 | "\n", 330 | " def __getitem__(self, i):\n", 331 | " input_seq = self.x[i]\n", 332 | " next_char = self.y[i]\n", 333 | " # print('\\tgetitem', i, input_seq, next_char)\n", 334 | " input_value = self.text2charindxs(input_seq)\n", 335 | " target_value = self.get_label_from_char(next_char)\n", 336 | " return input_value, target_value\n", 337 | "\n", 338 | " def text2charindxs(self, text):\n", 339 | " tensor = torch.zeros(len(text), dtype=torch.int32)\n", 340 | " for i, c in enumerate(text):\n", 341 | " tensor[i] = self.get_label_from_char(c)\n", 342 | " return tensor\n", 343 | "\n", 344 | " def get_label_from_char(self, c):\n", 345 | " return self.corpus.index(c)\n", 346 | "\n", 347 | " def get_char_from_label(self, l):\n", 348 | " return self.corpus[l]" 349 | ], 350 | "execution_count": null, 351 | "outputs": [] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": { 356 | "id": "9fgJtS62lYAN" 357 | }, 358 | "source": [ 359 | "##### Federated Dataset" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "metadata": { 365 | "id": "5DqL5pTmgn5X" 366 | }, 367 | "source": [ 368 | "class ShakespeareFedDataset(ShakespeareDataset):\n", 369 | " def __init__(self, x, y, corpus, seq_length):\n", 370 | " super(ShakespeareFedDataset, self).__init__(x, y, corpus, seq_length)\n", 371 | "\n", 372 | " def dataloader(self, batch_size, shuffle=True):\n", 373 | " return DataLoader(self,\n", 374 | " batch_size=batch_size,\n", 375 | " shuffle=shuffle,\n", 376 | " num_workers=0)\n" 377 | ], 378 | "execution_count": null, 379 | "outputs": [] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": { 384 | "id": "XelbyPsDlfgb" 385 | }, 386 | "source": [ 387 | "## Partitioning & Data Loaders" 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "metadata": { 393 | "id": "IOBblyFGlwlU" 394 | }, 395 | "source": [ 396 | "### IID" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "metadata": { 402 | "id": "cSZFWKmsgn1p" 403 | }, 404 | 
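For orientation, the partitioners defined in this section hand the server two differently shaped `data_dict` objects, and the `training` loop further down dispatches on whether a pooled dataset is passed alongside them. A minimal, self-contained sketch of the two layouts and of that dispatch (plain lists stand in for the notebook's `ShakespeareDataset`/`ShakespeareFedDataset` objects, and `"some_author_id"` is a placeholder key):

```python
# Schematic placeholders only; the real entries are ShakespeareDataset /
# ShakespeareFedDataset instances built from the downloaded train/test JSON files.

# IID layout: iid_partition returns one pooled train_ds plus
# {client_index: set(sample indices into that pooled dataset)}.
pooled_train_ds = ["sample 0", "sample 1", "sample 2", "sample 3"]
iid_data_dict = {0: {0, 2}, 1: {1, 3}}

# Non-IID layout: noniid_partition returns {author_id: {"train_ds": ..., "val_ds": ...}}
# with one dataset per Shakespeare speaking role.
noniid_data_dict = {
    "some_author_id": {"train_ds": ["line a", "line b"], "val_ds": ["line c"]},
}

def select_client_data(ds, data_dict, key):
    """Mirror of the dispatch in the training loop: pooled dataset + index set
    when ds is given (IID), a per-client dataset otherwise (non-IID)."""
    client_ds = ds if ds else data_dict[key]["train_ds"]
    idxs = data_dict[key] if ds else None
    return client_ds, idxs

print(select_client_data(pooled_train_ds, iid_data_dict, 0))         # (pooled ds, {0, 2})
print(select_client_data(None, noniid_data_dict, "some_author_id"))  # (['line a', 'line b'], None)
```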
"source": [ 405 | "def iid_partition_(dataset, clients):\n", 406 | " \"\"\"\n", 407 | " I.I.D paritioning of data over clients\n", 408 | " Shuffle the data\n", 409 | " Split it between clients\n", 410 | " \n", 411 | " params:\n", 412 | " - dataset (torch.utils.Dataset): Dataset\n", 413 | " - clients (int): Number of Clients to split the data between\n", 414 | "\n", 415 | " returns:\n", 416 | " - Dictionary of image indexes for each client\n", 417 | " \"\"\"\n", 418 | "\n", 419 | " num_items_per_client = int(len(dataset)/clients)\n", 420 | " client_dict = {}\n", 421 | " image_idxs = [i for i in range(len(dataset))]\n", 422 | "\n", 423 | " for i in range(clients):\n", 424 | " client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False))\n", 425 | " image_idxs = list(set(image_idxs) - client_dict[i])\n", 426 | "\n", 427 | " return client_dict" 428 | ], 429 | "execution_count": null, 430 | "outputs": [] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "metadata": { 435 | "id": "-lGwDyhSll9h" 436 | }, 437 | "source": [ 438 | "def iid_partition(corpus, seq_length=80, val_split=False):\n", 439 | "\n", 440 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n", 441 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n", 442 | "\n", 443 | " with open(train_file, 'r') as file:\n", 444 | " data_train = json.loads(unidecode(file.read()))\n", 445 | "\n", 446 | " with open(test_file, 'r') as file:\n", 447 | " data_test = json.loads(unidecode(file.read()))\n", 448 | "\n", 449 | " \n", 450 | " total_samples_train = sum(data_train['num_samples'])\n", 451 | "\n", 452 | " data_dict = {}\n", 453 | "\n", 454 | " x_train, y_train = [], []\n", 455 | " x_test, y_test = [], []\n", 456 | " # x_val, y_val = [], []\n", 457 | "\n", 458 | " users = list(zip(data_train['users'], data_train['num_samples']))\n", 459 | " # random.shuffle(users)\n", 460 | "\n", 461 | "\n", 462 | "\n", 463 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n", 464 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n", 465 | " sample_count = 0\n", 466 | " \n", 467 | " for i, (author_id, samples) in enumerate(users):\n", 468 | "\n", 469 | " if sample_count >= total_samples:\n", 470 | " print('Max samples reached', sample_count, '/', total_samples)\n", 471 | " break\n", 472 | "\n", 473 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n", 474 | " print('SKIP', author_id, samples)\n", 475 | " continue\n", 476 | " else:\n", 477 | " udata_train = data_train['user_data'][author_id]\n", 478 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n", 479 | " \n", 480 | " sample_count += max_samples\n", 481 | " # print('sample_count', sample_count)\n", 482 | "\n", 483 | " x_train.extend(data_train['user_data'][author_id]['x'][:max_samples])\n", 484 | " y_train.extend(data_train['user_data'][author_id]['y'][:max_samples])\n", 485 | "\n", 486 | " author_data = data_test['user_data'][author_id]\n", 487 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n", 488 | "\n", 489 | " if val_split:\n", 490 | " x_test.extend(author_data['x'][:int(test_size / 2)])\n", 491 | " y_test.extend(author_data['y'][:int(test_size / 2)])\n", 492 | " # x_val.extend(author_data['x'][int(test_size / 2):])\n", 493 | " # y_val.extend(author_data['y'][int(test_size / 
2):int(test_size)])\n", 494 | "\n", 495 | " else:\n", 496 | " x_test.extend(author_data['x'][:int(test_size)])\n", 497 | " y_test.extend(author_data['y'][:int(test_size)])\n", 498 | "\n", 499 | " train_ds = ShakespeareDataset(x_train, y_train, corpus, seq_length)\n", 500 | " test_ds = ShakespeareDataset(x_test, y_test, corpus, seq_length)\n", 501 | " # val_ds = ShakespeareDataset(x_val, y_val, corpus, seq_length)\n", 502 | "\n", 503 | " data_dict = iid_partition_(train_ds, clients=len(users))\n", 504 | "\n", 505 | " return train_ds, data_dict, test_ds" 506 | ], 507 | "execution_count": null, 508 | "outputs": [] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": { 513 | "id": "MFvc8mLoouKa" 514 | }, 515 | "source": [ 516 | "### Non-IID" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "metadata": { 522 | "id": "GZ76WsCZot9s" 523 | }, 524 | "source": [ 525 | "def noniid_partition(corpus, seq_length=80, val_split=False):\n", 526 | "\n", 527 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n", 528 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n", 529 | "\n", 530 | " with open(train_file, 'r') as file:\n", 531 | " data_train = json.loads(unidecode(file.read()))\n", 532 | "\n", 533 | " with open(test_file, 'r') as file:\n", 534 | " data_test = json.loads(unidecode(file.read()))\n", 535 | "\n", 536 | " \n", 537 | " total_samples_train = sum(data_train['num_samples'])\n", 538 | "\n", 539 | " data_dict = {}\n", 540 | "\n", 541 | " x_test, y_test = [], []\n", 542 | "\n", 543 | " users = list(zip(data_train['users'], data_train['num_samples']))\n", 544 | " # random.shuffle(users)\n", 545 | "\n", 546 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n", 547 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n", 548 | " sample_count = 0\n", 549 | " \n", 550 | " for i, (author_id, samples) in enumerate(users):\n", 551 | "\n", 552 | " if sample_count >= total_samples:\n", 553 | " print('Max samples reached', sample_count, '/', total_samples)\n", 554 | " break\n", 555 | "\n", 556 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n", 557 | " print('SKIP', author_id, samples)\n", 558 | " continue\n", 559 | " else:\n", 560 | " udata_train = data_train['user_data'][author_id]\n", 561 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n", 562 | " \n", 563 | " sample_count += max_samples\n", 564 | " # print('sample_count', sample_count)\n", 565 | "\n", 566 | " x_train = data_train['user_data'][author_id]['x'][:max_samples]\n", 567 | " y_train = data_train['user_data'][author_id]['y'][:max_samples]\n", 568 | "\n", 569 | " train_ds = ShakespeareFedDataset(x_train, y_train, corpus, seq_length)\n", 570 | "\n", 571 | " x_val, y_val = None, None\n", 572 | " val_ds = None\n", 573 | " author_data = data_test['user_data'][author_id]\n", 574 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n", 575 | " if val_split:\n", 576 | " x_test += author_data['x'][:int(test_size / 2)]\n", 577 | " y_test += author_data['y'][:int(test_size / 2)]\n", 578 | " x_val = author_data['x'][int(test_size / 2):]\n", 579 | " y_val = author_data['y'][int(test_size / 2):int(test_size)]\n", 580 | "\n", 581 | " val_ds = ShakespeareFedDataset(x_val, y_val, corpus, seq_length)\n", 582 | "\n", 583 | " else:\n", 584 | " x_test += 
author_data['x'][:int(test_size)]\n", 585 | " y_test += author_data['y'][:int(test_size)]\n", 586 | "\n", 587 | " data_dict[author_id] = {\n", 588 | " 'train_ds': train_ds,\n", 589 | " 'val_ds': val_ds\n", 590 | " }\n", 591 | "\n", 592 | " test_ds = ShakespeareFedDataset(x_test, y_test, corpus, seq_length)\n", 593 | "\n", 594 | " return data_dict, test_ds" 595 | ], 596 | "execution_count": null, 597 | "outputs": [] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "metadata": { 602 | "id": "sWVOxcAao2_t" 603 | }, 604 | "source": [ 605 | "## Models" 606 | ] 607 | }, 608 | { 609 | "cell_type": "markdown", 610 | "metadata": { 611 | "id": "gQQQ2mLeo6EA" 612 | }, 613 | "source": [ 614 | "### Shakespeare LSTM" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "metadata": { 620 | "id": "2mGXTrXRot7R" 621 | }, 622 | "source": [ 623 | "class ShakespeareLSTM(nn.Module):\n", 624 | " \"\"\"\n", 625 | " \"\"\"\n", 626 | "\n", 627 | " def __init__(self, input_dim, embedding_dim, hidden_dim, classes, lstm_layers=2, dropout=0.1, batch_first=True):\n", 628 | " super(ShakespeareLSTM, self).__init__()\n", 629 | " self.input_dim = input_dim\n", 630 | " self.embedding_dim = embedding_dim\n", 631 | " self.hidden_dim = hidden_dim\n", 632 | " self.classes = classes\n", 633 | " self.no_layers = lstm_layers\n", 634 | " \n", 635 | " self.embedding = nn.Embedding(num_embeddings=self.classes,\n", 636 | " embedding_dim=self.embedding_dim)\n", 637 | " self.lstm = nn.LSTM(input_size=self.embedding_dim, \n", 638 | " hidden_size=self.hidden_dim,\n", 639 | " num_layers=self.no_layers,\n", 640 | " batch_first=batch_first, \n", 641 | " dropout=dropout if self.no_layers > 1 else 0.)\n", 642 | " self.fc = nn.Linear(hidden_dim, self.classes)\n", 643 | "\n", 644 | " def forward(self, x, hc=None):\n", 645 | " batch_size = x.size(0)\n", 646 | " x_emb = self.embedding(x)\n", 647 | " out, (ht, ct) = self.lstm(x_emb.view(batch_size, -1, self.embedding_dim), hc)\n", 648 | " dense = self.fc(ht[-1])\n", 649 | " return dense\n", 650 | " \n", 651 | " def init_hidden(self, batch_size):\n", 652 | " return (Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)),\n", 653 | " Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)))\n" 654 | ], 655 | "execution_count": null, 656 | "outputs": [] 657 | }, 658 | { 659 | "cell_type": "markdown", 660 | "metadata": { 661 | "id": "5QsuJlVipMc8" 662 | }, 663 | "source": [ 664 | "#### Model Summary" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "metadata": { 670 | "id": "n_Vb0BYpot5I" 671 | }, 672 | "source": [ 673 | "batch_size = 10\n", 674 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n", 675 | "\n", 676 | "shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n", 677 | " embedding_dim=8, # mcmahan17a, fedprox, qFFL\n", 678 | " hidden_dim=256, # mcmahan17a, fedprox impl\n", 679 | " # hidden_dim=100, # fedprox paper\n", 680 | " classes=len(corpus),\n", 681 | " lstm_layers=2,\n", 682 | " dropout=0.1, # TODO:\n", 683 | " batch_first=True\n", 684 | " )\n", 685 | "\n", 686 | "if torch.cuda.is_available():\n", 687 | " shakespeare_lstm.cuda()\n", 688 | "\n", 689 | "\n", 690 | "\n", 691 | "hc = shakespeare_lstm.init_hidden(batch_size)\n", 692 | "\n", 693 | "x_sample = torch.zeros((batch_size, seq_length),\n", 694 | " dtype=torch.long,\n", 695 | " device=(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')))\n", 696 | "\n", 697 | "x_sample[0][0] = 1\n", 698 | "x_sample\n", 699 | "\n", 700 | "print(\"\\nShakespeare LSTM SUMMARY\")\n", 701 
| "print(summaryx(shakespeare_lstm, x_sample))" 702 | ], 703 | "execution_count": null, 704 | "outputs": [] 705 | }, 706 | { 707 | "cell_type": "markdown", 708 | "metadata": { 709 | "id": "qn7egnzTpeks" 710 | }, 711 | "source": [ 712 | "## FedProx Algorithm" 713 | ] 714 | }, 715 | { 716 | "cell_type": "markdown", 717 | "metadata": { 718 | "id": "vFFAfTOwpk4j" 719 | }, 720 | "source": [ 721 | "### Plot Utils" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "metadata": { 727 | "id": "oyYjWa6IpnTY" 728 | }, 729 | "source": [ 730 | "from sklearn.metrics import f1_score" 731 | ], 732 | "execution_count": null, 733 | "outputs": [] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "metadata": { 738 | "id": "367THsiTpo-C" 739 | }, 740 | "source": [ 741 | "def plot_scores(history, exp_id, title, suffix):\n", 742 | " accuracies = [x['accuracy'] for x in history]\n", 743 | " f1_macro = [x['f1_macro'] for x in history]\n", 744 | " f1_weighted = [x['f1_weighted'] for x in history]\n", 745 | "\n", 746 | " fig, ax = plt.subplots()\n", 747 | " ax.plot(accuracies, 'tab:orange')\n", 748 | " ax.set(xlabel='Rounds', ylabel='Test Accuracy', title=title)\n", 749 | " ax.grid()\n", 750 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Accuracy_{suffix}.jpg', format='jpg', dpi=300)\n", 751 | " plt.show()\n", 752 | "\n", 753 | " fig, ax = plt.subplots()\n", 754 | " ax.plot(f1_macro, 'tab:orange')\n", 755 | " ax.set(xlabel='Rounds', ylabel='Test F1 (macro)', title=title)\n", 756 | " ax.grid()\n", 757 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Macro_{suffix}.jpg', format='jpg')\n", 758 | " plt.show()\n", 759 | "\n", 760 | " fig, ax = plt.subplots()\n", 761 | " ax.plot(f1_weighted, 'tab:orange')\n", 762 | " ax.set(xlabel='Rounds', ylabel='Test F1 (weighted)', title=title)\n", 763 | " ax.grid()\n", 764 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Weighted_{suffix}.jpg', format='jpg')\n", 765 | " plt.show()\n", 766 | "\n", 767 | "\n", 768 | "def plot_losses(history, exp_id, title, suffix):\n", 769 | " val_losses = [x['loss'] for x in history]\n", 770 | " train_losses = [x['train_loss'] for x in history]\n", 771 | "\n", 772 | " fig, ax = plt.subplots()\n", 773 | " ax.plot(train_losses, 'tab:orange')\n", 774 | " ax.set(xlabel='Rounds', ylabel='Train Loss', title=title)\n", 775 | " ax.grid()\n", 776 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Train_Loss_{suffix}.jpg', format='jpg')\n", 777 | " plt.show()\n", 778 | "\n", 779 | " fig, ax = plt.subplots()\n", 780 | " ax.plot(val_losses, 'tab:orange')\n", 781 | " ax.set(xlabel='Rounds', ylabel='Test Loss', title=title)\n", 782 | " ax.grid()\n", 783 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Loss_{suffix}.jpg', format='jpg')\n", 784 | " plt.show()\n" 785 | ], 786 | "execution_count": null, 787 | "outputs": [] 788 | }, 789 | { 790 | "cell_type": "markdown", 791 | "metadata": { 792 | "id": "c640e4NnpksE" 793 | }, 794 | "source": [ 795 | "### Systems Heterogeneity Simulations\n", 796 | "\n", 797 | "Generate epochs for selected clients based on percentage of devices that corresponds to heterogeneity. \n", 798 | "\n", 799 | "Assign x number of epochs (chosen unifirmly at random between [1, E]) to 0%, 50% or 90% of the selected devices, respectively. Settings where 0% devices perform fewer than E epochs of work correspond to the environments without system heterogeneity, while 90% of the devices sending their partial solutions corresponds to highly heterogenous system." 
800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "metadata": { 805 | "id": "zuEZYnl5ot2m" 806 | }, 807 | "source": [ 808 | "def GenerateLocalEpochs(percentage, size, max_epochs):\n", 809 | " ''' Method generates list of epochs for selected clients\n", 810 | " to replicate system heteroggeneity\n", 811 | "\n", 812 | " Params:\n", 813 | " percentage: percentage of clients to have fewer than E epochs\n", 814 | " size: total size of the list\n", 815 | " max_epochs: maximum value for local epochs\n", 816 | " \n", 817 | " Returns:\n", 818 | " List of size epochs for each Client Update\n", 819 | "\n", 820 | " '''\n", 821 | "\n", 822 | " # if percentage is 0 then each client runs for E epochs\n", 823 | " if percentage == 0:\n", 824 | " return np.array([max_epochs]*size)\n", 825 | " else:\n", 826 | " # get the number of clients to have fewer than E epochs\n", 827 | " heterogenous_size = int((percentage/100) * size)\n", 828 | "\n", 829 | " # generate random uniform epochs of heterogenous size between 1 and E\n", 830 | " epoch_list = np.random.randint(1, max_epochs, heterogenous_size)\n", 831 | "\n", 832 | " # the rest of the clients will have E epochs\n", 833 | " remaining_size = size - heterogenous_size\n", 834 | " rem_list = [max_epochs]*remaining_size\n", 835 | "\n", 836 | " epoch_list = np.append(epoch_list, rem_list, axis=0)\n", 837 | " \n", 838 | " # shuffle the list and return\n", 839 | " np.random.shuffle(epoch_list)\n", 840 | "\n", 841 | " return epoch_list" 842 | ], 843 | "execution_count": null, 844 | "outputs": [] 845 | }, 846 | { 847 | "cell_type": "markdown", 848 | "metadata": { 849 | "id": "VQ9PZM0Gp9ve" 850 | }, 851 | "source": [ 852 | "### Local Training (Client Update)" 853 | ] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "metadata": { 858 | "id": "EDJFltwdotzZ" 859 | }, 860 | "source": [ 861 | "class CustomDataset(Dataset):\n", 862 | " def __init__(self, dataset, idxs):\n", 863 | " self.dataset = dataset\n", 864 | " self.idxs = list(idxs)\n", 865 | "\n", 866 | " def __len__(self):\n", 867 | " return len(self.idxs)\n", 868 | "\n", 869 | " def __getitem__(self, item):\n", 870 | " data, label = self.dataset[self.idxs[item]]\n", 871 | " return data, label" 872 | ], 873 | "execution_count": null, 874 | "outputs": [] 875 | }, 876 | { 877 | "cell_type": "code", 878 | "metadata": { 879 | "id": "HtRzU5Yepddq" 880 | }, 881 | "source": [ 882 | "class ClientUpdate(object):\n", 883 | " def __init__(self, dataset, batchSize, learning_rate, epochs, idxs, mu, algorithm):\n", 884 | " # self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batchSize, shuffle=True)\n", 885 | " if hasattr(dataset, 'dataloader'):\n", 886 | " self.train_loader = dataset.dataloader(batch_size=batch_size, shuffle=True)\n", 887 | " else:\n", 888 | " self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batch_size, shuffle=True)\n", 889 | "\n", 890 | " self.algorithm = algorithm\n", 891 | " self.learning_rate = learning_rate\n", 892 | " self.epochs = epochs\n", 893 | " self.mu = mu\n", 894 | "\n", 895 | " def train(self, model):\n", 896 | " # print(\"Client training for {} epochs.\".format(self.epochs))\n", 897 | " criterion = nn.CrossEntropyLoss()\n", 898 | " proximal_criterion = nn.MSELoss(reduction='mean')\n", 899 | " optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.5)\n", 900 | "\n", 901 | " # use the weights of global model for proximal term calculation\n", 902 | " global_model = copy.deepcopy(model)\n", 903 | "\n", 904 | " # calculate 
local training time\n", 905 | " start_time = time.time()\n", 906 | "\n", 907 | "\n", 908 | " e_loss = []\n", 909 | " for epoch in range(1, self.epochs+1):\n", 910 | "\n", 911 | " train_loss = 0.0\n", 912 | "\n", 913 | " model.train()\n", 914 | " for data, labels in self.train_loader:\n", 915 | "\n", 916 | " if torch.cuda.is_available():\n", 917 | " data, labels = data.cuda(), labels.cuda()\n", 918 | "\n", 919 | " # clear the gradients\n", 920 | " optimizer.zero_grad()\n", 921 | " # make a forward pass\n", 922 | " output = model(data)\n", 923 | "\n", 924 | " # calculate the loss + the proximal term\n", 925 | " _, pred = torch.max(output, 1)\n", 926 | "\n", 927 | " if self.algorithm == 'fedprox':\n", 928 | " proximal_term = 0.0\n", 929 | "\n", 930 | " # iterate through the current and global model parameters\n", 931 | " for w, w_t in zip(model.parameters(), global_model.parameters()) :\n", 932 | " # update the proximal term \n", 933 | " #proximal_term += torch.sum(torch.abs((w-w_t)**2))\n", 934 | " proximal_term += (w-w_t).norm(2)\n", 935 | "\n", 936 | " loss = criterion(output, labels) + (self.mu/2)*proximal_term\n", 937 | " else:\n", 938 | " loss = criterion(output, labels)\n", 939 | " \n", 940 | " # do a backwards pass\n", 941 | " loss.backward()\n", 942 | " # perform a single optimization step\n", 943 | " optimizer.step()\n", 944 | " # update training loss\n", 945 | " train_loss += loss.item()*data.size(0)\n", 946 | "\n", 947 | " # average losses\n", 948 | " train_loss = train_loss/len(self.train_loader.dataset)\n", 949 | " e_loss.append(train_loss)\n", 950 | "\n", 951 | " total_loss = sum(e_loss)/len(e_loss)\n", 952 | "\n", 953 | " return model.state_dict(), total_loss, (time.time() - start_time)" 954 | ], 955 | "execution_count": null, 956 | "outputs": [] 957 | }, 958 | { 959 | "cell_type": "markdown", 960 | "metadata": { 961 | "id": "3crFDN0xqGu6" 962 | }, 963 | "source": [ 964 | "### Server Side Training" 965 | ] 966 | }, 967 | { 968 | "cell_type": "code", 969 | "metadata": { 970 | "id": "c085xSOoqEHk" 971 | }, 972 | "source": [ 973 | "def training(model, rounds, batch_size, lr, ds, data_dict, test_ds, C, K, E, mu, percentage, plt_title, plt_color, target_test_accuracy,\n", 974 | " classes, algorithm=\"fedprox\", history=[], eval_every=1, tb_logger=None):\n", 975 | " \"\"\"\n", 976 | " Function implements the Federated Averaging Algorithm from the FedAvg paper.\n", 977 | " Specifically, this function is used for the server side training and weight update\n", 978 | "\n", 979 | " Params:\n", 980 | " - model: PyTorch model to train\n", 981 | " - rounds: Number of communication rounds for the client update\n", 982 | " - batch_size: Batch size for client update training\n", 983 | " - lr: Learning rate used for client update training\n", 984 | " - ds: Dataset used for training\n", 985 | " - data_dict: Type of data partition used for training (IID or non-IID)\n", 986 | " - test_data_dict: Data used for testing the model\n", 987 | " - C: Fraction of clients randomly chosen to perform computation on each round\n", 988 | " - K: Total number of clients\n", 989 | " - E: Number of training passes each client makes over its local dataset per round\n", 990 | " - mu: proximal term constant\n", 991 | " - percentage: percentage of selected client to have fewer than E epochs\n", 992 | " Returns:\n", 993 | " - model: Trained model on the server\n", 994 | " \"\"\"\n", 995 | "\n", 996 | " start = time.time()\n", 997 | "\n", 998 | " # global model weights\n", 999 | " global_weights = model.state_dict()\n", 
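One detail worth flagging before the round loop that follows: when the collected client weights `w` are combined further down, this notebook averages the state dicts uniformly (each sampled client contributes `1/len(w)`), whereas the FedAvg paper weights client $k$ by $n_k/n$, its share of the training samples; the two coincide for the equally sized IID shards but can differ under the non-IID split. A minimal sketch of the sample-weighted variant, assuming `w` is the list of client state dicts gathered in this loop and `client_sizes` is a hypothetical matching list of per-client sample counts:

```python
import copy

def weighted_average(w, client_sizes):
    """Sample-weighted FedAvg aggregation: client k contributes n_k / sum(n_k)."""
    total = float(sum(client_sizes))
    avg = copy.deepcopy(w[0])
    for key in avg.keys():
        avg[key] = w[0][key] * (client_sizes[0] / total)
        for i in range(1, len(w)):
            avg[key] = avg[key] + w[i][key] * (client_sizes[i] / total)
    return avg
```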
1000 | "\n", 1001 | " # training loss\n", 1002 | " train_loss = []\n", 1003 | "\n", 1004 | " # test accuracy\n", 1005 | " test_acc = []\n", 1006 | "\n", 1007 | " # store last loss for convergence\n", 1008 | " last_loss = 0.0\n", 1009 | "\n", 1010 | " # total time taken \n", 1011 | " total_time = 0\n", 1012 | "\n", 1013 | " print(f\"System heterogeneity set to {percentage}% stragglers.\\n\")\n", 1014 | " print(f\"Picking {max(int(C*K),1 )} random clients per round.\\n\")\n", 1015 | "\n", 1016 | " users_id = list(data_dict.keys())\n", 1017 | "\n", 1018 | " for curr_round in range(1, rounds+1):\n", 1019 | " w, local_loss, lst_local_train_time = [], [], []\n", 1020 | "\n", 1021 | " m = max(int(C*K), 1)\n", 1022 | "\n", 1023 | " heterogenous_epoch_list = GenerateLocalEpochs(percentage, size=m, max_epochs=E)\n", 1024 | " heterogenous_epoch_list = np.array(heterogenous_epoch_list)\n", 1025 | " # print('heterogenous_epoch_list', len(heterogenous_epoch_list))\n", 1026 | "\n", 1027 | " S_t = np.random.choice(range(K), m, replace=False)\n", 1028 | " S_t = np.array(S_t)\n", 1029 | " print('Clients: {}/{} -> {}'.format(len(S_t), K, S_t))\n", 1030 | " \n", 1031 | " # For Federated Averaging, drop all the clients that are stragglers\n", 1032 | " if algorithm == 'fedavg':\n", 1033 | " stragglers_indices = np.argwhere(heterogenous_epoch_list < E)\n", 1034 | " heterogenous_epoch_list = np.delete(heterogenous_epoch_list, stragglers_indices)\n", 1035 | " S_t = np.delete(S_t, stragglers_indices)\n", 1036 | "\n", 1037 | " # for _, (k, epoch) in tqdm(enumerate(zip(S_t, heterogenous_epoch_list))):\n", 1038 | " for i in tqdm(range(len(S_t))):\n", 1039 | " # print('k', k)\n", 1040 | " k = S_t[i]\n", 1041 | " epoch = heterogenous_epoch_list[i]\n", 1042 | " key = users_id[k]\n", 1043 | " ds_ = ds if ds else data_dict[key]['train_ds']\n", 1044 | " idxs = data_dict[key] if ds else None\n", 1045 | " # print(f'Client {k}: {len(idxs) if idxs else len(ds_)} samples')\n", 1046 | " local_update = ClientUpdate(dataset=ds_, batchSize=batch_size, learning_rate=lr, epochs=epoch, idxs=idxs, mu=mu, algorithm=algorithm)\n", 1047 | " weights, loss, local_train_time = local_update.train(model=copy.deepcopy(model))\n", 1048 | " # print(f'Local train time for {k} on {len(idxs) if idxs else len(ds_)} samples: {local_train_time}')\n", 1049 | " # print(f'Local train time: {local_train_time}')\n", 1050 | "\n", 1051 | " w.append(copy.deepcopy(weights))\n", 1052 | " local_loss.append(copy.deepcopy(loss))\n", 1053 | " lst_local_train_time.append(local_train_time)\n", 1054 | "\n", 1055 | " # calculate time to update the global weights\n", 1056 | " global_start_time = time.time()\n", 1057 | "\n", 1058 | " # updating the global weights\n", 1059 | " weights_avg = copy.deepcopy(w[0])\n", 1060 | " for k in weights_avg.keys():\n", 1061 | " for i in range(1, len(w)):\n", 1062 | " weights_avg[k] += w[i][k]\n", 1063 | "\n", 1064 | " weights_avg[k] = torch.div(weights_avg[k], len(w))\n", 1065 | "\n", 1066 | " global_weights = weights_avg\n", 1067 | "\n", 1068 | " global_end_time = time.time()\n", 1069 | "\n", 1070 | " # calculate total time \n", 1071 | " total_time += (global_end_time - global_start_time) + sum(lst_local_train_time)/len(lst_local_train_time)\n", 1072 | "\n", 1073 | " # move the updated weights to our model state dict\n", 1074 | " model.load_state_dict(global_weights)\n", 1075 | "\n", 1076 | " # loss\n", 1077 | " loss_avg = sum(local_loss) / len(local_loss)\n", 1078 | " print('Round: {}... 
\\tAverage Loss: {}'.format(curr_round, round(loss_avg, 3)))\n", 1079 | " train_loss.append(loss_avg)\n", 1080 | " if tb_logger:\n", 1081 | " tb_logger.add_scalar(f'Train/Loss', loss_avg, curr_round)\n", 1082 | "\n", 1083 | " # testing\n", 1084 | " # if curr_round % eval_every == 0:\n", 1085 | " test_scores = testing(model, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(classes), classes)\n", 1086 | " test_scores['train_loss'] = loss_avg\n", 1087 | " test_loss, test_accuracy = test_scores['loss'], test_scores['accuracy']\n", 1088 | " history.append(test_scores)\n", 1089 | " \n", 1090 | " # print('Round: {}... \\tAverage Loss: {} \\tTest Loss: {} \\tTest Acc: {}'.format(curr_round, round(loss_avg, 3), round(test_loss, 3), round(test_accuracy, 3)))\n", 1091 | "\n", 1092 | " if tb_logger:\n", 1093 | " tb_logger.add_scalar(f'Test/Loss', test_scores['loss'], curr_round)\n", 1094 | " tb_logger.add_scalars(f'Test/Scores', {\n", 1095 | " 'accuracy': test_scores['accuracy'], 'f1_macro': test_scores['f1_macro'], 'f1_weighted': test_scores['f1_weighted']\n", 1096 | " }, curr_round)\n", 1097 | "\n", 1098 | " test_acc.append(test_accuracy)\n", 1099 | " # break if we achieve the target test accuracy\n", 1100 | " if test_accuracy >= target_test_accuracy:\n", 1101 | " rounds = curr_round\n", 1102 | " break\n", 1103 | "\n", 1104 | " # break if we achieve convergence, i.e., loss between two consecutive rounds is <0.0001\n", 1105 | " if algorithm == 'fedprox' and abs(loss_avg - last_loss) < 1e-5:\n", 1106 | " rounds = curr_round\n", 1107 | " break\n", 1108 | " \n", 1109 | " # update the last loss\n", 1110 | " last_loss = loss_avg\n", 1111 | "\n", 1112 | " end = time.time()\n", 1113 | " \n", 1114 | " # plot train loss\n", 1115 | " fig, ax = plt.subplots()\n", 1116 | " x_axis = np.arange(1, rounds+1)\n", 1117 | " y_axis = np.array(train_loss)\n", 1118 | " ax.plot(x_axis, y_axis)\n", 1119 | "\n", 1120 | " ax.set(xlabel='Number of Rounds', ylabel='Train Loss', title=plt_title)\n", 1121 | " ax.grid()\n", 1122 | " # fig.savefig(plt_title+'.jpg', format='jpg')\n", 1123 | "\n", 1124 | " # plot test accuracy\n", 1125 | " fig1, ax1 = plt.subplots()\n", 1126 | " x_axis1 = np.arange(1, rounds+1)\n", 1127 | " y_axis1 = np.array(test_acc)\n", 1128 | " ax1.plot(x_axis1, y_axis1)\n", 1129 | "\n", 1130 | " ax1.set(xlabel='Number of Rounds', ylabel='Test Accuracy', title=plt_title)\n", 1131 | " ax1.grid()\n", 1132 | " # fig1.savefig(plt_title+'-test.jpg', format='jpg')\n", 1133 | " \n", 1134 | " print(\"Training Done! Total time taken to Train: {}\".format(end-start))\n", 1135 | "\n", 1136 | " return model, history" 1137 | ], 1138 | "execution_count": null, 1139 | "outputs": [] 1140 | }, 1141 | { 1142 | "cell_type": "markdown", 1143 | "metadata": { 1144 | "id": "YXtGLkoAqLIW" 1145 | }, 1146 | "source": [ 1147 | "### Testing Loop" 1148 | ] 1149 | }, 1150 | { 1151 | "cell_type": "code", 1152 | "metadata": { 1153 | "id": "dQJIJno4qKvc" 1154 | }, 1155 | "source": [ 1156 | "def testing(model, dataset, bs, criterion, num_classes, classes, print_all=False):\n", 1157 | " #test loss \n", 1158 | " test_loss = 0.0\n", 1159 | " correct_class = list(0. for i in range(num_classes))\n", 1160 | " total_class = list(0. 
for i in range(num_classes))\n", 1161 | "\n", 1162 | " test_loader = DataLoader(dataset, batch_size=bs)\n", 1163 | " l = len(test_loader)\n", 1164 | " model.eval()\n", 1165 | " print('running validation...')\n", 1166 | " for i, (data, labels) in enumerate(tqdm(test_loader)):\n", 1167 | "\n", 1168 | " if torch.cuda.is_available():\n", 1169 | " data, labels = data.cuda(), labels.cuda()\n", 1170 | "\n", 1171 | " output = model(data)\n", 1172 | " loss = criterion(output, labels)\n", 1173 | " test_loss += loss.item()*data.size(0)\n", 1174 | "\n", 1175 | " _, pred = torch.max(output, 1)\n", 1176 | "\n", 1177 | " # For F1Score\n", 1178 | " y_true = np.append(y_true, labels.data.view_as(pred).cpu().numpy()) if i != 0 else labels.data.view_as(pred).cpu().numpy()\n", 1179 | " y_hat = np.append(y_hat, pred.cpu().numpy()) if i != 0 else pred.cpu().numpy()\n", 1180 | "\n", 1181 | " correct_tensor = pred.eq(labels.data.view_as(pred))\n", 1182 | " correct = np.squeeze(correct_tensor.numpy()) if not torch.cuda.is_available() else np.squeeze(correct_tensor.cpu().numpy())\n", 1183 | "\n", 1184 | " #test accuracy for each object class\n", 1185 | " # for i in range(num_classes):\n", 1186 | " # label = labels.data[i]\n", 1187 | " # correct_class[label] += correct[i].item()\n", 1188 | " # total_class[label] += 1\n", 1189 | "\n", 1190 | " for i, lbl in enumerate(labels.data):\n", 1191 | " # print('lbl', i, lbl)\n", 1192 | " correct_class[lbl] += correct.data[i]\n", 1193 | " total_class[lbl] += 1\n", 1194 | " \n", 1195 | " # avg test loss\n", 1196 | " test_loss = test_loss/len(test_loader.dataset)\n", 1197 | " print(\"Test Loss: {:.6f}\\n\".format(test_loss))\n", 1198 | "\n", 1199 | " # Avg F1 Score\n", 1200 | " f1_macro = f1_score(y_true, y_hat, average='macro')\n", 1201 | " # F1-Score -> weigthed to consider class imbalance\n", 1202 | " f1_weighted = f1_score(y_true, y_hat, average='weighted')\n", 1203 | " print(\"F1 Score: {:.6f} (macro) {:.6f} (weighted) %\\n\".format(f1_macro, f1_weighted))\n", 1204 | "\n", 1205 | " # print test accuracy\n", 1206 | " if print_all:\n", 1207 | " for i in range(num_classes):\n", 1208 | " if total_class[i]>0:\n", 1209 | " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % \n", 1210 | " (classes[i], 100 * correct_class[i] / total_class[i],\n", 1211 | " np.sum(correct_class[i]), np.sum(total_class[i])))\n", 1212 | " else:\n", 1213 | " print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))\n", 1214 | "\n", 1215 | " overall_accuracy = np.sum(correct_class) / np.sum(total_class)\n", 1216 | "\n", 1217 | " print('\\nFinal Test Accuracy: {:.3f} ({}/{})'.format(overall_accuracy, np.sum(correct_class), np.sum(total_class)))\n", 1218 | "\n", 1219 | " return {'loss': test_loss, 'accuracy': overall_accuracy, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted}" 1220 | ], 1221 | "execution_count": null, 1222 | "outputs": [] 1223 | }, 1224 | { 1225 | "cell_type": "markdown", 1226 | "metadata": { 1227 | "id": "uxqXLBd8qbC2" 1228 | }, 1229 | "source": [ 1230 | "## Experiments" 1231 | ] 1232 | }, 1233 | { 1234 | "cell_type": "code", 1235 | "metadata": { 1236 | "id": "VRKlrkVHO8Na" 1237 | }, 1238 | "source": [ 1239 | "# FAIL-ON-PURPOSE" 1240 | ], 1241 | "execution_count": null, 1242 | "outputs": [] 1243 | }, 1244 | { 1245 | "cell_type": "code", 1246 | "metadata": { 1247 | "id": "E2CfSkNVqKtL" 1248 | }, 1249 | "source": [ 1250 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n", 1251 | "embedding_dim = 8 # mcmahan17a, fedprox, qFFL\n", 1252 | "# hidden_dim = 100 # fedprox paper\n", 1253 | 
"hidden_dim = 256 # mcmahan17a, fedprox impl\n", 1254 | "num_classes = len(corpus)\n", 1255 | "classes = list(range(num_classes))\n", 1256 | "lstm_layers = 2 # mcmahan17a, fedprox, qFFL\n", 1257 | "dropout = 0.1 # TODO\n" 1258 | ], 1259 | "execution_count": null, 1260 | "outputs": [] 1261 | }, 1262 | { 1263 | "cell_type": "code", 1264 | "metadata": { 1265 | "id": "OvCgT_qbmFAO" 1266 | }, 1267 | "source": [ 1268 | "class Hyperparameters():\n", 1269 | "\n", 1270 | " def __init__(self, total_clients):\n", 1271 | " # number of training rounds\n", 1272 | " self.rounds = 50\n", 1273 | " # client fraction\n", 1274 | " self.C = 0.5\n", 1275 | " # number of clients\n", 1276 | " self.K = total_clients\n", 1277 | " # number of training passes on local dataset for each roung\n", 1278 | " # self.E = 20\n", 1279 | " self.E = 1\n", 1280 | " # batch size\n", 1281 | " self.batch_size = 10\n", 1282 | " # learning Rate\n", 1283 | " self.lr = 0.8\n", 1284 | " # proximal term constant\n", 1285 | " # self.mu = 0.0\n", 1286 | " self.mu = 0.001\n", 1287 | " # percentage of clients to have fewer than E epochs\n", 1288 | " self.percentage = 0\n", 1289 | " # self.percentage = 50\n", 1290 | " # self.percentage = 90\n", 1291 | " # target test accuracy\n", 1292 | " self.target_test_accuracy= 99.0\n", 1293 | " # self.target_test_accuracy=96.0" 1294 | ], 1295 | "execution_count": null, 1296 | "outputs": [] 1297 | }, 1298 | { 1299 | "cell_type": "code", 1300 | "metadata": { 1301 | "id": "m_JVF83mfM3f" 1302 | }, 1303 | "source": [ 1304 | "exp_log = dict()" 1305 | ], 1306 | "execution_count": null, 1307 | "outputs": [] 1308 | }, 1309 | { 1310 | "cell_type": "markdown", 1311 | "metadata": { 1312 | "id": "rYOPtnYoqhWd" 1313 | }, 1314 | "source": [ 1315 | "### IID" 1316 | ] 1317 | }, 1318 | { 1319 | "cell_type": "code", 1320 | "metadata": { 1321 | "id": "FRKc7NrzqKpU" 1322 | }, 1323 | "source": [ 1324 | "train_ds, data_dict, test_ds = iid_partition(corpus, seq_length, val_split=True) # Not using val_ds but makes train eval periods faster\n", 1325 | "\n", 1326 | "total_clients = len(data_dict.keys())\n", 1327 | "'Total users:', total_clients" 1328 | ], 1329 | "execution_count": null, 1330 | "outputs": [] 1331 | }, 1332 | { 1333 | "cell_type": "code", 1334 | "metadata": { 1335 | "id": "eaKtpKT5q1q_" 1336 | }, 1337 | "source": [ 1338 | "hparams = Hyperparameters(total_clients)\n", 1339 | "hparams.__dict__" 1340 | ], 1341 | "execution_count": null, 1342 | "outputs": [] 1343 | }, 1344 | { 1345 | "cell_type": "code", 1346 | "metadata": { 1347 | "id": "5ilYMClTV_WR" 1348 | }, 1349 | "source": [ 1350 | "# Sweeping parameter\n", 1351 | "PARAM_NAME = 'clients_fraction'\n", 1352 | "PARAM_VALUE = hparams.C\n", 1353 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n", 1354 | "exp_id" 1355 | ], 1356 | "execution_count": null, 1357 | "outputs": [] 1358 | }, 1359 | { 1360 | "cell_type": "code", 1361 | "metadata": { 1362 | "id": "xAhy4CWVZy3F" 1363 | }, 1364 | "source": [ 1365 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n", 1366 | "os.makedirs(EXP_DIR, exist_ok=True)\n", 1367 | "\n", 1368 | "# tb_logger = SummaryWriter(log_dir)\n", 1369 | "# print(f'TBoard logger created at: {log_dir}')\n", 1370 | "\n", 1371 | "title = 'LSTM FedProx on IID'" 1372 | ], 1373 | "execution_count": null, 1374 | "outputs": [] 1375 | }, 1376 | { 1377 | "cell_type": "code", 1378 | "metadata": { 1379 | "id": "LwTdeiv8q8_L" 1380 | }, 1381 | "source": [ 1382 | "def run_experiment(run_id):\n", 1383 | "\n", 1384 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n", 1385 | " 
embedding_dim=embedding_dim, \n", 1386 | " hidden_dim=hidden_dim,\n", 1387 | " classes=num_classes,\n", 1388 | " lstm_layers=lstm_layers,\n", 1389 | " dropout=dropout,\n", 1390 | " batch_first=True\n", 1391 | " )\n", 1392 | "\n", 1393 | " if torch.cuda.is_available():\n", 1394 | " shakespeare_lstm.cuda()\n", 1395 | " \n", 1396 | " test_history = []\n", 1397 | "\n", 1398 | " lstm_iid_trained, test_history = training(shakespeare_lstm,\n", 1399 | " hparams.rounds, hparams.batch_size, hparams.lr,\n", 1400 | " train_ds,\n", 1401 | " data_dict,\n", 1402 | " test_ds,\n", 1403 | " hparams.C, hparams.K, hparams.E, hparams.mu, hparams.percentage,\n", 1404 | " title, \"green\",\n", 1405 | " hparams.target_test_accuracy,\n", 1406 | " corpus, # classes\n", 1407 | " history=test_history,\n", 1408 | " # tb_logger=tb_writer\n", 1409 | " )\n", 1410 | " \n", 1411 | "\n", 1412 | " final_scores = testing(lstm_iid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n", 1413 | " print(f'\\n\\n========================================================\\n\\n')\n", 1414 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n", 1415 | "\n", 1416 | " log = {\n", 1417 | " 'history': test_history,\n", 1418 | " 'hyperparams': hparams.__dict__\n", 1419 | " }\n", 1420 | "\n", 1421 | " with open(f'{EXP_DIR}/results_iid_{run_id}.pkl', 'wb') as file:\n", 1422 | " pickle.dump(log, file)\n", 1423 | "\n", 1424 | " return test_history\n" 1425 | ], 1426 | "execution_count": null, 1427 | "outputs": [] 1428 | }, 1429 | { 1430 | "cell_type": "code", 1431 | "metadata": { 1432 | "id": "gSU61KsSq87G" 1433 | }, 1434 | "source": [ 1435 | "exp_history = list()\n", 1436 | "for run_id in range(2): # TOTAL RUNS\n", 1437 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n", 1438 | " exp_history.append(run_experiment(run_id))\n", 1439 | " print(f'\\n\\n========================================================\\n\\n')" 1440 | ], 1441 | "execution_count": null, 1442 | "outputs": [] 1443 | }, 1444 | { 1445 | "cell_type": "code", 1446 | "metadata": { 1447 | "id": "us-HifGq3Uhf" 1448 | }, 1449 | "source": [ 1450 | "exp_log[title] = {\n", 1451 | " 'history': exp_history,\n", 1452 | " 'hyperparams': hparams.__dict__\n", 1453 | "}" 1454 | ], 1455 | "execution_count": null, 1456 | "outputs": [] 1457 | }, 1458 | { 1459 | "cell_type": "code", 1460 | "metadata": { 1461 | "id": "qDGpo4ug33dN" 1462 | }, 1463 | "source": [ 1464 | "df = None\n", 1465 | "for i, e in enumerate(exp_history):\n", 1466 | " if i == 0:\n", 1467 | " df = pd.json_normalize(e)\n", 1468 | " continue\n", 1469 | " df = df + pd.json_normalize(e)\n", 1470 | " \n", 1471 | "df_avg = df / len(exp_history)\n", 1472 | "avg_history = df_avg.to_dict(orient='records')" 1473 | ], 1474 | "execution_count": null, 1475 | "outputs": [] 1476 | }, 1477 | { 1478 | "cell_type": "code", 1479 | "metadata": { 1480 | "id": "Hf77BQAD36Eq" 1481 | }, 1482 | "source": [ 1483 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='IID')" 1484 | ], 1485 | "execution_count": null, 1486 | "outputs": [] 1487 | }, 1488 | { 1489 | "cell_type": "code", 1490 | "metadata": { 1491 | "id": "wJClynRJ38Dh" 1492 | }, 1493 | "source": [ 1494 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='IID')" 1495 | ], 1496 | "execution_count": null, 1497 | "outputs": [] 1498 | }, 1499 | { 1500 | "cell_type": "code", 1501 | "metadata": { 1502 | "id": "WKd0d7_a2e1C" 1503 | }, 1504 | "source": [ 1505 | "with open(f'{EXP_DIR}/results_iid.pkl', 'wb') as 
file:\n", 1506 | " pickle.dump(exp_log, file)" 1507 | ], 1508 | "execution_count": null, 1509 | "outputs": [] 1510 | }, 1511 | { 1512 | "cell_type": "markdown", 1513 | "metadata": { 1514 | "id": "BaoYWkWgqvUQ" 1515 | }, 1516 | "source": [ 1517 | "### Non-IID" 1518 | ] 1519 | }, 1520 | { 1521 | "cell_type": "code", 1522 | "metadata": { 1523 | "id": "epuk1epg2jX3" 1524 | }, 1525 | "source": [ 1526 | "exp_log = dict()" 1527 | ], 1528 | "execution_count": null, 1529 | "outputs": [] 1530 | }, 1531 | { 1532 | "cell_type": "code", 1533 | "metadata": { 1534 | "id": "pILgaho8qKgF" 1535 | }, 1536 | "source": [ 1537 | "data_dict, test_ds = noniid_partition(corpus, seq_length=seq_length, val_split=True)\n", 1538 | "\n", 1539 | "total_clients = len(data_dict.keys())\n", 1540 | "'Total users:', total_clients" 1541 | ], 1542 | "execution_count": null, 1543 | "outputs": [] 1544 | }, 1545 | { 1546 | "cell_type": "code", 1547 | "metadata": { 1548 | "id": "Y3o7qgBcqKX_" 1549 | }, 1550 | "source": [ 1551 | "hparams = Hyperparameters(total_clients)\n", 1552 | "hparams.__dict__" 1553 | ], 1554 | "execution_count": null, 1555 | "outputs": [] 1556 | }, 1557 | { 1558 | "cell_type": "code", 1559 | "metadata": { 1560 | "id": "VANr1h0Pq51N" 1561 | }, 1562 | "source": [ 1563 | "# Sweeping parameter\n", 1564 | "PARAM_NAME = 'clients_fraction'\n", 1565 | "PARAM_VALUE = hparams.C\n", 1566 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n", 1567 | "exp_id" 1568 | ], 1569 | "execution_count": null, 1570 | "outputs": [] 1571 | }, 1572 | { 1573 | "cell_type": "code", 1574 | "metadata": { 1575 | "id": "yXgYFIyZ4ipm" 1576 | }, 1577 | "source": [ 1578 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n", 1579 | "os.makedirs(EXP_DIR, exist_ok=True)\n", 1580 | "\n", 1581 | "# tb_logger = SummaryWriter(log_dir)\n", 1582 | "# print(f'TBoard logger created at: {log_dir}')\n", 1583 | "\n", 1584 | "title = 'LSTM FedProx on Non-IID'" 1585 | ], 1586 | "execution_count": null, 1587 | "outputs": [] 1588 | }, 1589 | { 1590 | "cell_type": "code", 1591 | "metadata": { 1592 | "id": "Vnv7UaE0q6dG" 1593 | }, 1594 | "source": [ 1595 | "def run_experiment(run_id):\n", 1596 | "\n", 1597 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length,\n", 1598 | " embedding_dim=embedding_dim,\n", 1599 | " hidden_dim=hidden_dim,\n", 1600 | " classes=num_classes,\n", 1601 | " lstm_layers=lstm_layers,\n", 1602 | " dropout=dropout,\n", 1603 | " batch_first=True\n", 1604 | " )\n", 1605 | "\n", 1606 | " if torch.cuda.is_available():\n", 1607 | " shakespeare_lstm.cuda()\n", 1608 | "\n", 1609 | " test_history = []\n", 1610 | "\n", 1611 | " lstm_non_iid_trained, test_history = training(shakespeare_lstm,\n", 1612 | " hparams.rounds, hparams.batch_size, hparams.lr,\n", 1613 | " None, # ds empty as it is included in data_dict\n", 1614 | " data_dict,\n", 1615 | " test_ds,\n", 1616 | " hparams.C, hparams.K, hparams.E, hparams.mu, hparams.percentage,\n", 1617 | " title, \"green\",\n", 1618 | " hparams.target_test_accuracy,\n", 1619 | " corpus, # classes\n", 1620 | " history=test_history,\n", 1621 | " # tb_logger=tb_writer\n", 1622 | " )\n", 1623 | "\n", 1624 | " \n", 1625 | "\n", 1626 | " final_scores = testing(lstm_non_iid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n", 1627 | " print(f'\\n\\n========================================================\\n\\n')\n", 1628 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n", 1629 | "\n", 1630 | " log = {\n", 1631 | " 'history': test_history,\n", 1632 | " 'hyperparams': hparams.__dict__\n", 1633 | 
" }\n", 1634 | "\n", 1635 | " with open(f'{EXP_DIR}/results_niid_{run_id}.pkl', 'wb') as file:\n", 1636 | " pickle.dump(log, file)\n", 1637 | "\n", 1638 | " return test_history\n" 1639 | ], 1640 | "execution_count": null, 1641 | "outputs": [] 1642 | }, 1643 | { 1644 | "cell_type": "code", 1645 | "metadata": { 1646 | "id": "0pLbVBwVq6Uw" 1647 | }, 1648 | "source": [ 1649 | "exp_history = list()\n", 1650 | "for run_id in range(2): # TOTAL RUNS\n", 1651 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n", 1652 | " exp_history.append(run_experiment(run_id))\n", 1653 | " print(f'\\n\\n========================================================\\n\\n')" 1654 | ], 1655 | "execution_count": null, 1656 | "outputs": [] 1657 | }, 1658 | { 1659 | "cell_type": "code", 1660 | "metadata": { 1661 | "id": "n5F38z5C4qw9" 1662 | }, 1663 | "source": [ 1664 | "exp_log[title] = {\n", 1665 | " 'history': exp_history,\n", 1666 | " 'hyperparams': hparams.__dict__\n", 1667 | "}" 1668 | ], 1669 | "execution_count": null, 1670 | "outputs": [] 1671 | }, 1672 | { 1673 | "cell_type": "code", 1674 | "metadata": { 1675 | "id": "inIGn3Mh4qpO" 1676 | }, 1677 | "source": [ 1678 | "df = None\n", 1679 | "for i, e in enumerate(exp_history):\n", 1680 | " if i == 0:\n", 1681 | " df = pd.json_normalize(e)\n", 1682 | " continue\n", 1683 | " df = df + pd.json_normalize(e)\n", 1684 | " \n", 1685 | "df_avg = df / len(exp_history)\n", 1686 | "avg_history = df_avg.to_dict(orient='records')" 1687 | ], 1688 | "execution_count": null, 1689 | "outputs": [] 1690 | }, 1691 | { 1692 | "cell_type": "code", 1693 | "metadata": { 1694 | "id": "z8ngcls64qjc" 1695 | }, 1696 | "source": [ 1697 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')" 1698 | ], 1699 | "execution_count": null, 1700 | "outputs": [] 1701 | }, 1702 | { 1703 | "cell_type": "code", 1704 | "metadata": { 1705 | "id": "GR9vjtYs4qBX" 1706 | }, 1707 | "source": [ 1708 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')" 1709 | ], 1710 | "execution_count": null, 1711 | "outputs": [] 1712 | }, 1713 | { 1714 | "cell_type": "markdown", 1715 | "metadata": { 1716 | "id": "adK1OTS-40Z8" 1717 | }, 1718 | "source": [ 1719 | "### Pickle Experiment Results" 1720 | ] 1721 | }, 1722 | { 1723 | "cell_type": "code", 1724 | "metadata": { 1725 | "id": "i5nl-hsa4zqw" 1726 | }, 1727 | "source": [ 1728 | "with open(f'{EXP_DIR}/results.pkl', 'wb') as file:\n", 1729 | " pickle.dump(exp_log, file)" 1730 | ], 1731 | "execution_count": null, 1732 | "outputs": [] 1733 | } 1734 | ] 1735 | } -------------------------------------------------------------------------------- /Shakespeare/qFedAvg.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "qFedAvg.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [ 9 | "x5UTD8mqIDWl", 10 | "MAtkg7ME8PQn", 11 | "GpokHKXO8gmS", 12 | "391xbZuU8vEn", 13 | "rxAlW1YK9B_b", 14 | "JtAFxz5h9MRy", 15 | "QSDdUZjs9W-g", 16 | "rhN9isJD9Z8f", 17 | "w8OsHzAt9kDc", 18 | "p-8XOZlR9wbV", 19 | "0vQVNPmM93M7", 20 | "JrTGp6vd-DKa" 21 | ], 22 | "toc_visible": true, 23 | "authorship_tag": "ABX9TyMR+JaDud0PZxZhcegRNqzN", 24 | "include_colab_link": true 25 | }, 26 | "kernelspec": { 27 | "name": "python3", 28 | "display_name": "Python 3" 29 | }, 30 | "language_info": { 31 | "name": "python" 32 | }, 33 | "accelerator": "GPU" 34 | }, 35 | "cells": [ 36 | { 37 | "cell_type": "markdown", 38 
| "metadata": { 39 | "id": "view-in-github", 40 | "colab_type": "text" 41 | }, 42 | "source": [ 43 | "\"Open" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "id": "4v_R7FXY7_fi" 50 | }, 51 | "source": [ 52 | "# FedPerf - Shakespeare + qFedAvg algorithm" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": { 58 | "id": "x5UTD8mqIDWl" 59 | }, 60 | "source": [ 61 | "## Setup & Dependencies Installation" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "metadata": { 67 | "id": "1vTQbtlX731V" 68 | }, 69 | "source": [ 70 | "%%capture\n", 71 | "!pip install torchsummaryX unidecode" 72 | ], 73 | "execution_count": null, 74 | "outputs": [] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "metadata": { 79 | "id": "FHlNfr2R8EvM" 80 | }, 81 | "source": [ 82 | "%load_ext tensorboard\n", 83 | "\n", 84 | "import copy\n", 85 | "from functools import reduce\n", 86 | "import json\n", 87 | "import matplotlib.pyplot as plt\n", 88 | "import numpy as np\n", 89 | "import os\n", 90 | "import pandas as pd\n", 91 | "import pickle\n", 92 | "import random\n", 93 | "from sklearn.model_selection import train_test_split\n", 94 | "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n", 95 | "import time\n", 96 | "import torch\n", 97 | "from torch.autograd import Variable\n", 98 | "import torch.nn as nn\n", 99 | "import torch.nn.functional as F\n", 100 | "from torch.utils.data import Dataset\n", 101 | "from torch.utils.data.dataloader import DataLoader\n", 102 | "from torch.utils.data.sampler import Sampler\n", 103 | "from torch.utils.tensorboard import SummaryWriter\n", 104 | "from torchsummary import summary\n", 105 | "from torchsummaryX import summary as summaryx\n", 106 | "from torchvision import transforms, utils, datasets\n", 107 | "from tqdm.notebook import tqdm\n", 108 | "from unidecode import unidecode\n", 109 | "\n", 110 | "%matplotlib inline\n", 111 | "\n", 112 | "# Check assigned GPU\n", 113 | "gpu_info = !nvidia-smi\n", 114 | "gpu_info = '\\n'.join(gpu_info)\n", 115 | "if gpu_info.find('failed') >= 0:\n", 116 | " print('Select the Runtime > \"Change runtime type\" menu to enable a GPU accelerator, ')\n", 117 | " print('and then re-execute this cell.')\n", 118 | "else:\n", 119 | " print(gpu_info)\n", 120 | "\n", 121 | "# set manual seed for reproducibility\n", 122 | "RANDOM_SEED = 42\n", 123 | "\n", 124 | "# general reproducibility\n", 125 | "random.seed(RANDOM_SEED)\n", 126 | "np.random.seed(RANDOM_SEED)\n", 127 | "torch.manual_seed(RANDOM_SEED)\n", 128 | "torch.cuda.manual_seed(RANDOM_SEED)\n", 129 | "\n", 130 | "# gpu training specific\n", 131 | "torch.backends.cudnn.deterministic = True\n", 132 | "torch.backends.cudnn.benchmark = False" 133 | ], 134 | "execution_count": null, 135 | "outputs": [] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": { 140 | "id": "AvWSM6mh8LZr" 141 | }, 142 | "source": [ 143 | "## Mount GDrive" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "metadata": { 149 | "id": "oh9Lj2-lIIfY" 150 | }, 151 | "source": [ 152 | "BASE_DIR = '/content/drive/MyDrive/FedPerf/shakespeare/qFedAvg'" 153 | ], 154 | "execution_count": null, 155 | "outputs": [] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "metadata": { 160 | "id": "518ez9Oz8LNE" 161 | }, 162 | "source": [ 163 | "try:\n", 164 | " from google.colab import drive\n", 165 | " drive.mount('/content/drive')\n", 166 | " os.makedirs(BASE_DIR, exist_ok=True)\n", 167 | "except:\n", 168 | " print(\"WARNING: Results won't be stored on GDrive\")\n", 169 | " 
BASE_DIR = './'\n", 170 | "\n" 171 | ], 172 | "execution_count": null, 173 | "outputs": [] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": { 178 | "id": "MAtkg7ME8PQn" 179 | }, 180 | "source": [ 181 | "## Loading Dataset" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "metadata": { 187 | "id": "-6ZYQ2SB8LLo" 188 | }, 189 | "source": [ 190 | "!rm -Rf data\n", 191 | "!mkdir -p data scripts" 192 | ], 193 | "execution_count": null, 194 | "outputs": [] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "metadata": { 199 | "id": "skkRzhu98LIx" 200 | }, 201 | "source": [ 202 | "GENERATE_DATASET = False # If False, download the dataset provided by the q-FFL paper\n", 203 | "DATA_DIR = 'data/'\n", 204 | "# Dataset generation params\n", 205 | "SAMPLES_FRACTION = 1. # If using an already generated dataset\n", 206 | "# SAMPLES_FRACTION = 0.2 # Fraction of total samples in the dataset - FedProx default script\n", 207 | "# SAMPLES_FRACTION = 0.05 # Fraction of total samples in the dataset - qFFL\n", 208 | "TRAIN_FRACTION = 0.8 # Train set size\n", 209 | "MIN_SAMPLES = 0 # Min samples per client (for filtering purposes) - FedProx\n", 210 | "# MIN_SAMPLES = 64 # Min samples per client (for filtering purposes) - qFFL" 211 | ], 212 | "execution_count": null, 213 | "outputs": [] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "metadata": { 218 | "id": "c9nCIGlB8LGA" 219 | }, 220 | "source": [ 221 | "# Download raw dataset\n", 222 | "# !wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt -O data/shakespeare.txt\n", 223 | "!wget --adjust-extension http://www.gutenberg.org/files/100/100-0.txt -O data/shakespeare.txt" 224 | ], 225 | "execution_count": null, 226 | "outputs": [] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "metadata": { 231 | "id": "jKQNm6cC8LDp" 232 | }, 233 | "source": [ 234 | "if not GENERATE_DATASET:\n", 235 | " !rm -Rf data/train data/test\n", 236 | " !gdown --id 1n46Mftp3_ahRi1Z6jYhEriyLtdRDS1tD # Download Shakespeare dataset used by the FedProx paper\n", 237 | " !unzip shakespeare.zip\n", 238 | " !mv -f shakespeare_paper/train data/\n", 239 | " !mv -f shakespeare_paper/test data/\n", 240 | " !rm -R shakespeare_paper/ shakespeare.zip\n" 241 | ], 242 | "execution_count": null, 243 | "outputs": [] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "metadata": { 248 | "id": "5ewNwmpP8LBg" 249 | }, 250 | "source": [ 251 | "corpus = []\n", 252 | "with open('data/shakespeare.txt', 'r') as f:\n", 253 | " data = list(unidecode(f.read()))\n", 254 | " corpus = list(set(list(data)))\n", 255 | "print('Corpus Length:', len(corpus))" 256 | ], 257 | "execution_count": null, 258 | "outputs": [] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": { 263 | "id": "GpokHKXO8gmS" 264 | }, 265 | "source": [ 266 | "#### Dataset Preprocessing script" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "metadata": { 272 | "id": "1YEeiGCT8K--" 273 | }, 274 | "source": [ 275 | "%%capture\n", 276 | "if GENERATE_DATASET:\n", 277 | " # Download dataset generation scripts\n", 278 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/preprocess_shakespeare.py -O scripts/preprocess_shakespeare.py\n", 279 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/shake_utils.py -O scripts/shake_utils.py\n", 280 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/gen_all_data.py -O scripts/gen_all_data.py\n", 281 | "\n", 282 | " 
# Download data preprocessing scripts\n", 283 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/sample.py -O scripts/sample.py\n", 284 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/remove_users.py -O scripts/remove_users.py" 285 | ], 286 | "execution_count": null, 287 | "outputs": [] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "metadata": { 292 | "id": "KE4LT4z48K8u" 293 | }, 294 | "source": [ 295 | "# Running scripts\n", 296 | "if GENERATE_DATASET:\n", 297 | " !mkdir -p data/raw_data data/all_data data/train data/test\n", 298 | " !python scripts/preprocess_shakespeare.py data/shakespeare.txt data/raw_data\n", 299 | " !python scripts/gen_all_data.py" 300 | ], 301 | "execution_count": null, 302 | "outputs": [] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": { 307 | "id": "391xbZuU8vEn" 308 | }, 309 | "source": [ 310 | "#### Dataset class" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "metadata": { 316 | "id": "5aS5pJenZ8Ow" 317 | }, 318 | "source": [ 319 | "class ShakespeareDataset(Dataset):\n", 320 | " def __init__(self, x, y, corpus, seq_length):\n", 321 | " self.x = x\n", 322 | " self.y = y\n", 323 | " self.corpus = corpus\n", 324 | " self.corpus_size = len(self.corpus)\n", 325 | " super(ShakespeareDataset, self).__init__()\n", 326 | "\n", 327 | " def __len__(self):\n", 328 | " return len(self.x)\n", 329 | "\n", 330 | " def __repr__(self):\n", 331 | " return f'{self.__class__} - (length: {self.__len__()})'\n", 332 | "\n", 333 | " def __getitem__(self, i):\n", 334 | " input_seq = self.x[i]\n", 335 | " next_char = self.y[i]\n", 336 | " # print('\\tgetitem', i, input_seq, next_char)\n", 337 | " input_value = self.text2charindxs(input_seq)\n", 338 | " target_value = self.get_label_from_char(next_char)\n", 339 | " return input_value, target_value\n", 340 | "\n", 341 | " def text2charindxs(self, text):\n", 342 | " tensor = torch.zeros(len(text), dtype=torch.int32)\n", 343 | " for i, c in enumerate(text):\n", 344 | " tensor[i] = self.get_label_from_char(c)\n", 345 | " return tensor\n", 346 | "\n", 347 | " def get_label_from_char(self, c):\n", 348 | " return self.corpus.index(c)\n", 349 | "\n", 350 | " def get_char_from_label(self, l):\n", 351 | " return self.corpus[l]" 352 | ], 353 | "execution_count": null, 354 | "outputs": [] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "metadata": { 359 | "id": "0r68yAYr8xwJ" 360 | }, 361 | "source": [ 362 | "##### Federated Dataset" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "metadata": { 368 | "id": "CUTWxbt98K6W" 369 | }, 370 | "source": [ 371 | "class ShakespeareFedDataset(ShakespeareDataset):\n", 372 | " def __init__(self, x, y, corpus, seq_length):\n", 373 | " super(ShakespeareFedDataset, self).__init__(x, y, corpus, seq_length)\n", 374 | "\n", 375 | " def dataloader(self, batch_size, shuffle=True):\n", 376 | " return DataLoader(self,\n", 377 | " batch_size=batch_size,\n", 378 | " shuffle=shuffle,\n", 379 | " num_workers=0)\n" 380 | ], 381 | "execution_count": null, 382 | "outputs": [] 383 | }, 384 | { 385 | "cell_type": "markdown", 386 | "metadata": { 387 | "id": "awOBv7tN8_ec" 388 | }, 389 | "source": [ 390 | "## Partitioning & Data Loaders" 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "metadata": { 396 | "id": "rxAlW1YK9B_b" 397 | }, 398 | "source": [ 399 | "### IID" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "metadata": { 405 | "id": "OqAvrJxm8K4D" 406 | }, 407 | "source": [ 408 | "def iid_partition_(dataset, 
clients):\n", 409 | " \"\"\"\n", 410 | " I.I.D paritioning of data over clients\n", 411 | " Shuffle the data\n", 412 | " Split it between clients\n", 413 | " \n", 414 | " params:\n", 415 | " - dataset (torch.utils.Dataset): Dataset containing the MNIST Images\n", 416 | " - clients (int): Number of Clients to split the data between\n", 417 | "\n", 418 | " returns:\n", 419 | " - Dictionary of image indexes for each client\n", 420 | " \"\"\"\n", 421 | "\n", 422 | " num_items_per_client = int(len(dataset)/clients)\n", 423 | " client_dict = {}\n", 424 | " image_idxs = [i for i in range(len(dataset))]\n", 425 | "\n", 426 | " for i in range(clients):\n", 427 | " client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False))\n", 428 | " image_idxs = list(set(image_idxs) - client_dict[i])\n", 429 | "\n", 430 | " return client_dict" 431 | ], 432 | "execution_count": null, 433 | "outputs": [] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "metadata": { 438 | "id": "0T-NUxsq8K1g" 439 | }, 440 | "source": [ 441 | "def iid_partition(corpus, seq_length=80, val_split=False):\n", 442 | "\n", 443 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n", 444 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n", 445 | "\n", 446 | " with open(train_file, 'r') as file:\n", 447 | " data_train = json.loads(unidecode(file.read()))\n", 448 | "\n", 449 | " with open(test_file, 'r') as file:\n", 450 | " data_test = json.loads(unidecode(file.read()))\n", 451 | "\n", 452 | " \n", 453 | " total_samples_train = sum(data_train['num_samples'])\n", 454 | "\n", 455 | " data_dict = {}\n", 456 | "\n", 457 | " x_train, y_train = [], []\n", 458 | " x_test, y_test = [], []\n", 459 | " # x_val, y_val = [], []\n", 460 | "\n", 461 | " users = list(zip(data_train['users'], data_train['num_samples']))\n", 462 | " # random.shuffle(users)\n", 463 | "\n", 464 | "\n", 465 | "\n", 466 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n", 467 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n", 468 | " sample_count = 0\n", 469 | " \n", 470 | " for i, (author_id, samples) in enumerate(users):\n", 471 | "\n", 472 | " if sample_count >= total_samples:\n", 473 | " print('Max samples reached', sample_count, '/', total_samples)\n", 474 | " break\n", 475 | "\n", 476 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n", 477 | " print('SKIP', author_id, samples)\n", 478 | " continue\n", 479 | " else:\n", 480 | " udata_train = data_train['user_data'][author_id]\n", 481 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n", 482 | " \n", 483 | " sample_count += max_samples\n", 484 | " # print('sample_count', sample_count)\n", 485 | "\n", 486 | " x_train.extend(data_train['user_data'][author_id]['x'][:max_samples])\n", 487 | " y_train.extend(data_train['user_data'][author_id]['y'][:max_samples])\n", 488 | "\n", 489 | " author_data = data_test['user_data'][author_id]\n", 490 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n", 491 | "\n", 492 | " if val_split:\n", 493 | " x_test.extend(author_data['x'][:int(test_size / 2)])\n", 494 | " y_test.extend(author_data['y'][:int(test_size / 2)])\n", 495 | " # x_val.extend(author_data['x'][int(test_size / 2):])\n", 496 | " # y_val.extend(author_data['y'][int(test_size / 2):int(test_size)])\n", 497 | 
"\n", 498 | " else:\n", 499 | " x_test.extend(author_data['x'][:int(test_size)])\n", 500 | " y_test.extend(author_data['y'][:int(test_size)])\n", 501 | "\n", 502 | " train_ds = ShakespeareDataset(x_train, y_train, corpus, seq_length)\n", 503 | " test_ds = ShakespeareDataset(x_test, y_test, corpus, seq_length)\n", 504 | " # val_ds = ShakespeareDataset(x_val, y_val, corpus, seq_length)\n", 505 | "\n", 506 | " data_dict = iid_partition_(train_ds, clients=len(users))\n", 507 | "\n", 508 | " return train_ds, data_dict, test_ds" 509 | ], 510 | "execution_count": null, 511 | "outputs": [] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": { 516 | "id": "JtAFxz5h9MRy" 517 | }, 518 | "source": [ 519 | "### Non-IID" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "metadata": { 525 | "id": "0w3yiODz8KzT" 526 | }, 527 | "source": [ 528 | "def noniid_partition(corpus, seq_length=80, val_split=False):\n", 529 | "\n", 530 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n", 531 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n", 532 | "\n", 533 | " with open(train_file, 'r') as file:\n", 534 | " data_train = json.loads(unidecode(file.read()))\n", 535 | "\n", 536 | " with open(test_file, 'r') as file:\n", 537 | " data_test = json.loads(unidecode(file.read()))\n", 538 | "\n", 539 | " \n", 540 | " total_samples_train = sum(data_train['num_samples'])\n", 541 | "\n", 542 | " data_dict = {}\n", 543 | "\n", 544 | " x_test, y_test = [], []\n", 545 | "\n", 546 | " users = list(zip(data_train['users'], data_train['num_samples']))\n", 547 | " # random.shuffle(users)\n", 548 | "\n", 549 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n", 550 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n", 551 | " sample_count = 0\n", 552 | " \n", 553 | " for i, (author_id, samples) in enumerate(users):\n", 554 | "\n", 555 | " if sample_count >= total_samples:\n", 556 | " print('Max samples reached', sample_count, '/', total_samples)\n", 557 | " break\n", 558 | "\n", 559 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n", 560 | " print('SKIP', author_id, samples)\n", 561 | " continue\n", 562 | " else:\n", 563 | " udata_train = data_train['user_data'][author_id]\n", 564 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n", 565 | " \n", 566 | " sample_count += max_samples\n", 567 | " # print('sample_count', sample_count)\n", 568 | "\n", 569 | " x_train = data_train['user_data'][author_id]['x'][:max_samples]\n", 570 | " y_train = data_train['user_data'][author_id]['y'][:max_samples]\n", 571 | "\n", 572 | " train_ds = ShakespeareFedDataset(x_train, y_train, corpus, seq_length)\n", 573 | "\n", 574 | " x_val, y_val = None, None\n", 575 | " val_ds = None\n", 576 | " author_data = data_test['user_data'][author_id]\n", 577 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n", 578 | " if val_split:\n", 579 | " x_test += author_data['x'][:int(test_size / 2)]\n", 580 | " y_test += author_data['y'][:int(test_size / 2)]\n", 581 | " x_val = author_data['x'][int(test_size / 2):]\n", 582 | " y_val = author_data['y'][int(test_size / 2):int(test_size)]\n", 583 | "\n", 584 | " val_ds = ShakespeareFedDataset(x_val, y_val, corpus, seq_length)\n", 585 | "\n", 586 | " else:\n", 587 | " x_test += author_data['x'][:int(test_size)]\n", 
588 | " y_test += author_data['y'][:int(test_size)]\n", 589 | "\n", 590 | " data_dict[author_id] = {\n", 591 | " 'train_ds': train_ds,\n", 592 | " 'val_ds': val_ds\n", 593 | " }\n", 594 | "\n", 595 | " test_ds = ShakespeareFedDataset(x_test, y_test, corpus, seq_length)\n", 596 | "\n", 597 | " return data_dict, test_ds" 598 | ], 599 | "execution_count": null, 600 | "outputs": [] 601 | }, 602 | { 603 | "cell_type": "markdown", 604 | "metadata": { 605 | "id": "QSDdUZjs9W-g" 606 | }, 607 | "source": [ 608 | "## Models" 609 | ] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": { 614 | "id": "rhN9isJD9Z8f" 615 | }, 616 | "source": [ 617 | "### Shakespeare LSTM" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "metadata": { 623 | "id": "GKXxN3HO8Kw3" 624 | }, 625 | "source": [ 626 | "class ShakespeareLSTM(nn.Module):\n", 627 | " \"\"\"\n", 628 | " \"\"\"\n", 629 | "\n", 630 | " def __init__(self, input_dim, embedding_dim, hidden_dim, classes, lstm_layers=2, dropout=0.1, batch_first=True):\n", 631 | " super(ShakespeareLSTM, self).__init__()\n", 632 | " self.input_dim = input_dim\n", 633 | " self.embedding_dim = embedding_dim\n", 634 | " self.hidden_dim = hidden_dim\n", 635 | " self.classes = classes\n", 636 | " self.no_layers = lstm_layers\n", 637 | " \n", 638 | " self.embedding = nn.Embedding(num_embeddings=self.classes,\n", 639 | " embedding_dim=self.embedding_dim)\n", 640 | " self.lstm = nn.LSTM(input_size=self.embedding_dim, \n", 641 | " hidden_size=self.hidden_dim,\n", 642 | " num_layers=self.no_layers,\n", 643 | " batch_first=batch_first, \n", 644 | " dropout=dropout if self.no_layers > 1 else 0.)\n", 645 | " self.fc = nn.Linear(hidden_dim, self.classes)\n", 646 | "\n", 647 | " def forward(self, x, hc=None):\n", 648 | " batch_size = x.size(0)\n", 649 | " x_emb = self.embedding(x)\n", 650 | " out, (ht, ct) = self.lstm(x_emb.view(batch_size, -1, self.embedding_dim), hc)\n", 651 | " dense = self.fc(ht[-1])\n", 652 | " return dense\n", 653 | " \n", 654 | " def init_hidden(self, batch_size):\n", 655 | " return (Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)),\n", 656 | " Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)))\n" 657 | ], 658 | "execution_count": null, 659 | "outputs": [] 660 | }, 661 | { 662 | "cell_type": "markdown", 663 | "metadata": { 664 | "id": "lwrRH5yD9eZB" 665 | }, 666 | "source": [ 667 | "#### Model Summary" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "metadata": { 673 | "id": "Zsl5z8CS8Kul" 674 | }, 675 | "source": [ 676 | "batch_size = 10\n", 677 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n", 678 | "\n", 679 | "shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n", 680 | " embedding_dim=8, # mcmahan17a, fedprox, qFFL\n", 681 | " hidden_dim=256, # mcmahan17a, fedprox impl\n", 682 | " # hidden_dim=100, # fedprox paper\n", 683 | " classes=len(corpus),\n", 684 | " lstm_layers=2,\n", 685 | " dropout=0.1, # TODO:\n", 686 | " batch_first=True\n", 687 | " )\n", 688 | "\n", 689 | "if torch.cuda.is_available():\n", 690 | " shakespeare_lstm.cuda()\n", 691 | "\n", 692 | "\n", 693 | "\n", 694 | "hc = shakespeare_lstm.init_hidden(batch_size)\n", 695 | "\n", 696 | "x_sample = torch.zeros((batch_size, seq_length),\n", 697 | " dtype=torch.long,\n", 698 | " device=(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')))\n", 699 | "\n", 700 | "x_sample[0][0] = 1\n", 701 | "x_sample\n", 702 | "\n", 703 | "print(\"\\nShakespeare LSTM SUMMARY\")\n", 704 | "print(summaryx(shakespeare_lstm, 
x_sample))" 705 | ], 706 | "execution_count": null, 707 | "outputs": [] 708 | }, 709 | { 710 | "cell_type": "markdown", 711 | "metadata": { 712 | "id": "bbJUc6X89jFw" 713 | }, 714 | "source": [ 715 | "## qFedAvg Algorithm" 716 | ] 717 | }, 718 | { 719 | "cell_type": "markdown", 720 | "metadata": { 721 | "id": "w8OsHzAt9kDc" 722 | }, 723 | "source": [ 724 | "### Plot Utils" 725 | ] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "metadata": { 730 | "id": "jn1JNhaI8KsX" 731 | }, 732 | "source": [ 733 | "from sklearn.metrics import f1_score" 734 | ], 735 | "execution_count": null, 736 | "outputs": [] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "metadata": { 741 | "id": "oyYu5s4d8KqH" 742 | }, 743 | "source": [ 744 | "def plot_scores(history, exp_id, title, suffix):\n", 745 | " accuracies = [x['accuracy'] for x in history]\n", 746 | " f1_macro = [x['f1_macro'] for x in history]\n", 747 | " f1_weighted = [x['f1_weighted'] for x in history]\n", 748 | "\n", 749 | " fig, ax = plt.subplots()\n", 750 | " ax.plot(accuracies, 'tab:orange')\n", 751 | " ax.set(xlabel='Rounds', ylabel='Test Accuracy', title=title)\n", 752 | " ax.grid()\n", 753 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Accuracy_{suffix}.jpg', format='jpg', dpi=300)\n", 754 | " plt.show()\n", 755 | "\n", 756 | " fig, ax = plt.subplots()\n", 757 | " ax.plot(f1_macro, 'tab:orange')\n", 758 | " ax.set(xlabel='Rounds', ylabel='Test F1 (macro)', title=title)\n", 759 | " ax.grid()\n", 760 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Macro_{suffix}.jpg', format='jpg')\n", 761 | " plt.show()\n", 762 | "\n", 763 | " fig, ax = plt.subplots()\n", 764 | " ax.plot(f1_weighted, 'tab:orange')\n", 765 | " ax.set(xlabel='Rounds', ylabel='Test F1 (weighted)', title=title)\n", 766 | " ax.grid()\n", 767 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Weighted_{suffix}.jpg', format='jpg')\n", 768 | " plt.show()\n", 769 | "\n", 770 | "\n", 771 | "def plot_losses(history, exp_id, title, suffix):\n", 772 | " val_losses = [x['loss'] for x in history]\n", 773 | " train_losses = [x['train_loss'] for x in history]\n", 774 | "\n", 775 | " fig, ax = plt.subplots()\n", 776 | " ax.plot(train_losses, 'tab:orange')\n", 777 | " ax.set(xlabel='Rounds', ylabel='Train Loss', title=title)\n", 778 | " ax.grid()\n", 779 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Train_Loss_{suffix}.jpg', format='jpg')\n", 780 | " plt.show()\n", 781 | "\n", 782 | " fig, ax = plt.subplots()\n", 783 | " ax.plot(val_losses, 'tab:orange')\n", 784 | " ax.set(xlabel='Rounds', ylabel='Test Loss', title=title)\n", 785 | " ax.grid()\n", 786 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Loss_{suffix}.jpg', format='jpg')\n", 787 | " plt.show()\n" 788 | ], 789 | "execution_count": null, 790 | "outputs": [] 791 | }, 792 | { 793 | "cell_type": "markdown", 794 | "metadata": { 795 | "id": "p-8XOZlR9wbV" 796 | }, 797 | "source": [ 798 | "### Local Training (Client Update)" 799 | ] 800 | }, 801 | { 802 | "cell_type": "code", 803 | "metadata": { 804 | "id": "83Jn60Gt8KlY" 805 | }, 806 | "source": [ 807 | "class CustomDataset(Dataset):\n", 808 | " def __init__(self, dataset, idxs):\n", 809 | " self.dataset = dataset\n", 810 | " self.idxs = list(idxs)\n", 811 | "\n", 812 | " def __len__(self):\n", 813 | " return len(self.idxs)\n", 814 | "\n", 815 | " def __getitem__(self, item):\n", 816 | " data, label = self.dataset[self.idxs[item]]\n", 817 | " return data, label" 818 | ], 819 | "execution_count": null, 820 | "outputs": [] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "metadata": { 825 | "id": "vOn7VSei8KjB" 826 
| }, 827 | "source": [ 828 | "class ClientUpdate(object):\n", 829 | " def __init__(self, dataset, batch_size, learning_rate, epochs, idxs, q=None):\n", 830 | " \"\"\"\n", 831 | "\n", 832 | " \"\"\"\n", 833 | " if hasattr(dataset, 'dataloader'):\n", 834 | " self.train_loader = dataset.dataloader(batch_size=batch_size, shuffle=True)\n", 835 | " else:\n", 836 | " self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batch_size, shuffle=True)\n", 837 | "\n", 838 | " self.learning_rate = learning_rate\n", 839 | " self.epochs = epochs\n", 840 | " self.q = q\n", 841 | " if not self.q:\n", 842 | " # TODO: Client itself adjust fairness \n", 843 | " pass\n", 844 | " self.mu = 1e-10\n", 845 | "\n", 846 | " def train(self, model):\n", 847 | "\n", 848 | " criterion = nn.CrossEntropyLoss()\n", 849 | " optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.5)\n", 850 | " # optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)\n", 851 | "\n", 852 | " e_loss = []\n", 853 | " model_weights = copy.deepcopy(model.state_dict())\n", 854 | " for epoch in range(1, self.epochs+1):\n", 855 | "\n", 856 | " train_loss = 0.0\n", 857 | "\n", 858 | " model.train()\n", 859 | " # for data, labels in tqdm(self.train_loader):\n", 860 | " for data, labels in self.train_loader:\n", 861 | "\n", 862 | " if torch.cuda.is_available():\n", 863 | " data, labels = data.cuda(), labels.cuda()\n", 864 | "\n", 865 | " # clear the gradients\n", 866 | " optimizer.zero_grad()\n", 867 | " # make a forward pass\n", 868 | " # print('input', data.size())\n", 869 | " output = model(data)\n", 870 | " # print('output', output.size())\n", 871 | " # print('labels', labels.size())\n", 872 | " # calculate the loss\n", 873 | " loss = criterion(output, labels)\n", 874 | " # do a backwards pass\n", 875 | " loss.backward()\n", 876 | " # perform a single optimization step\n", 877 | " optimizer.step()\n", 878 | " # update training loss\n", 879 | " train_loss += loss.item()*data.size(0)\n", 880 | "\n", 881 | " # average losses\n", 882 | " train_loss = train_loss/len(self.train_loader.dataset)\n", 883 | " e_loss.append(train_loss)\n", 884 | "\n", 885 | "\n", 886 | " total_loss = sum(e_loss)/len(e_loss)\n", 887 | "\n", 888 | " # delta weights\n", 889 | " model_weights_new = copy.deepcopy(model.state_dict())\n", 890 | " L = 1.0 / self.learning_rate\n", 891 | "\n", 892 | " delta_weights, delta, h = {}, {}, {}\n", 893 | " loss_q = np.float_power(total_loss + self.mu, self.q)\n", 894 | " # updating the global weights\n", 895 | " for k in model_weights_new.keys():\n", 896 | " delta_weights[k] = (model_weights[k] - model_weights_new[k]) * L\n", 897 | " delta[k] = loss_q * delta_weights[k]\n", 898 | " # Estimation of the local Lipchitz constant\n", 899 | " h[k] = (self.q * np.float_power(total_loss + self.mu, self.q - 1) * torch.pow(torch.norm(delta_weights[k]), 2)) + (L * loss_q)\n", 900 | "\n", 901 | " return delta, h, total_loss" 902 | ], 903 | "execution_count": null, 904 | "outputs": [] 905 | }, 906 | { 907 | "cell_type": "markdown", 908 | "metadata": { 909 | "id": "0vQVNPmM93M7" 910 | }, 911 | "source": [ 912 | "### Server Side Training" 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "metadata": { 918 | "id": "bhOfx0HSCX35" 919 | }, 920 | "source": [ 921 | "def client_sampling(n, m, weights=None, with_replace=False):\n", 922 | " pk = None\n", 923 | " if weights:\n", 924 | " total_weights = np.sum(np.asarray(weights))\n", 925 | " pk = [w * 1.0 / total_weights for w in weights]\n", 926 | 
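For reference, the `ClientUpdate.train` and `training` cells implement the q-FedAvg update from the q-FFL paper linked in the README. In the notation below, F_k is the average local training loss over the E local epochs (the code adds mu = 1e-10 before exponentiation), eta is the client learning rate, and S_t the set of sampled clients; the notebook evaluates these quantities per parameter tensor rather than once per client, and q = 0 reduces the rule to an unweighted FedAvg-style average.

```latex
% q-FedAvg update as implemented in ClientUpdate.train / training
\begin{aligned}
\Delta w_k^t &= L\,\bigl(w^t - \bar{w}_k^{t+1}\bigr), \qquad L = 1/\eta \\
\Delta_k^t   &= F_k^{\,q}(w^t)\,\Delta w_k^t \\
h_k^t        &= q\,F_k^{\,q-1}(w^t)\,\lVert \Delta w_k^t \rVert^2 + L\,F_k^{\,q}(w^t) \\
w^{t+1}      &= w^t - \frac{\sum_{k \in S_t} \Delta_k^t}{\sum_{k \in S_t} h_k^t}
\end{aligned}
```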
"\n", 927 | " return np.random.choice(range(n), m, replace=with_replace, p=pk)" 928 | ], 929 | "execution_count": null, 930 | "outputs": [] 931 | }, 932 | { 933 | "cell_type": "code", 934 | "metadata": { 935 | "id": "MCJZnkY08Kgs" 936 | }, 937 | "source": [ 938 | "def training(model, rounds, batch_size, lr, ds, data_dict, test_ds, C, K, E, q=0, sampling='uniform', tb_logger=None, test_history=[], perf_fig_file='loss.jpg'):\n", 939 | " \"\"\"\n", 940 | " Function implements the Federated Averaging Algorithm from the FedAvg paper.\n", 941 | " Specifically, this function is used for the server side training and weight update\n", 942 | "\n", 943 | " Params:\n", 944 | " - model: PyTorch model to train\n", 945 | " - rounds: Number of communication rounds for the client update\n", 946 | " - batch_size: Batch size for client update training\n", 947 | " - lr: Learning rate used for client update training\n", 948 | " - ds: Dataset used for training\n", 949 | " - data_dict: Type of data partition used for training (IID or non-IID)\n", 950 | " - test_ds Dataset used for global testing\n", 951 | " - C: Fraction of clients randomly chosen to perform computation on each round\n", 952 | " - K: Total number of clients\n", 953 | " - E: Number of training passes each client makes over its local dataset per round\n", 954 | " - q: Parameter that tunes the amount of fairness we wish to impose (default: 0 -> vanilla FedAvg objective)\n", 955 | " - sampling Uniform or weighted (default: uniform)\n", 956 | " - tb_logger: Tensorboard SummaryWriter\n", 957 | " - test_history: Test Scores history log\n", 958 | " - perf_fig_file File for storing final performance plot (loss)\n", 959 | " Returns:\n", 960 | " - model: Trained model on the server\n", 961 | " \"\"\"\n", 962 | "\n", 963 | " # global model weights\n", 964 | " global_weights = model.state_dict()\n", 965 | "\n", 966 | " # training loss\n", 967 | " train_loss = []\n", 968 | "\n", 969 | " # client weights by total samples\n", 970 | " p_k = None\n", 971 | " if sampling == 'weighted':\n", 972 | " p_k = [len(data_dict[c]) for c in data_dict] if ds else [len(data_dict[c]['train_ds']) for c in data_dict]\n", 973 | "\n", 974 | " # Time log\n", 975 | " start_time = time.time()\n", 976 | "\n", 977 | " users_id = list(data_dict.keys())\n", 978 | "\n", 979 | " # Orchestrate training\n", 980 | " for curr_round in range(1, rounds+1):\n", 981 | " deltas, hs, local_loss = [], [], []\n", 982 | "\n", 983 | " m = max(int(C*K), 1) \n", 984 | " S_t = client_sampling(K, m, weights=p_k, with_replace=False)\n", 985 | "\n", 986 | " print('Round: {} Picking {}/{} clients: {}'.format(curr_round, m, K, S_t))\n", 987 | "\n", 988 | " global_weights = model.state_dict()\n", 989 | "\n", 990 | " for k in tqdm(S_t):\n", 991 | " key = users_id[k]\n", 992 | " ds_ = ds if ds else data_dict[key]['train_ds']\n", 993 | " idxs = data_dict[key] if ds else None\n", 994 | " # print(f'Client {k}: {len(idxs) if idxs else len(ds_)} samples')\n", 995 | " local_update = ClientUpdate(dataset=ds_, batch_size=batch_size, learning_rate=lr, epochs=E, idxs=idxs, q=q)\n", 996 | " # weights, loss = local_update.train(model=copy.deepcopy(model))\n", 997 | " delta_k, h_k, loss = local_update.train(model=copy.deepcopy(model))\n", 998 | "\n", 999 | " deltas.append(copy.deepcopy(delta_k))\n", 1000 | " hs.append(copy.deepcopy(h_k))\n", 1001 | " local_loss.append(copy.deepcopy(loss))\n", 1002 | "\n", 1003 | " if tb_logger:\n", 1004 | " tb_logger.add_scalar(f'Round/S{k}', loss, curr_round)\n", 1005 | "\n", 1006 | " # 
Perform qFedAvg\n", 1007 | " h_sum = copy.deepcopy(hs[0])\n", 1008 | " delta_sum = copy.deepcopy(deltas[0])\n", 1009 | " \n", 1010 | " for k in h_sum.keys():\n", 1011 | " for i in range(1, len(hs)):\n", 1012 | " h_sum[k] += hs[i][k]\n", 1013 | " delta_sum[k] += deltas[i][k]\n", 1014 | "\n", 1015 | " new_weights = {}\n", 1016 | " for k in delta_sum.keys():\n", 1017 | " for i in range(len(deltas)):\n", 1018 | " new_weights[k] = delta_sum[k] / h_sum[k]\n", 1019 | "\n", 1020 | " # Updating global model weights\n", 1021 | " for k in global_weights.keys():\n", 1022 | " global_weights[k] -= new_weights[k]\n", 1023 | "\n", 1024 | " # move the updated weights to our model state dict\n", 1025 | " model.load_state_dict(global_weights)\n", 1026 | "\n", 1027 | " # loss\n", 1028 | " loss_avg = sum(local_loss) / len(local_loss)\n", 1029 | " print('Round: {}... \\tAverage Loss: {}'.format(curr_round, round(loss_avg, 3)))\n", 1030 | " train_loss.append(loss_avg)\n", 1031 | "\n", 1032 | " if tb_logger:\n", 1033 | " tb_logger.add_scalar('Train/Loss', loss_avg, curr_round)\n", 1034 | " # tb_logger.add_scalar(f'Train/Datapoints', total_datapoints, curr_round)\n", 1035 | " \n", 1036 | " # if curr_round % eval_every == 0:\n", 1037 | " test_scores = testing(model, test_ds, batch_size, nn.CrossEntropyLoss(), num_classes, list(range(num_classes)))\n", 1038 | " test_scores['train_loss'] = loss_avg\n", 1039 | " test_history.append(test_scores)\n", 1040 | " if tb_logger:\n", 1041 | " tb_logger.add_scalar(f'Test/Loss', test_scores['loss'], curr_round)\n", 1042 | " tb_logger.add_scalars(f'Test/Scores', {\n", 1043 | " 'accuracy': test_scores['accuracy'], 'f1_macro': test_scores['f1_macro'], 'f1_weighted': test_scores['f1_weighted']\n", 1044 | " }, curr_round)\n", 1045 | " \n", 1046 | " end_time = time.time()\n", 1047 | " \n", 1048 | " fig, ax = plt.subplots()\n", 1049 | " x_axis = np.arange(1, rounds+1)\n", 1050 | " y_axis = np.array(train_loss)\n", 1051 | " ax.plot(x_axis, y_axis, 'tab:orange')\n", 1052 | "\n", 1053 | " ax.set(xlabel='# Rounds', ylabel='Train Loss',\n", 1054 | " title=\"Model's Performance with q: {}\".format(q))\n", 1055 | " ax.grid()\n", 1056 | " #fig.savefig(perf_fig_file, format='jpg')\n", 1057 | "\n", 1058 | " print(\"Training Done! 
Total time: {}\".format(round(end_time - start_time, 3)))\n", 1059 | " return model, test_history" 1060 | ], 1061 | "execution_count": null, 1062 | "outputs": [] 1063 | }, 1064 | { 1065 | "cell_type": "markdown", 1066 | "metadata": { 1067 | "id": "JrTGp6vd-DKa" 1068 | }, 1069 | "source": [ 1070 | "### Testing Loop" 1071 | ] 1072 | }, 1073 | { 1074 | "cell_type": "code", 1075 | "metadata": { 1076 | "id": "yEnPyGYb8KeO" 1077 | }, 1078 | "source": [ 1079 | "def testing(model, dataset, bs, criterion, num_classes, classes, print_all=False):\n", 1080 | " #test loss \n", 1081 | " test_loss = 0.0\n", 1082 | " y_true, y_hat = None, None\n", 1083 | "\n", 1084 | " correct_class = list(0 for i in range(num_classes))\n", 1085 | " total_class = list(0 for i in range(num_classes))\n", 1086 | "\n", 1087 | " if hasattr(dataset, 'dataloader'):\n", 1088 | " test_loader = dataset.dataloader(batch_size=bs, shuffle=False)\n", 1089 | " else:\n", 1090 | " test_loader = DataLoader(dataset, batch_size=bs, shuffle=False)\n", 1091 | "\n", 1092 | " l = len(test_loader)\n", 1093 | "\n", 1094 | " model.eval()\n", 1095 | " for i, (data, labels) in enumerate(tqdm(test_loader)):\n", 1096 | "\n", 1097 | " if torch.cuda.is_available():\n", 1098 | " data, labels = data.cuda(), labels.cuda()\n", 1099 | "\n", 1100 | " output = model(data)\n", 1101 | " loss = criterion(output, labels)\n", 1102 | " test_loss += loss.item()*data.size(0)\n", 1103 | "\n", 1104 | " _, pred = torch.max(output, dim=1)\n", 1105 | "\n", 1106 | " # For F1Score\n", 1107 | " y_true = np.append(y_true, labels.data.view_as(pred).cpu().numpy()) if i != 0 else labels.data.view_as(pred).cpu().numpy()\n", 1108 | " y_hat = np.append(y_hat, pred.cpu().numpy()) if i != 0 else pred.cpu().numpy()\n", 1109 | "\n", 1110 | " correct_tensor = pred.eq(labels.data.view_as(pred))\n", 1111 | " correct = np.squeeze(correct_tensor.numpy()) if not torch.cuda.is_available() else np.squeeze(correct_tensor.cpu().numpy())\n", 1112 | "\n", 1113 | " #test accuracy for each object class\n", 1114 | " # for i in range(num_classes):\n", 1115 | " # label = labels.data[i]\n", 1116 | " # correct_class[label] += correct[i].item()\n", 1117 | " # total_class[label] += 1\n", 1118 | "\n", 1119 | " for i, lbl in enumerate(labels.data):\n", 1120 | " try:\n", 1121 | " # print(type(lbl))\n", 1122 | " # correct_class[lbl.data[0]] += correct.data[i]\n", 1123 | " correct_class[lbl.item()] += correct[i]\n", 1124 | " total_class[lbl.item()] += 1\n", 1125 | " except:\n", 1126 | " print('Error', lbl, i)\n", 1127 | " \n", 1128 | " # avg test loss\n", 1129 | " test_loss = test_loss/len(test_loader.dataset)\n", 1130 | " print(\"Test Loss: {:.6f}\\n\".format(test_loss))\n", 1131 | "\n", 1132 | " # Avg F1 Score\n", 1133 | " f1_macro = f1_score(y_true, y_hat, average='macro')\n", 1134 | " # F1-Score -> weigthed to consider class imbalance\n", 1135 | " f1_weighted = f1_score(y_true, y_hat, average='weighted')\n", 1136 | " print(\"F1 Score: {:.6f} (macro) {:.6f} (weighted) %\\n\".format(f1_macro, f1_weighted))\n", 1137 | "\n", 1138 | " # print test accuracy\n", 1139 | " if print_all:\n", 1140 | " for i in range(num_classes):\n", 1141 | " if total_class[i] > 0:\n", 1142 | " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % \n", 1143 | " (classes[i], 100 * correct_class[i] / total_class[i],\n", 1144 | " np.sum(correct_class[i]), np.sum(total_class[i])))\n", 1145 | " else:\n", 1146 | " print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))\n", 1147 | "\n", 1148 | " overall_accuracy = 
np.sum(correct_class) / np.sum(total_class)\n", 1149 | "\n", 1150 | " print('\\nFinal Test Accuracy: {:.3f} ({}/{})'.format(overall_accuracy, np.sum(correct_class), np.sum(total_class)))\n", 1151 | "\n", 1152 | " return {'loss': test_loss, 'accuracy': overall_accuracy, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted}" 1153 | ], 1154 | "execution_count": null, 1155 | "outputs": [] 1156 | }, 1157 | { 1158 | "cell_type": "markdown", 1159 | "metadata": { 1160 | "id": "Evpa7UZO-Hii" 1161 | }, 1162 | "source": [ 1163 | "## Experiments" 1164 | ] 1165 | }, 1166 | { 1167 | "cell_type": "code", 1168 | "metadata": { 1169 | "id": "HhZ2iUQ2am-_" 1170 | }, 1171 | "source": [ 1172 | "# FAIL-ON-PURPOSE" 1173 | ], 1174 | "execution_count": null, 1175 | "outputs": [] 1176 | }, 1177 | { 1178 | "cell_type": "code", 1179 | "metadata": { 1180 | "id": "QXNvhZAw8Kbx" 1181 | }, 1182 | "source": [ 1183 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n", 1184 | "embedding_dim = 8 # mcmahan17a, fedprox, qFFL\n", 1185 | "# hidden_dim = 100 # fedprox paper\n", 1186 | "hidden_dim = 256 # mcmahan17a, fedprox impl\n", 1187 | "num_classes = len(corpus)\n", 1188 | "classes = list(range(num_classes))\n", 1189 | "lstm_layers = 2 # mcmahan17a, fedprox, qFFL\n", 1190 | "dropout = 0.1 # TODO" 1191 | ], 1192 | "execution_count": null, 1193 | "outputs": [] 1194 | }, 1195 | { 1196 | "cell_type": "code", 1197 | "metadata": { 1198 | "id": "IadhhL4knbOc" 1199 | }, 1200 | "source": [ 1201 | "class Hyperparameters():\n", 1202 | "\n", 1203 | " def __init__(self, total_clients):\n", 1204 | " # number of training rounds\n", 1205 | " self.rounds = 50\n", 1206 | " # client fraction\n", 1207 | " self.C = 0.5\n", 1208 | " # number of clients\n", 1209 | " self.K = total_clients\n", 1210 | " # number of training passes on local dataset for each roung\n", 1211 | " self.E = 1 # qFFL\n", 1212 | " # batch size\n", 1213 | " self.batch_size = 10 # fedprox\n", 1214 | " # learning Rate\n", 1215 | " self.lr = 0.8 # fedprox, qFFL\n", 1216 | " # fairness\n", 1217 | " self.q = 0.001 # qFFL\n", 1218 | " # sampling\n", 1219 | " # self.sampling = 'uniform'\n", 1220 | " self.sampling = 'weighted'" 1221 | ], 1222 | "execution_count": null, 1223 | "outputs": [] 1224 | }, 1225 | { 1226 | "cell_type": "code", 1227 | "metadata": { 1228 | "id": "VtSTy1GencJ4" 1229 | }, 1230 | "source": [ 1231 | "exp_log = dict()" 1232 | ], 1233 | "execution_count": null, 1234 | "outputs": [] 1235 | }, 1236 | { 1237 | "cell_type": "markdown", 1238 | "metadata": { 1239 | "id": "hGoeL-HpDv7L" 1240 | }, 1241 | "source": [ 1242 | "### IID" 1243 | ] 1244 | }, 1245 | { 1246 | "cell_type": "code", 1247 | "metadata": { 1248 | "id": "61pZ-UY6D8nB" 1249 | }, 1250 | "source": [ 1251 | "train_ds, data_dict, test_ds = iid_partition(corpus, seq_length, val_split=True)\n", 1252 | "\n", 1253 | "total_clients = len(data_dict.keys())\n", 1254 | "'Total users:', total_clients" 1255 | ], 1256 | "execution_count": null, 1257 | "outputs": [] 1258 | }, 1259 | { 1260 | "cell_type": "code", 1261 | "metadata": { 1262 | "id": "nUmhoN-zGZp_" 1263 | }, 1264 | "source": [ 1265 | "hparams = Hyperparameters(total_clients)\n", 1266 | "hparams.__dict__" 1267 | ], 1268 | "execution_count": null, 1269 | "outputs": [] 1270 | }, 1271 | { 1272 | "cell_type": "code", 1273 | "metadata": { 1274 | "id": "XfNgps35XnxY" 1275 | }, 1276 | "source": [ 1277 | "# Sweeping parameter\n", 1278 | "PARAM_NAME = 'clients_fraction'\n", 1279 | "PARAM_VALUE = hparams.C\n", 1280 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n", 1281 | "exp_id" 
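A small illustrative sketch of the `'weighted'` option selected in `Hyperparameters` above, i.e. what the `client_sampling` helper does: clients are drawn without replacement with probability proportional to their number of training samples. The sample counts below are made up.

```python
# Illustrative only: weighted client sampling, proportional to sample counts.
import numpy as np

num_samples = [120, 40, 40]          # hypothetical per-client sample counts
pk = np.asarray(num_samples) / np.sum(num_samples)
picked = np.random.choice(range(len(num_samples)), size=2, replace=False, p=pk)
print(picked)                        # client 0 (p = 0.6) is the most likely to appear
```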
1282 | ], 1283 | "execution_count": null, 1284 | "outputs": [] 1285 | }, 1286 | { 1287 | "cell_type": "code", 1288 | "metadata": { 1289 | "id": "PmXNlFlYcl2H" 1290 | }, 1291 | "source": [ 1292 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n", 1293 | "os.makedirs(EXP_DIR, exist_ok=True)\n", 1294 | "\n", 1295 | "# tb_logger = SummaryWriter(log_dir)\n", 1296 | "# print(f'TBoard logger created at: {log_dir}')\n", 1297 | "\n", 1298 | "title = 'LSTM qFedAvg on IID'" 1299 | ], 1300 | "execution_count": null, 1301 | "outputs": [] 1302 | }, 1303 | { 1304 | "cell_type": "code", 1305 | "metadata": { 1306 | "id": "ymcomht5GdBO" 1307 | }, 1308 | "source": [ 1309 | "def run_experiment(run_id):\n", 1310 | "\n", 1311 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n", 1312 | " embedding_dim=embedding_dim, \n", 1313 | " hidden_dim=hidden_dim,\n", 1314 | " classes=num_classes,\n", 1315 | " lstm_layers=lstm_layers,\n", 1316 | " dropout=dropout,\n", 1317 | " batch_first=True\n", 1318 | " )\n", 1319 | "\n", 1320 | " if torch.cuda.is_available():\n", 1321 | " shakespeare_lstm.cuda()\n", 1322 | " \n", 1323 | " test_history = []\n", 1324 | "\n", 1325 | " lstm_iid_trained, test_history = training(shakespeare_lstm,\n", 1326 | " hparams.rounds, hparams.batch_size, hparams.lr,\n", 1327 | " train_ds, data_dict, test_ds,\n", 1328 | " hparams.C, hparams.K, hparams.E, hparams.q,\n", 1329 | " sampling=hparams.sampling,\n", 1330 | " test_history=test_history,\n", 1331 | " # tb_logger=tb_logger,\n", 1332 | " # perf_fig_file=f'{BASE_DIR}/loss.jpg'\n", 1333 | " )\n", 1334 | " \n", 1335 | " final_scores = testing(lstm_iid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n", 1336 | " print(f'\\n\\n========================================================\\n\\n')\n", 1337 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n", 1338 | "\n", 1339 | " log = {\n", 1340 | " 'history': test_history,\n", 1341 | " 'hyperparams': hparams.__dict__\n", 1342 | " }\n", 1343 | "\n", 1344 | " with open(f'{EXP_DIR}/results_iid_{run_id}.pkl', 'wb') as file:\n", 1345 | " pickle.dump(log, file)\n", 1346 | "\n", 1347 | " return test_history\n", 1348 | " " 1349 | ], 1350 | "execution_count": null, 1351 | "outputs": [] 1352 | }, 1353 | { 1354 | "cell_type": "code", 1355 | "metadata": { 1356 | "id": "xhPK_WL9GZhV" 1357 | }, 1358 | "source": [ 1359 | "exp_history = list()\n", 1360 | "for run_id in range(2): # TOTAL RUNS\n", 1361 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n", 1362 | " exp_history.append(run_experiment(run_id))\n", 1363 | " print(f'\\n\\n========================================================\\n\\n')" 1364 | ], 1365 | "execution_count": null, 1366 | "outputs": [] 1367 | }, 1368 | { 1369 | "cell_type": "code", 1370 | "metadata": { 1371 | "id": "a4Rxo_HZn10G" 1372 | }, 1373 | "source": [ 1374 | "exp_log[title] = {\n", 1375 | " 'history': exp_history,\n", 1376 | " 'hyperparams': hparams.__dict__\n", 1377 | "}" 1378 | ], 1379 | "execution_count": null, 1380 | "outputs": [] 1381 | }, 1382 | { 1383 | "cell_type": "code", 1384 | "metadata": { 1385 | "id": "YQGqrZ1_n-e3" 1386 | }, 1387 | "source": [ 1388 | "df = None\n", 1389 | "for i, e in enumerate(exp_history):\n", 1390 | " if i == 0:\n", 1391 | " df = pd.json_normalize(e)\n", 1392 | " continue\n", 1393 | " df = df + pd.json_normalize(e)\n", 1394 | " \n", 1395 | "df_avg = df / len(exp_history)\n", 1396 | "avg_history = df_avg.to_dict(orient='records')" 1397 | ], 1398 | "execution_count": null, 1399 | "outputs": [] 1400 
| }, 1401 | { 1402 | "cell_type": "code", 1403 | "metadata": { 1404 | "id": "H3SO1Cga5tLE" 1405 | }, 1406 | "source": [ 1407 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='IID')" 1408 | ], 1409 | "execution_count": null, 1410 | "outputs": [] 1411 | }, 1412 | { 1413 | "cell_type": "code", 1414 | "metadata": { 1415 | "id": "Pwo7EWRh5uQx" 1416 | }, 1417 | "source": [ 1418 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='IID')" 1419 | ], 1420 | "execution_count": null, 1421 | "outputs": [] 1422 | }, 1423 | { 1424 | "cell_type": "code", 1425 | "metadata": { 1426 | "id": "9KPWHl8128He" 1427 | }, 1428 | "source": [ 1429 | "with open(f'{EXP_DIR}/results_iid.pkl', 'wb') as file:\n", 1430 | " pickle.dump(exp_log, file)" 1431 | ], 1432 | "execution_count": null, 1433 | "outputs": [] 1434 | }, 1435 | { 1436 | "cell_type": "markdown", 1437 | "metadata": { 1438 | "id": "ttnrJ9FnDxFR" 1439 | }, 1440 | "source": [ 1441 | "### Non-IID" 1442 | ] 1443 | }, 1444 | { 1445 | "cell_type": "code", 1446 | "metadata": { 1447 | "id": "9T9hXMvT2_Ud" 1448 | }, 1449 | "source": [ 1450 | "exp_log = dict()" 1451 | ], 1452 | "execution_count": null, 1453 | "outputs": [] 1454 | }, 1455 | { 1456 | "cell_type": "code", 1457 | "metadata": { 1458 | "id": "yQlDBOdJEAaX" 1459 | }, 1460 | "source": [ 1461 | "data_dict, test_ds = noniid_partition(corpus, seq_length=seq_length, val_split=True)\n", 1462 | "\n", 1463 | "total_clients = len(data_dict.keys())\n", 1464 | "'Total users:', total_clients" 1465 | ], 1466 | "execution_count": null, 1467 | "outputs": [] 1468 | }, 1469 | { 1470 | "cell_type": "code", 1471 | "metadata": { 1472 | "id": "f5vIc19G8KUc" 1473 | }, 1474 | "source": [ 1475 | "hparams = Hyperparameters(total_clients)\n", 1476 | "hparams.__dict__" 1477 | ], 1478 | "execution_count": null, 1479 | "outputs": [] 1480 | }, 1481 | { 1482 | "cell_type": "code", 1483 | "metadata": { 1484 | "id": "WeGGOeN3rFDJ" 1485 | }, 1486 | "source": [ 1487 | "# Sweeping parameter\n", 1488 | "PARAM_NAME = 'clients_fraction'\n", 1489 | "PARAM_VALUE = hparams.C\n", 1490 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n", 1491 | "exp_id" 1492 | ], 1493 | "execution_count": null, 1494 | "outputs": [] 1495 | }, 1496 | { 1497 | "cell_type": "code", 1498 | "metadata": { 1499 | "id": "i3cGeO20ofEZ" 1500 | }, 1501 | "source": [ 1502 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n", 1503 | "os.makedirs(EXP_DIR, exist_ok=True)\n", 1504 | "\n", 1505 | "# tb_logger = SummaryWriter(log_dir)\n", 1506 | "# print(f'TBoard logger created at: {log_dir}')\n", 1507 | "\n", 1508 | "title = 'LSTM qFedAvg on Non-IID'" 1509 | ], 1510 | "execution_count": null, 1511 | "outputs": [] 1512 | }, 1513 | { 1514 | "cell_type": "code", 1515 | "metadata": { 1516 | "id": "bu1A3GUjGHWy" 1517 | }, 1518 | "source": [ 1519 | "def run_experiment(run_id):\n", 1520 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length,\n", 1521 | " embedding_dim=embedding_dim,\n", 1522 | " hidden_dim=hidden_dim,\n", 1523 | " classes=num_classes,\n", 1524 | " lstm_layers=lstm_layers,\n", 1525 | " dropout=dropout,\n", 1526 | " batch_first=True\n", 1527 | " )\n", 1528 | "\n", 1529 | " if torch.cuda.is_available():\n", 1530 | " shakespeare_lstm.cuda()\n", 1531 | "\n", 1532 | " test_history = []\n", 1533 | "\n", 1534 | " lstm_noniid_trained, test_history = training(shakespeare_lstm,\n", 1535 | " hparams.rounds, hparams.batch_size, hparams.lr,\n", 1536 | " None, data_dict, test_ds,\n", 1537 | " hparams.C, hparams.K, hparams.E, hparams.q,\n", 1538 | " 
sampling=hparams.sampling,\n", 1539 | " test_history=test_history,\n", 1540 | " # tb_logger=tb_logger,\n", 1541 | " # perf_fig_file=f'{BASE_DIR}/loss.jpg'\n", 1542 | " )\n", 1543 | " \n", 1544 | " final_scores = testing(lstm_noniid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n", 1545 | " print(f'\\n\\n========================================================\\n\\n')\n", 1546 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n", 1547 | "\n", 1548 | " log = {\n", 1549 | " 'history': test_history,\n", 1550 | " 'hyperparams': hparams.__dict__\n", 1551 | " }\n", 1552 | "\n", 1553 | " with open(f'{EXP_DIR}/results_niid_{run_id}.pkl', 'wb') as file:\n", 1554 | " pickle.dump(log, file)\n", 1555 | "\n", 1556 | " return test_history" 1557 | ], 1558 | "execution_count": null, 1559 | "outputs": [] 1560 | }, 1561 | { 1562 | "cell_type": "code", 1563 | "metadata": { 1564 | "id": "AaaamxrWGKct" 1565 | }, 1566 | "source": [ 1567 | "exp_history = list()\n", 1568 | "for run_id in range(2): # TOTAL RUNS\n", 1569 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n", 1570 | " exp_history.append(run_experiment(run_id))\n", 1571 | " print(f'\\n\\n========================================================\\n\\n')" 1572 | ], 1573 | "execution_count": null, 1574 | "outputs": [] 1575 | }, 1576 | { 1577 | "cell_type": "code", 1578 | "metadata": { 1579 | "id": "ftL-MoHxwe5C" 1580 | }, 1581 | "source": [ 1582 | "exp_log[title] = {\n", 1583 | " 'history': exp_history,\n", 1584 | " 'hyperparams': hparams.__dict__\n", 1585 | "}" 1586 | ], 1587 | "execution_count": null, 1588 | "outputs": [] 1589 | }, 1590 | { 1591 | "cell_type": "code", 1592 | "metadata": { 1593 | "id": "jceXDZXOwezj" 1594 | }, 1595 | "source": [ 1596 | "df = None\n", 1597 | "for i, e in enumerate(exp_history):\n", 1598 | " if i == 0:\n", 1599 | " df = pd.json_normalize(e)\n", 1600 | " continue\n", 1601 | " df = df + pd.json_normalize(e)\n", 1602 | " \n", 1603 | "df_avg = df / len(exp_history)\n", 1604 | "avg_history = df_avg.to_dict(orient='records')" 1605 | ], 1606 | "execution_count": null, 1607 | "outputs": [] 1608 | }, 1609 | { 1610 | "cell_type": "code", 1611 | "metadata": { 1612 | "id": "1u7uTHwJ6KXE" 1613 | }, 1614 | "source": [ 1615 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')" 1616 | ], 1617 | "execution_count": null, 1618 | "outputs": [] 1619 | }, 1620 | { 1621 | "cell_type": "code", 1622 | "metadata": { 1623 | "id": "um4eBO8O6KLX" 1624 | }, 1625 | "source": [ 1626 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')" 1627 | ], 1628 | "execution_count": null, 1629 | "outputs": [] 1630 | }, 1631 | { 1632 | "cell_type": "markdown", 1633 | "metadata": { 1634 | "id": "N80BpTFy6aR7" 1635 | }, 1636 | "source": [ 1637 | "### Pickle Experiment Results" 1638 | ] 1639 | }, 1640 | { 1641 | "cell_type": "code", 1642 | "metadata": { 1643 | "id": "tGlR8COy6aCN" 1644 | }, 1645 | "source": [ 1646 | "with open(f'{EXP_DIR}/results_niid.pkl', 'wb') as file:\n", 1647 | " pickle.dump(exp_log, file)" 1648 | ], 1649 | "execution_count": null, 1650 | "outputs": [] 1651 | } 1652 | ] 1653 | } --------------------------------------------------------------------------------
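To close, a minimal sketch (not a notebook cell) of how the pickled experiment logs written above can be reloaded and averaged across runs, mirroring the `pd.json_normalize` averaging used in the notebook. The path is a placeholder for whichever `results_*.pkl` file was written to `EXP_DIR`.

```python
import pickle
import pandas as pd

# Placeholder path: point this at one of the results_*.pkl files written above.
with open('results_niid.pkl', 'rb') as f:
    exp_log = pickle.load(f)

for title, entry in exp_log.items():
    runs = entry['history']                    # one per-round metrics list per run
    dfs = [pd.json_normalize(run) for run in runs]
    avg = sum(dfs[1:], dfs[0]) / len(dfs)      # element-wise mean across runs
    print(title)
    print(avg[['accuracy', 'f1_macro', 'f1_weighted', 'loss', 'train_loss']].tail(1))
```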