├── .gitignore
├── LICENSE
├── README.md
└── Shakespeare
    ├── FedAvg.ipynb
    ├── FedMed.ipynb
    ├── FedProx.ipynb
    └── qFedAvg.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Santiago Gonzalez Toral
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fl-algorithms
2 |
3 | A repository containing notebooks that implement different Federated Learning algorithms using PyTorch.
4 |
5 | # Experiments
6 |
7 | ### FL using the Shakespeare dataset
8 | - **Federated Averaging** (FedAvg)
9 | - [Notebook](Shakespeare/FedAvg.ipynb)
10 | - [Paper: Communication-Efficient Learning of Deep Networks from Decentralized Data](http://proceedings.mlr.press/v54/mcmahan17a.html)
11 | - **Federated Prox** (FedProx)
12 | - [Notebook](Shakespeare/FedProx.ipynb)
13 | - [Paper: Federated Optimization in Heterogeneous Networks](https://arxiv.org/abs/1812.06127)
14 | - **Federated Median** (FedMed)
15 | - [Notebook](Shakespeare/FedMed.ipynb)
16 | - [Paper: Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](http://proceedings.mlr.press/v80/yin18a.html)
17 | - **q-Federated Averaging** (qFedAvg) a.k.a. **q-Fair Federated Learning** (qFFL)
18 | - [Notebook](Shakespeare/qFedAvg.ipynb)
19 | - [Paper: Fair Resource Allocation In Federated Learning](https://arxiv.org/abs/1905.10497)
20 |
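The notebooks share the same server loop: sample a fraction `C` of the `K` clients each round, run `E` local epochs on each, then average the returned weights. As a quick orientation, here is a minimal sketch of that uniform averaging step (illustrative only; the notebooks perform it inline over `state_dict` tensors, and the helper name `average_weights` is not part of the code):

```python
import copy
import torch

def average_weights(local_weights):
    """Uniformly average a list of PyTorch state_dicts, mirroring the notebooks' FedAvg step."""
    avg = copy.deepcopy(local_weights[0])
    for key in avg.keys():
        for w in local_weights[1:]:
            avg[key] += w[key]  # accumulate each client's tensor for this parameter
        avg[key] = torch.div(avg[key], len(local_weights))  # divide by the number of clients
    return avg
```
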
21 | # License
22 |
23 | [MIT](LICENSE)
24 |
--------------------------------------------------------------------------------
/Shakespeare/FedAvg.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "FedAvg.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [
9 | "Y-ZegBArf4xQ",
10 | "cce_-qnxhD4n",
11 | "XelbyPsDlfgb",
12 | "IOBblyFGlwlU",
13 | "sWVOxcAao2_t",
14 | "vFFAfTOwpk4j",
15 | "c640e4NnpksE",
16 | "3crFDN0xqGu6",
17 | "YXtGLkoAqLIW"
18 | ],
19 | "toc_visible": true,
20 | "authorship_tag": "ABX9TyODmfkWrZZ1tNBNYaII6os2",
21 | "include_colab_link": true
22 | },
23 | "kernelspec": {
24 | "name": "python3",
25 | "display_name": "Python 3"
26 | },
27 | "language_info": {
28 | "name": "python"
29 | },
30 | "accelerator": "GPU"
31 | },
32 | "cells": [
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {
36 | "id": "view-in-github",
37 | "colab_type": "text"
38 | },
39 | "source": [
40 | "
"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {
46 | "id": "ZXnVpPCFfQkm"
47 | },
48 | "source": [
49 | "# FedPerf - Shakespeare + FedAvg algorithm"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {
55 | "id": "TWDjzbCzfUFt"
56 | },
57 | "source": [
58 | "## Setup & Dependencies Installation"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "metadata": {
64 | "id": "z_f-rWudW0Fo"
65 | },
66 | "source": [
67 | "%%capture\n",
68 | "!pip install torchsummaryX unidecode"
69 | ],
70 | "execution_count": null,
71 | "outputs": []
72 | },
73 | {
74 | "cell_type": "code",
75 | "metadata": {
76 | "id": "3TmMInb_fnw_"
77 | },
78 | "source": [
79 | "%load_ext tensorboard\n",
80 | "\n",
81 | "import copy\n",
82 | "from functools import reduce\n",
83 | "import json\n",
84 | "import matplotlib.pyplot as plt\n",
85 | "import numpy as np\n",
86 | "import os\n",
87 | "import pandas as pd\n",
88 | "import pickle\n",
89 | "import random\n",
90 | "from sklearn.model_selection import train_test_split\n",
91 | "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
92 | "import time\n",
93 | "import torch\n",
94 | "from torch.autograd import Variable\n",
95 | "import torch.nn as nn\n",
96 | "import torch.nn.functional as F\n",
97 | "from torch.utils.data import Dataset\n",
98 | "from torch.utils.data.dataloader import DataLoader\n",
99 | "from torch.utils.data.sampler import Sampler\n",
100 | "from torch.utils.tensorboard import SummaryWriter\n",
101 | "from torchsummary import summary\n",
102 | "from torchsummaryX import summary as summaryx\n",
103 | "from torchvision import transforms, utils, datasets\n",
104 | "from tqdm.notebook import tqdm\n",
105 | "from unidecode import unidecode\n",
106 | "\n",
107 | "%matplotlib inline\n",
108 | "\n",
109 | "# Check assigned GPU\n",
110 | "gpu_info = !nvidia-smi\n",
111 | "gpu_info = '\\n'.join(gpu_info)\n",
112 | "if gpu_info.find('failed') >= 0:\n",
113 | " print('Select the Runtime > \"Change runtime type\" menu to enable a GPU accelerator, ')\n",
114 | " print('and then re-execute this cell.')\n",
115 | "else:\n",
116 | " print(gpu_info)\n",
117 | "\n",
118 | "# set manual seed for reproducibility\n",
119 | "RANDOM_SEED = 42\n",
120 | "\n",
121 | "# general reproducibility\n",
122 | "random.seed(RANDOM_SEED)\n",
123 | "np.random.seed(RANDOM_SEED)\n",
124 | "torch.manual_seed(RANDOM_SEED)\n",
125 | "torch.cuda.manual_seed(RANDOM_SEED)\n",
126 | "\n",
127 | "# gpu training specific\n",
128 | "torch.backends.cudnn.deterministic = True\n",
129 | "torch.backends.cudnn.benchmark = False"
130 | ],
131 | "execution_count": null,
132 | "outputs": []
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {
137 | "id": "8Mfqv6uHfwh-"
138 | },
139 | "source": [
140 | "## Mount GDrive"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "metadata": {
146 | "id": "75qbJwxsGj-k"
147 | },
148 | "source": [
149 | "BASE_DIR = '/content/drive/MyDrive/FedPerf/shakespeare/FedAvg'"
150 | ],
151 | "execution_count": null,
152 | "outputs": []
153 | },
154 | {
155 | "cell_type": "code",
156 | "metadata": {
157 | "id": "Ost-6lXSfveC"
158 | },
159 | "source": [
160 | "try:\n",
161 | " from google.colab import drive\n",
162 | " drive.mount('/content/drive')\n",
163 | " os.makedirs(BASE_DIR, exist_ok=True)\n",
164 | "except:\n",
165 | " print(\"WARNING: Results won't be stored on GDrive\")\n",
166 | " BASE_DIR = './'\n",
167 | "\n"
168 | ],
169 | "execution_count": null,
170 | "outputs": []
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {
175 | "id": "Y-ZegBArf4xQ"
176 | },
177 | "source": [
178 | "## Loading Dataset"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "metadata": {
184 | "id": "hf03LRxof7Zj"
185 | },
186 | "source": [
187 | "!rm -Rf data\n",
188 | "!mkdir -p data scripts"
189 | ],
190 | "execution_count": null,
191 | "outputs": []
192 | },
193 | {
194 | "cell_type": "code",
195 | "metadata": {
196 | "id": "ngygA4-Fgobx"
197 | },
198 | "source": [
199 | "GENERATE_DATASET = False # If False, download the dataset provided by the q-FFL paper\n",
200 | "DATA_DIR = 'data/'\n",
201 | "# Dataset generation params\n",
202 | "SAMPLES_FRACTION = 1. # If using an already generated dataset\n",
203 | "# SAMPLES_FRACTION = 0.2 # Fraction of total samples in the dataset - FedProx default script\n",
204 | "# SAMPLES_FRACTION = 0.05 # Fraction of total samples in the dataset - qFFL\n",
205 | "TRAIN_FRACTION = 0.8 # Train set size\n",
206 | "MIN_SAMPLES = 0 # Min samples per client (for filtering purposes) - FedProx\n",
207 | "# MIN_SAMPLES = 64 # Min samples per client (for filtering purposes) - qFFL"
208 | ],
209 | "execution_count": null,
210 | "outputs": []
211 | },
212 | {
213 | "cell_type": "code",
214 | "metadata": {
215 | "id": "nUmwJgJygoYD"
216 | },
217 | "source": [
218 | "# Download raw dataset\n",
219 | "# !wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt -O data/shakespeare.txt\n",
220 | "!wget --adjust-extension http://www.gutenberg.org/files/100/100-0.txt -O data/shakespeare.txt"
221 | ],
222 | "execution_count": null,
223 | "outputs": []
224 | },
225 | {
226 | "cell_type": "code",
227 | "metadata": {
228 | "id": "4dCvx80BgoVr"
229 | },
230 | "source": [
231 | "if not GENERATE_DATASET:\n",
232 | " !rm -Rf data/train data/test\n",
233 | " !gdown --id 1n46Mftp3_ahRi1Z6jYhEriyLtdRDS1tD # Download Shakespeare dataset used by the FedProx paper\n",
234 | " !unzip shakespeare.zip\n",
235 | " !mv -f shakespeare_paper/train data/\n",
236 | " !mv -f shakespeare_paper/test data/\n",
237 | " !rm -R shakespeare_paper/ shakespeare.zip\n"
238 | ],
239 | "execution_count": null,
240 | "outputs": []
241 | },
242 | {
243 | "cell_type": "code",
244 | "metadata": {
245 | "id": "a4pzFvPvhQhq"
246 | },
247 | "source": [
248 | "corpus = []\n",
249 | "with open('data/shakespeare.txt', 'r') as f:\n",
250 | " data = list(unidecode(f.read()))\n",
251 | " corpus = list(set(list(data)))\n",
252 | "print('Corpus Length:', len(corpus))"
253 | ],
254 | "execution_count": null,
255 | "outputs": []
256 | },
257 | {
258 | "cell_type": "markdown",
259 | "metadata": {
260 | "id": "cce_-qnxhD4n"
261 | },
262 | "source": [
263 | "#### Dataset Preprocessing script"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "metadata": {
269 | "id": "Rt13M4IcgoTV"
270 | },
271 | "source": [
272 | "%%capture\n",
273 | "if GENERATE_DATASET:\n",
274 | " # Download dataset generation scripts\n",
275 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/preprocess_shakespeare.py -O scripts/preprocess_shakespeare.py\n",
276 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/shake_utils.py -O scripts/shake_utils.py\n",
277 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/gen_all_data.py -O scripts/gen_all_data.py\n",
278 | "\n",
279 | " # Download data preprocessing scripts\n",
280 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/sample.py -O scripts/sample.py\n",
281 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/remove_users.py -O scripts/remove_users.py"
282 | ],
283 | "execution_count": null,
284 | "outputs": []
285 | },
286 | {
287 | "cell_type": "code",
288 | "metadata": {
289 | "id": "EIEyRW27goPo"
290 | },
291 | "source": [
292 | "# Running scripts\n",
293 | "if GENERATE_DATASET:\n",
294 | " !mkdir -p data/raw_data data/all_data data/train data/test\n",
295 | " !python scripts/preprocess_shakespeare.py data/shakespeare.txt data/raw_data\n",
296 | " !python scripts/gen_all_data.py"
297 | ],
298 | "execution_count": null,
299 | "outputs": []
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "metadata": {
304 | "id": "mq8V6v_4hhhD"
305 | },
306 | "source": [
307 | "#### Dataset class"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "metadata": {
313 | "id": "H2SjEBKoWDxv"
314 | },
315 | "source": [
316 | "class ShakespeareDataset(Dataset):\n",
317 | " def __init__(self, x, y, corpus, seq_length):\n",
318 | " self.x = x\n",
319 | " self.y = y\n",
320 | " self.corpus = corpus\n",
321 | " self.corpus_size = len(self.corpus)\n",
322 | " super(ShakespeareDataset, self).__init__()\n",
323 | "\n",
324 | " def __len__(self):\n",
325 | " return len(self.x)\n",
326 | "\n",
327 | " def __repr__(self):\n",
328 | " return f'{self.__class__} - (length: {self.__len__()})'\n",
329 | "\n",
330 | " def __getitem__(self, i):\n",
331 | " input_seq = self.x[i]\n",
332 | " next_char = self.y[i]\n",
333 | " # print('\\tgetitem', i, input_seq, next_char)\n",
334 | " input_value = self.text2charindxs(input_seq)\n",
335 | " target_value = self.get_label_from_char(next_char)\n",
336 | " return input_value, target_value\n",
337 | "\n",
338 | " def text2charindxs(self, text):\n",
339 | " tensor = torch.zeros(len(text), dtype=torch.int32)\n",
340 | " for i, c in enumerate(text):\n",
341 | " tensor[i] = self.get_label_from_char(c)\n",
342 | " return tensor\n",
343 | "\n",
344 | " def get_label_from_char(self, c):\n",
345 | " return self.corpus.index(c)\n",
346 | "\n",
347 | " def get_char_from_label(self, l):\n",
348 | " return self.corpus[l]"
349 | ],
350 | "execution_count": null,
351 | "outputs": []
352 | },
353 | {
354 | "cell_type": "markdown",
355 | "metadata": {
356 | "id": "9fgJtS62lYAN"
357 | },
358 | "source": [
359 | "##### Federated Dataset"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "metadata": {
365 | "id": "5DqL5pTmgn5X"
366 | },
367 | "source": [
368 | "class ShakespeareFedDataset(ShakespeareDataset):\n",
369 | " def __init__(self, x, y, corpus, seq_length):\n",
370 | " super(ShakespeareFedDataset, self).__init__(x, y, corpus, seq_length)\n",
371 | "\n",
372 | " def dataloader(self, batch_size, shuffle=True):\n",
373 | " return DataLoader(self,\n",
374 | " batch_size=batch_size,\n",
375 | " shuffle=shuffle,\n",
376 | " num_workers=0)\n"
377 | ],
378 | "execution_count": null,
379 | "outputs": []
380 | },
381 | {
382 | "cell_type": "markdown",
383 | "metadata": {
384 | "id": "XelbyPsDlfgb"
385 | },
386 | "source": [
387 | "## Partitioning & Data Loaders"
388 | ]
389 | },
390 | {
391 | "cell_type": "markdown",
392 | "metadata": {
393 | "id": "IOBblyFGlwlU"
394 | },
395 | "source": [
396 | "### IID"
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "metadata": {
402 | "id": "cSZFWKmsgn1p"
403 | },
404 | "source": [
405 | "def iid_partition_(dataset, clients):\n",
406 | " \"\"\"\n",
407 | " I.I.D paritioning of data over clients\n",
408 | " Shuffle the data\n",
409 | " Split it between clients\n",
410 | " \n",
411 | " params:\n",
412 | " - dataset (torch.utils.Dataset): Dataset\n",
413 | " - clients (int): Number of Clients to split the data between\n",
414 | "\n",
415 | " returns:\n",
416 | " - Dictionary of image indexes for each client\n",
417 | " \"\"\"\n",
418 | "\n",
419 | " num_items_per_client = int(len(dataset)/clients)\n",
420 | " client_dict = {}\n",
421 | " image_idxs = [i for i in range(len(dataset))]\n",
422 | "\n",
423 | " for i in range(clients):\n",
424 | " client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False))\n",
425 | " image_idxs = list(set(image_idxs) - client_dict[i])\n",
426 | "\n",
427 | " return client_dict"
428 | ],
429 | "execution_count": null,
430 | "outputs": []
431 | },
432 | {
433 | "cell_type": "code",
434 | "metadata": {
435 | "id": "-lGwDyhSll9h"
436 | },
437 | "source": [
438 | "def iid_partition(corpus, seq_length=80, val_split=False):\n",
439 | "\n",
440 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n",
441 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n",
442 | "\n",
443 | " with open(train_file, 'r') as file:\n",
444 | " data_train = json.loads(unidecode(file.read()))\n",
445 | "\n",
446 | " with open(test_file, 'r') as file:\n",
447 | " data_test = json.loads(unidecode(file.read()))\n",
448 | "\n",
449 | " \n",
450 | " total_samples_train = sum(data_train['num_samples'])\n",
451 | "\n",
452 | " data_dict = {}\n",
453 | "\n",
454 | " x_train, y_train = [], []\n",
455 | " x_test, y_test = [], []\n",
456 | " # x_val, y_val = [], []\n",
457 | "\n",
458 | " users = list(zip(data_train['users'], data_train['num_samples']))\n",
459 | " # random.shuffle(users)\n",
460 | "\n",
461 | "\n",
462 | "\n",
463 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n",
464 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n",
465 | " sample_count = 0\n",
466 | " \n",
467 | " for i, (author_id, samples) in enumerate(users):\n",
468 | "\n",
469 | " if sample_count >= total_samples:\n",
470 | " print('Max samples reached', sample_count, '/', total_samples)\n",
471 | " break\n",
472 | "\n",
473 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n",
474 | " print('SKIP', author_id, samples)\n",
475 | " continue\n",
476 | " else:\n",
477 | " udata_train = data_train['user_data'][author_id]\n",
478 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n",
479 | " \n",
480 | " sample_count += max_samples\n",
481 | " # print('sample_count', sample_count)\n",
482 | "\n",
483 | " x_train.extend(data_train['user_data'][author_id]['x'][:max_samples])\n",
484 | " y_train.extend(data_train['user_data'][author_id]['y'][:max_samples])\n",
485 | "\n",
486 | " author_data = data_test['user_data'][author_id]\n",
487 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n",
488 | "\n",
489 | " if val_split:\n",
490 | " x_test.extend(author_data['x'][:int(test_size / 2)])\n",
491 | " y_test.extend(author_data['y'][:int(test_size / 2)])\n",
492 | " # x_val.extend(author_data['x'][int(test_size / 2):])\n",
493 | " # y_val.extend(author_data['y'][int(test_size / 2):int(test_size)])\n",
494 | "\n",
495 | " else:\n",
496 | " x_test.extend(author_data['x'][:int(test_size)])\n",
497 | " y_test.extend(author_data['y'][:int(test_size)])\n",
498 | "\n",
499 | " train_ds = ShakespeareDataset(x_train, y_train, corpus, seq_length)\n",
500 | " test_ds = ShakespeareDataset(x_test, y_test, corpus, seq_length)\n",
501 | " # val_ds = ShakespeareDataset(x_val, y_val, corpus, seq_length)\n",
502 | "\n",
503 | " data_dict = iid_partition_(train_ds, clients=len(users))\n",
504 | "\n",
505 | " return train_ds, data_dict, test_ds"
506 | ],
507 | "execution_count": null,
508 | "outputs": []
509 | },
510 | {
511 | "cell_type": "markdown",
512 | "metadata": {
513 | "id": "MFvc8mLoouKa"
514 | },
515 | "source": [
516 | "### Non-IID"
517 | ]
518 | },
519 | {
520 | "cell_type": "code",
521 | "metadata": {
522 | "id": "GZ76WsCZot9s"
523 | },
524 | "source": [
525 | "def noniid_partition(corpus, seq_length=80, val_split=False):\n",
526 | "\n",
527 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n",
528 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n",
529 | "\n",
530 | " with open(train_file, 'r') as file:\n",
531 | " data_train = json.loads(unidecode(file.read()))\n",
532 | "\n",
533 | " with open(test_file, 'r') as file:\n",
534 | " data_test = json.loads(unidecode(file.read()))\n",
535 | "\n",
536 | " \n",
537 | " total_samples_train = sum(data_train['num_samples'])\n",
538 | "\n",
539 | " data_dict = {}\n",
540 | "\n",
541 | " x_test, y_test = [], []\n",
542 | "\n",
543 | " users = list(zip(data_train['users'], data_train['num_samples']))\n",
544 | " # random.shuffle(users)\n",
545 | "\n",
546 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n",
547 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n",
548 | " sample_count = 0\n",
549 | " \n",
550 | " for i, (author_id, samples) in enumerate(users):\n",
551 | "\n",
552 | " if sample_count >= total_samples:\n",
553 | " print('Max samples reached', sample_count, '/', total_samples)\n",
554 | " break\n",
555 | "\n",
556 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n",
557 | " print('SKIP', author_id, samples)\n",
558 | " continue\n",
559 | " else:\n",
560 | " udata_train = data_train['user_data'][author_id]\n",
561 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n",
562 | " \n",
563 | " sample_count += max_samples\n",
564 | " # print('sample_count', sample_count)\n",
565 | "\n",
566 | " x_train = data_train['user_data'][author_id]['x'][:max_samples]\n",
567 | " y_train = data_train['user_data'][author_id]['y'][:max_samples]\n",
568 | "\n",
569 | " train_ds = ShakespeareFedDataset(x_train, y_train, corpus, seq_length)\n",
570 | "\n",
571 | " x_val, y_val = None, None\n",
572 | " val_ds = None\n",
573 | " author_data = data_test['user_data'][author_id]\n",
574 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n",
575 | " if val_split:\n",
576 | " x_test += author_data['x'][:int(test_size / 2)]\n",
577 | " y_test += author_data['y'][:int(test_size / 2)]\n",
578 | " x_val = author_data['x'][int(test_size / 2):]\n",
579 | " y_val = author_data['y'][int(test_size / 2):int(test_size)]\n",
580 | "\n",
581 | " val_ds = ShakespeareFedDataset(x_val, y_val, corpus, seq_length)\n",
582 | "\n",
583 | " else:\n",
584 | " x_test += author_data['x'][:int(test_size)]\n",
585 | " y_test += author_data['y'][:int(test_size)]\n",
586 | "\n",
587 | " data_dict[author_id] = {\n",
588 | " 'train_ds': train_ds,\n",
589 | " 'val_ds': val_ds\n",
590 | " }\n",
591 | "\n",
592 | " test_ds = ShakespeareFedDataset(x_test, y_test, corpus, seq_length)\n",
593 | "\n",
594 | " return data_dict, test_ds"
595 | ],
596 | "execution_count": null,
597 | "outputs": []
598 | },
599 | {
600 | "cell_type": "markdown",
601 | "metadata": {
602 | "id": "sWVOxcAao2_t"
603 | },
604 | "source": [
605 | "## Models"
606 | ]
607 | },
608 | {
609 | "cell_type": "markdown",
610 | "metadata": {
611 | "id": "gQQQ2mLeo6EA"
612 | },
613 | "source": [
614 | "### Shakespeare LSTM"
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "metadata": {
620 | "id": "2mGXTrXRot7R"
621 | },
622 | "source": [
623 | "class ShakespeareLSTM(nn.Module):\n",
624 | " \"\"\"\n",
625 | " \"\"\"\n",
626 | "\n",
627 | " def __init__(self, input_dim, embedding_dim, hidden_dim, classes, lstm_layers=2, dropout=0.1, batch_first=True):\n",
628 | " super(ShakespeareLSTM, self).__init__()\n",
629 | " self.input_dim = input_dim\n",
630 | " self.embedding_dim = embedding_dim\n",
631 | " self.hidden_dim = hidden_dim\n",
632 | " self.classes = classes\n",
633 | " self.no_layers = lstm_layers\n",
634 | " \n",
635 | " self.embedding = nn.Embedding(num_embeddings=self.classes,\n",
636 | " embedding_dim=self.embedding_dim)\n",
637 | " self.lstm = nn.LSTM(input_size=self.embedding_dim, \n",
638 | " hidden_size=self.hidden_dim,\n",
639 | " num_layers=self.no_layers,\n",
640 | " batch_first=batch_first, \n",
641 | " dropout=dropout if self.no_layers > 1 else 0.)\n",
642 | " self.fc = nn.Linear(hidden_dim, self.classes)\n",
643 | "\n",
644 | " def forward(self, x, hc=None):\n",
645 | " batch_size = x.size(0)\n",
646 | " x_emb = self.embedding(x)\n",
647 | " out, (ht, ct) = self.lstm(x_emb.view(batch_size, -1, self.embedding_dim), hc)\n",
648 | " dense = self.fc(ht[-1])\n",
649 | " return dense\n",
650 | " \n",
651 | " def init_hidden(self, batch_size):\n",
652 | " return (Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)),\n",
653 | " Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)))\n"
654 | ],
655 | "execution_count": null,
656 | "outputs": []
657 | },
658 | {
659 | "cell_type": "markdown",
660 | "metadata": {
661 | "id": "5QsuJlVipMc8"
662 | },
663 | "source": [
664 | "#### Model Summary"
665 | ]
666 | },
667 | {
668 | "cell_type": "code",
669 | "metadata": {
670 | "id": "n_Vb0BYpot5I"
671 | },
672 | "source": [
673 | "batch_size = 10\n",
674 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n",
675 | "\n",
676 | "shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n",
677 | " embedding_dim=8, # mcmahan17a, fedprox, qFFL\n",
678 | " hidden_dim=256, # mcmahan17a, fedprox impl\n",
679 | " # hidden_dim=100, # fedprox paper\n",
680 | " classes=len(corpus),\n",
681 | " lstm_layers=2,\n",
682 | " dropout=0.1, # TODO:\n",
683 | " batch_first=True\n",
684 | " )\n",
685 | "\n",
686 | "if torch.cuda.is_available():\n",
687 | " shakespeare_lstm.cuda()\n",
688 | "\n",
689 | "\n",
690 | "\n",
691 | "hc = shakespeare_lstm.init_hidden(batch_size)\n",
692 | "\n",
693 | "x_sample = torch.zeros((batch_size, seq_length),\n",
694 | " dtype=torch.long,\n",
695 | " device=(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')))\n",
696 | "\n",
697 | "x_sample[0][0] = 1\n",
698 | "x_sample\n",
699 | "\n",
700 | "print(\"\\nShakespeare LSTM SUMMARY\")\n",
701 | "print(summaryx(shakespeare_lstm, x_sample))"
702 | ],
703 | "execution_count": null,
704 | "outputs": []
705 | },
706 | {
707 | "cell_type": "markdown",
708 | "metadata": {
709 | "id": "qn7egnzTpeks"
710 | },
711 | "source": [
712 | "## FedAvg Algorithm"
713 | ]
714 | },
715 | {
716 | "cell_type": "markdown",
717 | "metadata": {
718 | "id": "vFFAfTOwpk4j"
719 | },
720 | "source": [
721 | "### Plot Utils"
722 | ]
723 | },
724 | {
725 | "cell_type": "code",
726 | "metadata": {
727 | "id": "oyYjWa6IpnTY"
728 | },
729 | "source": [
730 | "from sklearn.metrics import f1_score"
731 | ],
732 | "execution_count": null,
733 | "outputs": []
734 | },
735 | {
736 | "cell_type": "code",
737 | "metadata": {
738 | "id": "367THsiTpo-C"
739 | },
740 | "source": [
741 | "def plot_scores(history, exp_id, title, suffix):\n",
742 | " accuracies = [x['accuracy'] for x in history]\n",
743 | " f1_macro = [x['f1_macro'] for x in history]\n",
744 | " f1_weighted = [x['f1_weighted'] for x in history]\n",
745 | "\n",
746 | " fig, ax = plt.subplots()\n",
747 | " ax.plot(accuracies, 'tab:orange')\n",
748 | " ax.set(xlabel='Rounds', ylabel='Test Accuracy', title=title)\n",
749 | " ax.grid()\n",
750 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Accuracy_{suffix}.jpg', format='jpg', dpi=300)\n",
751 | " plt.show()\n",
752 | "\n",
753 | " fig, ax = plt.subplots()\n",
754 | " ax.plot(f1_macro, 'tab:orange')\n",
755 | " ax.set(xlabel='Rounds', ylabel='Test F1 (macro)', title=title)\n",
756 | " ax.grid()\n",
757 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Macro_{suffix}.jpg', format='jpg')\n",
758 | " plt.show()\n",
759 | "\n",
760 | " fig, ax = plt.subplots()\n",
761 | " ax.plot(f1_weighted, 'tab:orange')\n",
762 | " ax.set(xlabel='Rounds', ylabel='Test F1 (weighted)', title=title)\n",
763 | " ax.grid()\n",
764 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Weighted_{suffix}.jpg', format='jpg')\n",
765 | " plt.show()\n",
766 | "\n",
767 | "\n",
768 | "def plot_losses(history, exp_id, title, suffix):\n",
769 | " val_losses = [x['loss'] for x in history]\n",
770 | " train_losses = [x['train_loss'] for x in history]\n",
771 | "\n",
772 | " fig, ax = plt.subplots()\n",
773 | " ax.plot(train_losses, 'tab:orange')\n",
774 | " ax.set(xlabel='Rounds', ylabel='Train Loss', title=title)\n",
775 | " ax.grid()\n",
776 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Train_Loss_{suffix}.jpg', format='jpg')\n",
777 | " plt.show()\n",
778 | "\n",
779 | " fig, ax = plt.subplots()\n",
780 | " ax.plot(val_losses, 'tab:orange')\n",
781 | " ax.set(xlabel='Rounds', ylabel='Test Loss', title=title)\n",
782 | " ax.grid()\n",
783 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Loss_{suffix}.jpg', format='jpg')\n",
784 | " plt.show()\n"
785 | ],
786 | "execution_count": null,
787 | "outputs": []
788 | },
789 | {
790 | "cell_type": "markdown",
791 | "metadata": {
792 | "id": "c640e4NnpksE"
793 | },
794 | "source": [
795 | "### Systems Heterogeneity Simulations\n",
796 | "\n",
797 | "Generate epochs for selected clients based on percentage of devices that corresponds to heterogeneity. \n",
798 | "\n",
799 | "Assign x number of epochs (chosen unifirmly at random between [1, E]) to 0%, 50% or 90% of the selected devices, respectively. Settings where 0% devices perform fewer than E epochs of work correspond to the environments without system heterogeneity, while 90% of the devices sending their partial solutions corresponds to highly heterogenous system."
800 | ]
801 | },
802 | {
803 | "cell_type": "code",
804 | "metadata": {
805 | "id": "zuEZYnl5ot2m"
806 | },
807 | "source": [
808 | "def GenerateLocalEpochs(percentage, size, max_epochs):\n",
809 | " ''' Method generates list of epochs for selected clients\n",
810 | " to replicate system heteroggeneity\n",
811 | "\n",
812 | " Params:\n",
813 | " percentage: percentage of clients to have fewer than E epochs\n",
814 | " size: total size of the list\n",
815 | " max_epochs: maximum value for local epochs\n",
816 | " \n",
817 | " Returns:\n",
818 | " List of size epochs for each Client Update\n",
819 | "\n",
820 | " '''\n",
821 | "\n",
822 | " # if percentage is 0 then each client runs for E epochs\n",
823 | " if percentage == 0:\n",
824 | " return np.array([max_epochs]*size)\n",
825 | " else:\n",
826 | " # get the number of clients to have fewer than E epochs\n",
827 | " heterogenous_size = int((percentage/100) * size)\n",
828 | "\n",
829 | " # generate random uniform epochs of heterogenous size between 1 and E\n",
830 | " epoch_list = np.random.randint(1, max_epochs, heterogenous_size)\n",
831 | "\n",
832 | " # the rest of the clients will have E epochs\n",
833 | " remaining_size = size - heterogenous_size\n",
834 | " rem_list = [max_epochs]*remaining_size\n",
835 | "\n",
836 | " epoch_list = np.append(epoch_list, rem_list, axis=0)\n",
837 | " \n",
838 | " # shuffle the list and return\n",
839 | " np.random.shuffle(epoch_list)\n",
840 | "\n",
841 | " return epoch_list"
842 | ],
843 | "execution_count": null,
844 | "outputs": []
845 | },
846 | {
847 | "cell_type": "markdown",
848 | "metadata": {
849 | "id": "VQ9PZM0Gp9ve"
850 | },
851 | "source": [
852 | "### Local Training (Client Update)"
853 | ]
854 | },
855 | {
856 | "cell_type": "code",
857 | "metadata": {
858 | "id": "EDJFltwdotzZ"
859 | },
860 | "source": [
861 | "class CustomDataset(Dataset):\n",
862 | " def __init__(self, dataset, idxs):\n",
863 | " self.dataset = dataset\n",
864 | " self.idxs = list(idxs)\n",
865 | "\n",
866 | " def __len__(self):\n",
867 | " return len(self.idxs)\n",
868 | "\n",
869 | " def __getitem__(self, item):\n",
870 | " data, label = self.dataset[self.idxs[item]]\n",
871 | " return data, label"
872 | ],
873 | "execution_count": null,
874 | "outputs": []
875 | },
876 | {
877 | "cell_type": "code",
878 | "metadata": {
879 | "id": "HtRzU5Yepddq"
880 | },
881 | "source": [
882 | "class ClientUpdate(object):\n",
883 | " def __init__(self, dataset, batchSize, learning_rate, epochs, idxs, mu, algorithm):\n",
884 | " # self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batchSize, shuffle=True)\n",
885 | " if hasattr(dataset, 'dataloader'):\n",
886 | " self.train_loader = dataset.dataloader(batch_size=batch_size, shuffle=True)\n",
887 | " else:\n",
888 | " self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batch_size, shuffle=True)\n",
889 | "\n",
890 | " self.algorithm = algorithm\n",
891 | " self.learning_rate = learning_rate\n",
892 | " self.epochs = epochs\n",
893 | " self.mu = mu\n",
894 | "\n",
895 | " def train(self, model):\n",
896 | " # print(\"Client training for {} epochs.\".format(self.epochs))\n",
897 | " criterion = nn.CrossEntropyLoss()\n",
898 | " proximal_criterion = nn.MSELoss(reduction='mean')\n",
899 | " optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.5)\n",
900 | "\n",
901 | " # use the weights of global model for proximal term calculation\n",
902 | " global_model = copy.deepcopy(model)\n",
903 | "\n",
904 | " # calculate local training time\n",
905 | " start_time = time.time()\n",
906 | "\n",
907 | "\n",
908 | " e_loss = []\n",
909 | " for epoch in range(1, self.epochs+1):\n",
910 | "\n",
911 | " train_loss = 0.0\n",
912 | "\n",
913 | " model.train()\n",
914 | " for data, labels in self.train_loader:\n",
915 | "\n",
916 | " if torch.cuda.is_available():\n",
917 | " data, labels = data.cuda(), labels.cuda()\n",
918 | "\n",
919 | " # clear the gradients\n",
920 | " optimizer.zero_grad()\n",
921 | " # make a forward pass\n",
922 | " output = model(data)\n",
923 | "\n",
924 | " # calculate the loss + the proximal term\n",
925 | " _, pred = torch.max(output, 1)\n",
926 | "\n",
927 | " if self.algorithm == 'fedprox':\n",
928 | " proximal_term = 0.0\n",
929 | "\n",
930 | " # iterate through the current and global model parameters\n",
931 | " for w, w_t in zip(model.parameters(), global_model.parameters()) :\n",
932 | " # update the proximal term \n",
933 | " #proximal_term += torch.sum(torch.abs((w-w_t)**2))\n",
934 | " proximal_term += (w-w_t).norm(2)\n",
935 | "\n",
936 | " loss = criterion(output, labels) + (self.mu/2)*proximal_term\n",
937 | " else:\n",
938 | " loss = criterion(output, labels)\n",
939 | " \n",
940 | " # do a backwards pass\n",
941 | " loss.backward()\n",
942 | " # perform a single optimization step\n",
943 | " optimizer.step()\n",
944 | " # update training loss\n",
945 | " train_loss += loss.item()*data.size(0)\n",
946 | "\n",
947 | " # average losses\n",
948 | " train_loss = train_loss/len(self.train_loader.dataset)\n",
949 | " e_loss.append(train_loss)\n",
950 | "\n",
951 | " total_loss = sum(e_loss)/len(e_loss)\n",
952 | "\n",
953 | " return model.state_dict(), total_loss, (time.time() - start_time)"
954 | ],
955 | "execution_count": null,
956 | "outputs": []
957 | },
958 | {
959 | "cell_type": "markdown",
960 | "metadata": {
961 | "id": "3crFDN0xqGu6"
962 | },
963 | "source": [
964 | "### Server Side Training"
965 | ]
966 | },
967 | {
968 | "cell_type": "code",
969 | "metadata": {
970 | "id": "c085xSOoqEHk"
971 | },
972 | "source": [
973 | "def training(model, rounds, batch_size, lr, ds, data_dict, test_ds, C, K, E, mu, percentage, plt_title, plt_color, target_test_accuracy,\n",
974 | " classes, algorithm=\"fedprox\", history=[], eval_every=1, tb_logger=None):\n",
975 | " \"\"\"\n",
976 | " Function implements the Federated Averaging Algorithm from the FedAvg paper.\n",
977 | " Specifically, this function is used for the server side training and weight update\n",
978 | "\n",
979 | " Params:\n",
980 | " - model: PyTorch model to train\n",
981 | " - rounds: Number of communication rounds for the client update\n",
982 | " - batch_size: Batch size for client update training\n",
983 | " - lr: Learning rate used for client update training\n",
984 | " - ds: Dataset used for training\n",
985 | " - data_dict: Type of data partition used for training (IID or non-IID)\n",
986 | " - test_data_dict: Data used for testing the model\n",
987 | " - C: Fraction of clients randomly chosen to perform computation on each round\n",
988 | " - K: Total number of clients\n",
989 | " - E: Number of training passes each client makes over its local dataset per round\n",
990 | " - mu: proximal term constant\n",
991 | " - percentage: percentage of selected client to have fewer than E epochs\n",
992 | " Returns:\n",
993 | " - model: Trained model on the server\n",
994 | " \"\"\"\n",
995 | "\n",
996 | " start = time.time()\n",
997 | "\n",
998 | " # global model weights\n",
999 | " global_weights = model.state_dict()\n",
1000 | "\n",
1001 | " # training loss\n",
1002 | " train_loss = []\n",
1003 | "\n",
1004 | " # test accuracy\n",
1005 | " test_acc = []\n",
1006 | "\n",
1007 | " # store last loss for convergence\n",
1008 | " last_loss = 0.0\n",
1009 | "\n",
1010 | " # total time taken \n",
1011 | " total_time = 0\n",
1012 | "\n",
1013 | " print(f\"System heterogeneity set to {percentage}% stragglers.\\n\")\n",
1014 | " print(f\"Picking {max(int(C*K),1 )} random clients per round.\\n\")\n",
1015 | "\n",
1016 | " users_id = list(data_dict.keys())\n",
1017 | "\n",
1018 | " for curr_round in range(1, rounds+1):\n",
1019 | " w, local_loss, lst_local_train_time = [], [], []\n",
1020 | "\n",
1021 | " m = max(int(C*K), 1)\n",
1022 | "\n",
1023 | " heterogenous_epoch_list = GenerateLocalEpochs(percentage, size=m, max_epochs=E)\n",
1024 | " heterogenous_epoch_list = np.array(heterogenous_epoch_list)\n",
1025 | " # print('heterogenous_epoch_list', len(heterogenous_epoch_list))\n",
1026 | "\n",
1027 | " S_t = np.random.choice(range(K), m, replace=False)\n",
1028 | " S_t = np.array(S_t)\n",
1029 | " print('Clients: {}/{} -> {}'.format(len(S_t), K, S_t))\n",
1030 | " \n",
1031 | " # For Federated Averaging, drop all the clients that are stragglers\n",
1032 | " if algorithm == 'fedavg':\n",
1033 | " stragglers_indices = np.argwhere(heterogenous_epoch_list < E)\n",
1034 | " heterogenous_epoch_list = np.delete(heterogenous_epoch_list, stragglers_indices)\n",
1035 | " S_t = np.delete(S_t, stragglers_indices)\n",
1036 | "\n",
1037 | " # for _, (k, epoch) in tqdm(enumerate(zip(S_t, heterogenous_epoch_list))):\n",
1038 | " for i in tqdm(range(len(S_t))):\n",
1039 | " # print('k', k)\n",
1040 | " k = S_t[i]\n",
1041 | " epoch = heterogenous_epoch_list[i]\n",
1042 | " key = users_id[k]\n",
1043 | " ds_ = ds if ds else data_dict[key]['train_ds']\n",
1044 | " idxs = data_dict[key] if ds else None\n",
1045 | " # print(f'Client {k}: {len(idxs) if idxs else len(ds_)} samples')\n",
1046 | " local_update = ClientUpdate(dataset=ds_, batchSize=batch_size, learning_rate=lr, epochs=epoch, idxs=idxs, mu=mu, algorithm=algorithm)\n",
1047 | " weights, loss, local_train_time = local_update.train(model=copy.deepcopy(model))\n",
1048 | " # print(f'Local train time for {k} on {len(idxs) if idxs else len(ds_)} samples: {local_train_time}')\n",
1049 | " # print(f'Local train time: {local_train_time}')\n",
1050 | "\n",
1051 | " w.append(copy.deepcopy(weights))\n",
1052 | " local_loss.append(copy.deepcopy(loss))\n",
1053 | " lst_local_train_time.append(local_train_time)\n",
1054 | "\n",
1055 | " # calculate time to update the global weights\n",
1056 | " global_start_time = time.time()\n",
1057 | "\n",
1058 | " # updating the global weights\n",
1059 | " weights_avg = copy.deepcopy(w[0])\n",
1060 | " for k in weights_avg.keys():\n",
1061 | " for i in range(1, len(w)):\n",
1062 | " weights_avg[k] += w[i][k]\n",
1063 | "\n",
1064 | " weights_avg[k] = torch.div(weights_avg[k], len(w))\n",
1065 | "\n",
1066 | " global_weights = weights_avg\n",
1067 | "\n",
1068 | " global_end_time = time.time()\n",
1069 | "\n",
1070 | " # calculate total time \n",
1071 | " total_time += (global_end_time - global_start_time) + sum(lst_local_train_time)/len(lst_local_train_time)\n",
1072 | "\n",
1073 | " # move the updated weights to our model state dict\n",
1074 | " model.load_state_dict(global_weights)\n",
1075 | "\n",
1076 | " # loss\n",
1077 | " loss_avg = sum(local_loss) / len(local_loss)\n",
1078 | " print('Round: {}... \\tAverage Loss: {}'.format(curr_round, round(loss_avg, 3)))\n",
1079 | " train_loss.append(loss_avg)\n",
1080 | " if tb_logger:\n",
1081 | " tb_logger.add_scalar(f'Train/Loss', loss_avg, curr_round)\n",
1082 | "\n",
1083 | " # testing\n",
1084 | " # if curr_round % eval_every == 0:\n",
1085 | " test_scores = testing(model, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(classes), classes)\n",
1086 | " test_scores['train_loss'] = loss_avg\n",
1087 | " test_loss, test_accuracy = test_scores['loss'], test_scores['accuracy']\n",
1088 | " history.append(test_scores)\n",
1089 | " \n",
1090 | " # print('Round: {}... \\tAverage Loss: {} \\tTest Loss: {} \\tTest Acc: {}'.format(curr_round, round(loss_avg, 3), round(test_loss, 3), round(test_accuracy, 3)))\n",
1091 | "\n",
1092 | " if tb_logger:\n",
1093 | " tb_logger.add_scalar(f'Test/Loss', test_scores['loss'], curr_round)\n",
1094 | " tb_logger.add_scalars(f'Test/Scores', {\n",
1095 | " 'accuracy': test_scores['accuracy'], 'f1_macro': test_scores['f1_macro'], 'f1_weighted': test_scores['f1_weighted']\n",
1096 | " }, curr_round)\n",
1097 | "\n",
1098 | " test_acc.append(test_accuracy)\n",
1099 | " # break if we achieve the target test accuracy\n",
1100 | " if test_accuracy >= target_test_accuracy:\n",
1101 | " rounds = curr_round\n",
1102 | " break\n",
1103 | "\n",
1104 | " # break if we achieve convergence, i.e., loss between two consecutive rounds is <0.0001\n",
1105 | " if algorithm == 'fedprox' and abs(loss_avg - last_loss) < 1e-5:\n",
1106 | " rounds = curr_round\n",
1107 | " break\n",
1108 | " \n",
1109 | " # update the last loss\n",
1110 | " last_loss = loss_avg\n",
1111 | "\n",
1112 | " end = time.time()\n",
1113 | " \n",
1114 | " # plot train loss\n",
1115 | " fig, ax = plt.subplots()\n",
1116 | " x_axis = np.arange(1, rounds+1)\n",
1117 | " y_axis = np.array(train_loss)\n",
1118 | " ax.plot(x_axis, y_axis)\n",
1119 | "\n",
1120 | " ax.set(xlabel='Number of Rounds', ylabel='Train Loss', title=plt_title)\n",
1121 | " ax.grid()\n",
1122 | " # fig.savefig(plt_title+'.jpg', format='jpg')\n",
1123 | "\n",
1124 | " # plot test accuracy\n",
1125 | " fig1, ax1 = plt.subplots()\n",
1126 | " x_axis1 = np.arange(1, rounds+1)\n",
1127 | " y_axis1 = np.array(test_acc)\n",
1128 | " ax1.plot(x_axis1, y_axis1)\n",
1129 | "\n",
1130 | " ax1.set(xlabel='Number of Rounds', ylabel='Test Accuracy', title=plt_title)\n",
1131 | " ax1.grid()\n",
1132 | " # fig1.savefig(plt_title+'-test.jpg', format='jpg')\n",
1133 | " \n",
1134 | " print(\"Training Done! Total time taken to Train: {}\".format(end-start))\n",
1135 | "\n",
1136 | " return model, history"
1137 | ],
1138 | "execution_count": null,
1139 | "outputs": []
1140 | },
1141 | {
1142 | "cell_type": "markdown",
1143 | "metadata": {
1144 | "id": "YXtGLkoAqLIW"
1145 | },
1146 | "source": [
1147 | "### Testing Loop"
1148 | ]
1149 | },
1150 | {
1151 | "cell_type": "code",
1152 | "metadata": {
1153 | "id": "dQJIJno4qKvc"
1154 | },
1155 | "source": [
1156 | "def testing(model, dataset, bs, criterion, num_classes, classes, print_all=False):\n",
1157 | " #test loss \n",
1158 | " test_loss = 0.0\n",
1159 | " correct_class = list(0. for i in range(num_classes))\n",
1160 | " total_class = list(0. for i in range(num_classes))\n",
1161 | "\n",
1162 | " test_loader = DataLoader(dataset, batch_size=bs)\n",
1163 | " l = len(test_loader)\n",
1164 | " model.eval()\n",
1165 | " print('running validation...')\n",
1166 | " for i, (data, labels) in enumerate(tqdm(test_loader)):\n",
1167 | "\n",
1168 | " if torch.cuda.is_available():\n",
1169 | " data, labels = data.cuda(), labels.cuda()\n",
1170 | "\n",
1171 | " output = model(data)\n",
1172 | " loss = criterion(output, labels)\n",
1173 | " test_loss += loss.item()*data.size(0)\n",
1174 | "\n",
1175 | " _, pred = torch.max(output, 1)\n",
1176 | "\n",
1177 | " # For F1Score\n",
1178 | " y_true = np.append(y_true, labels.data.view_as(pred).cpu().numpy()) if i != 0 else labels.data.view_as(pred).cpu().numpy()\n",
1179 | " y_hat = np.append(y_hat, pred.cpu().numpy()) if i != 0 else pred.cpu().numpy()\n",
1180 | "\n",
1181 | " correct_tensor = pred.eq(labels.data.view_as(pred))\n",
1182 | " correct = np.squeeze(correct_tensor.numpy()) if not torch.cuda.is_available() else np.squeeze(correct_tensor.cpu().numpy())\n",
1183 | "\n",
1184 | " #test accuracy for each object class\n",
1185 | " # for i in range(num_classes):\n",
1186 | " # label = labels.data[i]\n",
1187 | " # correct_class[label] += correct[i].item()\n",
1188 | " # total_class[label] += 1\n",
1189 | "\n",
1190 | " for i, lbl in enumerate(labels.data):\n",
1191 | " # print('lbl', i, lbl)\n",
1192 | " correct_class[lbl] += correct.data[i]\n",
1193 | " total_class[lbl] += 1\n",
1194 | " \n",
1195 | " # avg test loss\n",
1196 | " test_loss = test_loss/len(test_loader.dataset)\n",
1197 | " print(\"Test Loss: {:.6f}\\n\".format(test_loss))\n",
1198 | "\n",
1199 | " # Avg F1 Score\n",
1200 | " f1_macro = f1_score(y_true, y_hat, average='macro')\n",
1201 | " # F1-Score -> weigthed to consider class imbalance\n",
1202 | " f1_weighted = f1_score(y_true, y_hat, average='weighted')\n",
1203 | " print(\"F1 Score: {:.6f} (macro) {:.6f} (weighted) %\\n\".format(f1_macro, f1_weighted))\n",
1204 | "\n",
1205 | " # print test accuracy\n",
1206 | " if print_all:\n",
1207 | " for i in range(num_classes):\n",
1208 | " if total_class[i]>0:\n",
1209 | " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % \n",
1210 | " (classes[i], 100 * correct_class[i] / total_class[i],\n",
1211 | " np.sum(correct_class[i]), np.sum(total_class[i])))\n",
1212 | " else:\n",
1213 | " print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))\n",
1214 | "\n",
1215 | " overall_accuracy = np.sum(correct_class) / np.sum(total_class)\n",
1216 | "\n",
1217 | " print('\\nFinal Test Accuracy: {:.3f} ({}/{})'.format(overall_accuracy, np.sum(correct_class), np.sum(total_class)))\n",
1218 | "\n",
1219 | " return {'loss': test_loss, 'accuracy': overall_accuracy, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted}"
1220 | ],
1221 | "execution_count": null,
1222 | "outputs": []
1223 | },
1224 | {
1225 | "cell_type": "markdown",
1226 | "metadata": {
1227 | "id": "uxqXLBd8qbC2"
1228 | },
1229 | "source": [
1230 | "## Experiments"
1231 | ]
1232 | },
1233 | {
1234 | "cell_type": "code",
1235 | "metadata": {
1236 | "id": "VRKlrkVHO8Na"
1237 | },
1238 | "source": [
1239 | "FAIL-ON-PURPOSE"
1240 | ],
1241 | "execution_count": null,
1242 | "outputs": []
1243 | },
1244 | {
1245 | "cell_type": "code",
1246 | "metadata": {
1247 | "id": "E2CfSkNVqKtL"
1248 | },
1249 | "source": [
1250 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n",
1251 | "embedding_dim = 8 # mcmahan17a, fedprox, qFFL\n",
1252 | "# hidden_dim = 100 # fedprox paper\n",
1253 | "hidden_dim = 256 # mcmahan17a, fedprox impl\n",
1254 | "num_classes = len(corpus)\n",
1255 | "classes = list(range(num_classes))\n",
1256 | "lstm_layers = 2 # mcmahan17a, fedprox, qFFL\n",
1257 | "dropout = 0.1 # TODO\n"
1258 | ],
1259 | "execution_count": null,
1260 | "outputs": []
1261 | },
1262 | {
1263 | "cell_type": "code",
1264 | "metadata": {
1265 | "id": "OvCgT_qbmFAO"
1266 | },
1267 | "source": [
1268 | "class Hyperparameters():\n",
1269 | "\n",
1270 | " def __init__(self, total_clients):\n",
1271 | " # number of training rounds\n",
1272 | " self.rounds = 50\n",
1273 | " # client fraction\n",
1274 | " self.C = 0.5\n",
1275 | " # number of clients\n",
1276 | " self.K = total_clients\n",
1277 | " # number of training passes on local dataset for each roung\n",
1278 | " # self.E = 20\n",
1279 | " self.E = 1\n",
1280 | " # batch size\n",
1281 | " self.batch_size = 10\n",
1282 | " # learning Rate\n",
1283 | " self.lr = 0.8\n",
1284 | " # proximal term constant\n",
1285 | " # self.mu = 0.0\n",
1286 | " self.mu = 0.001\n",
1287 | " # percentage of clients to have fewer than E epochs\n",
1288 | " self.percentage = 0\n",
1289 | " # self.percentage = 50\n",
1290 | " # self.percentage = 90\n",
1291 | " # target test accuracy\n",
1292 | " self.target_test_accuracy= 99.0\n",
1293 | " # self.target_test_accuracy=96.0"
1294 | ],
1295 | "execution_count": null,
1296 | "outputs": []
1297 | },
1298 | {
1299 | "cell_type": "code",
1300 | "metadata": {
1301 | "id": "m_JVF83mfM3f"
1302 | },
1303 | "source": [
1304 | "exp_log = dict()"
1305 | ],
1306 | "execution_count": null,
1307 | "outputs": []
1308 | },
1309 | {
1310 | "cell_type": "markdown",
1311 | "metadata": {
1312 | "id": "rYOPtnYoqhWd"
1313 | },
1314 | "source": [
1315 | "### IID"
1316 | ]
1317 | },
1318 | {
1319 | "cell_type": "code",
1320 | "metadata": {
1321 | "id": "FRKc7NrzqKpU"
1322 | },
1323 | "source": [
1324 | "train_ds, data_dict, test_ds = iid_partition(corpus, seq_length, val_split=True) # Not using val_ds but makes train eval periods faster\n",
1325 | "\n",
1326 | "total_clients = len(data_dict.keys())\n",
1327 | "'Total users:', total_clients"
1328 | ],
1329 | "execution_count": null,
1330 | "outputs": []
1331 | },
1332 | {
1333 | "cell_type": "code",
1334 | "metadata": {
1335 | "id": "eaKtpKT5q1q_"
1336 | },
1337 | "source": [
1338 | "hparams = Hyperparameters(total_clients)\n",
1339 | "hparams.__dict__"
1340 | ],
1341 | "execution_count": null,
1342 | "outputs": []
1343 | },
1344 | {
1345 | "cell_type": "code",
1346 | "metadata": {
1347 | "id": "5ilYMClTV_WR"
1348 | },
1349 | "source": [
1350 | "# Sweeping parameter\n",
1351 | "PARAM_NAME = 'clients_fraction'\n",
1352 | "PARAM_VALUE = hparams.C\n",
1353 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n",
1354 | "exp_id"
1355 | ],
1356 | "execution_count": null,
1357 | "outputs": []
1358 | },
1359 | {
1360 | "cell_type": "code",
1361 | "metadata": {
1362 | "id": "xAhy4CWVZy3F"
1363 | },
1364 | "source": [
1365 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n",
1366 | "os.makedirs(EXP_DIR, exist_ok=True)\n",
1367 | "\n",
1368 | "# tb_logger = SummaryWriter(log_dir)\n",
1369 | "# print(f'TBoard logger created at: {log_dir}')\n",
1370 | "\n",
1371 | "title = 'LSTM FedProx on IID'"
1372 | ],
1373 | "execution_count": null,
1374 | "outputs": []
1375 | },
1376 | {
1377 | "cell_type": "code",
1378 | "metadata": {
1379 | "id": "LwTdeiv8q8_L"
1380 | },
1381 | "source": [
1382 | "def run_experiment(run_id):\n",
1383 | "\n",
1384 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n",
1385 | " embedding_dim=embedding_dim, \n",
1386 | " hidden_dim=hidden_dim,\n",
1387 | " classes=num_classes,\n",
1388 | " lstm_layers=lstm_layers,\n",
1389 | " dropout=dropout,\n",
1390 | " batch_first=True\n",
1391 | " )\n",
1392 | "\n",
1393 | " if torch.cuda.is_available():\n",
1394 | " shakespeare_lstm.cuda()\n",
1395 | " \n",
1396 | " test_history = []\n",
1397 | "\n",
1398 | " lstm_iid_trained, test_history = training(shakespeare_lstm,\n",
1399 | " hparams.rounds, hparams.batch_size, hparams.lr,\n",
1400 | " train_ds,\n",
1401 | " data_dict,\n",
1402 | " test_ds,\n",
1403 | " hparams.C, hparams.K, hparams.E, hparams.mu, hparams.percentage,\n",
1404 | " title, \"green\",\n",
1405 | " hparams.target_test_accuracy,\n",
1406 | " corpus, # classes\n",
1407 | " history=test_history,\n",
1408 | " algorithm='fedavg',\n",
1409 | " # tb_logger=tb_writer\n",
1410 | " )\n",
1411 | " \n",
1412 | "\n",
1413 | " final_scores = testing(lstm_iid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n",
1414 | " print(f'\\n\\n========================================================\\n\\n')\n",
1415 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n",
1416 | "\n",
1417 | " log = {\n",
1418 | " 'history': test_history,\n",
1419 | " 'hyperparams': hparams.__dict__\n",
1420 | " }\n",
1421 | "\n",
1422 | " with open(f'{EXP_DIR}/results_iid_{run_id}.pkl', 'wb') as file:\n",
1423 | " pickle.dump(log, file)\n",
1424 | "\n",
1425 | " return test_history\n"
1426 | ],
1427 | "execution_count": null,
1428 | "outputs": []
1429 | },
1430 | {
1431 | "cell_type": "code",
1432 | "metadata": {
1433 | "id": "gSU61KsSq87G"
1434 | },
1435 | "source": [
1436 | "exp_history = list()\n",
1437 | "for run_id in range(2): # TOTAL RUNS\n",
1438 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n",
1439 | " exp_history.append(run_experiment(run_id))\n",
1440 | " print(f'\\n\\n========================================================\\n\\n')"
1441 | ],
1442 | "execution_count": null,
1443 | "outputs": []
1444 | },
1445 | {
1446 | "cell_type": "code",
1447 | "metadata": {
1448 | "id": "us-HifGq3Uhf"
1449 | },
1450 | "source": [
1451 | "exp_log[title] = {\n",
1452 | " 'history': exp_history,\n",
1453 | " 'hyperparams': hparams.__dict__\n",
1454 | "}"
1455 | ],
1456 | "execution_count": null,
1457 | "outputs": []
1458 | },
1459 | {
1460 | "cell_type": "code",
1461 | "metadata": {
1462 | "id": "qDGpo4ug33dN"
1463 | },
1464 | "source": [
1465 | "df = None\n",
1466 | "for i, e in enumerate(exp_history):\n",
1467 | " if i == 0:\n",
1468 | " df = pd.json_normalize(e)\n",
1469 | " continue\n",
1470 | " df = df + pd.json_normalize(e)\n",
1471 | " \n",
1472 | "df_avg = df / len(exp_history)\n",
1473 | "avg_history = df_avg.to_dict(orient='records')"
1474 | ],
1475 | "execution_count": null,
1476 | "outputs": []
1477 | },
1478 | {
1479 | "cell_type": "code",
1480 | "metadata": {
1481 | "id": "Hf77BQAD36Eq"
1482 | },
1483 | "source": [
1484 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='IID')"
1485 | ],
1486 | "execution_count": null,
1487 | "outputs": []
1488 | },
1489 | {
1490 | "cell_type": "code",
1491 | "metadata": {
1492 | "id": "wJClynRJ38Dh"
1493 | },
1494 | "source": [
1495 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='IID')"
1496 | ],
1497 | "execution_count": null,
1498 | "outputs": []
1499 | },
1500 | {
1501 | "cell_type": "code",
1502 | "metadata": {
1503 | "id": "N36h40Zs1h_f"
1504 | },
1505 | "source": [
1506 | "with open(f'{EXP_DIR}/results_iid.pkl', 'wb') as file:\n",
1507 | " pickle.dump(exp_log, file)"
1508 | ],
1509 | "execution_count": null,
1510 | "outputs": []
1511 | },
1512 | {
1513 | "cell_type": "markdown",
1514 | "metadata": {
1515 | "id": "BaoYWkWgqvUQ"
1516 | },
1517 | "source": [
1518 | "### Non-IID"
1519 | ]
1520 | },
1521 | {
1522 | "cell_type": "code",
1523 | "metadata": {
1524 | "id": "lVmhd_791lk7"
1525 | },
1526 | "source": [
1527 | "exp_log = dict()"
1528 | ],
1529 | "execution_count": null,
1530 | "outputs": []
1531 | },
1532 | {
1533 | "cell_type": "code",
1534 | "metadata": {
1535 | "id": "pILgaho8qKgF"
1536 | },
1537 | "source": [
1538 | "data_dict, test_ds = noniid_partition(corpus, seq_length=seq_length, val_split=True)\n",
1539 | "\n",
1540 | "total_clients = len(data_dict.keys())\n",
1541 | "'Total users:', total_clients"
1542 | ],
1543 | "execution_count": null,
1544 | "outputs": []
1545 | },
1546 | {
1547 | "cell_type": "code",
1548 | "metadata": {
1549 | "id": "Y3o7qgBcqKX_"
1550 | },
1551 | "source": [
1552 | "hparams = Hyperparameters(total_clients)\n",
1553 | "hparams.__dict__"
1554 | ],
1555 | "execution_count": null,
1556 | "outputs": []
1557 | },
1558 | {
1559 | "cell_type": "code",
1560 | "metadata": {
1561 | "id": "VANr1h0Pq51N"
1562 | },
1563 | "source": [
1564 | "# Sweeping parameter\n",
1565 | "PARAM_NAME = 'clients_fraction'\n",
1566 | "PARAM_VALUE = hparams.C\n",
1567 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n",
1568 | "exp_id"
1569 | ],
1570 | "execution_count": null,
1571 | "outputs": []
1572 | },
1573 | {
1574 | "cell_type": "code",
1575 | "metadata": {
1576 | "id": "yXgYFIyZ4ipm"
1577 | },
1578 | "source": [
1579 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n",
1580 | "os.makedirs(EXP_DIR, exist_ok=True)\n",
1581 | "\n",
1582 | "# tb_logger = SummaryWriter(log_dir)\n",
1583 | "# print(f'TBoard logger created at: {log_dir}')\n",
1584 | "\n",
1585 | "title = 'LSTM FedProx on Non-IID'"
1586 | ],
1587 | "execution_count": null,
1588 | "outputs": []
1589 | },
1590 | {
1591 | "cell_type": "code",
1592 | "metadata": {
1593 | "id": "Vnv7UaE0q6dG"
1594 | },
1595 | "source": [
1596 | "def run_experiment(run_id):\n",
1597 | "\n",
1598 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length,\n",
1599 | " embedding_dim=embedding_dim,\n",
1600 | " hidden_dim=hidden_dim,\n",
1601 | " classes=num_classes,\n",
1602 | " lstm_layers=lstm_layers,\n",
1603 | " dropout=dropout,\n",
1604 | " batch_first=True\n",
1605 | " )\n",
1606 | "\n",
1607 | " if torch.cuda.is_available():\n",
1608 | " shakespeare_lstm.cuda()\n",
1609 | "\n",
1610 | " test_history = []\n",
1611 | "\n",
1612 | " lstm_non_iid_trained, test_history = training(shakespeare_lstm,\n",
1613 | " hparams.rounds, hparams.batch_size, hparams.lr,\n",
1614 | " None, # ds empty as it is included in data_dict\n",
1615 | " data_dict,\n",
1616 | " test_ds,\n",
1617 | " hparams.C, hparams.K, hparams.E, hparams.mu, hparams.percentage,\n",
1618 | " title, \"green\",\n",
1619 | " hparams.target_test_accuracy,\n",
1620 | " corpus, # classes\n",
1621 | " history=test_history,\n",
1622 | " algorithm='fedavg',\n",
1623 | " # tb_logger=tb_writer\n",
1624 | " )\n",
1625 | "\n",
1626 | " \n",
1627 | "\n",
1628 | " final_scores = testing(lstm_non_iid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n",
1629 | " print(f'\\n\\n========================================================\\n\\n')\n",
1630 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n",
1631 | "\n",
1632 | " log = {\n",
1633 | " 'history': test_history,\n",
1634 | " 'hyperparams': hparams.__dict__\n",
1635 | " }\n",
1636 | "\n",
1637 | " with open(f'{EXP_DIR}/results_niid_{run_id}.pkl', 'wb') as file:\n",
1638 | " pickle.dump(log, file)\n",
1639 | "\n",
1640 | " return test_history\n"
1641 | ],
1642 | "execution_count": null,
1643 | "outputs": []
1644 | },
1645 | {
1646 | "cell_type": "code",
1647 | "metadata": {
1648 | "id": "0pLbVBwVq6Uw"
1649 | },
1650 | "source": [
1651 | "exp_history = list()\n",
1652 | "for run_id in range(2): # TOTAL RUNS\n",
1653 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n",
1654 | " exp_history.append(run_experiment(run_id))\n",
1655 | " print(f'\\n\\n========================================================\\n\\n')"
1656 | ],
1657 | "execution_count": null,
1658 | "outputs": []
1659 | },
1660 | {
1661 | "cell_type": "code",
1662 | "metadata": {
1663 | "id": "n5F38z5C4qw9"
1664 | },
1665 | "source": [
1666 | "exp_log[title] = {\n",
1667 | " 'history': exp_history,\n",
1668 | " 'hyperparams': hparams.__dict__\n",
1669 | "}"
1670 | ],
1671 | "execution_count": null,
1672 | "outputs": []
1673 | },
1674 | {
1675 | "cell_type": "code",
1676 | "metadata": {
1677 | "id": "inIGn3Mh4qpO"
1678 | },
1679 | "source": [
1680 | "df = None\n",
1681 | "for i, e in enumerate(exp_history):\n",
1682 | " if i == 0:\n",
1683 | " df = pd.json_normalize(e)\n",
1684 | " continue\n",
1685 | " df = df + pd.json_normalize(e)\n",
1686 | " \n",
1687 | "df_avg = df / len(exp_history)\n",
1688 | "avg_history = df_avg.to_dict(orient='records')"
1689 | ],
1690 | "execution_count": null,
1691 | "outputs": []
1692 | },
1693 | {
1694 | "cell_type": "code",
1695 | "metadata": {
1696 | "id": "z8ngcls64qjc"
1697 | },
1698 | "source": [
1699 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')"
1700 | ],
1701 | "execution_count": null,
1702 | "outputs": []
1703 | },
1704 | {
1705 | "cell_type": "code",
1706 | "metadata": {
1707 | "id": "GR9vjtYs4qBX"
1708 | },
1709 | "source": [
1710 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')"
1711 | ],
1712 | "execution_count": null,
1713 | "outputs": []
1714 | },
1715 | {
1716 | "cell_type": "markdown",
1717 | "metadata": {
1718 | "id": "adK1OTS-40Z8"
1719 | },
1720 | "source": [
1721 | "### Pickle Experiment Results"
1722 | ]
1723 | },
1724 | {
1725 | "cell_type": "code",
1726 | "metadata": {
1727 | "id": "i5nl-hsa4zqw"
1728 | },
1729 | "source": [
1730 | "with open(f'{EXP_DIR}/results_niid.pkl', 'wb') as file:\n",
1731 | " pickle.dump(exp_log, file)"
1732 | ],
1733 | "execution_count": null,
1734 | "outputs": []
1735 | }
1736 | ]
1737 | }
--------------------------------------------------------------------------------
/Shakespeare/FedProx.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "FedProx.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [
9 | "TWDjzbCzfUFt",
10 | "cce_-qnxhD4n",
11 | "MFvc8mLoouKa",
12 | "sWVOxcAao2_t",
13 | "vFFAfTOwpk4j",
14 | "c640e4NnpksE",
15 | "VQ9PZM0Gp9ve",
16 | "3crFDN0xqGu6",
17 | "YXtGLkoAqLIW"
18 | ],
19 | "toc_visible": true,
20 | "authorship_tag": "ABX9TyOeIyifUjZfud7miIF289x8",
21 | "include_colab_link": true
22 | },
23 | "kernelspec": {
24 | "name": "python3",
25 | "display_name": "Python 3"
26 | },
27 | "language_info": {
28 | "name": "python"
29 | },
30 | "accelerator": "GPU"
31 | },
32 | "cells": [
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {
36 | "id": "view-in-github",
37 | "colab_type": "text"
38 | },
39 | "source": [
40 | "
"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {
46 | "id": "ZXnVpPCFfQkm"
47 | },
48 | "source": [
49 | "# FedPerf - Shakespeare + FedProx algorithm"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {
55 | "id": "TWDjzbCzfUFt"
56 | },
57 | "source": [
58 | "## Setup & Dependencies Installation"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "metadata": {
64 | "id": "z_f-rWudW0Fo"
65 | },
66 | "source": [
67 | "%%capture\n",
68 | "!pip install torchsummaryX unidecode"
69 | ],
70 | "execution_count": null,
71 | "outputs": []
72 | },
73 | {
74 | "cell_type": "code",
75 | "metadata": {
76 | "id": "3TmMInb_fnw_"
77 | },
78 | "source": [
79 | "%load_ext tensorboard\n",
80 | "\n",
81 | "import copy\n",
82 | "from functools import reduce\n",
83 | "import json\n",
84 | "import matplotlib.pyplot as plt\n",
85 | "import numpy as np\n",
86 | "import os\n",
87 | "import pickle\n",
88 | "import pandas as pd\n",
89 | "import random\n",
90 | "from sklearn.model_selection import train_test_split\n",
91 | "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
92 | "import time\n",
93 | "import torch\n",
94 | "from torch.autograd import Variable\n",
95 | "import torch.nn as nn\n",
96 | "import torch.nn.functional as F\n",
97 | "from torch.utils.data import Dataset\n",
98 | "from torch.utils.data.dataloader import DataLoader\n",
99 | "from torch.utils.data.sampler import Sampler\n",
100 | "from torch.utils.tensorboard import SummaryWriter\n",
101 | "from torchsummary import summary\n",
102 | "from torchsummaryX import summary as summaryx\n",
103 | "from torchvision import transforms, utils, datasets\n",
104 | "from tqdm.notebook import tqdm\n",
105 | "from unidecode import unidecode\n",
106 | "\n",
107 | "%matplotlib inline\n",
108 | "\n",
109 | "# Check assigned GPU\n",
110 | "gpu_info = !nvidia-smi\n",
111 | "gpu_info = '\\n'.join(gpu_info)\n",
112 | "if gpu_info.find('failed') >= 0:\n",
113 | " print('Select the Runtime > \"Change runtime type\" menu to enable a GPU accelerator, ')\n",
114 | " print('and then re-execute this cell.')\n",
115 | "else:\n",
116 | " print(gpu_info)\n",
117 | "\n",
118 | "# set manual seed for reproducibility\n",
119 | "RANDOM_SEED = 42\n",
120 | "\n",
121 | "# general reproducibility\n",
122 | "random.seed(RANDOM_SEED)\n",
123 | "np.random.seed(RANDOM_SEED)\n",
124 | "torch.manual_seed(RANDOM_SEED)\n",
125 | "torch.cuda.manual_seed(RANDOM_SEED)\n",
126 | "\n",
127 | "# gpu training specific\n",
128 | "torch.backends.cudnn.deterministic = True\n",
129 | "torch.backends.cudnn.benchmark = False"
130 | ],
131 | "execution_count": null,
132 | "outputs": []
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {
137 | "id": "8Mfqv6uHfwh-"
138 | },
139 | "source": [
140 | "## Mount GDrive"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "metadata": {
146 | "id": "75qbJwxsGj-k"
147 | },
148 | "source": [
149 | "BASE_DIR = '/content/drive/MyDrive/FedPerf/shakespeare/FedProx'"
150 | ],
151 | "execution_count": null,
152 | "outputs": []
153 | },
154 | {
155 | "cell_type": "code",
156 | "metadata": {
157 | "id": "Ost-6lXSfveC"
158 | },
159 | "source": [
160 | "try:\n",
161 | " from google.colab import drive\n",
162 | " drive.mount('/content/drive')\n",
163 | " os.makedirs(BASE_DIR, exist_ok=True)\n",
164 | "except:\n",
165 | " print(\"WARNING: Results won't be stored on GDrive\")\n",
166 | " BASE_DIR = './'\n",
167 | "\n"
168 | ],
169 | "execution_count": null,
170 | "outputs": []
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {
175 | "id": "Y-ZegBArf4xQ"
176 | },
177 | "source": [
178 | "## Loading Dataset"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "metadata": {
184 | "id": "hf03LRxof7Zj"
185 | },
186 | "source": [
187 | "!rm -Rf data\n",
188 | "!mkdir -p data scripts"
189 | ],
190 | "execution_count": null,
191 | "outputs": []
192 | },
193 | {
194 | "cell_type": "code",
195 | "metadata": {
196 | "id": "ngygA4-Fgobx"
197 | },
198 | "source": [
199 | "GENERATE_DATASET = False # If False, download the dataset provided by the q-FFL paper\n",
200 | "DATA_DIR = 'data/'\n",
201 | "# Dataset generation params\n",
202 | "SAMPLES_FRACTION = 1. # If using an already generated dataset\n",
203 | "# SAMPLES_FRACTION = 0.2 # Fraction of total samples in the dataset - FedProx default script\n",
204 | "# SAMPLES_FRACTION = 0.05 # Fraction of total samples in the dataset - qFFL\n",
205 | "TRAIN_FRACTION = 0.8 # Train set size\n",
206 | "MIN_SAMPLES = 0 # Min samples per client (for filtering purposes) - FedProx\n",
207 | "# MIN_SAMPLES = 64 # Min samples per client (for filtering purposes) - qFFL"
208 | ],
209 | "execution_count": null,
210 | "outputs": []
211 | },
212 | {
213 | "cell_type": "code",
214 | "metadata": {
215 | "id": "nUmwJgJygoYD"
216 | },
217 | "source": [
218 | "# Download raw dataset\n",
219 | "# !wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt -O data/shakespeare.txt\n",
220 | "!wget --adjust-extension http://www.gutenberg.org/files/100/100-0.txt -O data/shakespeare.txt"
221 | ],
222 | "execution_count": null,
223 | "outputs": []
224 | },
225 | {
226 | "cell_type": "code",
227 | "metadata": {
228 | "id": "4dCvx80BgoVr"
229 | },
230 | "source": [
231 | "if not GENERATE_DATASET:\n",
232 | " !rm -Rf data/train data/test\n",
233 | " !gdown --id 1n46Mftp3_ahRi1Z6jYhEriyLtdRDS1tD # Download Shakespeare dataset used by the FedProx paper\n",
234 | " !unzip shakespeare.zip\n",
235 | " !mv -f shakespeare_paper/train data/\n",
236 | " !mv -f shakespeare_paper/test data/\n",
237 | " !rm -R shakespeare_paper/ shakespeare.zip\n"
238 | ],
239 | "execution_count": null,
240 | "outputs": []
241 | },
242 | {
243 | "cell_type": "code",
244 | "metadata": {
245 | "id": "a4pzFvPvhQhq"
246 | },
247 | "source": [
248 | "corpus = []\n",
249 | "with open('data/shakespeare.txt', 'r') as f:\n",
250 | " data = list(unidecode(f.read()))\n",
251 | " corpus = list(set(list(data)))\n",
252 | "print('Corpus Length:', len(corpus))"
253 | ],
254 | "execution_count": null,
255 | "outputs": []
256 | },
257 | {
258 | "cell_type": "markdown",
259 | "metadata": {
260 | "id": "cce_-qnxhD4n"
261 | },
262 | "source": [
263 | "#### Dataset Preprocessing script"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "metadata": {
269 | "id": "Rt13M4IcgoTV"
270 | },
271 | "source": [
272 | "%%capture\n",
273 | "if GENERATE_DATASET:\n",
274 | " # Download dataset generation scripts\n",
275 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/preprocess_shakespeare.py -O scripts/preprocess_shakespeare.py\n",
276 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/shake_utils.py -O scripts/shake_utils.py\n",
277 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/gen_all_data.py -O scripts/gen_all_data.py\n",
278 | "\n",
279 | " # Download data preprocessing scripts\n",
280 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/sample.py -O scripts/sample.py\n",
281 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/remove_users.py -O scripts/remove_users.py"
282 | ],
283 | "execution_count": null,
284 | "outputs": []
285 | },
286 | {
287 | "cell_type": "code",
288 | "metadata": {
289 | "id": "EIEyRW27goPo"
290 | },
291 | "source": [
292 | "# Running scripts\n",
293 | "if GENERATE_DATASET:\n",
294 | " !mkdir -p data/raw_data data/all_data data/train data/test\n",
295 | " !python scripts/preprocess_shakespeare.py data/shakespeare.txt data/raw_data\n",
296 | " !python scripts/gen_all_data.py"
297 | ],
298 | "execution_count": null,
299 | "outputs": []
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "metadata": {
304 | "id": "mq8V6v_4hhhD"
305 | },
306 | "source": [
307 | "#### Dataset class"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "metadata": {
313 | "id": "H2SjEBKoWDxv"
314 | },
315 | "source": [
316 | "class ShakespeareDataset(Dataset):\n",
317 | " def __init__(self, x, y, corpus, seq_length):\n",
318 | " self.x = x\n",
319 | " self.y = y\n",
320 | " self.corpus = corpus\n",
321 | " self.corpus_size = len(self.corpus)\n",
322 | " super(ShakespeareDataset, self).__init__()\n",
323 | "\n",
324 | " def __len__(self):\n",
325 | " return len(self.x)\n",
326 | "\n",
327 | " def __repr__(self):\n",
328 | " return f'{self.__class__} - (length: {self.__len__()})'\n",
329 | "\n",
330 | " def __getitem__(self, i):\n",
331 | " input_seq = self.x[i]\n",
332 | " next_char = self.y[i]\n",
333 | " # print('\\tgetitem', i, input_seq, next_char)\n",
334 | " input_value = self.text2charindxs(input_seq)\n",
335 | " target_value = self.get_label_from_char(next_char)\n",
336 | " return input_value, target_value\n",
337 | "\n",
338 | " def text2charindxs(self, text):\n",
339 | " tensor = torch.zeros(len(text), dtype=torch.int32)\n",
340 | " for i, c in enumerate(text):\n",
341 | " tensor[i] = self.get_label_from_char(c)\n",
342 | " return tensor\n",
343 | "\n",
344 | " def get_label_from_char(self, c):\n",
345 | " return self.corpus.index(c)\n",
346 | "\n",
347 | " def get_char_from_label(self, l):\n",
348 | " return self.corpus[l]"
349 | ],
350 | "execution_count": null,
351 | "outputs": []
352 | },
353 | {
354 | "cell_type": "markdown",
355 | "metadata": {
356 | "id": "9fgJtS62lYAN"
357 | },
358 | "source": [
359 | "##### Federated Dataset"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "metadata": {
365 | "id": "5DqL5pTmgn5X"
366 | },
367 | "source": [
368 | "class ShakespeareFedDataset(ShakespeareDataset):\n",
369 | " def __init__(self, x, y, corpus, seq_length):\n",
370 | " super(ShakespeareFedDataset, self).__init__(x, y, corpus, seq_length)\n",
371 | "\n",
372 | " def dataloader(self, batch_size, shuffle=True):\n",
373 | " return DataLoader(self,\n",
374 | " batch_size=batch_size,\n",
375 | " shuffle=shuffle,\n",
376 | " num_workers=0)\n"
377 | ],
378 | "execution_count": null,
379 | "outputs": []
380 | },
381 | {
382 | "cell_type": "markdown",
383 | "metadata": {
384 | "id": "XelbyPsDlfgb"
385 | },
386 | "source": [
387 | "## Partitioning & Data Loaders"
388 | ]
389 | },
390 | {
391 | "cell_type": "markdown",
392 | "metadata": {
393 | "id": "IOBblyFGlwlU"
394 | },
395 | "source": [
396 | "### IID"
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "metadata": {
402 | "id": "cSZFWKmsgn1p"
403 | },
404 | "source": [
405 | "def iid_partition_(dataset, clients):\n",
406 | " \"\"\"\n",
407 | " I.I.D paritioning of data over clients\n",
408 | " Shuffle the data\n",
409 | " Split it between clients\n",
410 | " \n",
411 | " params:\n",
412 | " - dataset (torch.utils.Dataset): Dataset\n",
413 | " - clients (int): Number of Clients to split the data between\n",
414 | "\n",
415 | " returns:\n",
416 | " - Dictionary of image indexes for each client\n",
417 | " \"\"\"\n",
418 | "\n",
419 | " num_items_per_client = int(len(dataset)/clients)\n",
420 | " client_dict = {}\n",
421 | " image_idxs = [i for i in range(len(dataset))]\n",
422 | "\n",
423 | " for i in range(clients):\n",
424 | " client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False))\n",
425 | " image_idxs = list(set(image_idxs) - client_dict[i])\n",
426 | "\n",
427 | " return client_dict"
428 | ],
429 | "execution_count": null,
430 | "outputs": []
431 | },
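{
  "cell_type": "markdown",
  "metadata": {},
  "source": [
    "A small sanity check added for clarity (toy numbers, not part of the original experiments): `iid_partition_` should hand every client the same number of mutually disjoint indices whose union covers the dataset."
  ]
},
{
  "cell_type": "code",
  "metadata": {},
  "source": [
    "# Added example with a dummy 100-item dataset split across 10 clients\n",
    "toy_partition = iid_partition_(list(range(100)), clients=10)\n",
    "print([len(v) for v in toy_partition.values()])   # expected: ten sets of 10 indices each\n",
    "print(len(set().union(*toy_partition.values())))  # expected: 100 (disjoint, full coverage)"
  ],
  "execution_count": null,
  "outputs": []
},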
432 | {
433 | "cell_type": "code",
434 | "metadata": {
435 | "id": "-lGwDyhSll9h"
436 | },
437 | "source": [
438 | "def iid_partition(corpus, seq_length=80, val_split=False):\n",
439 | "\n",
440 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n",
441 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n",
442 | "\n",
443 | " with open(train_file, 'r') as file:\n",
444 | " data_train = json.loads(unidecode(file.read()))\n",
445 | "\n",
446 | " with open(test_file, 'r') as file:\n",
447 | " data_test = json.loads(unidecode(file.read()))\n",
448 | "\n",
449 | " \n",
450 | " total_samples_train = sum(data_train['num_samples'])\n",
451 | "\n",
452 | " data_dict = {}\n",
453 | "\n",
454 | " x_train, y_train = [], []\n",
455 | " x_test, y_test = [], []\n",
456 | " # x_val, y_val = [], []\n",
457 | "\n",
458 | " users = list(zip(data_train['users'], data_train['num_samples']))\n",
459 | " # random.shuffle(users)\n",
460 | "\n",
461 | "\n",
462 | "\n",
463 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n",
464 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n",
465 | " sample_count = 0\n",
466 | " \n",
467 | " for i, (author_id, samples) in enumerate(users):\n",
468 | "\n",
469 | " if sample_count >= total_samples:\n",
470 | " print('Max samples reached', sample_count, '/', total_samples)\n",
471 | " break\n",
472 | "\n",
473 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n",
474 | " print('SKIP', author_id, samples)\n",
475 | " continue\n",
476 | " else:\n",
477 | " udata_train = data_train['user_data'][author_id]\n",
478 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n",
479 | " \n",
480 | " sample_count += max_samples\n",
481 | " # print('sample_count', sample_count)\n",
482 | "\n",
483 | " x_train.extend(data_train['user_data'][author_id]['x'][:max_samples])\n",
484 | " y_train.extend(data_train['user_data'][author_id]['y'][:max_samples])\n",
485 | "\n",
486 | " author_data = data_test['user_data'][author_id]\n",
487 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n",
488 | "\n",
489 | " if val_split:\n",
490 | " x_test.extend(author_data['x'][:int(test_size / 2)])\n",
491 | " y_test.extend(author_data['y'][:int(test_size / 2)])\n",
492 | " # x_val.extend(author_data['x'][int(test_size / 2):])\n",
493 | " # y_val.extend(author_data['y'][int(test_size / 2):int(test_size)])\n",
494 | "\n",
495 | " else:\n",
496 | " x_test.extend(author_data['x'][:int(test_size)])\n",
497 | " y_test.extend(author_data['y'][:int(test_size)])\n",
498 | "\n",
499 | " train_ds = ShakespeareDataset(x_train, y_train, corpus, seq_length)\n",
500 | " test_ds = ShakespeareDataset(x_test, y_test, corpus, seq_length)\n",
501 | " # val_ds = ShakespeareDataset(x_val, y_val, corpus, seq_length)\n",
502 | "\n",
503 | " data_dict = iid_partition_(train_ds, clients=len(users))\n",
504 | "\n",
505 | " return train_ds, data_dict, test_ds"
506 | ],
507 | "execution_count": null,
508 | "outputs": []
509 | },
510 | {
511 | "cell_type": "markdown",
512 | "metadata": {
513 | "id": "MFvc8mLoouKa"
514 | },
515 | "source": [
516 | "### Non-IID"
517 | ]
518 | },
519 | {
520 | "cell_type": "code",
521 | "metadata": {
522 | "id": "GZ76WsCZot9s"
523 | },
524 | "source": [
525 | "def noniid_partition(corpus, seq_length=80, val_split=False):\n",
526 | "\n",
527 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n",
528 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n",
529 | "\n",
530 | " with open(train_file, 'r') as file:\n",
531 | " data_train = json.loads(unidecode(file.read()))\n",
532 | "\n",
533 | " with open(test_file, 'r') as file:\n",
534 | " data_test = json.loads(unidecode(file.read()))\n",
535 | "\n",
536 | " \n",
537 | " total_samples_train = sum(data_train['num_samples'])\n",
538 | "\n",
539 | " data_dict = {}\n",
540 | "\n",
541 | " x_test, y_test = [], []\n",
542 | "\n",
543 | " users = list(zip(data_train['users'], data_train['num_samples']))\n",
544 | " # random.shuffle(users)\n",
545 | "\n",
546 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n",
547 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n",
548 | " sample_count = 0\n",
549 | " \n",
550 | " for i, (author_id, samples) in enumerate(users):\n",
551 | "\n",
552 | " if sample_count >= total_samples:\n",
553 | " print('Max samples reached', sample_count, '/', total_samples)\n",
554 | " break\n",
555 | "\n",
556 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n",
557 | " print('SKIP', author_id, samples)\n",
558 | " continue\n",
559 | " else:\n",
560 | " udata_train = data_train['user_data'][author_id]\n",
561 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n",
562 | " \n",
563 | " sample_count += max_samples\n",
564 | " # print('sample_count', sample_count)\n",
565 | "\n",
566 | " x_train = data_train['user_data'][author_id]['x'][:max_samples]\n",
567 | " y_train = data_train['user_data'][author_id]['y'][:max_samples]\n",
568 | "\n",
569 | " train_ds = ShakespeareFedDataset(x_train, y_train, corpus, seq_length)\n",
570 | "\n",
571 | " x_val, y_val = None, None\n",
572 | " val_ds = None\n",
573 | " author_data = data_test['user_data'][author_id]\n",
574 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n",
575 | " if val_split:\n",
576 | " x_test += author_data['x'][:int(test_size / 2)]\n",
577 | " y_test += author_data['y'][:int(test_size / 2)]\n",
578 | " x_val = author_data['x'][int(test_size / 2):]\n",
579 | " y_val = author_data['y'][int(test_size / 2):int(test_size)]\n",
580 | "\n",
581 | " val_ds = ShakespeareFedDataset(x_val, y_val, corpus, seq_length)\n",
582 | "\n",
583 | " else:\n",
584 | " x_test += author_data['x'][:int(test_size)]\n",
585 | " y_test += author_data['y'][:int(test_size)]\n",
586 | "\n",
587 | " data_dict[author_id] = {\n",
588 | " 'train_ds': train_ds,\n",
589 | " 'val_ds': val_ds\n",
590 | " }\n",
591 | "\n",
592 | " test_ds = ShakespeareFedDataset(x_test, y_test, corpus, seq_length)\n",
593 | "\n",
594 | " return data_dict, test_ds"
595 | ],
596 | "execution_count": null,
597 | "outputs": []
598 | },
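{
  "cell_type": "markdown",
  "metadata": {},
  "source": [
    "For reference (added note): `noniid_partition` keys the returned dict by author id, and each entry holds that author's own datasets; the server loop later reads `data_dict[key]['train_ds']` for every sampled client. The sketch below uses a hypothetical author id purely for illustration."
  ]
},
{
  "cell_type": "code",
  "metadata": {},
  "source": [
    "# Illustration only -- the author id is hypothetical; real keys come from the downloaded train/test JSON files\n",
    "# data_dict = {\n",
    "#     'SOME_AUTHOR_ID': {\n",
    "#         'train_ds': ShakespeareFedDataset(...),  # that author's local training split\n",
    "#         'val_ds': ShakespeareFedDataset(...),    # or None when val_split=False\n",
    "#     },\n",
    "#     ...\n",
    "# }\n",
    "# test_ds is a single ShakespeareFedDataset built from every selected author's test samples."
  ],
  "execution_count": null,
  "outputs": []
},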
599 | {
600 | "cell_type": "markdown",
601 | "metadata": {
602 | "id": "sWVOxcAao2_t"
603 | },
604 | "source": [
605 | "## Models"
606 | ]
607 | },
608 | {
609 | "cell_type": "markdown",
610 | "metadata": {
611 | "id": "gQQQ2mLeo6EA"
612 | },
613 | "source": [
614 | "### Shakespeare LSTM"
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "metadata": {
620 | "id": "2mGXTrXRot7R"
621 | },
622 | "source": [
623 | "class ShakespeareLSTM(nn.Module):\n",
624 | " \"\"\"\n",
625 | " \"\"\"\n",
626 | "\n",
627 | " def __init__(self, input_dim, embedding_dim, hidden_dim, classes, lstm_layers=2, dropout=0.1, batch_first=True):\n",
628 | " super(ShakespeareLSTM, self).__init__()\n",
629 | " self.input_dim = input_dim\n",
630 | " self.embedding_dim = embedding_dim\n",
631 | " self.hidden_dim = hidden_dim\n",
632 | " self.classes = classes\n",
633 | " self.no_layers = lstm_layers\n",
634 | " \n",
635 | " self.embedding = nn.Embedding(num_embeddings=self.classes,\n",
636 | " embedding_dim=self.embedding_dim)\n",
637 | " self.lstm = nn.LSTM(input_size=self.embedding_dim, \n",
638 | " hidden_size=self.hidden_dim,\n",
639 | " num_layers=self.no_layers,\n",
640 | " batch_first=batch_first, \n",
641 | " dropout=dropout if self.no_layers > 1 else 0.)\n",
642 | " self.fc = nn.Linear(hidden_dim, self.classes)\n",
643 | "\n",
644 | " def forward(self, x, hc=None):\n",
645 | " batch_size = x.size(0)\n",
646 | " x_emb = self.embedding(x)\n",
647 | " out, (ht, ct) = self.lstm(x_emb.view(batch_size, -1, self.embedding_dim), hc)\n",
648 | " dense = self.fc(ht[-1])\n",
649 | " return dense\n",
650 | " \n",
651 | " def init_hidden(self, batch_size):\n",
652 | " return (Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)),\n",
653 | " Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)))\n"
654 | ],
655 | "execution_count": null,
656 | "outputs": []
657 | },
658 | {
659 | "cell_type": "markdown",
660 | "metadata": {
661 | "id": "5QsuJlVipMc8"
662 | },
663 | "source": [
664 | "#### Model Summary"
665 | ]
666 | },
667 | {
668 | "cell_type": "code",
669 | "metadata": {
670 | "id": "n_Vb0BYpot5I"
671 | },
672 | "source": [
673 | "batch_size = 10\n",
674 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n",
675 | "\n",
676 | "shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n",
677 | " embedding_dim=8, # mcmahan17a, fedprox, qFFL\n",
678 | " hidden_dim=256, # mcmahan17a, fedprox impl\n",
679 | " # hidden_dim=100, # fedprox paper\n",
680 | " classes=len(corpus),\n",
681 | " lstm_layers=2,\n",
682 | " dropout=0.1, # TODO:\n",
683 | " batch_first=True\n",
684 | " )\n",
685 | "\n",
686 | "if torch.cuda.is_available():\n",
687 | " shakespeare_lstm.cuda()\n",
688 | "\n",
689 | "\n",
690 | "\n",
691 | "hc = shakespeare_lstm.init_hidden(batch_size)\n",
692 | "\n",
693 | "x_sample = torch.zeros((batch_size, seq_length),\n",
694 | " dtype=torch.long,\n",
695 | " device=(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')))\n",
696 | "\n",
697 | "x_sample[0][0] = 1\n",
698 | "x_sample\n",
699 | "\n",
700 | "print(\"\\nShakespeare LSTM SUMMARY\")\n",
701 | "print(summaryx(shakespeare_lstm, x_sample))"
702 | ],
703 | "execution_count": null,
704 | "outputs": []
705 | },
706 | {
707 | "cell_type": "markdown",
708 | "metadata": {
709 | "id": "qn7egnzTpeks"
710 | },
711 | "source": [
712 | "## FedProx Algorithm"
713 | ]
714 | },
715 | {
716 | "cell_type": "markdown",
717 | "metadata": {
718 | "id": "vFFAfTOwpk4j"
719 | },
720 | "source": [
721 | "### Plot Utils"
722 | ]
723 | },
724 | {
725 | "cell_type": "code",
726 | "metadata": {
727 | "id": "oyYjWa6IpnTY"
728 | },
729 | "source": [
730 | "from sklearn.metrics import f1_score"
731 | ],
732 | "execution_count": null,
733 | "outputs": []
734 | },
735 | {
736 | "cell_type": "code",
737 | "metadata": {
738 | "id": "367THsiTpo-C"
739 | },
740 | "source": [
741 | "def plot_scores(history, exp_id, title, suffix):\n",
742 | " accuracies = [x['accuracy'] for x in history]\n",
743 | " f1_macro = [x['f1_macro'] for x in history]\n",
744 | " f1_weighted = [x['f1_weighted'] for x in history]\n",
745 | "\n",
746 | " fig, ax = plt.subplots()\n",
747 | " ax.plot(accuracies, 'tab:orange')\n",
748 | " ax.set(xlabel='Rounds', ylabel='Test Accuracy', title=title)\n",
749 | " ax.grid()\n",
750 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Accuracy_{suffix}.jpg', format='jpg', dpi=300)\n",
751 | " plt.show()\n",
752 | "\n",
753 | " fig, ax = plt.subplots()\n",
754 | " ax.plot(f1_macro, 'tab:orange')\n",
755 | " ax.set(xlabel='Rounds', ylabel='Test F1 (macro)', title=title)\n",
756 | " ax.grid()\n",
757 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Macro_{suffix}.jpg', format='jpg')\n",
758 | " plt.show()\n",
759 | "\n",
760 | " fig, ax = plt.subplots()\n",
761 | " ax.plot(f1_weighted, 'tab:orange')\n",
762 | " ax.set(xlabel='Rounds', ylabel='Test F1 (weighted)', title=title)\n",
763 | " ax.grid()\n",
764 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Weighted_{suffix}.jpg', format='jpg')\n",
765 | " plt.show()\n",
766 | "\n",
767 | "\n",
768 | "def plot_losses(history, exp_id, title, suffix):\n",
769 | " val_losses = [x['loss'] for x in history]\n",
770 | " train_losses = [x['train_loss'] for x in history]\n",
771 | "\n",
772 | " fig, ax = plt.subplots()\n",
773 | " ax.plot(train_losses, 'tab:orange')\n",
774 | " ax.set(xlabel='Rounds', ylabel='Train Loss', title=title)\n",
775 | " ax.grid()\n",
776 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Train_Loss_{suffix}.jpg', format='jpg')\n",
777 | " plt.show()\n",
778 | "\n",
779 | " fig, ax = plt.subplots()\n",
780 | " ax.plot(val_losses, 'tab:orange')\n",
781 | " ax.set(xlabel='Rounds', ylabel='Test Loss', title=title)\n",
782 | " ax.grid()\n",
783 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Loss_{suffix}.jpg', format='jpg')\n",
784 | " plt.show()\n"
785 | ],
786 | "execution_count": null,
787 | "outputs": []
788 | },
789 | {
790 | "cell_type": "markdown",
791 | "metadata": {
792 | "id": "c640e4NnpksE"
793 | },
794 | "source": [
795 | "### Systems Heterogeneity Simulations\n",
796 | "\n",
797 | "Generate epochs for selected clients based on percentage of devices that corresponds to heterogeneity. \n",
798 | "\n",
799 | "Assign x number of epochs (chosen unifirmly at random between [1, E]) to 0%, 50% or 90% of the selected devices, respectively. Settings where 0% devices perform fewer than E epochs of work correspond to the environments without system heterogeneity, while 90% of the devices sending their partial solutions corresponds to highly heterogenous system."
800 | ]
801 | },
802 | {
803 | "cell_type": "code",
804 | "metadata": {
805 | "id": "zuEZYnl5ot2m"
806 | },
807 | "source": [
808 | "def GenerateLocalEpochs(percentage, size, max_epochs):\n",
809 | " ''' Method generates list of epochs for selected clients\n",
810 | " to replicate system heteroggeneity\n",
811 | "\n",
812 | " Params:\n",
813 | " percentage: percentage of clients to have fewer than E epochs\n",
814 | " size: total size of the list\n",
815 | " max_epochs: maximum value for local epochs\n",
816 | " \n",
817 | " Returns:\n",
818 | " List of size epochs for each Client Update\n",
819 | "\n",
820 | " '''\n",
821 | "\n",
822 | " # if percentage is 0 then each client runs for E epochs\n",
823 | " if percentage == 0:\n",
824 | " return np.array([max_epochs]*size)\n",
825 | " else:\n",
826 | " # get the number of clients to have fewer than E epochs\n",
827 | " heterogenous_size = int((percentage/100) * size)\n",
828 | "\n",
829 | " # generate random uniform epochs of heterogenous size between 1 and E\n",
830 | " epoch_list = np.random.randint(1, max_epochs, heterogenous_size)\n",
831 | "\n",
832 | " # the rest of the clients will have E epochs\n",
833 | " remaining_size = size - heterogenous_size\n",
834 | " rem_list = [max_epochs]*remaining_size\n",
835 | "\n",
836 | " epoch_list = np.append(epoch_list, rem_list, axis=0)\n",
837 | " \n",
838 | " # shuffle the list and return\n",
839 | " np.random.shuffle(epoch_list)\n",
840 | "\n",
841 | " return epoch_list"
842 | ],
843 | "execution_count": null,
844 | "outputs": []
845 | },
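{
  "cell_type": "markdown",
  "metadata": {},
  "source": [
    "A quick added sanity check of `GenerateLocalEpochs` (the values below are illustrative assumptions): with `percentage=50`, `size=10` and `max_epochs=5`, half of the selected clients receive a random epoch count in `[1, 5)` and the other half receive the full 5 epochs."
  ]
},
{
  "cell_type": "code",
  "metadata": {},
  "source": [
    "# Added example: 10 selected clients, 50% stragglers, E = 5\n",
    "example_epochs = GenerateLocalEpochs(percentage=50, size=10, max_epochs=5)\n",
    "print(example_epochs)                   # shuffled mix of 5s and values in [1, 4]\n",
    "print(int((example_epochs < 5).sum()))  # number of stragglers; expected 5"
  ],
  "execution_count": null,
  "outputs": []
},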
846 | {
847 | "cell_type": "markdown",
848 | "metadata": {
849 | "id": "VQ9PZM0Gp9ve"
850 | },
851 | "source": [
852 | "### Local Training (Client Update)"
853 | ]
854 | },
855 | {
856 | "cell_type": "code",
857 | "metadata": {
858 | "id": "EDJFltwdotzZ"
859 | },
860 | "source": [
861 | "class CustomDataset(Dataset):\n",
862 | " def __init__(self, dataset, idxs):\n",
863 | " self.dataset = dataset\n",
864 | " self.idxs = list(idxs)\n",
865 | "\n",
866 | " def __len__(self):\n",
867 | " return len(self.idxs)\n",
868 | "\n",
869 | " def __getitem__(self, item):\n",
870 | " data, label = self.dataset[self.idxs[item]]\n",
871 | " return data, label"
872 | ],
873 | "execution_count": null,
874 | "outputs": []
875 | },
876 | {
877 | "cell_type": "code",
878 | "metadata": {
879 | "id": "HtRzU5Yepddq"
880 | },
881 | "source": [
882 | "class ClientUpdate(object):\n",
883 | " def __init__(self, dataset, batchSize, learning_rate, epochs, idxs, mu, algorithm):\n",
884 | " # self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batchSize, shuffle=True)\n",
885 | " if hasattr(dataset, 'dataloader'):\n",
886 | " self.train_loader = dataset.dataloader(batch_size=batch_size, shuffle=True)\n",
887 | " else:\n",
888 | " self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batch_size, shuffle=True)\n",
889 | "\n",
890 | " self.algorithm = algorithm\n",
891 | " self.learning_rate = learning_rate\n",
892 | " self.epochs = epochs\n",
893 | " self.mu = mu\n",
894 | "\n",
895 | " def train(self, model):\n",
896 | " # print(\"Client training for {} epochs.\".format(self.epochs))\n",
897 | " criterion = nn.CrossEntropyLoss()\n",
898 | " proximal_criterion = nn.MSELoss(reduction='mean')\n",
899 | " optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.5)\n",
900 | "\n",
901 | " # use the weights of global model for proximal term calculation\n",
902 | " global_model = copy.deepcopy(model)\n",
903 | "\n",
904 | " # calculate local training time\n",
905 | " start_time = time.time()\n",
906 | "\n",
907 | "\n",
908 | " e_loss = []\n",
909 | " for epoch in range(1, self.epochs+1):\n",
910 | "\n",
911 | " train_loss = 0.0\n",
912 | "\n",
913 | " model.train()\n",
914 | " for data, labels in self.train_loader:\n",
915 | "\n",
916 | " if torch.cuda.is_available():\n",
917 | " data, labels = data.cuda(), labels.cuda()\n",
918 | "\n",
919 | " # clear the gradients\n",
920 | " optimizer.zero_grad()\n",
921 | " # make a forward pass\n",
922 | " output = model(data)\n",
923 | "\n",
924 | " # calculate the loss + the proximal term\n",
925 | " _, pred = torch.max(output, 1)\n",
926 | "\n",
927 | " if self.algorithm == 'fedprox':\n",
928 | " proximal_term = 0.0\n",
929 | "\n",
930 | " # iterate through the current and global model parameters\n",
931 | " for w, w_t in zip(model.parameters(), global_model.parameters()) :\n",
932 | " # update the proximal term \n",
933 | " #proximal_term += torch.sum(torch.abs((w-w_t)**2))\n",
934 | " proximal_term += (w-w_t).norm(2)\n",
935 | "\n",
936 | " loss = criterion(output, labels) + (self.mu/2)*proximal_term\n",
937 | " else:\n",
938 | " loss = criterion(output, labels)\n",
939 | " \n",
940 | " # do a backwards pass\n",
941 | " loss.backward()\n",
942 | " # perform a single optimization step\n",
943 | " optimizer.step()\n",
944 | " # update training loss\n",
945 | " train_loss += loss.item()*data.size(0)\n",
946 | "\n",
947 | " # average losses\n",
948 | " train_loss = train_loss/len(self.train_loader.dataset)\n",
949 | " e_loss.append(train_loss)\n",
950 | "\n",
951 | " total_loss = sum(e_loss)/len(e_loss)\n",
952 | "\n",
953 | " return model.state_dict(), total_loss, (time.time() - start_time)"
954 | ],
955 | "execution_count": null,
956 | "outputs": []
957 | },
958 | {
959 | "cell_type": "markdown",
960 | "metadata": {
961 | "id": "3crFDN0xqGu6"
962 | },
963 | "source": [
964 | "### Server Side Training"
965 | ]
966 | },
967 | {
968 | "cell_type": "code",
969 | "metadata": {
970 | "id": "c085xSOoqEHk"
971 | },
972 | "source": [
973 | "def training(model, rounds, batch_size, lr, ds, data_dict, test_ds, C, K, E, mu, percentage, plt_title, plt_color, target_test_accuracy,\n",
974 | " classes, algorithm=\"fedprox\", history=[], eval_every=1, tb_logger=None):\n",
975 | " \"\"\"\n",
976 | " Function implements the Federated Averaging Algorithm from the FedAvg paper.\n",
977 | " Specifically, this function is used for the server side training and weight update\n",
978 | "\n",
979 | " Params:\n",
980 | " - model: PyTorch model to train\n",
981 | " - rounds: Number of communication rounds for the client update\n",
982 | " - batch_size: Batch size for client update training\n",
983 | " - lr: Learning rate used for client update training\n",
984 | " - ds: Dataset used for training\n",
985 | " - data_dict: Type of data partition used for training (IID or non-IID)\n",
986 | " - test_data_dict: Data used for testing the model\n",
987 | " - C: Fraction of clients randomly chosen to perform computation on each round\n",
988 | " - K: Total number of clients\n",
989 | " - E: Number of training passes each client makes over its local dataset per round\n",
990 | " - mu: proximal term constant\n",
991 | " - percentage: percentage of selected client to have fewer than E epochs\n",
992 | " Returns:\n",
993 | " - model: Trained model on the server\n",
994 | " \"\"\"\n",
995 | "\n",
996 | " start = time.time()\n",
997 | "\n",
998 | " # global model weights\n",
999 | " global_weights = model.state_dict()\n",
1000 | "\n",
1001 | " # training loss\n",
1002 | " train_loss = []\n",
1003 | "\n",
1004 | " # test accuracy\n",
1005 | " test_acc = []\n",
1006 | "\n",
1007 | " # store last loss for convergence\n",
1008 | " last_loss = 0.0\n",
1009 | "\n",
1010 | " # total time taken \n",
1011 | " total_time = 0\n",
1012 | "\n",
1013 | " print(f\"System heterogeneity set to {percentage}% stragglers.\\n\")\n",
1014 | " print(f\"Picking {max(int(C*K),1 )} random clients per round.\\n\")\n",
1015 | "\n",
1016 | " users_id = list(data_dict.keys())\n",
1017 | "\n",
1018 | " for curr_round in range(1, rounds+1):\n",
1019 | " w, local_loss, lst_local_train_time = [], [], []\n",
1020 | "\n",
1021 | " m = max(int(C*K), 1)\n",
1022 | "\n",
1023 | " heterogenous_epoch_list = GenerateLocalEpochs(percentage, size=m, max_epochs=E)\n",
1024 | " heterogenous_epoch_list = np.array(heterogenous_epoch_list)\n",
1025 | " # print('heterogenous_epoch_list', len(heterogenous_epoch_list))\n",
1026 | "\n",
1027 | " S_t = np.random.choice(range(K), m, replace=False)\n",
1028 | " S_t = np.array(S_t)\n",
1029 | " print('Clients: {}/{} -> {}'.format(len(S_t), K, S_t))\n",
1030 | " \n",
1031 | " # For Federated Averaging, drop all the clients that are stragglers\n",
1032 | " if algorithm == 'fedavg':\n",
1033 | " stragglers_indices = np.argwhere(heterogenous_epoch_list < E)\n",
1034 | " heterogenous_epoch_list = np.delete(heterogenous_epoch_list, stragglers_indices)\n",
1035 | " S_t = np.delete(S_t, stragglers_indices)\n",
1036 | "\n",
1037 | " # for _, (k, epoch) in tqdm(enumerate(zip(S_t, heterogenous_epoch_list))):\n",
1038 | " for i in tqdm(range(len(S_t))):\n",
1039 | " # print('k', k)\n",
1040 | " k = S_t[i]\n",
1041 | " epoch = heterogenous_epoch_list[i]\n",
1042 | " key = users_id[k]\n",
1043 | " ds_ = ds if ds else data_dict[key]['train_ds']\n",
1044 | " idxs = data_dict[key] if ds else None\n",
1045 | " # print(f'Client {k}: {len(idxs) if idxs else len(ds_)} samples')\n",
1046 | " local_update = ClientUpdate(dataset=ds_, batchSize=batch_size, learning_rate=lr, epochs=epoch, idxs=idxs, mu=mu, algorithm=algorithm)\n",
1047 | " weights, loss, local_train_time = local_update.train(model=copy.deepcopy(model))\n",
1048 | " # print(f'Local train time for {k} on {len(idxs) if idxs else len(ds_)} samples: {local_train_time}')\n",
1049 | " # print(f'Local train time: {local_train_time}')\n",
1050 | "\n",
1051 | " w.append(copy.deepcopy(weights))\n",
1052 | " local_loss.append(copy.deepcopy(loss))\n",
1053 | " lst_local_train_time.append(local_train_time)\n",
1054 | "\n",
1055 | " # calculate time to update the global weights\n",
1056 | " global_start_time = time.time()\n",
1057 | "\n",
1058 | " # updating the global weights\n",
1059 | " weights_avg = copy.deepcopy(w[0])\n",
1060 | " for k in weights_avg.keys():\n",
1061 | " for i in range(1, len(w)):\n",
1062 | " weights_avg[k] += w[i][k]\n",
1063 | "\n",
1064 | " weights_avg[k] = torch.div(weights_avg[k], len(w))\n",
1065 | "\n",
1066 | " global_weights = weights_avg\n",
1067 | "\n",
1068 | " global_end_time = time.time()\n",
1069 | "\n",
1070 | " # calculate total time \n",
1071 | " total_time += (global_end_time - global_start_time) + sum(lst_local_train_time)/len(lst_local_train_time)\n",
1072 | "\n",
1073 | " # move the updated weights to our model state dict\n",
1074 | " model.load_state_dict(global_weights)\n",
1075 | "\n",
1076 | " # loss\n",
1077 | " loss_avg = sum(local_loss) / len(local_loss)\n",
1078 | " print('Round: {}... \\tAverage Loss: {}'.format(curr_round, round(loss_avg, 3)))\n",
1079 | " train_loss.append(loss_avg)\n",
1080 | " if tb_logger:\n",
1081 | " tb_logger.add_scalar(f'Train/Loss', loss_avg, curr_round)\n",
1082 | "\n",
1083 | " # testing\n",
1084 | " # if curr_round % eval_every == 0:\n",
1085 | " test_scores = testing(model, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(classes), classes)\n",
1086 | " test_scores['train_loss'] = loss_avg\n",
1087 | " test_loss, test_accuracy = test_scores['loss'], test_scores['accuracy']\n",
1088 | " history.append(test_scores)\n",
1089 | " \n",
1090 | " # print('Round: {}... \\tAverage Loss: {} \\tTest Loss: {} \\tTest Acc: {}'.format(curr_round, round(loss_avg, 3), round(test_loss, 3), round(test_accuracy, 3)))\n",
1091 | "\n",
1092 | " if tb_logger:\n",
1093 | " tb_logger.add_scalar(f'Test/Loss', test_scores['loss'], curr_round)\n",
1094 | " tb_logger.add_scalars(f'Test/Scores', {\n",
1095 | " 'accuracy': test_scores['accuracy'], 'f1_macro': test_scores['f1_macro'], 'f1_weighted': test_scores['f1_weighted']\n",
1096 | " }, curr_round)\n",
1097 | "\n",
1098 | " test_acc.append(test_accuracy)\n",
1099 | " # break if we achieve the target test accuracy\n",
1100 | " if test_accuracy >= target_test_accuracy:\n",
1101 | " rounds = curr_round\n",
1102 | " break\n",
1103 | "\n",
1104 | " # break if we achieve convergence, i.e., loss between two consecutive rounds is <0.0001\n",
1105 | " if algorithm == 'fedprox' and abs(loss_avg - last_loss) < 1e-5:\n",
1106 | " rounds = curr_round\n",
1107 | " break\n",
1108 | " \n",
1109 | " # update the last loss\n",
1110 | " last_loss = loss_avg\n",
1111 | "\n",
1112 | " end = time.time()\n",
1113 | " \n",
1114 | " # plot train loss\n",
1115 | " fig, ax = plt.subplots()\n",
1116 | " x_axis = np.arange(1, rounds+1)\n",
1117 | " y_axis = np.array(train_loss)\n",
1118 | " ax.plot(x_axis, y_axis)\n",
1119 | "\n",
1120 | " ax.set(xlabel='Number of Rounds', ylabel='Train Loss', title=plt_title)\n",
1121 | " ax.grid()\n",
1122 | " # fig.savefig(plt_title+'.jpg', format='jpg')\n",
1123 | "\n",
1124 | " # plot test accuracy\n",
1125 | " fig1, ax1 = plt.subplots()\n",
1126 | " x_axis1 = np.arange(1, rounds+1)\n",
1127 | " y_axis1 = np.array(test_acc)\n",
1128 | " ax1.plot(x_axis1, y_axis1)\n",
1129 | "\n",
1130 | " ax1.set(xlabel='Number of Rounds', ylabel='Test Accuracy', title=plt_title)\n",
1131 | " ax1.grid()\n",
1132 | " # fig1.savefig(plt_title+'-test.jpg', format='jpg')\n",
1133 | " \n",
1134 | " print(\"Training Done! Total time taken to Train: {}\".format(end-start))\n",
1135 | "\n",
1136 | " return model, history"
1137 | ],
1138 | "execution_count": null,
1139 | "outputs": []
1140 | },
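{
  "cell_type": "markdown",
  "metadata": {},
  "source": [
    "A minimal, self-contained sketch (added; the tensors are made-up toy values) of the aggregation step inside `training()`: client `state_dict`s are summed key by key and divided by the number of participating clients. Note that this averages clients uniformly rather than weighting them by local dataset size as in the original FedAvg formulation."
  ]
},
{
  "cell_type": "code",
  "metadata": {},
  "source": [
    "# Added illustration of the uniform weight averaging used in training()\n",
    "toy_client_weights = [\n",
    "    {'w': torch.tensor([1.0, 2.0])},  # client 0 (toy values)\n",
    "    {'w': torch.tensor([3.0, 4.0])},  # client 1 (toy values)\n",
    "]\n",
    "toy_avg = copy.deepcopy(toy_client_weights[0])\n",
    "for key in toy_avg.keys():\n",
    "    for i in range(1, len(toy_client_weights)):\n",
    "        toy_avg[key] += toy_client_weights[i][key]\n",
    "    toy_avg[key] = torch.div(toy_avg[key], len(toy_client_weights))\n",
    "print(toy_avg)  # expected: {'w': tensor([2., 3.])}"
  ],
  "execution_count": null,
  "outputs": []
},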
1141 | {
1142 | "cell_type": "markdown",
1143 | "metadata": {
1144 | "id": "YXtGLkoAqLIW"
1145 | },
1146 | "source": [
1147 | "### Testing Loop"
1148 | ]
1149 | },
1150 | {
1151 | "cell_type": "code",
1152 | "metadata": {
1153 | "id": "dQJIJno4qKvc"
1154 | },
1155 | "source": [
1156 | "def testing(model, dataset, bs, criterion, num_classes, classes, print_all=False):\n",
1157 | " #test loss \n",
1158 | " test_loss = 0.0\n",
1159 | " correct_class = list(0. for i in range(num_classes))\n",
1160 | " total_class = list(0. for i in range(num_classes))\n",
1161 | "\n",
1162 | " test_loader = DataLoader(dataset, batch_size=bs)\n",
1163 | " l = len(test_loader)\n",
1164 | " model.eval()\n",
1165 | " print('running validation...')\n",
1166 | " for i, (data, labels) in enumerate(tqdm(test_loader)):\n",
1167 | "\n",
1168 | " if torch.cuda.is_available():\n",
1169 | " data, labels = data.cuda(), labels.cuda()\n",
1170 | "\n",
1171 | " output = model(data)\n",
1172 | " loss = criterion(output, labels)\n",
1173 | " test_loss += loss.item()*data.size(0)\n",
1174 | "\n",
1175 | " _, pred = torch.max(output, 1)\n",
1176 | "\n",
1177 | " # For F1Score\n",
1178 | " y_true = np.append(y_true, labels.data.view_as(pred).cpu().numpy()) if i != 0 else labels.data.view_as(pred).cpu().numpy()\n",
1179 | " y_hat = np.append(y_hat, pred.cpu().numpy()) if i != 0 else pred.cpu().numpy()\n",
1180 | "\n",
1181 | " correct_tensor = pred.eq(labels.data.view_as(pred))\n",
1182 | " correct = np.squeeze(correct_tensor.numpy()) if not torch.cuda.is_available() else np.squeeze(correct_tensor.cpu().numpy())\n",
1183 | "\n",
1184 | " #test accuracy for each object class\n",
1185 | " # for i in range(num_classes):\n",
1186 | " # label = labels.data[i]\n",
1187 | " # correct_class[label] += correct[i].item()\n",
1188 | " # total_class[label] += 1\n",
1189 | "\n",
1190 | " for i, lbl in enumerate(labels.data):\n",
1191 | " # print('lbl', i, lbl)\n",
1192 | " correct_class[lbl] += correct.data[i]\n",
1193 | " total_class[lbl] += 1\n",
1194 | " \n",
1195 | " # avg test loss\n",
1196 | " test_loss = test_loss/len(test_loader.dataset)\n",
1197 | " print(\"Test Loss: {:.6f}\\n\".format(test_loss))\n",
1198 | "\n",
1199 | " # Avg F1 Score\n",
1200 | " f1_macro = f1_score(y_true, y_hat, average='macro')\n",
1201 | " # F1-Score -> weigthed to consider class imbalance\n",
1202 | " f1_weighted = f1_score(y_true, y_hat, average='weighted')\n",
1203 | " print(\"F1 Score: {:.6f} (macro) {:.6f} (weighted) %\\n\".format(f1_macro, f1_weighted))\n",
1204 | "\n",
1205 | " # print test accuracy\n",
1206 | " if print_all:\n",
1207 | " for i in range(num_classes):\n",
1208 | " if total_class[i]>0:\n",
1209 | " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % \n",
1210 | " (classes[i], 100 * correct_class[i] / total_class[i],\n",
1211 | " np.sum(correct_class[i]), np.sum(total_class[i])))\n",
1212 | " else:\n",
1213 | " print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))\n",
1214 | "\n",
1215 | " overall_accuracy = np.sum(correct_class) / np.sum(total_class)\n",
1216 | "\n",
1217 | " print('\\nFinal Test Accuracy: {:.3f} ({}/{})'.format(overall_accuracy, np.sum(correct_class), np.sum(total_class)))\n",
1218 | "\n",
1219 | " return {'loss': test_loss, 'accuracy': overall_accuracy, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted}"
1220 | ],
1221 | "execution_count": null,
1222 | "outputs": []
1223 | },
1224 | {
1225 | "cell_type": "markdown",
1226 | "metadata": {
1227 | "id": "uxqXLBd8qbC2"
1228 | },
1229 | "source": [
1230 | "## Experiments"
1231 | ]
1232 | },
1233 | {
1234 | "cell_type": "code",
1235 | "metadata": {
1236 | "id": "VRKlrkVHO8Na"
1237 | },
1238 | "source": [
1239 | "# FAIL-ON-PURPOSE"
1240 | ],
1241 | "execution_count": null,
1242 | "outputs": []
1243 | },
1244 | {
1245 | "cell_type": "code",
1246 | "metadata": {
1247 | "id": "E2CfSkNVqKtL"
1248 | },
1249 | "source": [
1250 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n",
1251 | "embedding_dim = 8 # mcmahan17a, fedprox, qFFL\n",
1252 | "# hidden_dim = 100 # fedprox paper\n",
1253 | "hidden_dim = 256 # mcmahan17a, fedprox impl\n",
1254 | "num_classes = len(corpus)\n",
1255 | "classes = list(range(num_classes))\n",
1256 | "lstm_layers = 2 # mcmahan17a, fedprox, qFFL\n",
1257 | "dropout = 0.1 # TODO\n"
1258 | ],
1259 | "execution_count": null,
1260 | "outputs": []
1261 | },
1262 | {
1263 | "cell_type": "code",
1264 | "metadata": {
1265 | "id": "OvCgT_qbmFAO"
1266 | },
1267 | "source": [
1268 | "class Hyperparameters():\n",
1269 | "\n",
1270 | " def __init__(self, total_clients):\n",
1271 | " # number of training rounds\n",
1272 | " self.rounds = 50\n",
1273 | " # client fraction\n",
1274 | " self.C = 0.5\n",
1275 | " # number of clients\n",
1276 | " self.K = total_clients\n",
1277 | " # number of training passes on local dataset for each roung\n",
1278 | " # self.E = 20\n",
1279 | " self.E = 1\n",
1280 | " # batch size\n",
1281 | " self.batch_size = 10\n",
1282 | " # learning Rate\n",
1283 | " self.lr = 0.8\n",
1284 | " # proximal term constant\n",
1285 | " # self.mu = 0.0\n",
1286 | " self.mu = 0.001\n",
1287 | " # percentage of clients to have fewer than E epochs\n",
1288 | " self.percentage = 0\n",
1289 | " # self.percentage = 50\n",
1290 | " # self.percentage = 90\n",
1291 | " # target test accuracy\n",
1292 | " self.target_test_accuracy= 99.0\n",
1293 | " # self.target_test_accuracy=96.0"
1294 | ],
1295 | "execution_count": null,
1296 | "outputs": []
1297 | },
1298 | {
1299 | "cell_type": "code",
1300 | "metadata": {
1301 | "id": "m_JVF83mfM3f"
1302 | },
1303 | "source": [
1304 | "exp_log = dict()"
1305 | ],
1306 | "execution_count": null,
1307 | "outputs": []
1308 | },
1309 | {
1310 | "cell_type": "markdown",
1311 | "metadata": {
1312 | "id": "rYOPtnYoqhWd"
1313 | },
1314 | "source": [
1315 | "### IID"
1316 | ]
1317 | },
1318 | {
1319 | "cell_type": "code",
1320 | "metadata": {
1321 | "id": "FRKc7NrzqKpU"
1322 | },
1323 | "source": [
1324 | "train_ds, data_dict, test_ds = iid_partition(corpus, seq_length, val_split=True) # Not using val_ds but makes train eval periods faster\n",
1325 | "\n",
1326 | "total_clients = len(data_dict.keys())\n",
1327 | "'Total users:', total_clients"
1328 | ],
1329 | "execution_count": null,
1330 | "outputs": []
1331 | },
1332 | {
1333 | "cell_type": "code",
1334 | "metadata": {
1335 | "id": "eaKtpKT5q1q_"
1336 | },
1337 | "source": [
1338 | "hparams = Hyperparameters(total_clients)\n",
1339 | "hparams.__dict__"
1340 | ],
1341 | "execution_count": null,
1342 | "outputs": []
1343 | },
1344 | {
1345 | "cell_type": "code",
1346 | "metadata": {
1347 | "id": "5ilYMClTV_WR"
1348 | },
1349 | "source": [
1350 | "# Sweeping parameter\n",
1351 | "PARAM_NAME = 'clients_fraction'\n",
1352 | "PARAM_VALUE = hparams.C\n",
1353 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n",
1354 | "exp_id"
1355 | ],
1356 | "execution_count": null,
1357 | "outputs": []
1358 | },
1359 | {
1360 | "cell_type": "code",
1361 | "metadata": {
1362 | "id": "xAhy4CWVZy3F"
1363 | },
1364 | "source": [
1365 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n",
1366 | "os.makedirs(EXP_DIR, exist_ok=True)\n",
1367 | "\n",
1368 | "# tb_logger = SummaryWriter(log_dir)\n",
1369 | "# print(f'TBoard logger created at: {log_dir}')\n",
1370 | "\n",
1371 | "title = 'LSTM FedProx on IID'"
1372 | ],
1373 | "execution_count": null,
1374 | "outputs": []
1375 | },
1376 | {
1377 | "cell_type": "code",
1378 | "metadata": {
1379 | "id": "LwTdeiv8q8_L"
1380 | },
1381 | "source": [
1382 | "def run_experiment(run_id):\n",
1383 | "\n",
1384 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n",
1385 | " embedding_dim=embedding_dim, \n",
1386 | " hidden_dim=hidden_dim,\n",
1387 | " classes=num_classes,\n",
1388 | " lstm_layers=lstm_layers,\n",
1389 | " dropout=dropout,\n",
1390 | " batch_first=True\n",
1391 | " )\n",
1392 | "\n",
1393 | " if torch.cuda.is_available():\n",
1394 | " shakespeare_lstm.cuda()\n",
1395 | " \n",
1396 | " test_history = []\n",
1397 | "\n",
1398 | " lstm_iid_trained, test_history = training(shakespeare_lstm,\n",
1399 | " hparams.rounds, hparams.batch_size, hparams.lr,\n",
1400 | " train_ds,\n",
1401 | " data_dict,\n",
1402 | " test_ds,\n",
1403 | " hparams.C, hparams.K, hparams.E, hparams.mu, hparams.percentage,\n",
1404 | " title, \"green\",\n",
1405 | " hparams.target_test_accuracy,\n",
1406 | " corpus, # classes\n",
1407 | " history=test_history,\n",
1408 | " # tb_logger=tb_writer\n",
1409 | " )\n",
1410 | " \n",
1411 | "\n",
1412 | " final_scores = testing(lstm_iid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n",
1413 | " print(f'\\n\\n========================================================\\n\\n')\n",
1414 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n",
1415 | "\n",
1416 | " log = {\n",
1417 | " 'history': test_history,\n",
1418 | " 'hyperparams': hparams.__dict__\n",
1419 | " }\n",
1420 | "\n",
1421 | " with open(f'{EXP_DIR}/results_iid_{run_id}.pkl', 'wb') as file:\n",
1422 | " pickle.dump(log, file)\n",
1423 | "\n",
1424 | " return test_history\n"
1425 | ],
1426 | "execution_count": null,
1427 | "outputs": []
1428 | },
1429 | {
1430 | "cell_type": "code",
1431 | "metadata": {
1432 | "id": "gSU61KsSq87G"
1433 | },
1434 | "source": [
1435 | "exp_history = list()\n",
1436 | "for run_id in range(2): # TOTAL RUNS\n",
1437 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n",
1438 | " exp_history.append(run_experiment(run_id))\n",
1439 | " print(f'\\n\\n========================================================\\n\\n')"
1440 | ],
1441 | "execution_count": null,
1442 | "outputs": []
1443 | },
1444 | {
1445 | "cell_type": "code",
1446 | "metadata": {
1447 | "id": "us-HifGq3Uhf"
1448 | },
1449 | "source": [
1450 | "exp_log[title] = {\n",
1451 | " 'history': exp_history,\n",
1452 | " 'hyperparams': hparams.__dict__\n",
1453 | "}"
1454 | ],
1455 | "execution_count": null,
1456 | "outputs": []
1457 | },
1458 | {
1459 | "cell_type": "code",
1460 | "metadata": {
1461 | "id": "qDGpo4ug33dN"
1462 | },
1463 | "source": [
1464 | "df = None\n",
1465 | "for i, e in enumerate(exp_history):\n",
1466 | " if i == 0:\n",
1467 | " df = pd.json_normalize(e)\n",
1468 | " continue\n",
1469 | " df = df + pd.json_normalize(e)\n",
1470 | " \n",
1471 | "df_avg = df / len(exp_history)\n",
1472 | "avg_history = df_avg.to_dict(orient='records')"
1473 | ],
1474 | "execution_count": null,
1475 | "outputs": []
1476 | },
1477 | {
1478 | "cell_type": "code",
1479 | "metadata": {
1480 | "id": "Hf77BQAD36Eq"
1481 | },
1482 | "source": [
1483 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='IID')"
1484 | ],
1485 | "execution_count": null,
1486 | "outputs": []
1487 | },
1488 | {
1489 | "cell_type": "code",
1490 | "metadata": {
1491 | "id": "wJClynRJ38Dh"
1492 | },
1493 | "source": [
1494 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='IID')"
1495 | ],
1496 | "execution_count": null,
1497 | "outputs": []
1498 | },
1499 | {
1500 | "cell_type": "code",
1501 | "metadata": {
1502 | "id": "WKd0d7_a2e1C"
1503 | },
1504 | "source": [
1505 | "with open(f'{EXP_DIR}/results_iid.pkl', 'wb') as file:\n",
1506 | " pickle.dump(exp_log, file)"
1507 | ],
1508 | "execution_count": null,
1509 | "outputs": []
1510 | },
1511 | {
1512 | "cell_type": "markdown",
1513 | "metadata": {
1514 | "id": "BaoYWkWgqvUQ"
1515 | },
1516 | "source": [
1517 | "### Non-IID"
1518 | ]
1519 | },
1520 | {
1521 | "cell_type": "code",
1522 | "metadata": {
1523 | "id": "epuk1epg2jX3"
1524 | },
1525 | "source": [
1526 | "exp_log = dict()"
1527 | ],
1528 | "execution_count": null,
1529 | "outputs": []
1530 | },
1531 | {
1532 | "cell_type": "code",
1533 | "metadata": {
1534 | "id": "pILgaho8qKgF"
1535 | },
1536 | "source": [
1537 | "data_dict, test_ds = noniid_partition(corpus, seq_length=seq_length, val_split=True)\n",
1538 | "\n",
1539 | "total_clients = len(data_dict.keys())\n",
1540 | "'Total users:', total_clients"
1541 | ],
1542 | "execution_count": null,
1543 | "outputs": []
1544 | },
1545 | {
1546 | "cell_type": "code",
1547 | "metadata": {
1548 | "id": "Y3o7qgBcqKX_"
1549 | },
1550 | "source": [
1551 | "hparams = Hyperparameters(total_clients)\n",
1552 | "hparams.__dict__"
1553 | ],
1554 | "execution_count": null,
1555 | "outputs": []
1556 | },
1557 | {
1558 | "cell_type": "code",
1559 | "metadata": {
1560 | "id": "VANr1h0Pq51N"
1561 | },
1562 | "source": [
1563 | "# Sweeping parameter\n",
1564 | "PARAM_NAME = 'clients_fraction'\n",
1565 | "PARAM_VALUE = hparams.C\n",
1566 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n",
1567 | "exp_id"
1568 | ],
1569 | "execution_count": null,
1570 | "outputs": []
1571 | },
1572 | {
1573 | "cell_type": "code",
1574 | "metadata": {
1575 | "id": "yXgYFIyZ4ipm"
1576 | },
1577 | "source": [
1578 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n",
1579 | "os.makedirs(EXP_DIR, exist_ok=True)\n",
1580 | "\n",
1581 | "# tb_logger = SummaryWriter(log_dir)\n",
1582 | "# print(f'TBoard logger created at: {log_dir}')\n",
1583 | "\n",
1584 | "title = 'LSTM FedProx on Non-IID'"
1585 | ],
1586 | "execution_count": null,
1587 | "outputs": []
1588 | },
1589 | {
1590 | "cell_type": "code",
1591 | "metadata": {
1592 | "id": "Vnv7UaE0q6dG"
1593 | },
1594 | "source": [
1595 | "def run_experiment(run_id):\n",
1596 | "\n",
1597 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length,\n",
1598 | " embedding_dim=embedding_dim,\n",
1599 | " hidden_dim=hidden_dim,\n",
1600 | " classes=num_classes,\n",
1601 | " lstm_layers=lstm_layers,\n",
1602 | " dropout=dropout,\n",
1603 | " batch_first=True\n",
1604 | " )\n",
1605 | "\n",
1606 | " if torch.cuda.is_available():\n",
1607 | " shakespeare_lstm.cuda()\n",
1608 | "\n",
1609 | " test_history = []\n",
1610 | "\n",
1611 | " lstm_non_iid_trained, test_history = training(shakespeare_lstm,\n",
1612 | " hparams.rounds, hparams.batch_size, hparams.lr,\n",
1613 | " None, # ds empty as it is included in data_dict\n",
1614 | " data_dict,\n",
1615 | " test_ds,\n",
1616 | " hparams.C, hparams.K, hparams.E, hparams.mu, hparams.percentage,\n",
1617 | " title, \"green\",\n",
1618 | " hparams.target_test_accuracy,\n",
1619 | " corpus, # classes\n",
1620 | " history=test_history,\n",
1621 | " # tb_logger=tb_writer\n",
1622 | " )\n",
1623 | "\n",
1624 | " \n",
1625 | "\n",
1626 | " final_scores = testing(lstm_non_iid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n",
1627 | " print(f'\\n\\n========================================================\\n\\n')\n",
1628 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n",
1629 | "\n",
1630 | " log = {\n",
1631 | " 'history': test_history,\n",
1632 | " 'hyperparams': hparams.__dict__\n",
1633 | " }\n",
1634 | "\n",
1635 | " with open(f'{EXP_DIR}/results_niid_{run_id}.pkl', 'wb') as file:\n",
1636 | " pickle.dump(log, file)\n",
1637 | "\n",
1638 | " return test_history\n"
1639 | ],
1640 | "execution_count": null,
1641 | "outputs": []
1642 | },
1643 | {
1644 | "cell_type": "code",
1645 | "metadata": {
1646 | "id": "0pLbVBwVq6Uw"
1647 | },
1648 | "source": [
1649 | "exp_history = list()\n",
1650 | "for run_id in range(2): # TOTAL RUNS\n",
1651 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n",
1652 | " exp_history.append(run_experiment(run_id))\n",
1653 | " print(f'\\n\\n========================================================\\n\\n')"
1654 | ],
1655 | "execution_count": null,
1656 | "outputs": []
1657 | },
1658 | {
1659 | "cell_type": "code",
1660 | "metadata": {
1661 | "id": "n5F38z5C4qw9"
1662 | },
1663 | "source": [
1664 | "exp_log[title] = {\n",
1665 | " 'history': exp_history,\n",
1666 | " 'hyperparams': hparams.__dict__\n",
1667 | "}"
1668 | ],
1669 | "execution_count": null,
1670 | "outputs": []
1671 | },
1672 | {
1673 | "cell_type": "code",
1674 | "metadata": {
1675 | "id": "inIGn3Mh4qpO"
1676 | },
1677 | "source": [
1678 | "df = None\n",
1679 | "for i, e in enumerate(exp_history):\n",
1680 | " if i == 0:\n",
1681 | " df = pd.json_normalize(e)\n",
1682 | " continue\n",
1683 | " df = df + pd.json_normalize(e)\n",
1684 | " \n",
1685 | "df_avg = df / len(exp_history)\n",
1686 | "avg_history = df_avg.to_dict(orient='records')"
1687 | ],
1688 | "execution_count": null,
1689 | "outputs": []
1690 | },
1691 | {
1692 | "cell_type": "code",
1693 | "metadata": {
1694 | "id": "z8ngcls64qjc"
1695 | },
1696 | "source": [
1697 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')"
1698 | ],
1699 | "execution_count": null,
1700 | "outputs": []
1701 | },
1702 | {
1703 | "cell_type": "code",
1704 | "metadata": {
1705 | "id": "GR9vjtYs4qBX"
1706 | },
1707 | "source": [
1708 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')"
1709 | ],
1710 | "execution_count": null,
1711 | "outputs": []
1712 | },
1713 | {
1714 | "cell_type": "markdown",
1715 | "metadata": {
1716 | "id": "adK1OTS-40Z8"
1717 | },
1718 | "source": [
1719 | "### Pickle Experiment Results"
1720 | ]
1721 | },
1722 | {
1723 | "cell_type": "code",
1724 | "metadata": {
1725 | "id": "i5nl-hsa4zqw"
1726 | },
1727 | "source": [
1728 | "with open(f'{EXP_DIR}/results.pkl', 'wb') as file:\n",
1729 | " pickle.dump(exp_log, file)"
1730 | ],
1731 | "execution_count": null,
1732 | "outputs": []
1733 | }
1734 | ]
1735 | }
--------------------------------------------------------------------------------
/Shakespeare/qFedAvg.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "qFedAvg.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [
9 | "x5UTD8mqIDWl",
10 | "MAtkg7ME8PQn",
11 | "GpokHKXO8gmS",
12 | "391xbZuU8vEn",
13 | "rxAlW1YK9B_b",
14 | "JtAFxz5h9MRy",
15 | "QSDdUZjs9W-g",
16 | "rhN9isJD9Z8f",
17 | "w8OsHzAt9kDc",
18 | "p-8XOZlR9wbV",
19 | "0vQVNPmM93M7",
20 | "JrTGp6vd-DKa"
21 | ],
22 | "toc_visible": true,
23 | "authorship_tag": "ABX9TyMR+JaDud0PZxZhcegRNqzN",
24 | "include_colab_link": true
25 | },
26 | "kernelspec": {
27 | "name": "python3",
28 | "display_name": "Python 3"
29 | },
30 | "language_info": {
31 | "name": "python"
32 | },
33 | "accelerator": "GPU"
34 | },
35 | "cells": [
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {
39 | "id": "view-in-github",
40 | "colab_type": "text"
41 | },
42 | "source": [
43 | "
"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {
49 | "id": "4v_R7FXY7_fi"
50 | },
51 | "source": [
52 | "# FedPerf - Shakespeare + qFedAvg algorithm"
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "metadata": {
58 | "id": "x5UTD8mqIDWl"
59 | },
60 | "source": [
61 | "## Setup & Dependencies Installation"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "metadata": {
67 | "id": "1vTQbtlX731V"
68 | },
69 | "source": [
70 | "%%capture\n",
71 | "!pip install torchsummaryX unidecode"
72 | ],
73 | "execution_count": null,
74 | "outputs": []
75 | },
76 | {
77 | "cell_type": "code",
78 | "metadata": {
79 | "id": "FHlNfr2R8EvM"
80 | },
81 | "source": [
82 | "%load_ext tensorboard\n",
83 | "\n",
84 | "import copy\n",
85 | "from functools import reduce\n",
86 | "import json\n",
87 | "import matplotlib.pyplot as plt\n",
88 | "import numpy as np\n",
89 | "import os\n",
90 | "import pandas as pd\n",
91 | "import pickle\n",
92 | "import random\n",
93 | "from sklearn.model_selection import train_test_split\n",
94 | "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
95 | "import time\n",
96 | "import torch\n",
97 | "from torch.autograd import Variable\n",
98 | "import torch.nn as nn\n",
99 | "import torch.nn.functional as F\n",
100 | "from torch.utils.data import Dataset\n",
101 | "from torch.utils.data.dataloader import DataLoader\n",
102 | "from torch.utils.data.sampler import Sampler\n",
103 | "from torch.utils.tensorboard import SummaryWriter\n",
104 | "from torchsummary import summary\n",
105 | "from torchsummaryX import summary as summaryx\n",
106 | "from torchvision import transforms, utils, datasets\n",
107 | "from tqdm.notebook import tqdm\n",
108 | "from unidecode import unidecode\n",
109 | "\n",
110 | "%matplotlib inline\n",
111 | "\n",
112 | "# Check assigned GPU\n",
113 | "gpu_info = !nvidia-smi\n",
114 | "gpu_info = '\\n'.join(gpu_info)\n",
115 | "if gpu_info.find('failed') >= 0:\n",
116 | " print('Select the Runtime > \"Change runtime type\" menu to enable a GPU accelerator, ')\n",
117 | " print('and then re-execute this cell.')\n",
118 | "else:\n",
119 | " print(gpu_info)\n",
120 | "\n",
121 | "# set manual seed for reproducibility\n",
122 | "RANDOM_SEED = 42\n",
123 | "\n",
124 | "# general reproducibility\n",
125 | "random.seed(RANDOM_SEED)\n",
126 | "np.random.seed(RANDOM_SEED)\n",
127 | "torch.manual_seed(RANDOM_SEED)\n",
128 | "torch.cuda.manual_seed(RANDOM_SEED)\n",
129 | "\n",
130 | "# gpu training specific\n",
131 | "torch.backends.cudnn.deterministic = True\n",
132 | "torch.backends.cudnn.benchmark = False"
133 | ],
134 | "execution_count": null,
135 | "outputs": []
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "metadata": {
140 | "id": "AvWSM6mh8LZr"
141 | },
142 | "source": [
143 | "## Mount GDrive"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "metadata": {
149 | "id": "oh9Lj2-lIIfY"
150 | },
151 | "source": [
152 | "BASE_DIR = '/content/drive/MyDrive/FedPerf/shakespeare/qFedAvg'"
153 | ],
154 | "execution_count": null,
155 | "outputs": []
156 | },
157 | {
158 | "cell_type": "code",
159 | "metadata": {
160 | "id": "518ez9Oz8LNE"
161 | },
162 | "source": [
163 | "try:\n",
164 | " from google.colab import drive\n",
165 | " drive.mount('/content/drive')\n",
166 | " os.makedirs(BASE_DIR, exist_ok=True)\n",
167 | "except:\n",
168 | " print(\"WARNING: Results won't be stored on GDrive\")\n",
169 | " BASE_DIR = './'\n",
170 | "\n"
171 | ],
172 | "execution_count": null,
173 | "outputs": []
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "metadata": {
178 | "id": "MAtkg7ME8PQn"
179 | },
180 | "source": [
181 | "## Loading Dataset"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "metadata": {
187 | "id": "-6ZYQ2SB8LLo"
188 | },
189 | "source": [
190 | "!rm -Rf data\n",
191 | "!mkdir -p data scripts"
192 | ],
193 | "execution_count": null,
194 | "outputs": []
195 | },
196 | {
197 | "cell_type": "code",
198 | "metadata": {
199 | "id": "skkRzhu98LIx"
200 | },
201 | "source": [
202 | "GENERATE_DATASET = False # If False, download the dataset provided by the q-FFL paper\n",
203 | "DATA_DIR = 'data/'\n",
204 | "# Dataset generation params\n",
205 | "SAMPLES_FRACTION = 1. # If using an already generated dataset\n",
206 | "# SAMPLES_FRACTION = 0.2 # Fraction of total samples in the dataset - FedProx default script\n",
207 | "# SAMPLES_FRACTION = 0.05 # Fraction of total samples in the dataset - qFFL\n",
208 | "TRAIN_FRACTION = 0.8 # Train set size\n",
209 | "MIN_SAMPLES = 0 # Min samples per client (for filtering purposes) - FedProx\n",
210 | "# MIN_SAMPLES = 64 # Min samples per client (for filtering purposes) - qFFL"
211 | ],
212 | "execution_count": null,
213 | "outputs": []
214 | },
215 | {
216 | "cell_type": "code",
217 | "metadata": {
218 | "id": "c9nCIGlB8LGA"
219 | },
220 | "source": [
221 | "# Download raw dataset\n",
222 | "# !wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt -O data/shakespeare.txt\n",
223 | "!wget --adjust-extension http://www.gutenberg.org/files/100/100-0.txt -O data/shakespeare.txt"
224 | ],
225 | "execution_count": null,
226 | "outputs": []
227 | },
228 | {
229 | "cell_type": "code",
230 | "metadata": {
231 | "id": "jKQNm6cC8LDp"
232 | },
233 | "source": [
234 | "if not GENERATE_DATASET:\n",
235 | " !rm -Rf data/train data/test\n",
236 | " !gdown --id 1n46Mftp3_ahRi1Z6jYhEriyLtdRDS1tD # Download Shakespeare dataset used by the FedProx paper\n",
237 | " !unzip shakespeare.zip\n",
238 | " !mv -f shakespeare_paper/train data/\n",
239 | " !mv -f shakespeare_paper/test data/\n",
240 | " !rm -R shakespeare_paper/ shakespeare.zip\n"
241 | ],
242 | "execution_count": null,
243 | "outputs": []
244 | },
245 | {
246 | "cell_type": "code",
247 | "metadata": {
248 | "id": "5ewNwmpP8LBg"
249 | },
250 | "source": [
251 | "corpus = []\n",
252 | "with open('data/shakespeare.txt', 'r') as f:\n",
253 | " data = list(unidecode(f.read()))\n",
254 | " corpus = list(set(list(data)))\n",
255 | "print('Corpus Length:', len(corpus))"
256 | ],
257 | "execution_count": null,
258 | "outputs": []
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {
263 | "id": "GpokHKXO8gmS"
264 | },
265 | "source": [
266 | "#### Dataset Preprocessing script"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "metadata": {
272 | "id": "1YEeiGCT8K--"
273 | },
274 | "source": [
275 | "%%capture\n",
276 | "if GENERATE_DATASET:\n",
277 | " # Download dataset generation scripts\n",
278 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/preprocess_shakespeare.py -O scripts/preprocess_shakespeare.py\n",
279 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/shake_utils.py -O scripts/shake_utils.py\n",
280 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/data/shakespeare/preprocess/gen_all_data.py -O scripts/gen_all_data.py\n",
281 | "\n",
282 | " # Download data preprocessing scripts\n",
283 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/sample.py -O scripts/sample.py\n",
284 | " !wget https://raw.githubusercontent.com/ml-lab/FedProx/master/utils/remove_users.py -O scripts/remove_users.py"
285 | ],
286 | "execution_count": null,
287 | "outputs": []
288 | },
289 | {
290 | "cell_type": "code",
291 | "metadata": {
292 | "id": "KE4LT4z48K8u"
293 | },
294 | "source": [
295 | "# Running scripts\n",
296 | "if GENERATE_DATASET:\n",
297 | " !mkdir -p data/raw_data data/all_data data/train data/test\n",
298 | " !python scripts/preprocess_shakespeare.py data/shakespeare.txt data/raw_data\n",
299 | " !python scripts/gen_all_data.py"
300 | ],
301 | "execution_count": null,
302 | "outputs": []
303 | },
304 | {
305 | "cell_type": "markdown",
306 | "metadata": {
307 | "id": "391xbZuU8vEn"
308 | },
309 | "source": [
310 | "#### Dataset class"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "metadata": {
316 | "id": "5aS5pJenZ8Ow"
317 | },
318 | "source": [
319 | "class ShakespeareDataset(Dataset):\n",
320 | " def __init__(self, x, y, corpus, seq_length):\n",
321 | " self.x = x\n",
322 | " self.y = y\n",
323 | " self.corpus = corpus\n",
324 | " self.corpus_size = len(self.corpus)\n",
325 | " super(ShakespeareDataset, self).__init__()\n",
326 | "\n",
327 | " def __len__(self):\n",
328 | " return len(self.x)\n",
329 | "\n",
330 | " def __repr__(self):\n",
331 | " return f'{self.__class__} - (length: {self.__len__()})'\n",
332 | "\n",
333 | " def __getitem__(self, i):\n",
334 | " input_seq = self.x[i]\n",
335 | " next_char = self.y[i]\n",
336 | " # print('\\tgetitem', i, input_seq, next_char)\n",
337 | " input_value = self.text2charindxs(input_seq)\n",
338 | " target_value = self.get_label_from_char(next_char)\n",
339 | " return input_value, target_value\n",
340 | "\n",
341 | " def text2charindxs(self, text):\n",
342 | " tensor = torch.zeros(len(text), dtype=torch.int32)\n",
343 | " for i, c in enumerate(text):\n",
344 | " tensor[i] = self.get_label_from_char(c)\n",
345 | " return tensor\n",
346 | "\n",
347 | " def get_label_from_char(self, c):\n",
348 | " return self.corpus.index(c)\n",
349 | "\n",
350 | " def get_char_from_label(self, l):\n",
351 | " return self.corpus[l]"
352 | ],
353 | "execution_count": null,
354 | "outputs": []
355 | },
356 | {
357 | "cell_type": "markdown",
358 | "metadata": {
359 | "id": "0r68yAYr8xwJ"
360 | },
361 | "source": [
362 | "##### Federated Dataset"
363 | ]
364 | },
365 | {
366 | "cell_type": "code",
367 | "metadata": {
368 | "id": "CUTWxbt98K6W"
369 | },
370 | "source": [
371 | "class ShakespeareFedDataset(ShakespeareDataset):\n",
372 | " def __init__(self, x, y, corpus, seq_length):\n",
373 | " super(ShakespeareFedDataset, self).__init__(x, y, corpus, seq_length)\n",
374 | "\n",
375 | " def dataloader(self, batch_size, shuffle=True):\n",
376 | " return DataLoader(self,\n",
377 | " batch_size=batch_size,\n",
378 | " shuffle=shuffle,\n",
379 | " num_workers=0)\n"
380 | ],
381 | "execution_count": null,
382 | "outputs": []
383 | },
384 | {
385 | "cell_type": "markdown",
386 | "metadata": {
387 | "id": "awOBv7tN8_ec"
388 | },
389 | "source": [
390 | "## Partitioning & Data Loaders"
391 | ]
392 | },
393 | {
394 | "cell_type": "markdown",
395 | "metadata": {
396 | "id": "rxAlW1YK9B_b"
397 | },
398 | "source": [
399 | "### IID"
400 | ]
401 | },
402 | {
403 | "cell_type": "code",
404 | "metadata": {
405 | "id": "OqAvrJxm8K4D"
406 | },
407 | "source": [
408 | "def iid_partition_(dataset, clients):\n",
409 | " \"\"\"\n",
410 | " I.I.D paritioning of data over clients\n",
411 | " Shuffle the data\n",
412 | " Split it between clients\n",
413 | " \n",
414 | " params:\n",
415 | " - dataset (torch.utils.Dataset): Dataset containing the MNIST Images\n",
416 | " - clients (int): Number of Clients to split the data between\n",
417 | "\n",
418 | " returns:\n",
419 | " - Dictionary of image indexes for each client\n",
420 | " \"\"\"\n",
421 | "\n",
422 | " num_items_per_client = int(len(dataset)/clients)\n",
423 | " client_dict = {}\n",
424 | " image_idxs = [i for i in range(len(dataset))]\n",
425 | "\n",
426 | " for i in range(clients):\n",
427 | " client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False))\n",
428 | " image_idxs = list(set(image_idxs) - client_dict[i])\n",
429 | "\n",
430 | " return client_dict"
431 | ],
432 | "execution_count": null,
433 | "outputs": []
434 | },
435 | {
436 | "cell_type": "code",
437 | "metadata": {
438 | "id": "0T-NUxsq8K1g"
439 | },
440 | "source": [
441 | "def iid_partition(corpus, seq_length=80, val_split=False):\n",
442 | "\n",
443 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n",
444 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n",
445 | "\n",
446 | " with open(train_file, 'r') as file:\n",
447 | " data_train = json.loads(unidecode(file.read()))\n",
448 | "\n",
449 | " with open(test_file, 'r') as file:\n",
450 | " data_test = json.loads(unidecode(file.read()))\n",
451 | "\n",
452 | " \n",
453 | " total_samples_train = sum(data_train['num_samples'])\n",
454 | "\n",
455 | " data_dict = {}\n",
456 | "\n",
457 | " x_train, y_train = [], []\n",
458 | " x_test, y_test = [], []\n",
459 | " # x_val, y_val = [], []\n",
460 | "\n",
461 | " users = list(zip(data_train['users'], data_train['num_samples']))\n",
462 | " # random.shuffle(users)\n",
463 | "\n",
464 | "\n",
465 | "\n",
466 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n",
467 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n",
468 | " sample_count = 0\n",
469 | " \n",
470 | " for i, (author_id, samples) in enumerate(users):\n",
471 | "\n",
472 | " if sample_count >= total_samples:\n",
473 | " print('Max samples reached', sample_count, '/', total_samples)\n",
474 | " break\n",
475 | "\n",
476 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n",
477 | " print('SKIP', author_id, samples)\n",
478 | " continue\n",
479 | " else:\n",
480 | " udata_train = data_train['user_data'][author_id]\n",
481 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n",
482 | " \n",
483 | " sample_count += max_samples\n",
484 | " # print('sample_count', sample_count)\n",
485 | "\n",
486 | " x_train.extend(data_train['user_data'][author_id]['x'][:max_samples])\n",
487 | " y_train.extend(data_train['user_data'][author_id]['y'][:max_samples])\n",
488 | "\n",
489 | " author_data = data_test['user_data'][author_id]\n",
490 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n",
491 | "\n",
492 | " if val_split:\n",
493 | " x_test.extend(author_data['x'][:int(test_size / 2)])\n",
494 | " y_test.extend(author_data['y'][:int(test_size / 2)])\n",
495 | " # x_val.extend(author_data['x'][int(test_size / 2):])\n",
496 | " # y_val.extend(author_data['y'][int(test_size / 2):int(test_size)])\n",
497 | "\n",
498 | " else:\n",
499 | " x_test.extend(author_data['x'][:int(test_size)])\n",
500 | " y_test.extend(author_data['y'][:int(test_size)])\n",
501 | "\n",
502 | " train_ds = ShakespeareDataset(x_train, y_train, corpus, seq_length)\n",
503 | " test_ds = ShakespeareDataset(x_test, y_test, corpus, seq_length)\n",
504 | " # val_ds = ShakespeareDataset(x_val, y_val, corpus, seq_length)\n",
505 | "\n",
506 | " data_dict = iid_partition_(train_ds, clients=len(users))\n",
507 | "\n",
508 | " return train_ds, data_dict, test_ds"
509 | ],
510 | "execution_count": null,
511 | "outputs": []
512 | },
513 | {
514 | "cell_type": "markdown",
515 | "metadata": {
516 | "id": "JtAFxz5h9MRy"
517 | },
518 | "source": [
519 | "### Non-IID"
520 | ]
521 | },
522 | {
523 | "cell_type": "code",
524 | "metadata": {
525 | "id": "0w3yiODz8KzT"
526 | },
527 | "source": [
528 | "def noniid_partition(corpus, seq_length=80, val_split=False):\n",
529 | "\n",
530 | " train_file = [os.path.join(DATA_DIR, 'train', f) for f in os.listdir(f'{DATA_DIR}/train') if f.endswith('.json')][0]\n",
531 | " test_file = [os.path.join(DATA_DIR, 'test', f) for f in os.listdir(f'{DATA_DIR}/test') if f.endswith('.json')][0]\n",
532 | "\n",
533 | " with open(train_file, 'r') as file:\n",
534 | " data_train = json.loads(unidecode(file.read()))\n",
535 | "\n",
536 | " with open(test_file, 'r') as file:\n",
537 | " data_test = json.loads(unidecode(file.read()))\n",
538 | "\n",
539 | " \n",
540 | " total_samples_train = sum(data_train['num_samples'])\n",
541 | "\n",
542 | " data_dict = {}\n",
543 | "\n",
544 | " x_test, y_test = [], []\n",
545 | "\n",
546 | " users = list(zip(data_train['users'], data_train['num_samples']))\n",
547 | " # random.shuffle(users)\n",
548 | "\n",
549 | " total_samples = int(sum(data_train['num_samples']) * SAMPLES_FRACTION)\n",
550 | " print('Objective', total_samples, '/', sum(data_train['num_samples']))\n",
551 | " sample_count = 0\n",
552 | " \n",
553 | " for i, (author_id, samples) in enumerate(users):\n",
554 | "\n",
555 | " if sample_count >= total_samples:\n",
556 | " print('Max samples reached', sample_count, '/', total_samples)\n",
557 | " break\n",
558 | "\n",
559 | " if samples < MIN_SAMPLES: # or data_train['num_samples'][i] > 10000:\n",
560 | " print('SKIP', author_id, samples)\n",
561 | " continue\n",
562 | " else:\n",
563 | " udata_train = data_train['user_data'][author_id]\n",
564 | " max_samples = samples if (sample_count + samples) <= total_samples else (sample_count + samples - total_samples) \n",
565 | " \n",
566 | " sample_count += max_samples\n",
567 | " # print('sample_count', sample_count)\n",
568 | "\n",
569 | " x_train = data_train['user_data'][author_id]['x'][:max_samples]\n",
570 | " y_train = data_train['user_data'][author_id]['y'][:max_samples]\n",
571 | "\n",
572 | " train_ds = ShakespeareFedDataset(x_train, y_train, corpus, seq_length)\n",
573 | "\n",
574 | " x_val, y_val = None, None\n",
575 | " val_ds = None\n",
576 | " author_data = data_test['user_data'][author_id]\n",
577 | " test_size = int(len(author_data['x']) * SAMPLES_FRACTION)\n",
578 | " if val_split:\n",
579 | " x_test += author_data['x'][:int(test_size / 2)]\n",
580 | " y_test += author_data['y'][:int(test_size / 2)]\n",
581 | " x_val = author_data['x'][int(test_size / 2):]\n",
582 | " y_val = author_data['y'][int(test_size / 2):int(test_size)]\n",
583 | "\n",
584 | " val_ds = ShakespeareFedDataset(x_val, y_val, corpus, seq_length)\n",
585 | "\n",
586 | " else:\n",
587 | " x_test += author_data['x'][:int(test_size)]\n",
588 | " y_test += author_data['y'][:int(test_size)]\n",
589 | "\n",
590 | " data_dict[author_id] = {\n",
591 | " 'train_ds': train_ds,\n",
592 | " 'val_ds': val_ds\n",
593 | " }\n",
594 | "\n",
595 | " test_ds = ShakespeareFedDataset(x_test, y_test, corpus, seq_length)\n",
596 | "\n",
597 | " return data_dict, test_ds"
598 | ],
599 | "execution_count": null,
600 | "outputs": []
601 | },
602 | {
603 | "cell_type": "markdown",
604 | "metadata": {
605 | "id": "QSDdUZjs9W-g"
606 | },
607 | "source": [
608 | "## Models"
609 | ]
610 | },
611 | {
612 | "cell_type": "markdown",
613 | "metadata": {
614 | "id": "rhN9isJD9Z8f"
615 | },
616 | "source": [
617 | "### Shakespeare LSTM"
618 | ]
619 | },
620 | {
621 | "cell_type": "code",
622 | "metadata": {
623 | "id": "GKXxN3HO8Kw3"
624 | },
625 | "source": [
626 | "class ShakespeareLSTM(nn.Module):\n",
627 | " \"\"\"\n",
628 | " \"\"\"\n",
629 | "\n",
630 | " def __init__(self, input_dim, embedding_dim, hidden_dim, classes, lstm_layers=2, dropout=0.1, batch_first=True):\n",
631 | " super(ShakespeareLSTM, self).__init__()\n",
632 | " self.input_dim = input_dim\n",
633 | " self.embedding_dim = embedding_dim\n",
634 | " self.hidden_dim = hidden_dim\n",
635 | " self.classes = classes\n",
636 | " self.no_layers = lstm_layers\n",
637 | " \n",
638 | " self.embedding = nn.Embedding(num_embeddings=self.classes,\n",
639 | " embedding_dim=self.embedding_dim)\n",
640 | " self.lstm = nn.LSTM(input_size=self.embedding_dim, \n",
641 | " hidden_size=self.hidden_dim,\n",
642 | " num_layers=self.no_layers,\n",
643 | " batch_first=batch_first, \n",
644 | " dropout=dropout if self.no_layers > 1 else 0.)\n",
645 | " self.fc = nn.Linear(hidden_dim, self.classes)\n",
646 | "\n",
647 | " def forward(self, x, hc=None):\n",
648 | " batch_size = x.size(0)\n",
649 | " x_emb = self.embedding(x)\n",
650 | " out, (ht, ct) = self.lstm(x_emb.view(batch_size, -1, self.embedding_dim), hc)\n",
651 | " dense = self.fc(ht[-1])\n",
652 | " return dense\n",
653 | " \n",
654 | " def init_hidden(self, batch_size):\n",
655 | " return (Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)),\n",
656 | " Variable(torch.zeros(self.no_layers, batch_size, self.hidden_dim)))\n"
657 | ],
658 | "execution_count": null,
659 | "outputs": []
660 | },
661 | {
662 | "cell_type": "markdown",
663 | "metadata": {
664 | "id": "lwrRH5yD9eZB"
665 | },
666 | "source": [
667 | "#### Model Summary"
668 | ]
669 | },
670 | {
671 | "cell_type": "code",
672 | "metadata": {
673 | "id": "Zsl5z8CS8Kul"
674 | },
675 | "source": [
676 | "batch_size = 10\n",
677 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n",
678 | "\n",
679 | "shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n",
680 | " embedding_dim=8, # mcmahan17a, fedprox, qFFL\n",
681 | " hidden_dim=256, # mcmahan17a, fedprox impl\n",
682 | " # hidden_dim=100, # fedprox paper\n",
683 | " classes=len(corpus),\n",
684 | " lstm_layers=2,\n",
685 | " dropout=0.1, # TODO:\n",
686 | " batch_first=True\n",
687 | " )\n",
688 | "\n",
689 | "if torch.cuda.is_available():\n",
690 | " shakespeare_lstm.cuda()\n",
691 | "\n",
692 | "\n",
693 | "\n",
694 | "hc = shakespeare_lstm.init_hidden(batch_size)\n",
695 | "\n",
696 | "x_sample = torch.zeros((batch_size, seq_length),\n",
697 | " dtype=torch.long,\n",
698 | " device=(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')))\n",
699 | "\n",
700 | "x_sample[0][0] = 1\n",
701 | "x_sample\n",
702 | "\n",
703 | "print(\"\\nShakespeare LSTM SUMMARY\")\n",
704 | "print(summaryx(shakespeare_lstm, x_sample))"
705 | ],
706 | "execution_count": null,
707 | "outputs": []
708 | },
709 | {
710 | "cell_type": "markdown",
711 | "metadata": {
712 | "id": "bbJUc6X89jFw"
713 | },
714 | "source": [
715 | "## qFedAvg Algorithm"
716 | ]
717 | },
718 | {
719 | "cell_type": "markdown",
720 | "metadata": {
721 | "id": "w8OsHzAt9kDc"
722 | },
723 | "source": [
724 | "### Plot Utils"
725 | ]
726 | },
727 | {
728 | "cell_type": "code",
729 | "metadata": {
730 | "id": "jn1JNhaI8KsX"
731 | },
732 | "source": [
733 | "from sklearn.metrics import f1_score"
734 | ],
735 | "execution_count": null,
736 | "outputs": []
737 | },
738 | {
739 | "cell_type": "code",
740 | "metadata": {
741 | "id": "oyYu5s4d8KqH"
742 | },
743 | "source": [
744 | "def plot_scores(history, exp_id, title, suffix):\n",
745 | " accuracies = [x['accuracy'] for x in history]\n",
746 | " f1_macro = [x['f1_macro'] for x in history]\n",
747 | " f1_weighted = [x['f1_weighted'] for x in history]\n",
748 | "\n",
749 | " fig, ax = plt.subplots()\n",
750 | " ax.plot(accuracies, 'tab:orange')\n",
751 | " ax.set(xlabel='Rounds', ylabel='Test Accuracy', title=title)\n",
752 | " ax.grid()\n",
753 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Accuracy_{suffix}.jpg', format='jpg', dpi=300)\n",
754 | " plt.show()\n",
755 | "\n",
756 | " fig, ax = plt.subplots()\n",
757 | " ax.plot(f1_macro, 'tab:orange')\n",
758 | " ax.set(xlabel='Rounds', ylabel='Test F1 (macro)', title=title)\n",
759 | " ax.grid()\n",
760 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Macro_{suffix}.jpg', format='jpg')\n",
761 | " plt.show()\n",
762 | "\n",
763 | " fig, ax = plt.subplots()\n",
764 | " ax.plot(f1_weighted, 'tab:orange')\n",
765 | " ax.set(xlabel='Rounds', ylabel='Test F1 (weighted)', title=title)\n",
766 | " ax.grid()\n",
767 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_F1_Weighted_{suffix}.jpg', format='jpg')\n",
768 | " plt.show()\n",
769 | "\n",
770 | "\n",
771 | "def plot_losses(history, exp_id, title, suffix):\n",
772 | " val_losses = [x['loss'] for x in history]\n",
773 | " train_losses = [x['train_loss'] for x in history]\n",
774 | "\n",
775 | " fig, ax = plt.subplots()\n",
776 | " ax.plot(train_losses, 'tab:orange')\n",
777 | " ax.set(xlabel='Rounds', ylabel='Train Loss', title=title)\n",
778 | " ax.grid()\n",
779 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Train_Loss_{suffix}.jpg', format='jpg')\n",
780 | " plt.show()\n",
781 | "\n",
782 | " fig, ax = plt.subplots()\n",
783 | " ax.plot(val_losses, 'tab:orange')\n",
784 | " ax.set(xlabel='Rounds', ylabel='Test Loss', title=title)\n",
785 | " ax.grid()\n",
786 | " fig.savefig(f'{BASE_DIR}/{exp_id}/Test_Loss_{suffix}.jpg', format='jpg')\n",
787 | " plt.show()\n"
788 | ],
789 | "execution_count": null,
790 | "outputs": []
791 | },
792 | {
793 | "cell_type": "markdown",
794 | "metadata": {
795 | "id": "p-8XOZlR9wbV"
796 | },
797 | "source": [
798 | "### Local Training (Client Update)"
799 | ]
800 | },
801 | {
802 | "cell_type": "code",
803 | "metadata": {
804 | "id": "83Jn60Gt8KlY"
805 | },
806 | "source": [
807 | "class CustomDataset(Dataset):\n",
808 | " def __init__(self, dataset, idxs):\n",
809 | " self.dataset = dataset\n",
810 | " self.idxs = list(idxs)\n",
811 | "\n",
812 | " def __len__(self):\n",
813 | " return len(self.idxs)\n",
814 | "\n",
815 | " def __getitem__(self, item):\n",
816 | " data, label = self.dataset[self.idxs[item]]\n",
817 | " return data, label"
818 | ],
819 | "execution_count": null,
820 | "outputs": []
821 | },
822 | {
823 | "cell_type": "code",
824 | "metadata": {
825 | "id": "vOn7VSei8KjB"
826 | },
827 | "source": [
828 | "class ClientUpdate(object):\n",
829 | " def __init__(self, dataset, batch_size, learning_rate, epochs, idxs, q=None):\n",
830 | " \"\"\"\n",
831 | "\n",
832 | " \"\"\"\n",
833 | " if hasattr(dataset, 'dataloader'):\n",
834 | " self.train_loader = dataset.dataloader(batch_size=batch_size, shuffle=True)\n",
835 | " else:\n",
836 | " self.train_loader = DataLoader(CustomDataset(dataset, idxs), batch_size=batch_size, shuffle=True)\n",
837 | "\n",
838 | " self.learning_rate = learning_rate\n",
839 | " self.epochs = epochs\n",
840 | " self.q = q\n",
841 | " if not self.q:\n",
842 | " # TODO: Client itself adjust fairness \n",
843 | " pass\n",
844 | " self.mu = 1e-10\n",
845 | "\n",
846 | " def train(self, model):\n",
847 | "\n",
848 | " criterion = nn.CrossEntropyLoss()\n",
849 | " optimizer = torch.optim.SGD(model.parameters(), lr=self.learning_rate, momentum=0.5)\n",
850 | " # optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)\n",
851 | "\n",
852 | " e_loss = []\n",
853 | " model_weights = copy.deepcopy(model.state_dict())\n",
854 | " for epoch in range(1, self.epochs+1):\n",
855 | "\n",
856 | " train_loss = 0.0\n",
857 | "\n",
858 | " model.train()\n",
859 | " # for data, labels in tqdm(self.train_loader):\n",
860 | " for data, labels in self.train_loader:\n",
861 | "\n",
862 | " if torch.cuda.is_available():\n",
863 | " data, labels = data.cuda(), labels.cuda()\n",
864 | "\n",
865 | " # clear the gradients\n",
866 | " optimizer.zero_grad()\n",
867 | " # make a forward pass\n",
868 | " # print('input', data.size())\n",
869 | " output = model(data)\n",
870 | " # print('output', output.size())\n",
871 | " # print('labels', labels.size())\n",
872 | " # calculate the loss\n",
873 | " loss = criterion(output, labels)\n",
874 | " # do a backwards pass\n",
875 | " loss.backward()\n",
876 | " # perform a single optimization step\n",
877 | " optimizer.step()\n",
878 | " # update training loss\n",
879 | " train_loss += loss.item()*data.size(0)\n",
880 | "\n",
881 | " # average losses\n",
882 | " train_loss = train_loss/len(self.train_loader.dataset)\n",
883 | " e_loss.append(train_loss)\n",
884 | "\n",
885 | "\n",
886 | " total_loss = sum(e_loss)/len(e_loss)\n",
887 | "\n",
888 | " # delta weights\n",
889 | " model_weights_new = copy.deepcopy(model.state_dict())\n",
890 | " L = 1.0 / self.learning_rate\n",
891 | "\n",
892 | " delta_weights, delta, h = {}, {}, {}\n",
893 | " loss_q = np.float_power(total_loss + self.mu, self.q)\n",
894 | " # updating the global weights\n",
895 | " for k in model_weights_new.keys():\n",
896 | " delta_weights[k] = (model_weights[k] - model_weights_new[k]) * L\n",
897 | " delta[k] = loss_q * delta_weights[k]\n",
898 | " # Estimation of the local Lipchitz constant\n",
899 | " h[k] = (self.q * np.float_power(total_loss + self.mu, self.q - 1) * torch.pow(torch.norm(delta_weights[k]), 2)) + (L * loss_q)\n",
900 | "\n",
901 | " return delta, h, total_loss"
902 | ],
903 | "execution_count": null,
904 | "outputs": []
905 | },
906 | {
907 | "cell_type": "markdown",
908 | "metadata": {
909 | "id": "0vQVNPmM93M7"
910 | },
911 | "source": [
912 | "### Server Side Training"
913 | ]
914 | },
915 | {
916 | "cell_type": "code",
917 | "metadata": {
918 | "id": "bhOfx0HSCX35"
919 | },
920 | "source": [
921 | "def client_sampling(n, m, weights=None, with_replace=False):\n",
922 | " pk = None\n",
923 | " if weights:\n",
924 | " total_weights = np.sum(np.asarray(weights))\n",
925 | " pk = [w * 1.0 / total_weights for w in weights]\n",
926 | "\n",
927 | " return np.random.choice(range(n), m, replace=with_replace, p=pk)"
928 | ],
929 | "execution_count": null,
930 | "outputs": []
931 | },
932 | {
933 | "cell_type": "code",
934 | "metadata": {
935 | "id": "MCJZnkY08Kgs"
936 | },
937 | "source": [
938 | "def training(model, rounds, batch_size, lr, ds, data_dict, test_ds, C, K, E, q=0, sampling='uniform', tb_logger=None, test_history=[], perf_fig_file='loss.jpg'):\n",
939 | " \"\"\"\n",
940 | " Function implements the Federated Averaging Algorithm from the FedAvg paper.\n",
941 | " Specifically, this function is used for the server side training and weight update\n",
942 | "\n",
943 | " Params:\n",
944 | " - model: PyTorch model to train\n",
945 | " - rounds: Number of communication rounds for the client update\n",
946 | " - batch_size: Batch size for client update training\n",
947 | " - lr: Learning rate used for client update training\n",
948 | " - ds: Dataset used for training\n",
949 | " - data_dict: Type of data partition used for training (IID or non-IID)\n",
950 | " - test_ds Dataset used for global testing\n",
951 | " - C: Fraction of clients randomly chosen to perform computation on each round\n",
952 | " - K: Total number of clients\n",
953 | " - E: Number of training passes each client makes over its local dataset per round\n",
954 | " - q: Parameter that tunes the amount of fairness we wish to impose (default: 0 -> vanilla FedAvg objective)\n",
955 | " - sampling Uniform or weighted (default: uniform)\n",
956 | " - tb_logger: Tensorboard SummaryWriter\n",
957 | " - test_history: Test Scores history log\n",
958 | " - perf_fig_file File for storing final performance plot (loss)\n",
959 | " Returns:\n",
960 | " - model: Trained model on the server\n",
961 | " \"\"\"\n",
962 | "\n",
963 | " # global model weights\n",
964 | " global_weights = model.state_dict()\n",
965 | "\n",
966 | " # training loss\n",
967 | " train_loss = []\n",
968 | "\n",
969 | " # client weights by total samples\n",
970 | " p_k = None\n",
971 | " if sampling == 'weighted':\n",
972 | " p_k = [len(data_dict[c]) for c in data_dict] if ds else [len(data_dict[c]['train_ds']) for c in data_dict]\n",
973 | "\n",
974 | " # Time log\n",
975 | " start_time = time.time()\n",
976 | "\n",
977 | " users_id = list(data_dict.keys())\n",
978 | "\n",
979 | " # Orchestrate training\n",
980 | " for curr_round in range(1, rounds+1):\n",
981 | " deltas, hs, local_loss = [], [], []\n",
982 | "\n",
983 | " m = max(int(C*K), 1) \n",
984 | " S_t = client_sampling(K, m, weights=p_k, with_replace=False)\n",
985 | "\n",
986 | " print('Round: {} Picking {}/{} clients: {}'.format(curr_round, m, K, S_t))\n",
987 | "\n",
988 | " global_weights = model.state_dict()\n",
989 | "\n",
990 | " for k in tqdm(S_t):\n",
991 | " key = users_id[k]\n",
992 | " ds_ = ds if ds else data_dict[key]['train_ds']\n",
993 | " idxs = data_dict[key] if ds else None\n",
994 | " # print(f'Client {k}: {len(idxs) if idxs else len(ds_)} samples')\n",
995 | " local_update = ClientUpdate(dataset=ds_, batch_size=batch_size, learning_rate=lr, epochs=E, idxs=idxs, q=q)\n",
996 | " # weights, loss = local_update.train(model=copy.deepcopy(model))\n",
997 | " delta_k, h_k, loss = local_update.train(model=copy.deepcopy(model))\n",
998 | "\n",
999 | " deltas.append(copy.deepcopy(delta_k))\n",
1000 | " hs.append(copy.deepcopy(h_k))\n",
1001 | " local_loss.append(copy.deepcopy(loss))\n",
1002 | "\n",
1003 | " if tb_logger:\n",
1004 | " tb_logger.add_scalar(f'Round/S{k}', loss, curr_round)\n",
1005 | "\n",
1006 | " # Perform qFedAvg\n",
1007 | " h_sum = copy.deepcopy(hs[0])\n",
1008 | " delta_sum = copy.deepcopy(deltas[0])\n",
1009 | " \n",
1010 | " for k in h_sum.keys():\n",
1011 | " for i in range(1, len(hs)):\n",
1012 | " h_sum[k] += hs[i][k]\n",
1013 | " delta_sum[k] += deltas[i][k]\n",
1014 | "\n",
1015 | " new_weights = {}\n",
1016 | " for k in delta_sum.keys():\n",
1017 | " for i in range(len(deltas)):\n",
1018 | " new_weights[k] = delta_sum[k] / h_sum[k]\n",
1019 | "\n",
1020 | " # Updating global model weights\n",
1021 | " for k in global_weights.keys():\n",
1022 | " global_weights[k] -= new_weights[k]\n",
1023 | "\n",
1024 | " # move the updated weights to our model state dict\n",
1025 | " model.load_state_dict(global_weights)\n",
1026 | "\n",
1027 | " # loss\n",
1028 | " loss_avg = sum(local_loss) / len(local_loss)\n",
1029 | " print('Round: {}... \\tAverage Loss: {}'.format(curr_round, round(loss_avg, 3)))\n",
1030 | " train_loss.append(loss_avg)\n",
1031 | "\n",
1032 | " if tb_logger:\n",
1033 | " tb_logger.add_scalar('Train/Loss', loss_avg, curr_round)\n",
1034 | " # tb_logger.add_scalar(f'Train/Datapoints', total_datapoints, curr_round)\n",
1035 | " \n",
1036 | " # if curr_round % eval_every == 0:\n",
1037 | " test_scores = testing(model, test_ds, batch_size, nn.CrossEntropyLoss(), num_classes, list(range(num_classes)))\n",
1038 | " test_scores['train_loss'] = loss_avg\n",
1039 | " test_history.append(test_scores)\n",
1040 | " if tb_logger:\n",
1041 | " tb_logger.add_scalar(f'Test/Loss', test_scores['loss'], curr_round)\n",
1042 | " tb_logger.add_scalars(f'Test/Scores', {\n",
1043 | " 'accuracy': test_scores['accuracy'], 'f1_macro': test_scores['f1_macro'], 'f1_weighted': test_scores['f1_weighted']\n",
1044 | " }, curr_round)\n",
1045 | " \n",
1046 | " end_time = time.time()\n",
1047 | " \n",
1048 | " fig, ax = plt.subplots()\n",
1049 | " x_axis = np.arange(1, rounds+1)\n",
1050 | " y_axis = np.array(train_loss)\n",
1051 | " ax.plot(x_axis, y_axis, 'tab:orange')\n",
1052 | "\n",
1053 | " ax.set(xlabel='# Rounds', ylabel='Train Loss',\n",
1054 | " title=\"Model's Performance with q: {}\".format(q))\n",
1055 | " ax.grid()\n",
1056 | " #fig.savefig(perf_fig_file, format='jpg')\n",
1057 | "\n",
1058 | " print(\"Training Done! Total time: {}\".format(round(end_time - start_time, 3)))\n",
1059 | " return model, test_history"
1060 | ],
1061 | "execution_count": null,
1062 | "outputs": []
1063 | },
1064 | {
1065 | "cell_type": "markdown",
1066 | "metadata": {
1067 | "id": "JrTGp6vd-DKa"
1068 | },
1069 | "source": [
1070 | "### Testing Loop"
1071 | ]
1072 | },
1073 | {
1074 | "cell_type": "code",
1075 | "metadata": {
1076 | "id": "yEnPyGYb8KeO"
1077 | },
1078 | "source": [
1079 | "def testing(model, dataset, bs, criterion, num_classes, classes, print_all=False):\n",
1080 | " #test loss \n",
1081 | " test_loss = 0.0\n",
1082 | " y_true, y_hat = None, None\n",
1083 | "\n",
1084 | " correct_class = list(0 for i in range(num_classes))\n",
1085 | " total_class = list(0 for i in range(num_classes))\n",
1086 | "\n",
1087 | " if hasattr(dataset, 'dataloader'):\n",
1088 | " test_loader = dataset.dataloader(batch_size=bs, shuffle=False)\n",
1089 | " else:\n",
1090 | " test_loader = DataLoader(dataset, batch_size=bs, shuffle=False)\n",
1091 | "\n",
1092 | " l = len(test_loader)\n",
1093 | "\n",
1094 | " model.eval()\n",
1095 | " for i, (data, labels) in enumerate(tqdm(test_loader)):\n",
1096 | "\n",
1097 | " if torch.cuda.is_available():\n",
1098 | " data, labels = data.cuda(), labels.cuda()\n",
1099 | "\n",
1100 | " output = model(data)\n",
1101 | " loss = criterion(output, labels)\n",
1102 | " test_loss += loss.item()*data.size(0)\n",
1103 | "\n",
1104 | " _, pred = torch.max(output, dim=1)\n",
1105 | "\n",
1106 | " # For F1Score\n",
1107 | " y_true = np.append(y_true, labels.data.view_as(pred).cpu().numpy()) if i != 0 else labels.data.view_as(pred).cpu().numpy()\n",
1108 | " y_hat = np.append(y_hat, pred.cpu().numpy()) if i != 0 else pred.cpu().numpy()\n",
1109 | "\n",
1110 | " correct_tensor = pred.eq(labels.data.view_as(pred))\n",
1111 | " correct = np.squeeze(correct_tensor.numpy()) if not torch.cuda.is_available() else np.squeeze(correct_tensor.cpu().numpy())\n",
1112 | "\n",
1113 | " #test accuracy for each object class\n",
1114 | " # for i in range(num_classes):\n",
1115 | " # label = labels.data[i]\n",
1116 | " # correct_class[label] += correct[i].item()\n",
1117 | " # total_class[label] += 1\n",
1118 | "\n",
1119 | " for i, lbl in enumerate(labels.data):\n",
1120 | " try:\n",
1121 | " # print(type(lbl))\n",
1122 | " # correct_class[lbl.data[0]] += correct.data[i]\n",
1123 | " correct_class[lbl.item()] += correct[i]\n",
1124 | " total_class[lbl.item()] += 1\n",
1125 | " except:\n",
1126 | " print('Error', lbl, i)\n",
1127 | " \n",
1128 | " # avg test loss\n",
1129 | " test_loss = test_loss/len(test_loader.dataset)\n",
1130 | " print(\"Test Loss: {:.6f}\\n\".format(test_loss))\n",
1131 | "\n",
1132 | " # Avg F1 Score\n",
1133 | " f1_macro = f1_score(y_true, y_hat, average='macro')\n",
1134 | " # F1-Score -> weigthed to consider class imbalance\n",
1135 | " f1_weighted = f1_score(y_true, y_hat, average='weighted')\n",
1136 | " print(\"F1 Score: {:.6f} (macro) {:.6f} (weighted) %\\n\".format(f1_macro, f1_weighted))\n",
1137 | "\n",
1138 | " # print test accuracy\n",
1139 | " if print_all:\n",
1140 | " for i in range(num_classes):\n",
1141 | " if total_class[i] > 0:\n",
1142 | " print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % \n",
1143 | " (classes[i], 100 * correct_class[i] / total_class[i],\n",
1144 | " np.sum(correct_class[i]), np.sum(total_class[i])))\n",
1145 | " else:\n",
1146 | " print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))\n",
1147 | "\n",
1148 | " overall_accuracy = np.sum(correct_class) / np.sum(total_class)\n",
1149 | "\n",
1150 | " print('\\nFinal Test Accuracy: {:.3f} ({}/{})'.format(overall_accuracy, np.sum(correct_class), np.sum(total_class)))\n",
1151 | "\n",
1152 | " return {'loss': test_loss, 'accuracy': overall_accuracy, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted}"
1153 | ],
1154 | "execution_count": null,
1155 | "outputs": []
1156 | },
1157 | {
1158 | "cell_type": "markdown",
1159 | "metadata": {
1160 | "id": "Evpa7UZO-Hii"
1161 | },
1162 | "source": [
1163 | "## Experiments"
1164 | ]
1165 | },
1166 | {
1167 | "cell_type": "code",
1168 | "metadata": {
1169 | "id": "HhZ2iUQ2am-_"
1170 | },
1171 | "source": [
1172 | "# FAIL-ON-PURPOSE"
1173 | ],
1174 | "execution_count": null,
1175 | "outputs": []
1176 | },
1177 | {
1178 | "cell_type": "code",
1179 | "metadata": {
1180 | "id": "QXNvhZAw8Kbx"
1181 | },
1182 | "source": [
1183 | "seq_length = 80 # mcmahan17a, fedprox, qFFL\n",
1184 | "embedding_dim = 8 # mcmahan17a, fedprox, qFFL\n",
1185 | "# hidden_dim = 100 # fedprox paper\n",
1186 | "hidden_dim = 256 # mcmahan17a, fedprox impl\n",
1187 | "num_classes = len(corpus)\n",
1188 | "classes = list(range(num_classes))\n",
1189 | "lstm_layers = 2 # mcmahan17a, fedprox, qFFL\n",
1190 | "dropout = 0.1 # TODO"
1191 | ],
1192 | "execution_count": null,
1193 | "outputs": []
1194 | },
1195 | {
1196 | "cell_type": "code",
1197 | "metadata": {
1198 | "id": "IadhhL4knbOc"
1199 | },
1200 | "source": [
1201 | "class Hyperparameters():\n",
1202 | "\n",
1203 | " def __init__(self, total_clients):\n",
1204 | " # number of training rounds\n",
1205 | " self.rounds = 50\n",
1206 | " # client fraction\n",
1207 | " self.C = 0.5\n",
1208 | " # number of clients\n",
1209 | " self.K = total_clients\n",
1210 | " # number of training passes on local dataset for each roung\n",
1211 | " self.E = 1 # qFFL\n",
1212 | " # batch size\n",
1213 | " self.batch_size = 10 # fedprox\n",
1214 | " # learning Rate\n",
1215 | " self.lr = 0.8 # fedprox, qFFL\n",
1216 | " # fairness\n",
1217 | " self.q = 0.001 # qFFL\n",
1218 | " # sampling\n",
1219 | " # self.sampling = 'uniform'\n",
1220 | " self.sampling = 'weighted'"
1221 | ],
1222 | "execution_count": null,
1223 | "outputs": []
1224 | },
1225 | {
1226 | "cell_type": "code",
1227 | "metadata": {
1228 | "id": "VtSTy1GencJ4"
1229 | },
1230 | "source": [
1231 | "exp_log = dict()"
1232 | ],
1233 | "execution_count": null,
1234 | "outputs": []
1235 | },
1236 | {
1237 | "cell_type": "markdown",
1238 | "metadata": {
1239 | "id": "hGoeL-HpDv7L"
1240 | },
1241 | "source": [
1242 | "### IID"
1243 | ]
1244 | },
1245 | {
1246 | "cell_type": "code",
1247 | "metadata": {
1248 | "id": "61pZ-UY6D8nB"
1249 | },
1250 | "source": [
1251 | "train_ds, data_dict, test_ds = iid_partition(corpus, seq_length, val_split=True)\n",
1252 | "\n",
1253 | "total_clients = len(data_dict.keys())\n",
1254 | "'Total users:', total_clients"
1255 | ],
1256 | "execution_count": null,
1257 | "outputs": []
1258 | },
1259 | {
1260 | "cell_type": "code",
1261 | "metadata": {
1262 | "id": "nUmhoN-zGZp_"
1263 | },
1264 | "source": [
1265 | "hparams = Hyperparameters(total_clients)\n",
1266 | "hparams.__dict__"
1267 | ],
1268 | "execution_count": null,
1269 | "outputs": []
1270 | },
1271 | {
1272 | "cell_type": "code",
1273 | "metadata": {
1274 | "id": "XfNgps35XnxY"
1275 | },
1276 | "source": [
1277 | "# Sweeping parameter\n",
1278 | "PARAM_NAME = 'clients_fraction'\n",
1279 | "PARAM_VALUE = hparams.C\n",
1280 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n",
1281 | "exp_id"
1282 | ],
1283 | "execution_count": null,
1284 | "outputs": []
1285 | },
1286 | {
1287 | "cell_type": "code",
1288 | "metadata": {
1289 | "id": "PmXNlFlYcl2H"
1290 | },
1291 | "source": [
1292 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n",
1293 | "os.makedirs(EXP_DIR, exist_ok=True)\n",
1294 | "\n",
1295 | "# tb_logger = SummaryWriter(log_dir)\n",
1296 | "# print(f'TBoard logger created at: {log_dir}')\n",
1297 | "\n",
1298 | "title = 'LSTM qFedAvg on IID'"
1299 | ],
1300 | "execution_count": null,
1301 | "outputs": []
1302 | },
1303 | {
1304 | "cell_type": "code",
1305 | "metadata": {
1306 | "id": "ymcomht5GdBO"
1307 | },
1308 | "source": [
1309 | "def run_experiment(run_id):\n",
1310 | "\n",
1311 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length, \n",
1312 | " embedding_dim=embedding_dim, \n",
1313 | " hidden_dim=hidden_dim,\n",
1314 | " classes=num_classes,\n",
1315 | " lstm_layers=lstm_layers,\n",
1316 | " dropout=dropout,\n",
1317 | " batch_first=True\n",
1318 | " )\n",
1319 | "\n",
1320 | " if torch.cuda.is_available():\n",
1321 | " shakespeare_lstm.cuda()\n",
1322 | " \n",
1323 | " test_history = []\n",
1324 | "\n",
1325 | " lstm_iid_trained, test_history = training(shakespeare_lstm,\n",
1326 | " hparams.rounds, hparams.batch_size, hparams.lr,\n",
1327 | " train_ds, data_dict, test_ds,\n",
1328 | " hparams.C, hparams.K, hparams.E, hparams.q,\n",
1329 | " sampling=hparams.sampling,\n",
1330 | " test_history=test_history,\n",
1331 | " # tb_logger=tb_logger,\n",
1332 | " # perf_fig_file=f'{BASE_DIR}/loss.jpg'\n",
1333 | " )\n",
1334 | " \n",
1335 | " final_scores = testing(lstm_iid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n",
1336 | " print(f'\\n\\n========================================================\\n\\n')\n",
1337 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n",
1338 | "\n",
1339 | " log = {\n",
1340 | " 'history': test_history,\n",
1341 | " 'hyperparams': hparams.__dict__\n",
1342 | " }\n",
1343 | "\n",
1344 | " with open(f'{EXP_DIR}/results_iid_{run_id}.pkl', 'wb') as file:\n",
1345 | " pickle.dump(log, file)\n",
1346 | "\n",
1347 | " return test_history\n",
1348 | " "
1349 | ],
1350 | "execution_count": null,
1351 | "outputs": []
1352 | },
1353 | {
1354 | "cell_type": "code",
1355 | "metadata": {
1356 | "id": "xhPK_WL9GZhV"
1357 | },
1358 | "source": [
1359 | "exp_history = list()\n",
1360 | "for run_id in range(2): # TOTAL RUNS\n",
1361 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n",
1362 | " exp_history.append(run_experiment(run_id))\n",
1363 | " print(f'\\n\\n========================================================\\n\\n')"
1364 | ],
1365 | "execution_count": null,
1366 | "outputs": []
1367 | },
1368 | {
1369 | "cell_type": "code",
1370 | "metadata": {
1371 | "id": "a4Rxo_HZn10G"
1372 | },
1373 | "source": [
1374 | "exp_log[title] = {\n",
1375 | " 'history': exp_history,\n",
1376 | " 'hyperparams': hparams.__dict__\n",
1377 | "}"
1378 | ],
1379 | "execution_count": null,
1380 | "outputs": []
1381 | },
1382 | {
1383 | "cell_type": "code",
1384 | "metadata": {
1385 | "id": "YQGqrZ1_n-e3"
1386 | },
1387 | "source": [
1388 | "df = None\n",
1389 | "for i, e in enumerate(exp_history):\n",
1390 | " if i == 0:\n",
1391 | " df = pd.json_normalize(e)\n",
1392 | " continue\n",
1393 | " df = df + pd.json_normalize(e)\n",
1394 | " \n",
1395 | "df_avg = df / len(exp_history)\n",
1396 | "avg_history = df_avg.to_dict(orient='records')"
1397 | ],
1398 | "execution_count": null,
1399 | "outputs": []
1400 | },
1401 | {
1402 | "cell_type": "code",
1403 | "metadata": {
1404 | "id": "H3SO1Cga5tLE"
1405 | },
1406 | "source": [
1407 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='IID')"
1408 | ],
1409 | "execution_count": null,
1410 | "outputs": []
1411 | },
1412 | {
1413 | "cell_type": "code",
1414 | "metadata": {
1415 | "id": "Pwo7EWRh5uQx"
1416 | },
1417 | "source": [
1418 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='IID')"
1419 | ],
1420 | "execution_count": null,
1421 | "outputs": []
1422 | },
1423 | {
1424 | "cell_type": "code",
1425 | "metadata": {
1426 | "id": "9KPWHl8128He"
1427 | },
1428 | "source": [
1429 | "with open(f'{EXP_DIR}/results_iid.pkl', 'wb') as file:\n",
1430 | " pickle.dump(exp_log, file)"
1431 | ],
1432 | "execution_count": null,
1433 | "outputs": []
1434 | },
1435 | {
1436 | "cell_type": "markdown",
1437 | "metadata": {
1438 | "id": "ttnrJ9FnDxFR"
1439 | },
1440 | "source": [
1441 | "### Non-IID"
1442 | ]
1443 | },
1444 | {
1445 | "cell_type": "code",
1446 | "metadata": {
1447 | "id": "9T9hXMvT2_Ud"
1448 | },
1449 | "source": [
1450 | "exp_log = dict()"
1451 | ],
1452 | "execution_count": null,
1453 | "outputs": []
1454 | },
1455 | {
1456 | "cell_type": "code",
1457 | "metadata": {
1458 | "id": "yQlDBOdJEAaX"
1459 | },
1460 | "source": [
1461 | "data_dict, test_ds = noniid_partition(corpus, seq_length=seq_length, val_split=True)\n",
1462 | "\n",
1463 | "total_clients = len(data_dict.keys())\n",
1464 | "'Total users:', total_clients"
1465 | ],
1466 | "execution_count": null,
1467 | "outputs": []
1468 | },
1469 | {
1470 | "cell_type": "code",
1471 | "metadata": {
1472 | "id": "f5vIc19G8KUc"
1473 | },
1474 | "source": [
1475 | "hparams = Hyperparameters(total_clients)\n",
1476 | "hparams.__dict__"
1477 | ],
1478 | "execution_count": null,
1479 | "outputs": []
1480 | },
1481 | {
1482 | "cell_type": "code",
1483 | "metadata": {
1484 | "id": "WeGGOeN3rFDJ"
1485 | },
1486 | "source": [
1487 | "# Sweeping parameter\n",
1488 | "PARAM_NAME = 'clients_fraction'\n",
1489 | "PARAM_VALUE = hparams.C\n",
1490 | "exp_id = f'{PARAM_NAME}/{PARAM_VALUE}'\n",
1491 | "exp_id"
1492 | ],
1493 | "execution_count": null,
1494 | "outputs": []
1495 | },
1496 | {
1497 | "cell_type": "code",
1498 | "metadata": {
1499 | "id": "i3cGeO20ofEZ"
1500 | },
1501 | "source": [
1502 | "EXP_DIR = f'{BASE_DIR}/{exp_id}'\n",
1503 | "os.makedirs(EXP_DIR, exist_ok=True)\n",
1504 | "\n",
1505 | "# tb_logger = SummaryWriter(log_dir)\n",
1506 | "# print(f'TBoard logger created at: {log_dir}')\n",
1507 | "\n",
1508 | "title = 'LSTM qFedAvg on Non-IID'"
1509 | ],
1510 | "execution_count": null,
1511 | "outputs": []
1512 | },
1513 | {
1514 | "cell_type": "code",
1515 | "metadata": {
1516 | "id": "bu1A3GUjGHWy"
1517 | },
1518 | "source": [
1519 | "def run_experiment(run_id):\n",
1520 | " shakespeare_lstm = ShakespeareLSTM(input_dim=seq_length,\n",
1521 | " embedding_dim=embedding_dim,\n",
1522 | " hidden_dim=hidden_dim,\n",
1523 | " classes=num_classes,\n",
1524 | " lstm_layers=lstm_layers,\n",
1525 | " dropout=dropout,\n",
1526 | " batch_first=True\n",
1527 | " )\n",
1528 | "\n",
1529 | " if torch.cuda.is_available():\n",
1530 | " shakespeare_lstm.cuda()\n",
1531 | "\n",
1532 | " test_history = []\n",
1533 | "\n",
1534 | " lstm_noniid_trained, test_history = training(shakespeare_lstm,\n",
1535 | " hparams.rounds, hparams.batch_size, hparams.lr,\n",
1536 | " None, data_dict, test_ds,\n",
1537 | " hparams.C, hparams.K, hparams.E, hparams.q,\n",
1538 | " sampling=hparams.sampling,\n",
1539 | " test_history=test_history,\n",
1540 | " # tb_logger=tb_logger,\n",
1541 | " # perf_fig_file=f'{BASE_DIR}/loss.jpg'\n",
1542 | " )\n",
1543 | " \n",
1544 | " final_scores = testing(lstm_noniid_trained, test_ds, batch_size * 2, nn.CrossEntropyLoss(), len(corpus), corpus)\n",
1545 | " print(f'\\n\\n========================================================\\n\\n')\n",
1546 | " print(f'Final scores for Exp {run_id} \\n {final_scores}')\n",
1547 | "\n",
1548 | " log = {\n",
1549 | " 'history': test_history,\n",
1550 | " 'hyperparams': hparams.__dict__\n",
1551 | " }\n",
1552 | "\n",
1553 | " with open(f'{EXP_DIR}/results_niid_{run_id}.pkl', 'wb') as file:\n",
1554 | " pickle.dump(log, file)\n",
1555 | "\n",
1556 | " return test_history"
1557 | ],
1558 | "execution_count": null,
1559 | "outputs": []
1560 | },
1561 | {
1562 | "cell_type": "code",
1563 | "metadata": {
1564 | "id": "AaaamxrWGKct"
1565 | },
1566 | "source": [
1567 | "exp_history = list()\n",
1568 | "for run_id in range(2): # TOTAL RUNS\n",
1569 | " print(f'============== RUNNING EXPERIMENT #{run_id} ==============')\n",
1570 | " exp_history.append(run_experiment(run_id))\n",
1571 | " print(f'\\n\\n========================================================\\n\\n')"
1572 | ],
1573 | "execution_count": null,
1574 | "outputs": []
1575 | },
1576 | {
1577 | "cell_type": "code",
1578 | "metadata": {
1579 | "id": "ftL-MoHxwe5C"
1580 | },
1581 | "source": [
1582 | "exp_log[title] = {\n",
1583 | " 'history': exp_history,\n",
1584 | " 'hyperparams': hparams.__dict__\n",
1585 | "}"
1586 | ],
1587 | "execution_count": null,
1588 | "outputs": []
1589 | },
1590 | {
1591 | "cell_type": "code",
1592 | "metadata": {
1593 | "id": "jceXDZXOwezj"
1594 | },
1595 | "source": [
1596 | "df = None\n",
1597 | "for i, e in enumerate(exp_history):\n",
1598 | " if i == 0:\n",
1599 | " df = pd.json_normalize(e)\n",
1600 | " continue\n",
1601 | " df = df + pd.json_normalize(e)\n",
1602 | " \n",
1603 | "df_avg = df / len(exp_history)\n",
1604 | "avg_history = df_avg.to_dict(orient='records')"
1605 | ],
1606 | "execution_count": null,
1607 | "outputs": []
1608 | },
1609 | {
1610 | "cell_type": "code",
1611 | "metadata": {
1612 | "id": "1u7uTHwJ6KXE"
1613 | },
1614 | "source": [
1615 | "plot_scores(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')"
1616 | ],
1617 | "execution_count": null,
1618 | "outputs": []
1619 | },
1620 | {
1621 | "cell_type": "code",
1622 | "metadata": {
1623 | "id": "um4eBO8O6KLX"
1624 | },
1625 | "source": [
1626 | "plot_losses(history=avg_history, exp_id=exp_id, title=title, suffix='nonIID')"
1627 | ],
1628 | "execution_count": null,
1629 | "outputs": []
1630 | },
1631 | {
1632 | "cell_type": "markdown",
1633 | "metadata": {
1634 | "id": "N80BpTFy6aR7"
1635 | },
1636 | "source": [
1637 | "### Pickle Experiment Results"
1638 | ]
1639 | },
1640 | {
1641 | "cell_type": "code",
1642 | "metadata": {
1643 | "id": "tGlR8COy6aCN"
1644 | },
1645 | "source": [
1646 | "with open(f'{EXP_DIR}/results_niid.pkl', 'wb') as file:\n",
1647 | " pickle.dump(exp_log, file)"
1648 | ],
1649 | "execution_count": null,
1650 | "outputs": []
1651 | }
1652 | ]
1653 | }
--------------------------------------------------------------------------------