├── README.md
├── code_for_plots
│   ├── Robustness_and_Privacy.ipynb
│   └── print_privacy_leakage.ipynb
├── diagrams
│   └── FbFTL_diagram.png
├── history.txt
├── main_GLOBECOM_CIFAR10_VGG16.py
├── main_TMLCN_CIFAR10_VGG16.py
└── main_TMLCN_FLANT5_SAMSUM.py
/README.md:
--------------------------------------------------------------------------------
1 | # FbFTL: Communication-Efficient Feature-based Federated Transfer Learning
2 |
3 | This is the official Python implementation of the simulations for Feature-based Federated Transfer Learning (FbFTL), from the following conference paper and journal paper:
4 |
5 | Communication-Efficient Feature-based Federated Transfer Learning.([Globecom2022](https://ieeexplore.ieee.org/abstract/document/10000612), [arXiv](https://arxiv.org/abs/2209.05395))
6 | Feng Wang, M. Cenk Gursoy and Senem Velipasalar
7 | Department of Electrical Engineering and Computer Science, Syracuse University
8 |
9 | Feature-based Federated Transfer Learning: Communication Efficiency, Robustness and Privacy.([TMLCN](https://ieeexplore.ieee.org/abstract/document/10542971), [arXiv](https://arxiv.org/abs/2405.09014))
10 | Feng Wang, M. Cenk Gursoy and Senem Velipasalar
11 | Department of Electrical Engineering and Computer Science, Syracuse University
12 |
13 | ---
14 |
15 |
16 |
17 | We propose FbFTL, a federated learning approach that uploads features and outputs instead of gradients, reducing the uplink payload by more than five orders of magnitude. Please refer to the journal paper for a detailed explanation of the learning structure, system design, robustness analysis, and privacy analysis.
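
The privacy analysis is accompanied by the notebook `code_for_plots/print_privacy_leakage.ipynb`, which computes the entropy of a client's label-count vector when its K labels are drawn uniformly from N classes. The following is a compact sketch of that computation, not the notebook code itself. The function name `label_count_entropy` is hypothetical, and the uniform, independent label assumption is taken from the notebook's comments.

```python
import math
from collections import Counter
from itertools import combinations_with_replacement


def label_count_entropy(num_classes, samples_per_client):
    """Return (B, H) for label-count vectors of one client.

    B is the number of distinct label-count vectors and H is their entropy
    in bits, assuming each of the K labels is drawn independently and
    uniformly from the N classes (assumption taken from the notebook).
    """
    n, k = num_classes, samples_per_client
    num_vectors = 0
    entropy = 0.0
    # Each sorted label tuple corresponds to exactly one count vector.
    for labels in combinations_with_replacement(range(n), k):
        num_vectors += 1
        denom = 1
        for count in Counter(labels).values():
            denom *= math.factorial(count)
        # Multinomial probability of this count vector under uniform labels.
        p = math.factorial(k) / (n ** k) / denom
        entropy -= p * math.log2(p)
    return num_vectors, entropy


if __name__ == "__main__":
    # CIFAR-10 setting from the notebook: N = 10 classes, K = 8 samples.
    print(label_count_entropy(10, 8))
    # "Dry bean" setting from the notebook: N = 4, K = 4.
    print(label_count_entropy(4, 4))
```

For N = 10 and K = 8 this reproduces the values stored in the notebook output (B = 24310, entropy of about 13.86 bits).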
18 |
19 |
20 | # Results on CIFAR-10 Dataset with VGG16 Model
21 | The following table compares federated learning with [FedAvg](http://proceedings.mlr.press/v54/mcmahan17a.html) (FL), federated transfer learning with FedAvg that updates the full model (FTLf), federated transfer learning with FedAvg that updates only the task-specific sub-model (FTLc), and FbFTL. All of them train a [VGG16](https://arxiv.org/abs/1409.1556) model on the [CIFAR-10](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.222.9220&rep=rep1&type=pdf) dataset. For the transfer learning approaches, the source model is pre-trained on the [ImageNet](https://ieeexplore.ieee.org/abstract/document/5206848?casa_token=QncCRBM1tzAAAAAA:QuoJhjJAHRplmLJ4jcFw5JWdfASjmbIVlvpCrHgTPIFu63gpSUlBeACB78S0AH34qqQnsBOdoQ) dataset. Compared to the other methods, FbFTL reduces the uplink payload by up to five orders of magnitude.
22 |
23 | | | FL | FTLf | FTLc | FbFTL |
24 | | ---- | ----- | ---- | ---- | ---- |
25 | | upload batches | 656250 | 193750 | 525000 | 50000 |
26 | | upload parameters per batch | 153144650 | 153144650 | 35665418 | 4096 |
27 | | uplink payload per batch | **4.9 Gb** | **4.9 Gb** | **1.1 Gb** | **131 Kb** |
28 | | total uplink payload | **3216 Tb** | **949 Tb** | **599 Tb** | **6.6 Gb** |
29 | | total downlink payload | 402 Tb | 253 Tb | 322 Tb | 3.8 Gb |
30 | | test accuracy | 89.42\% | 93.75\% | 86.51\% | 86.51\% |
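
The payload rows of the table can be cross-checked from the batch and parameter counts. The sketch below assumes that every uploaded value occupies 32 bits and that 1 Gb means 10^9 bits; both are assumptions that happen to reproduce the table, not statements from the paper. The head layer sizes are read from the replacement of `classifier[6]` in `main_GLOBECOM_CIFAR10_VGG16.py`, and the helper names are hypothetical.

```python
BITS_PER_VALUE = 32  # assumption: each uploaded value is a 32-bit float


def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


# Task-specific head that replaces classifier[6] in the CIFAR-10 script.
HEAD_PARAMS = (
    2 * linear_params(4096, 4096)
    + linear_params(4096, 512)
    + linear_params(512, 10)
)

# method -> (upload batches, uploaded values per batch), copied from the table
ROWS = {
    "FL": (656250, 153144650),
    "FTLf": (193750, 153144650),
    "FTLc": (525000, HEAD_PARAMS),
    "FbFTL": (50000, 4096),
}


def per_batch_bits(method):
    """Uplink payload of a single upload, in bits."""
    return ROWS[method][1] * BITS_PER_VALUE


def total_uplink_bits(method):
    """Uplink payload summed over all uploads, in bits."""
    batches, _ = ROWS[method]
    return batches * per_batch_bits(method)


if __name__ == "__main__":
    print("FTLc head parameters:", HEAD_PARAMS)
    for name in ROWS:
        print(name, per_batch_bits(name), "bits per batch,",
              total_uplink_bits(name), "bits in total")
```

Under these assumptions the head has 35665418 parameters, matching the FTLc column, and an FbFTL upload carries 4096 × 32 = 131072 bits, which is the 131 Kb entry.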
31 |
32 | # Results on SAMSum summarization task with FLAN-T5-small language model
33 | In the following table, we take [FLAN-T5-small](https://www.jmlr.org/papers/volume25/23-0870/23-0870.pdf) as the pre-trained language model and fine-tune it on the [SAMSum](https://www.aclweb.org/anthology/D19-5409) summarization task. Because this is a fine-tuning task, the experiment does not include an FL setting; we compare federated transfer learning with FedAvg that updates the full model (FTLf), federated transfer learning with FedAvg that updates only the task-specific sub-model (FTLc), and FbFTL. Compared to the other methods, FbFTL reduces the total uplink payload by more than five orders of magnitude.
34 |
35 | | | FTLf | FTLc | FbFTL | FTLc | FbFTL | FTLc | FbFTL |
36 | | ---- | ----- | ---- | ---- | ---- | ---- | ---- | ---- |
37 | | number of trained encoders | 8 | 8 | 8 | 4 | 4 | 2 | 2 |
38 | | number of upload batches | 132588 | 36830 | 7366 | 88392 | 7366 | 103124 | 7366 |
39 | | upload parameters per batch | 109860224 | 60511616 | 1024 | 51070144 | 1024 | 46349504 | 1024 |
40 | | uplink payload per batch | **3.5 Gb** | **1.9 Gb** | **32.7 Kb** | **1.6 Gb** | **32.7 Kb** | **1.5 Gb** | **32.7 Kb** |
41 | | total uplink payload | **466.1 Tb** | **71.3 Tb** | **241.4 Mb** | **144.5 Tb** | **241.4 Mb** | **152.9 Tb** | **241.4 Mb** |
42 | | total downlink payload | 116.0 Tb | 32.2 Tb | 1.58 Gb | 77.3 Tb | 1.88 Gb | 90.2 Tb | 2.03 Gb |
43 | | test ROUGE-1 | 45.9249 | 45.4680 | 45.4680 | 45.2827 | 45.2827 | 44.9862 | 44.9862 |
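
The size of the reduction can be read off the table as a base-10 logarithm of the ratio of total uplink payloads. The sketch below assumes 32 bits per uploaded value, which reproduces the payload rows above; the dictionary keys and function names are hypothetical labels for the table columns.

```python
import math

BITS_PER_VALUE = 32  # assumption: each uploaded value is a 32-bit float

# column -> (number of upload batches, uploaded values per batch), from the table
BASELINES = {
    "FTLf, 8 encoders": (132588, 109860224),
    "FTLc, 8 encoders": (36830, 60511616),
    "FTLc, 4 encoders": (88392, 51070144),
    "FTLc, 2 encoders": (103124, 46349504),
}
# FbFTL uploads the same amount in every column of the table.
FBFTL = (7366, 1024)


def total_bits(batches, values_per_batch):
    """Total uplink payload in bits."""
    return batches * values_per_batch * BITS_PER_VALUE


def orders_of_magnitude_saved(baseline):
    """log10 of (baseline total uplink payload / FbFTL total uplink payload)."""
    return math.log10(total_bits(*BASELINES[baseline]) / total_bits(*FBFTL))


if __name__ == "__main__":
    for name in BASELINES:
        print(name, round(orders_of_magnitude_saved(name), 2))
```

Every baseline in the table comes out above five orders of magnitude under this assumption, with the full-model FTLf column giving the largest ratio.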
44 |
45 | # Required packages installation
46 | We use python==3.6.9, numpy==1.19.5, torch==1.4.0, torchvision==0.5.0, and CUDA version 11.6 for the experiments on CIFAR-10 with VGG16. The dataset and the source model will be automatically downloaded.
47 |
48 | Additionally, for the experiments on SAMSum with FLAN-T5, we use transformers==4.30.2, torchinfo==1.8.0, datasets==2.13.2, nltk==3.8.1, evaluate==0.4.1, and huggingface_hub==0.16.4.
49 |
50 | # Citation
51 | If you find our work useful in your research, please consider citing:
52 | ```
53 | @inproceedings{wang2022communication,
54 | title={Communication-Efficient and Privacy-Preserving Feature-based Federated Transfer Learning},
55 | author={Wang, Feng and Gursoy, M Cenk and Velipasalar, Senem},
56 | booktitle={GLOBECOM 2022-2022 IEEE Global Communications Conference},
57 | pages={3875--3880},
58 | year={2022},
59 | organization={IEEE}
60 | }
61 | ```
62 |
63 | ```
64 | @article{wang2024feature,
65 | title={Feature-based Federated Transfer Learning: Communication Efficiency, Robustness and Privacy},
66 | author={Wang, Feng and Gursoy, M Cenk and Velipasalar, Senem},
67 | journal={IEEE Transactions on Machine Learning in Communications and Networking},
68 | year={2024},
69 | publisher={IEEE}
70 | }
71 | ```
72 |
--------------------------------------------------------------------------------
/code_for_plots/print_privacy_leakage.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": []
7 | },
8 | "kernelspec": {
9 | "name": "python3",
10 | "display_name": "Python 3"
11 | },
12 | "language_info": {
13 | "name": "python"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "source": [
20 | "# CIFAR-10"
21 | ],
22 | "metadata": {
23 | "id": "RnC64R4j-se4"
24 | }
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {
30 | "id": "Ogn9Z9JF0EHG"
31 | },
32 | "outputs": [],
33 | "source": [
34 | "import numpy as np\n",
35 | "import matplotlib.pyplot as plt\n",
36 | "import math\n",
37 | "from math import factorial as f\n",
38 | "from time import sleep\n",
39 | "import itertools\n",
40 | "import time"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "source": [
46 | "K = 8\n",
47 | "N = 10\n",
48 | "B = f(N+K-1) // f(K) // f(N-1)\n",
49 | "\n",
50 | "# calculate H(B|Y=y), y is uniform\n",
51 | "p_y = 0\n",
52 | "P_ys = []\n",
53 | "b_count = 0\n",
54 | "H_y = 0\n",
55 | "\n",
56 | "for c0 in range(0, N+K-1):\n",
57 | " for c1 in range(c0+1, N+K-1):\n",
58 | " for c2 in range(c1+1, N+K-1):\n",
59 | " for c3 in range(c2+1, N+K-1):\n",
60 | " for c4 in range(c3+1, N+K-1):\n",
61 | " for c5 in range(c4+1, N+K-1):\n",
62 | " for c6 in range(c5+1, N+K-1):\n",
63 | " for c7 in range(c6+1, N+K-1):\n",
64 | " b_count += 1\n",
65 | " Ns = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
66 | " Ns[c0-0] += 1\n",
67 | " Ns[c1-1] += 1\n",
68 | " Ns[c2-2] += 1\n",
69 | " Ns[c3-3] += 1\n",
70 | " Ns[c4-4] += 1\n",
71 | " Ns[c5-5] += 1\n",
72 | " Ns[c6-6] += 1\n",
73 | " Ns[c7-7] += 1\n",
74 | " p_y = f(K) / (N**K) / (f(Ns[0])*f(Ns[1])*f(Ns[2])*f(Ns[3])*f(Ns[4])*f(Ns[5])*f(Ns[6])*f(Ns[7])*f(Ns[8])*f(Ns[9]))\n",
75 | " P_ys.append(p_y)\n",
76 | " H_y += -p_y * math.log2(p_y)\n",
77 | "\n",
78 | "print(len(P_ys))\n",
79 | "print(sum(P_ys))\n",
80 | "print(b_count)\n",
81 | "print('B', B)\n",
82 | "print(math.log2(B))\n",
83 | "print('H_y', H_y)\n",
84 | "print(2**H_y)"
85 | ],
86 | "metadata": {
87 | "colab": {
88 | "base_uri": "https://localhost:8080/"
89 | },
90 | "id": "7qZfigBQGzVQ",
91 | "outputId": "1d0a0bbf-fdb0-4fc7-aee6-2248cb12cd81"
92 | },
93 | "execution_count": null,
94 | "outputs": [
95 | {
96 | "output_type": "stream",
97 | "name": "stdout",
98 | "text": [
99 | "24310\n",
100 | "1.0000000000001252\n",
101 | "24310\n",
102 | "B 24310\n",
103 | "14.569262272916092\n",
104 | "H_y 13.860045564098357\n",
105 | "14869.263443691536\n"
106 | ]
107 | }
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "source": [
113 | "# calculate H(B|S=s)\n",
114 | "U_range = np.arange(1, 6252, 1250)\n",
115 | "# U_range = np.arange(1, B, 3000)\n",
116 | "H_ss = np.zeros_like(U_range).astype(float)\n",
117 | "\n",
118 | "for i, U in enumerate(U_range):\n",
119 | " H_s = 0\n",
120 | " p_s = 0\n",
121 | " choices = np.random.choice(B, U, p=P_ys)\n",
122 | " # print(len(choices), choices)\n",
123 | " for u in range(U):\n",
124 | " p_s = sum([1 for c in choices if c == u]) / U\n",
125 | " # print(p_s)\n",
126 | " if p_s != 0:\n",
127 | " H_s += -p_s * math.log2(p_s)\n",
128 | " print('U:', U, 'H_s:', H_s)\n",
129 | " H_ss[i] = H_s\n",
130 | "\n",
131 | "print('H_ss', H_ss)"
132 | ],
133 | "metadata": {
134 | "colab": {
135 | "base_uri": "https://localhost:8080/"
136 | },
137 | "id": "ePqcA5Ua7aZk",
138 | "outputId": "7f86a441-d4ff-4127-8e16-58b339d148b9"
139 | },
140 | "execution_count": null,
141 | "outputs": [
142 | {
143 | "output_type": "stream",
144 | "name": "stdout",
145 | "text": [
146 | "U: 1 H_s: 0\n",
147 | "U: 1251 H_s: 0.19578959694642661\n",
148 | "U: 2501 H_s: 0.527494304228384\n",
149 | "U: 3751 H_s: 1.6176505965077759\n",
150 | "U: 5001 H_s: 2.240630027446816\n",
151 | "U: 6251 H_s: 3.09074958685801\n",
152 | "H_ss [0. 0.1957896 0.5274943 1.6176506 2.24063003 3.09074959]\n"
153 | ]
154 | }
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "source": [
160 | "H_ss = np.array([ 0., 1.14445591, 2.92142532, 5.66463935, 7.32508148, 8.84656221, 10.8259317, 12.2532801, 13.15547667])\n",
161 | "H_y = 13.860045564098357\n",
162 | "plt.plot(np.arange(9)*6250/9, H_y-H_ss)\n",
163 | "# H_ss = np.array([ 0., 0.1891638, 0.78579461, 1.54249295, 2.22972449, 3.19595329])\n",
164 | "# H_y = 13.860045564098357\n",
165 | "# plt.plot(np.arange(6)*6250/6, H_y-H_ss)\n",
166 | "plt.xlabel('Number of clients')\n",
167 | "plt.ylabel('Privacy leakage (in bits)')\n",
168 | "plt.show()"
169 | ],
170 | "metadata": {
171 | "colab": {
172 | "base_uri": "https://localhost:8080/",
173 | "height": 279
174 | },
175 | "id": "bbZ_N7hK87yL",
176 | "outputId": "b63f85ad-1ffd-4ae2-ab25-3659484863df"
177 | },
178 | "execution_count": null,
179 | "outputs": [
180 | {
181 | "output_type": "display_data",
182 | "data": {
183 | "text/plain": [
184 | ""
185 | ],
186 | "image/png": "\n"
187 | },
188 | "metadata": {
189 | "needs_background": "light"
190 | }
191 | }
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "source": [
197 | "# Fixed UK=50000"
198 | ],
199 | "metadata": {
200 | "id": "2eQs7BvRBM9Z"
201 | }
202 | },
203 | {
204 | "cell_type": "code",
205 | "source": [
206 | "UK = 50000 # 50000\n",
207 | "N = 10\n",
208 | "\n",
209 | "for K in [9]: # [8]\n",
210 | " U = int(UK / K)\n",
211 | " B = f(N+K-1) // f(K) // f(N-1)\n",
212 | " print(\"K, U, B:\", K, U, B)\n",
213 | "\n",
214 | " # calculate H(B|Y=y), y is uniform\n",
215 | " p_y = 0\n",
216 | " P_ys = []\n",
217 | " b_count = 0\n",
218 | " H_y = 0\n",
219 | " for comb in itertools.combinations(np.arange(N+K-1), N-1):\n",
220 | " b_count += 1\n",
221 | " p_y = f(K) / (N**K) / f(comb[0]) / f(N+K-2-comb[-1])\n",
222 | " for n in range(1, N-1):\n",
223 | " p_y = p_y / f(comb[n]-comb[n-1]-1)\n",
224 | " P_ys.append(p_y)\n",
225 | " H_y += -p_y * math.log2(p_y)\n",
226 | "\n",
227 | " # print('len(P_ys)', len(P_ys))\n",
228 | " # print('sum(P_ys)', sum(P_ys))\n",
229 | " # print('b_count', b_count)\n",
230 | " # print('B', B)\n",
231 | " # print('math.log2(B)', math.log2(B))\n",
232 | " print('H_y', H_y)\n",
233 | " # print('2**H_y', 2**H_y)\n",
234 | "\n",
235 | " # calculate H(B|S=s)\n",
236 | " H_s = 0\n",
237 | " p_s = 0\n",
238 | " choices = np.random.choice(B, U, p=P_ys)\n",
239 | " # print(len(choices), choices)\n",
240 | " start_time = time.time()\n",
241 | " for u in range(U):\n",
242 | " if u%10000==0:\n",
243 | " end_time = time.time()\n",
244 | " print('u:', u, '/', U, ', H_s:', H_s, ', time:', round((end_time-start_time)/60))\n",
245 | " p_s = sum([1 for c in choices if c == u]) / U\n",
246 | " # print(p_s)\n",
247 | " if p_s != 0:\n",
248 | " H_s += -p_s * math.log2(p_s)\n",
249 | " print('H_s', H_s)"
250 | ],
251 | "metadata": {
252 | "id": "I26xJAKI87vq",
253 | "colab": {
254 | "base_uri": "https://localhost:8080/"
255 | },
256 | "outputId": "abc10e1d-2476-4108-f9ad-883f2e6fbaca"
257 | },
258 | "execution_count": null,
259 | "outputs": [
260 | {
261 | "output_type": "stream",
262 | "name": "stdout",
263 | "text": [
264 | "K, U, B: 9 5555 48620\n",
265 | "H_y 14.7084540600815\n",
266 | "u: 0 / 5555 , H_s: 0 , time: 0\n",
267 | "H_s 0.5174473607079045\n"
268 | ]
269 | }
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "source": [
275 | "Ks = [1, 2, 4, 6, 7, 8, 9, 10, 16]\n",
276 | "Ks_H_y = [3.321928, 5.743856, 9.286393, 11.8602849, 12.91726, 13.860046, 14.708454, 15.4776598, 18.95812]\n",
277 | "K_s_H_s_UK100000 = [3.321874, 5.74309, 9.25941, 11.64429, 12.33949, 5.448149, 1.7305256, 0.505497, 0]\n",
278 | "K_s_H_s_UK50000 = [3.321928, 5.738995, 9.2478531, 11.399456, 6.4285309, 1.970087, 0.517447, 0.17900, 0]\n",
279 | "K_s_H_s_UK5000 = [3.318592, 5.73230, 8.81008, 1.091727, 0.146047, 0.056241, 0, 0, 0]\n",
280 | "\n",
281 | "# plt.plot(Ks, Ks_H_y, label='$\\mathregular{H^a(B|Y=y_0)}$')\n",
282 | "# plt.plot(Ks, K_s_H_s_UK5000, label='$\\mathregular{H^a(B|S=s), UK=5000}$')\n",
283 | "# plt.plot(Ks, K_s_H_s_UK50000, label='$\\mathregular{H^a(B|S=s), UK=50000}$')\n",
284 | "# plt.plot(Ks, K_s_H_s_UK100000, label='$\\mathregular{H^a(B|S=s), UK=100000}$')\n",
285 | "plt.plot(Ks, Ks_H_y, label='$\\mathregular{H^a(B|P_{y,uni})}$')\n",
286 | "plt.plot(Ks, K_s_H_s_UK5000, label='$\\mathregular{H^a(B|\\mathcal{C}_U=c_U), UK=5000}$')\n",
287 | "plt.plot(Ks, K_s_H_s_UK50000, label='$\\mathregular{H^a(B|\\mathcal{C}_U=c_U), UK=50000}$')\n",
288 | "plt.plot(Ks, K_s_H_s_UK100000, label='$\\mathregular{H^a(B|\\mathcal{C}_U=c_U), UK=100000}$')\n",
289 | "plt.xlabel('Number of samples K from each client')\n",
290 | "plt.ylabel('Label privacy information (in bits)')\n",
291 | "plt.legend()\n",
292 | "plt.show()"
293 | ],
294 | "metadata": {
295 | "id": "befatQox87ql",
296 | "colab": {
297 | "base_uri": "https://localhost:8080/",
298 | "height": 279
299 | },
300 | "outputId": "b202ec23-d599-42a4-a941-0ee1804dbff2"
301 | },
302 | "execution_count": null,
303 | "outputs": [
304 | {
305 | "output_type": "display_data",
306 | "data": {
307 | "text/plain": [
308 | ""
309 | ],
310 | "image/png": "\n"
311 | },
312 | "metadata": {
313 | "needs_background": "light"
314 | }
315 | }
316 | ]
317 | },
318 | {
319 | "cell_type": "markdown",
320 | "source": [
321 | "# Dry bean"
322 | ],
323 | "metadata": {
324 | "id": "DP0fdL00-wYq"
325 | }
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": null,
330 | "metadata": {
331 | "id": "xBPJ8vmAEiJQ"
332 | },
333 | "outputs": [],
334 | "source": [
335 | "K = 4\n",
336 | "N = 4\n",
337 | "B = f(N+K-1) // f(K) // f(N-1)"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "source": [
343 | "# calculate H(B|Y=y), y is uniform\n",
344 | "p_y = 0\n",
345 | "P_ys = []\n",
346 | "b_count = 0\n",
347 | "H_y = 0\n",
348 | "\n",
349 | "for c0 in range(0, N+K-1):\n",
350 | " for c1 in range(c0+1, N+K-1):\n",
351 | " for c2 in range(c1+1, N+K-1):\n",
352 | " for c3 in range(c2+1, N+K-1):\n",
353 | " b_count += 1\n",
354 | " Ns = [0, 0, 0, 0, 0, 0]\n",
355 | " Ns[c0-0] += 1\n",
356 | " Ns[c1-1] += 1\n",
357 | " Ns[c2-2] += 1\n",
358 | " Ns[c3-3] += 1\n",
359 | " p_y = f(K) / (N**K) / (f(Ns[0])*f(Ns[1])*f(Ns[2])*f(Ns[3])*f(Ns[4])*f(Ns[5]))\n",
360 | " P_ys.append(p_y)\n",
361 | " H_y += -p_y * math.log2(p_y)\n",
362 | "\n",
363 | "# print(sum(P_ys))\n",
364 | "# print(b_count)\n",
365 | "print('B', B)\n",
366 | "print(math.log2(B))\n",
367 | "print('H_y', H_y)\n",
368 | "print(2**H_y)"
369 | ],
370 | "metadata": {
371 | "colab": {
372 | "base_uri": "https://localhost:8080/"
373 | },
374 | "outputId": "7c0c2d06-1666-4b67-8c66-130dfece1472",
375 | "id": "ObFHHIEsEiJa"
376 | },
377 | "execution_count": null,
378 | "outputs": [
379 | {
380 | "output_type": "stream",
381 | "name": "stdout",
382 | "text": [
383 | "B 35\n",
384 | "5.129283016944966\n",
385 | "H_y 4.81510800723783\n",
386 | "28.150877863635262\n"
387 | ]
388 | }
389 | ]
390 | },
391 | {
392 | "cell_type": "code",
393 | "source": [
394 | "# calculate H(B|S=s)\n",
395 | "U_range = np.arange(1, 1176, 110)\n",
396 | "# U_range = np.arange(1, B, 3000)\n",
397 | "H_ss = np.zeros_like(U_range).astype(float)\n",
398 | "\n",
399 | "for i, U in enumerate(U_range):\n",
400 | " H_s = 0\n",
401 | " p_s = 0\n",
402 | " choices = np.random.choice(B, U, p=P_ys)\n",
403 | " # print(len(choices), choices)\n",
404 | " for u in range(U):\n",
405 | " p_s = sum([1 for c in choices if c == u]) / U\n",
406 | " # print(p_s)\n",
407 | " if p_s != 0:\n",
408 | " H_s += -p_s * math.log2(p_s)\n",
409 | " print(H_s)\n",
410 | " H_ss[i] = H_s\n",
411 | "\n",
412 | "print('H_ss', H_ss)"
413 | ],
414 | "metadata": {
415 | "colab": {
416 | "base_uri": "https://localhost:8080/"
417 | },
418 | "outputId": "d4fe9fbc-70dd-4ccb-9947-13b5786c4743",
419 | "id": "z8SJ6CCVEiJb"
420 | },
421 | "execution_count": null,
422 | "outputs": [
423 | {
424 | "output_type": "stream",
425 | "name": "stdout",
426 | "text": [
427 | "0\n",
428 | "4.488025774313082\n",
429 | "4.702617634802406\n",
430 | "4.802929804553897\n",
431 | "4.750954940629122\n",
432 | "4.744101917925756\n",
433 | "4.8117979838104485\n",
434 | "4.768437327163887\n",
435 | "4.815896365695055\n",
436 | "4.789852490143935\n",
437 | "4.782352207555449\n",
438 | "H_ss [0. 4.48802577 4.70261763 4.8029298 4.75095494 4.74410192\n",
439 | " 4.81179798 4.76843733 4.81589637 4.78985249 4.78235221]\n"
440 | ]
441 | }
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "source": [
447 | "H_ss = np.array([0., 4.48802577, 4.70261763, 4.8029298, 4.75095494, 4.74410192, 4.81179798, 4.76843733, 4.81589637, 4.78985249, 4.78235221])\n",
448 | "H_y = 4.81510800723783\n",
449 | "plt.plot(np.arange(11)*1176/11, H_y-H_ss)\n",
450 | "plt.xlabel('Number of clients')\n",
451 | "plt.ylabel('Privacy leakage (in bits)')\n",
452 | "plt.show()"
453 | ],
454 | "metadata": {
455 | "colab": {
456 | "base_uri": "https://localhost:8080/",
457 | "height": 281
458 | },
459 | "outputId": "c9ab6df1-93a5-481a-a404-dfc8767a4395",
460 | "id": "7YxzWs4OEiJb"
461 | },
462 | "execution_count": null,
463 | "outputs": [
464 | {
465 | "output_type": "display_data",
466 | "data": {
467 | "text/plain": [
468 | ""
469 | ],
470 | "image/png": "\n"
471 | },
472 | "metadata": {
473 | "needs_background": "light"
474 | }
475 | }
476 | ]
477 | }
478 | ]
479 | }
--------------------------------------------------------------------------------
/diagrams/FbFTL_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wfwf10/Feature-based-Federated-Transfer-Learning/6e332076b157d3d7dfcfc46b3170c966747e5510/diagrams/FbFTL_diagram.png
--------------------------------------------------------------------------------
/history.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/main_GLOBECOM_CIFAR10_VGG16.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torchvision import datasets
8 | from torchvision import transforms
9 | from torch.utils.data import DataLoader
10 | import datetime
11 |
12 | #######################################
13 | ### PRE-TRAINED MODELS AVAILABLE HERE
14 | ## https://pytorch.org/docs/stable/torchvision/models.html
15 | from torchvision import models
16 | #######################################
17 |
18 | now = datetime.datetime.now()
19 |
20 | # Device
21 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22 | print('Device:', DEVICE)
23 |
24 | ##########################
25 | ### SETTINGS
26 | ##########################
27 |
28 | FL_type = 'FbFTL'
29 | train_set_denominator = 'full' # 'full', int <= 50000 # pick a subset with 50000/int training samples
30 |
31 | # Hyperparameters
32 | NUM_CLASSES = 10
33 | random_seed = 1
34 | learning_rate = 1e-2
35 | num_epochs = 200
36 | batch_size = 64
37 | momentum = 0.9
38 | lr_decay = 5e-4
39 |
40 | write_hist = True
41 |
42 | def adjust_learning_rate(optimizer, epoch):
43 |     """Sets the learning rate to the initial LR halved every num_epochs/10 epochs"""
44 |     lr = learning_rate * (0.5 ** ((epoch * 10) // num_epochs))
45 |     for param_group in optimizer.param_groups:
46 |         param_group['lr'] = lr
47 |
48 | if torch.cuda.is_available():
49 | torch.backends.cudnn.deterministic = True
50 | torch.cuda.manual_seed(random_seed) # Sets the seed for generating random numbers for the current GPU.
51 | torch.manual_seed(random_seed) # sets the seed for generating random numbers.
52 |
53 | if write_hist:
54 | file1 = open('history.txt', 'a')
55 | file1.write('\n \n \n Time:')
56 | file1.write(str(now.year) + ' ' + str(now.month) + ' ' + str(now.day) + ' '
57 | + str(now.hour) + ' ' + str(now.minute) + ' ' + str(now.second)
58 | + ' ' + FL_type + ', train_deno:' + str(train_set_denominator))
59 | file1.close()
60 |
61 | ##########################
62 | ### CIFAR10 DATASET
63 | ##########################
64 |
65 | custom_transform = transforms.Compose([
66 | transforms.Resize((224, 224)),
67 | transforms.ToTensor(),
68 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
69 | std=[0.229, 0.224, 0.225])
70 | ])
71 |
72 | train_dataset = datasets.CIFAR10(root='data', train=True, transform=custom_transform,download=True)
73 | test_dataset = datasets.CIFAR10(root='data', train=False, transform=custom_transform)
74 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=8, shuffle=False)
75 |
76 | if train_set_denominator == 'full':
77 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=8, shuffle=True)
78 | else:
79 | selected_list = list(range(0, len(train_dataset), train_set_denominator))
80 | trainset_1 = torch.utils.data.Subset(train_dataset, selected_list)
81 | train_loader = torch.utils.data.DataLoader(dataset=trainset_1, batch_size=batch_size, num_workers=8, shuffle=True)
82 |
83 | ##########################
84 | ### LOAD MODEL
85 | ##########################
86 |
87 | model = models.vgg16(pretrained=True)
88 | for param in model.parameters():
89 | param.requires_grad = False
90 |
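91 | model.classifier[3].requires_grad = True  # note: setting requires_grad on a Module leaves its parameters frozen; only the new classifier[6] below is trainable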
92 | model.classifier[6] = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5), nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5), nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.5), nn.Linear(512, NUM_CLASSES))
93 |
94 | ##########################
95 | ### TRAIN MODEL
96 | ##########################
97 |
98 | model = model.to(DEVICE)
99 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=lr_decay, momentum=momentum)
100 |
101 | def compute_accuracy(model, data_loader):
102 | model.eval()
103 | correct_pred, num_examples = 0, 0
104 | with torch.no_grad():
105 | for i, (features, targets) in enumerate(data_loader):
106 | features = features.to(DEVICE)
107 | targets = targets.to(DEVICE)
108 |
109 | logits = model(features)
110 | _, predicted_labels = torch.max(logits, 1)
111 | num_examples += targets.size(0)
112 | correct_pred += (predicted_labels == targets).sum()
113 | return correct_pred.float()/num_examples * 100
114 |
115 |
116 | def compute_epoch_loss(model, data_loader):
117 | model.eval()
118 | curr_loss, num_examples = 0., 0
119 | with torch.no_grad():
120 | for batch_idx, (features, targets) in enumerate(data_loader):
121 | features = features.to(DEVICE)
122 | targets = targets.to(DEVICE)
123 | logits = model(features)
124 | loss = F.cross_entropy(logits, targets, reduction='sum')
125 | num_examples += targets.size(0)
126 | curr_loss += loss
127 |
128 | curr_loss = curr_loss / num_examples
129 | return curr_loss
130 |
131 |
132 |
133 | start_time = time.time()
134 | for epoch in range(num_epochs):
135 | model.train()
136 | for batch_idx, (features, targets) in enumerate(train_loader):
137 |
138 | features = features.to(DEVICE)
139 | targets = targets.to(DEVICE)
140 |
141 | ### FORWARD AND BACK PROP
142 | logits = model(features)
143 | cost = F.cross_entropy(logits, targets)
144 | optimizer.zero_grad()
145 |
146 | cost.backward()
147 |
148 | ### UPDATE MODEL PARAMETERS
149 | optimizer.step()
150 |
151 | ### LOGGING
152 | if not batch_idx % 50:
153 | print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'
154 | %(epoch+1, num_epochs, batch_idx,
155 | len(train_loader), cost))
156 |
157 | model.eval()
158 | accuracy = compute_accuracy(model, test_loader)
159 | loss = compute_epoch_loss(model, test_loader)
160 | with torch.set_grad_enabled(False): # save memory during inference
161 | print('Epoch: %03d/%03d | Test: %.3f%% | Loss: %.3f' % (epoch+1, num_epochs, accuracy, loss))
162 | if write_hist:
163 | file1 = open('history.txt', 'a')
164 | file1.write('\n Epoch: %03d/%03d | Test: %.3f%% | Loss: %.3f' % (epoch+1, num_epochs, accuracy, loss))
165 | file1.close()
166 |
167 |
168 | print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
169 |
170 | print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
171 |
172 | with torch.set_grad_enabled(False): # save memory during inference
173 | accuracy = compute_accuracy(model, test_loader)
174 | print('Test accuracy: %.2f%%' % (accuracy))
175 | if write_hist:
176 | file1 = open('history.txt', 'a')
177 | file1.write('\n Test accuracy: %.2f%%' % (accuracy))
178 | file1.close()
179 |
180 |
181 |
--------------------------------------------------------------------------------
/main_TMLCN_CIFAR10_VGG16.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torchvision import datasets
8 | from torchvision import transforms
9 | from torch.utils.data import DataLoader
10 | import matplotlib.pyplot as plt
11 | import datetime
12 | import math
13 | from collections import deque
14 | import copy
15 |
16 | #######################################
17 | ### PRE-TRAINED MODELS AVAILABLE HERE
18 | ## https://pytorch.org/docs/stable/torchvision/models.html
19 | from torchvision import models
20 | #######################################
21 |
22 | now = datetime.datetime.now()
23 |
24 | # Device
25 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26 | print('Device:', DEVICE)
27 |
28 | ##########################
29 | ### SETTINGS
30 | ##########################
31 |
32 | FL_type = 'FbFTL' # 'FL', 'FTLf', 'FTLc', 'FbFTL'
33 | train_set_denominator = 'full' # 'full', int <= 50000 # pick a subset with 50000/int training samples
34 |
35 | # Hyperparameters
36 | NUM_CLASSES = 10
37 | U_clients = 6250 # number of clients, 50000/8
38 | random_seed = 1 # transfer = True / False
39 | learning_rate = 1e-2 # 1e-3, 0.05, 1e-2
40 | num_epochs = 200 # 10, 300, 200
41 | batch_size = 64 # 128, 128, (out of memory:64)
42 | momentum = 0.9 # None, 0.9
43 | lr_decay = 5e-4 # 1e-6, 5e-4
44 | if FL_type == 'FL':
45 |     transfer, full = False, True # transfer or train model from scratch # train whole model or last few layers
46 |     sigma = 0. # 0.5 relative std for additive Gaussian noise on gradients
47 | elif FL_type == 'FTLf':
48 |     transfer, full = True, True
49 |     sigma = 0. # 0.3305 relative std for additive Gaussian noise on gradients
50 | elif FL_type == 'FTLc':
51 |     transfer, full = True, False
52 |     sigma = 0. # 0.285 relative std for additive Gaussian noise on gradients
53 | elif FL_type == 'FbFTL':
54 |     transfer, full = True, False
55 |     sigma = 0 # 0.8? relative std for additive Gaussian noise on features
56 |     saved_noise = True # save noise at beginning
57 | else:
58 |     raise ValueError('Unknown FL_type: ' + FL_type)
59 | relative_noise_type = 'all_std' # 'individual', 'all_std'
60 | packet_loss_rate = 0. # 0, 0.05, 0.1, 0.15
61 | quan_digit = 32 # digits kept after feature quantization: None (max:(12~18)(6~8), min=0, std~0.8) or int
62 | sparse_rate = 0.9 # ratio of uplink elements kept after sparsification: None or (0,1]
63 | class ErrorFeedback(object):
64 | queue = deque(maxlen=U_clients)
65 | temp = deque()
66 | if (quan_digit or sparse_rate) and FL_type != 'FbFTL':
67 | errfdbk = ErrorFeedback()
68 | # print(errfdbk.queue, errfdbk.temp)
69 |
70 | write_hist = True
71 |
72 | def adjust_learning_rate(optimizer, epoch):
73 |     """Sets the learning rate to the initial LR halved every num_epochs/10 epochs"""
74 |     lr = learning_rate * (0.5 ** ((epoch * 10) // num_epochs))
75 |     for param_group in optimizer.param_groups:
76 |         param_group['lr'] = lr
77 |
78 | if torch.cuda.is_available():
79 | torch.backends.cudnn.deterministic = True
80 | torch.cuda.manual_seed(random_seed) # Sets the seed for generating random numbers for the current GPU.
81 | torch.manual_seed(random_seed) # sets the seed for generating random numbers.
82 |
83 | if write_hist:
84 | file1 = open('history.txt', 'a')  # repository-relative log file, as in main_GLOBECOM_CIFAR10_VGG16.py
85 | file1.write('\n \n \n Time:')
86 | file1.write(str(now.year) + ' ' + str(now.month) + ' ' + str(now.day) + ' ' + str(now.hour) + ' '
87 | + str(now.minute) + ' ' + str(now.second) + ' ' + FL_type
88 | # + ', train_deno:' + str(train_set_denominator)
89 | # + ', sigma:' + str(sigma)
90 | # + ', packet_loss_rate:' + str(packet_loss_rate)
91 | + ', quantization digits:' + str(quan_digit)
92 | + ', sparsification rate:' + str(sparse_rate)
93 | )
94 | file1.close()
95 |
96 | ##########################
97 | ### CIFAR10 DATASET
98 | ##########################
99 |
100 | custom_transform = transforms.Compose([
101 | transforms.Resize((224, 224)),
102 | transforms.ToTensor(),
103 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
104 | std=[0.229, 0.224, 0.225])
105 | ])
106 |
107 | ## Note that this particular normalization scheme is necessary since it was used for pre-training the network on ImageNet.
108 | ## These are the channel-means and standard deviations for z-score normalization.
109 |
110 | train_dataset = datasets.CIFAR10(root='data', train=True, transform=custom_transform,download=True)
111 | test_dataset = datasets.CIFAR10(root='data', train=False, transform=custom_transform)
112 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=8, shuffle=False)
113 |
114 | if train_set_denominator == 'full':
115 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=8, shuffle=True)
116 | train_set_len = len(train_dataset)
117 | else:
118 | # print('len(train_dataset)', len(train_dataset)) # 50000
119 | # selected_list = list(range(0, len(train_dataset), 2))
120 | selected_list = list(range(0, len(train_dataset), train_set_denominator))
121 | trainset_1 = torch.utils.data.Subset(train_dataset, selected_list)
122 | train_loader = torch.utils.data.DataLoader(dataset=trainset_1, batch_size=batch_size, num_workers=8, shuffle=True)
123 | train_set_len = len(selected_list)
124 |
125 | # # Checking the dataset
126 | # for images, labels in train_loader:
127 | # print('Image batch dimensions:', images.shape)
128 | # print('Image label dimensions:', labels.shape)
129 | # break
130 |
131 | # labels = torch.zeros(10, dtype=torch.long)
132 | # for batch_idx, (features, targets) in enumerate(train_loader):
133 | # for t in targets:
134 | # labels[t] += 1
135 | # print('labels', labels)
136 |
137 | ##########################
138 | ### LOAD MODEL
139 | ##########################
140 |
141 | class GaussianNoise(nn.Module):
142 | """Gaussian noise regularizer.
143 | Args:
144 | sigma (float, optional): relative standard deviation used to generate the
145 | noise. Relative means that it will be multiplied by the magnitude of
146 | the value you are adding the noise to. This means that sigma can be
147 | the same regardless of the scale of the vector.
148 | is_relative_detach (bool, optional): whether to detach the variable before
149 | computing the scale of the noise. If `False` then the scale of the noise
150 | won't be seen as a constant but something to optimize: this will bias the
151 | network to generate vectors with smaller values.
152 | """
153 | def __init__(self, sigma=0, is_relative_detach=False):
154 | super().__init__()
155 | self.sigma = sigma
156 | self.is_relative_detach = is_relative_detach
157 | if saved_noise:
158 | self.register_buffer('noise', torch.empty(train_set_len*4096).normal_(mean=0,std=1))
159 | self.i = 0
160 | else:
161 | self.register_buffer('noise', torch.tensor(0))
162 |
163 | def forward(self, x):
164 | if self.training and quan_digit:
165 | # print(x)
166 | # print(x.dtype, torch.max(x), torch.min(x), torch.std(x))
167 | x = torch.round((2**quan_digit-1) / torch.max(x) * x) * torch.max(x) / (2**quan_digit-1)
168 | # print(x.dtype, torch.max(x), torch.min(x), torch.std(x))
169 | # print(x)
170 | # quit()
171 | if self.training and self.sigma != 0:
172 | if torch.cuda.is_available():
173 | torch.backends.cudnn.deterministic = True
174 | torch.cuda.manual_seed(random_seed) # Sets the seed for generating random numbers for the current GPU.
175 | torch.manual_seed(random_seed) # sets the seed for generating random numbers.
176 | if relative_noise_type == 'individual':
177 | scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
178 | sampled_noise = self.noise.expand(*x.size()).float().normal_() * scale
179 | elif relative_noise_type == 'all_std':
180 | x_std = torch.std(x.detach()) if self.is_relative_detach else torch.std(x)
181 | # print(*x.size()) # 64 4096
182 | if saved_noise:
183 | sampled_noise = torch.reshape(self.noise[self.i*batch_size*4096 : (self.i+1)*batch_size*4096],(-1, 4096)
184 | ).detach().float() * x_std * self.sigma
185 | self.i = self.i + 1 if (self.i+1)*batch_size*4096 256:
221 | m.reset_parameters()
222 | for i in range(21, 31):
223 | model.features[i].apply(weight_reset)
224 | for i in range(7):
225 | model.classifier[i].apply(weight_reset)
226 |
227 | for param in model.parameters():
228 | param.requires_grad = True
229 |
230 | # model.classifier[3].requires_grad = True # (4096, 4096, relu, dropout(0.5))
231 | # model.classifier[6] = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.5), nn.Linear(512, NUM_CLASSES))
232 | model.classifier[6] = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5), nn.Linear(4096, 512),
233 | nn.ReLU(), nn.Dropout(0.5), nn.Linear(512, NUM_CLASSES))
234 |
235 |
236 | def Gaussian_noise_to_weights(m):
237 | if sigma!=0:
238 | with torch.no_grad():
239 | for param in m.parameters():
240 | if param.requires_grad:
241 | if relative_noise_type == 'individual':
242 | # print(param.grad.view(-1))
243 | # print(param)
244 | # scale = sigma * param.detach()
245 | scale = sigma * param.grad.detach() # todo: * math.sqrt(batch_size/8)
246 | noise = torch.tensor(0).to(DEVICE)
247 | sampled_noise = noise.expand(*param.size()).float().normal_() * scale
248 | elif relative_noise_type == 'all_std':
249 | param_grad_std = torch.std(param.grad.detach())
250 | noise = torch.tensor(0).to(DEVICE)
251 | sampled_noise = noise.expand(*param.size()).float().normal_(std=param_grad_std*sigma) # todo: * math.sqrt(batch_size/8)
252 | # param = param + sampled_noise
253 | param.add_(sampled_noise)
254 |
255 | def Errfdbk_to_weights(m):
256 | print("inner model.apply")
257 | with torch.no_grad():
258 | for param in m.parameters():
259 | print('len(param)', len(param))
260 | if param.requires_grad:
261 | print('len(errfdbk.queue)', len(errfdbk.queue))
262 | # print(errfdbk.temp)
263 | print('len(errfdbk.temp)', len(errfdbk.temp))
264 | p_grad = param.grad.detach()
265 | if err_flag:
266 | p_grad += errfdbk.temp.popleft()
267 | p_grad_qs = copy.deepcopy(p_grad)
268 | if sparse_rate:
269 | pass
270 | if quan_digit:
271 | pass
272 | err = p_grad_qs - p_grad
273 | param.add_(err)
274 | errfdbk.temp.append(-err)
275 | print('seems good')
276 |
277 | if FL_type == 'FbFTL':
278 | received_batches_FbFTL = np.ones(len(train_loader))
279 | received_batches_FbFTL[:int(len(train_loader)*packet_loss_rate)] = 0
280 | np.random.shuffle(received_batches_FbFTL)
281 |
282 | def Packet_Received(batch_idx):
283 | if FL_type == 'FbFTL':
284 | return received_batches_FbFTL[batch_idx]
285 | else:
286 | return np.random.choice(2, p=[packet_loss_rate, 1-packet_loss_rate])
287 |
288 | # for a in range(3):
289 | # print('round', a)
290 | # for i in range(50):
291 | # print(Packet_Received(i))
292 |
293 | # print(model)
294 |
295 | # for name, param in model.named_parameters():
296 | # print(name, torch.numel(param), param.requires_grad)
297 | # quit()
298 |
299 | # model(torch.randn(1, 3, 224, 224)).mean().backward()
300 | # for name, param in model.named_parameters():
301 | # print(name, param.grad)
302 | # print('value', param.data)
303 | # wait = input('next layer')
304 |
305 | ##########################
306 | ### TRAIN MODEL
307 | ##########################
308 |
309 | model = model.to(DEVICE)
310 | # if transfer:
311 | # optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
312 | # else:
313 | # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=lr_decay, momentum=momentum) # , nesterov=True)
314 | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=lr_decay, momentum=momentum)
315 |
316 |
317 | def compute_accuracy(model, data_loader):
318 | model.eval()
319 | correct_pred, num_examples = 0, 0
320 | with torch.no_grad():
321 | for i, (features, targets) in enumerate(data_loader):
322 | features = features.to(DEVICE)
323 | targets = targets.to(DEVICE)
324 |
325 | logits = model(features)
326 | _, predicted_labels = torch.max(logits, 1)
327 | num_examples += targets.size(0)
328 | correct_pred += (predicted_labels == targets).sum()
329 | return correct_pred.float()/num_examples * 100
330 |
331 |
332 | def compute_epoch_loss(model, data_loader):
333 | model.eval()
334 | curr_loss, num_examples = 0., 0
335 | with torch.no_grad():
336 | for batch_idx, (features, targets) in enumerate(data_loader):
337 | features = features.to(DEVICE)
338 | targets = targets.to(DEVICE)
339 | logits = model(features)
340 | loss = F.cross_entropy(logits, targets, reduction='sum')
341 | num_examples += targets.size(0)
342 | curr_loss += loss
343 |
344 | curr_loss = curr_loss / num_examples
345 | return curr_loss
346 |
347 |
348 |
349 | start_time = time.time()
350 | for epoch in range(num_epochs):
351 | # adjust_learning_rate(optimizer, epoch)
352 |
353 | model.train()
354 | for batch_idx, (features, targets) in enumerate(train_loader):
355 |
356 | if Packet_Received(batch_idx):
357 | features = features.to(DEVICE)
358 | targets = targets.to(DEVICE)
359 |
360 | ### FORWARD AND BACK PROP
361 | logits = model(features)
362 | cost = F.cross_entropy(logits, targets)
363 | optimizer.zero_grad()
364 |
365 | cost.backward()
366 |
367 | ### UPDATE MODEL PARAMETERS
368 | optimizer.step()
369 |
370 | ### PRIVACY NOISE
371 | if FL_type != 'FbFTL':
372 | # Gaussian_noise_to_weights(model, sigma * math.sqrt(batch_size/8))
373 | model.apply(Gaussian_noise_to_weights)
374 |
375 | ### Sparsification/Quantization with Error Feedback # errfdbk.queue = deque(maxlen=U_clients)
376 | if (quan_digit or sparse_rate) and FL_type != 'FbFTL':
377 | print('main loop: len(errfdbk.queue)', len(errfdbk.queue))
378 | if len(errfdbk.queue) < U_clients:
379 | errfdbk.temp = deque()
380 | err_flag = False
381 | else:
382 | errfdbk.temp = errfdbk.queue.popleft()
383 | err_flag = True
384 | model.apply(Errfdbk_to_weights)
385 | print('completed one cycle!')
386 | errfdbk.queue.append(errfdbk.temp)
387 |
388 | ### LOGGING
389 | if not batch_idx % 50:
390 | print('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' % (epoch+1, num_epochs, batch_idx, len(train_loader), cost))
391 | # if FL_type == 'FbFTL':
392 | # print(model.classifier[6][0].sigma)
393 | # model.classifier[6][0].set_sigma(sigma=0.4)
394 |
395 | model.eval()
396 | accuracy = compute_accuracy(model, test_loader)
397 | loss = compute_epoch_loss(model, test_loader)
398 | with torch.set_grad_enabled(False): # save memory during inference
399 | print('Epoch: %03d/%03d | Test: %.3f%% | Loss: %.3f' % (epoch+1, num_epochs, accuracy, loss))
400 | if write_hist:
401 | file1 = open('/data1/feng/LaTFL/history.txt', 'a')
402 | file1.write('\n Epoch: %03d/%03d | Test: %.3f%% | Loss: %.3f' % (epoch+1, num_epochs, accuracy, loss))
403 | file1.close()
404 |
405 |
406 | print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
407 |
408 | print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
409 |
410 | with torch.set_grad_enabled(False): # save memory during inference
411 | accuracy = compute_accuracy(model, test_loader)
412 | print('Test accuracy: %.2f%%' % (accuracy))
413 | if write_hist:
414 | file1 = open('/data1/feng/LaTFL/history.txt', 'a')
415 | file1.write('\n Test accuracy: %.2f%%' % (accuracy))
416 | file1.close()
417 |
418 | # model.save_weights('/data1/feng/LaTFL/cifar10vgg.h5')
419 |
420 |
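
The feature quantization step in GaussianNoise.forward above rescales the activations by their maximum, rounds them to 2**quan_digit - 1 steps, and rescales back. Below is a minimal pure-Python sketch of that rule, assuming non-negative inputs (as after a ReLU). The name quantize_uniform and the all-zero guard are illustrative assumptions and are not part of this repository.

```python
def quantize_uniform(values, bits):
    """Round each value to the nearest of 2**bits - 1 steps between 0 and max(values)."""
    levels = 2 ** bits - 1   # number of quantization steps
    peak = max(values)       # the scale is set by the largest value
    if peak == 0:
        # Assumed guard: the original script does not handle an all-zero input.
        return list(values)
    # Same arithmetic as the script: scale up, round, scale back down.
    return [round(v * levels / peak) * peak / levels for v in values]


features = [0.0, 0.3, 0.8, 1.0]
quantized = quantize_uniform(features, bits=2)
```

With bits=2 there are three steps, so values in [0, 1] snap to 0, 1/3, 2/3, or 1, and the rounding error of any element is at most half a step, peak / (2 * levels).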
--------------------------------------------------------------------------------
/main_TMLCN_FLANT5_SAMSUM.py:
--------------------------------------------------------------------------------
1 | # Source: https://www.philschmid.de/fine-tune-flan-t5
2 | # tensorboard: tensorboard --logdir log
3 | # Run the following in terminal before python3 FbFTL_LLM/test.py: huggingface-cli login (enter your own access token when prompted)
4 |
5 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
6 | from torchinfo import summary
7 | from datasets import load_dataset
8 | from random import randrange
9 | from datasets import concatenate_datasets
10 | import evaluate
11 | import nltk
12 | import numpy as np
13 | from nltk.tokenize import sent_tokenize
14 | nltk.download("punkt")
15 | from huggingface_hub import HfFolder
16 | from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
17 | import os, sys
18 | from collections import deque
19 |
20 | HfFolder.save_token(os.environ['HF_TOKEN']) # read the access token from the HF_TOKEN environment variable (name assumed) instead of hard-coding it
21 |
22 | # FL parameters
23 | FL_type = 'FbFTL' # 'FL', 'FTLf'(same as FL), 'FTLc', 'FbFTL'
24 | model_to_train = [False, 4, 0, True] # available for 'FTLc' and 'FbFTL': [False, 4, 0, True]
25 | # shared: (boolean), encoder: (int:0~8, index after which is trained, 10 to not train layer norm), decoder: (same as encoder), lm_head: (boolean)
26 | train_set_denominator = -1 # -1 to use full training dataset, or 0 < int <= 14732 : pick a subset with [int] training samples
27 | # Hyperparameters
28 | NUM_CLASSES = 32128
29 | learning_rate = 2e-4 # TODO: source code 5e-5, full model 2e-4
30 | num_train_epochs = 20 # TODO: source code 5, full model 20
31 | batch_size = 8 # source code 8
32 | # TODO: implement following params
33 | # if FL_type == 'FL':
34 | # transfer, full = False, True # transfer or train model from scratch # train whole model or last few layers
35 | # sigma = 0. # 0.5 relative std for additive Gaussian noise on gradients
36 | # elif FL_type == 'FTLf':
37 | # transfer, full = True, True
38 | # sigma = 0. # 0.3305 relative std for additive Gaussian noise on gradients
39 | # elif FL_type == 'FTLc':
40 | # transfer, full = True, False
41 | # sigma = 0. # 0.285 relative std for additive Gaussian noise on gradients
42 | # elif FL_type == 'FbFTL':
43 | # transfer, full = True, False
44 | # sigma = 0 # 0.8? relative std for additive Gaussian noise on features
45 | # saved_noise = True # save noise at beginning
46 | # else:
47 | # raise ValueError('Unknown FL_type: ' + FL_type)
48 | # relative_noise_type = 'all_std' # 'individual', 'all_std'
49 | # packet_loss_rate = 0. # 0, 0.05, 0.1, 0.15
50 | # quan_digit = 32 # digits kept after feature quantization: None (max:(12~18)(6~8), min=0, std~0.8) or int
51 | # sparse_rate = 0.9 # ratio of uplink elements kept after sparsification: None or (0,1]
52 |
53 | # Load dataset from the hub
54 | dataset_id = "samsum"
55 | dataset = load_dataset(dataset_id)
56 | if train_set_denominator != -1:
57 | dataset['train'] = dataset['train'].select(range(train_set_denominator))
58 | print(f"Train dataset size: {len(dataset['train'])}") # 14732
59 | print(f"Test dataset size: {len(dataset['test'])}") # 819
60 | sample = dataset['train'][randrange(len(dataset["train"]))]
61 | # print(f"dialogue: \n{sample['dialogue']}\n---------------")
62 | # print(f"summary: \n{sample['summary']}\n---------------")
63 | # sys.exit()
64 |
65 | # Loading the Model
66 | model_id="google/flan-t5-small"
67 | tokenizer = AutoTokenizer.from_pretrained(model_id)
68 | model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
69 | # for parameter in model.parameters(): # model: shared, encoder, decoder, lm_head
70 | # parameter.requires_grad = False
71 | if FL_type in ['FTLc', 'FbFTL'] and not model_to_train[0]:
72 | for parameter in model.shared.parameters(): # initial embedding layer, not trained
73 | parameter.requires_grad = False
74 |
75 | if FL_type in ['FTLc', 'FbFTL']:
76 | for i, m in enumerate(model.encoder.block): # whether train encoder blocks
77 | if i < model_to_train[1]:
78 | for parameter in m.parameters():
79 | parameter.requires_grad = False
80 | if model_to_train[1] >= 10:
81 | for parameter in model.encoder.final_layer_norm.parameters(): # whether train encoder layer norm
82 | parameter.requires_grad = False
83 |
84 | if FL_type in ['FTLc', 'FbFTL']:
85 | for i, m in enumerate(model.decoder.block): # whether train decoder blocks
86 | if i < model_to_train[2]:
87 | for parameter in m.parameters():
88 | parameter.requires_grad = False
89 | if model_to_train[2] >= 10:
90 | for parameter in model.decoder.final_layer_norm.parameters(): # whether train decoder layer norm
91 | parameter.requires_grad = False
92 |
93 | if FL_type in ['FTLc', 'FbFTL'] and not model_to_train[3]:
94 | for parameter in model.lm_head.parameters(): # final output layer; frozen only when model_to_train[3] is False
95 | parameter.requires_grad = False
96 | summary(model)
97 | # print(model)
98 | # for param in model.state_dict():
99 | # print(param)
100 | # sys.exit()
101 |
102 | # Preprocess data
103 | tokenized_inputs = concatenate_datasets([dataset["train"], dataset["test"]]).map(lambda x: tokenizer(x["dialogue"], truncation=True),
104 | batched=True, remove_columns=["dialogue", "summary"])
105 | max_source_length = max([len(x) for x in tokenized_inputs["input_ids"]])
106 | print(f"Max source length: {max_source_length}")
107 | tokenized_targets = concatenate_datasets([dataset["train"], dataset["test"]]).map(lambda x: tokenizer(x["summary"], truncation=True),
108 | batched=True, remove_columns=["dialogue", "summary"])
109 | max_target_length = max([len(x) for x in tokenized_targets["input_ids"]])
110 | print(f"Max target length: {max_target_length}")
111 |
112 | def preprocess_function(sample,padding="max_length"):
113 | # add prefix to the input for t5
114 | inputs = ["summarize: " + item for item in sample["dialogue"]]
115 | # tokenize inputs
116 | model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
117 | # Tokenize targets with the `text_target` keyword argument
118 | labels = tokenizer(text_target=sample["summary"], max_length=max_target_length, padding=padding, truncation=True)
119 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
120 | # padding in the loss.
121 | if padding == "max_length":
122 | labels["input_ids"] = [
123 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
124 | ]
125 | model_inputs["labels"] = labels["input_ids"]
126 | return model_inputs
127 |
128 | tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["dialogue", "summary", "id"])
129 | print(f"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}")
130 |
131 | # Metric
132 | metric = evaluate.load("rouge")
133 | # helper function to postprocess text
134 | def postprocess_text(preds, labels):
135 | preds = [pred.strip() for pred in preds]
136 | labels = [label.strip() for label in labels]
137 | # rougeLSum expects newline after each sentence
138 | preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
139 | labels = ["\n".join(sent_tokenize(label)) for label in labels]
140 | return preds, labels
141 |
142 | def compute_metrics(eval_preds):
143 | preds, labels = eval_preds
144 | if isinstance(preds, tuple):
145 | preds = preds[0]
146 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
147 | # Replace -100 in the labels as we can't decode them.
148 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
149 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
150 | # Some simple post-processing
151 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
152 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
153 | result = {k: round(v * 100, 4) for k, v in result.items()}
154 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
155 | result["gen_len"] = np.mean(prediction_lens)
156 | return result
157 |
158 | # we want to ignore tokenizer pad token in the loss
159 | label_pad_token_id = -100
160 | # Data collator
161 | data_collator = DataCollatorForSeq2Seq(
162 | tokenizer,
163 | model=model,
164 | label_pad_token_id=label_pad_token_id,
165 | pad_to_multiple_of=8
166 | )
167 |
168 | # Hugging Face repository id
169 | repository_id = f"{model_id.split('/')[1]}-{dataset_id}"
170 |
171 | # Define training args
172 | training_args = Seq2SeqTrainingArguments(
173 | output_dir=repository_id,
174 | per_device_train_batch_size=batch_size,
175 | per_device_eval_batch_size=8,
176 | predict_with_generate=True,
177 | fp16=False, # Overflows with fp16
178 | learning_rate=learning_rate,
179 | num_train_epochs=num_train_epochs,
180 | # logging & evaluation strategies
181 | logging_dir=f"{repository_id}/logs",
182 | logging_strategy="steps",
183 | logging_steps=500,
184 | evaluation_strategy="epoch",
185 | save_strategy="epoch",
186 | save_total_limit=2,
187 | load_best_model_at_end=True,
188 | # metric_for_best_model="overall_f1",
189 | # push to hub parameters
190 | report_to="tensorboard",
191 | push_to_hub=False,
192 | hub_strategy="every_save",
193 | hub_model_id=repository_id,
194 | hub_token=HfFolder.get_token(),
195 | )
196 |
197 | # Create Trainer instance
198 | trainer = Seq2SeqTrainer(
199 | model=model,
200 | args=training_args,
201 | data_collator=data_collator,
202 | train_dataset=tokenized_dataset["train"],
203 | eval_dataset=tokenized_dataset["test"],
204 | compute_metrics=compute_metrics,
205 | )
206 |
207 | # Start training
208 | trainer.train()
209 |
210 | trainer.evaluate()
211 |
212 | # Save our tokenizer and create model card
213 | tokenizer.save_pretrained(repository_id)
214 | trainer.create_model_card()
215 | # Push the results to the hub
216 | # trainer.push_to_hub()
217 |
218 |
219 | # Run Inference and summarize ChatGPT dialogues
220 | # from transformers import pipeline
221 | # from random import randrange
222 | # # load model and tokenizer from huggingface hub with pipeline
223 | # summarizer = pipeline("summarization", model="philschmid/flan-t5-base-samsum", device=0)
224 | # # select a random test sample
225 | # sample = dataset['test'][randrange(len(dataset["test"]))]
226 | # print(f"dialogue: \n{sample['dialogue']}\n---------------")
227 | # # summarize dialogue
228 | # res = summarizer(sample["dialogue"])
229 | # print(f"flan-t5-base summary:\n{res[0]['summary_text']}")
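
In the script above, preprocess_function replaces padding token ids in the labels with -100 so that the loss ignores them, and compute_metrics reverses the substitution before decoding. Below is a minimal sketch of both directions on plain Python lists. The helper names mask_padding and unmask_padding, the sample token ids, and the pad id 0 are illustrative assumptions, not code from this repository.

```python
IGNORE_INDEX = -100  # label value the script uses for positions excluded from the loss


def mask_padding(label_ids, pad_token_id):
    """Replace every padding id with IGNORE_INDEX, sequence by sequence."""
    return [[tok if tok != pad_token_id else IGNORE_INDEX for tok in seq]
            for seq in label_ids]


def unmask_padding(label_ids, pad_token_id):
    """Put the padding id back so the sequences can be decoded again."""
    return [[tok if tok != IGNORE_INDEX else pad_token_id for tok in seq]
            for seq in label_ids]


# Two padded label sequences; 0 stands in for the pad id in this example.
labels = [[71, 4, 1, 0, 0], [9, 1, 0, 0, 0]]
masked = mask_padding(labels, pad_token_id=0)
restored = unmask_padding(masked, pad_token_id=0)
```

The round trip is lossless as long as -100 never occurs as a real token id, which holds for non-negative vocabulary indices.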
--------------------------------------------------------------------------------