├── README.md
├── Tensorflow_federated_Federated_Core.ipynb
└── fl.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | A simple implementation of FedAvg with TensorFlow Federated
2 |
--------------------------------------------------------------------------------
/Tensorflow_federated_Federated_Core.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Tensorflow federated:Federated Core.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyOHba4/2VwWpEJA0ihY4bCU",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | }
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "view-in-github",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "
"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "0gkQwKdMqbed"
32 | },
33 | "source": [
34 | "## 前言\n",
35 | "本文来源于Tensorflow federated(TFF)[官网上的教程](https://www.tensorflow.org/federated/tutorials/custom_federated_algorithms_1),力求用简洁的文字阐述TFF框架的使用方法,同时也记录自己的实验过程。本文是系列教程一,讲解TFF底层架构Federated Core(FC)的一些概念。\n",
36 | "\n",
37 | "首先应该明白,TFF中数据是第一公民:编写代码时,不需要指明某段代码是运行在server还是client,但一定要指明某变量、常量是放在server还是client,是不是全局唯一的\n",
38 | "\n",
39 | "## 准备实验环境"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "metadata": {
45 | "id": "H4GcktbZiBPt"
46 | },
47 | "source": [
48 | "!pip install --quiet --upgrade tensorflow_federated_nightly\n",
49 | "!pip install --quiet --upgrade nest_asyncio\n",
50 | "import nest_asyncio\n",
51 | "nest_asyncio.apply()\n",
52 | "import collections\n",
53 | "import numpy as np\n",
54 | "import tensorflow as tf\n",
55 | "import tensorflow_federated as tff"
56 | ],
57 | "execution_count": null,
58 | "outputs": []
59 | },
60 | {
61 | "cell_type": "code",
62 | "metadata": {
63 | "id": "j5OuCYA8ibhm"
64 | },
65 | "source": [
66 | "@tff.federated_computation\n",
67 | "def hello_world():\n",
68 | " return 'Hello, World!'\n",
69 | "\n",
70 | "hello_world()"
71 | ],
72 | "execution_count": null,
73 | "outputs": []
74 | },
75 | {
76 | "cell_type": "markdown",
77 | "metadata": {
78 | "id": "92izVg-9uLI7"
79 | },
80 | "source": [
81 | "## 联邦数据 Federated data\n",
82 | "联邦数据抽象地指代分布在不同设备上的本地数据。\n",
83 | "\n",
84 | "例如,一组温度传感器中的数据认为是一个federated value,有数据类型、数据存放位置:\n",
85 | "```python\n",
86 | "federated_float_on_clients = tff.type_at_clients(tf.float32)\n",
87 | "```\n",
88 | "1. 它有member和placement属性,分别代表成员的数据类型(Type,T)和存放位置(Group,G)\n",
89 | "```\n",
90 | ">>> str(federated_float_on_clients.member)\n",
91 | "'float32'\n",
92 | ">>> str(federated_float_on_clients.placement)\n",
93 | "'CLIENTS'\n",
94 | "```\n",
95 | "定义一个federated type由T和G组成,记作`{T}@G`。这里由于每个设备中的值不相同,我们需要用花括号将T括起来\n",
96 | "```\n",
97 | ">>> str(federated_float_on_clients)\n",
98 | "'{float32}@CLIENTS'\n",
99 | "```\n",
100 | "2. `all_equal`属性表示它们是否相同,默认为false\n",
101 | "```\n",
102 | ">>> federated_float_on_clients.all_equal\n",
103 | "False\n",
104 | "```\n",
105 | "此时,由于每个设备中的值相同,我们就去掉花括号,记作`T@G`\n",
106 | "```\n",
107 | ">>> str(tff.type_at_clients(tf.float32, all_equal=True))\n",
108 | "'float32@CLIENTS'\n",
109 | "```\n",
110 | "3. 例子\n",
111 | "\n",
112 | " 例如一个线性回归模型有a和b两个参数:\n",
113 | " ```\n",
114 | " simple_regression_model_type = (\n",
115 | " tff.StructType([('a', tf.float32), ('b', tf.float32)]))\n",
116 | " >>> str(simple_regression_model_type)\n",
117 | " '@CLIENTS'\n",
118 | " ```\n",
119 | "\n",
120 | " 注意,这里的`tf.float32`是`tff.TensorType(dtype=tf.float32, shape=[])`的缩写。tff.TensorType方法创建一个TFF中的tensor类型。(?)\n",
121 | " ```\n",
122 | " str(tff.type_at_clients(\n",
123 | " simple_regression_model_type, all_equal=True))\n",
124 | " '@CLIENTS'\n",
125 | " ```\n",
126 | " 表示所有设备中都有a和b这两个参数,且都相等。\n",
127 | "\n",
128 | "\n"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {
134 | "id": "VI0-Xo5N20er"
135 | },
136 | "source": [
137 | "## 位置 Placement\n",
138 | "our goal is for TFF to enable **writing code that you could deploy for execution on groups of physical devices in a distributed system, potentially including mobile or embedded devices running Android**. Each of of those devices would receive a separate set of instructions to execute locally, depending on the role it plays in the system (an end-user device, a centralized coordinator, an intermediate layer in a multi-tier architecture, etc.). **It is important to be able to reason about which subsets of devices execute what code, and where different portions of the data might physically materialize.**\n",
139 | "\n",
140 | "以上原文说明了数据为什么要搞的这么复杂。TFF代码写好后,是可以生成一整套前后端代码的,此时就会用到数据存放位置这个概念。比如服务器端的参数和设备端的参数,在TFF中写在一起,但分割后就会存放在不同的设备上。\n",
141 | "\n",
142 | "\n",
143 | "大多数设备不能运行python,因此,TFF不关心操作符,而关心数据。前者因编程语言而异,在不同的编程环境下有不同的实现方式(安卓、ios、web);后者则不变。TFF中很多函数是抽象的,是跨网络、跨设备的,比如`broadcast`函数,将参数分发给部分设备。\n",
144 | "\n",
145 | "Within the body of TFF code, by design, **there's no way to enumerate the devices that constitute the group represented by tff.CLIENTS, or to probe for the existence of a specific device in the group**. There's no concept of a device or client identity anywhere in the Federated Core API, the underlying set of architectural abstractions, or the core runtime infrastructure we provide to support simulations. All the computation logic you write will be expressed as operations on the entire client group.\n",
146 | "\n",
147 | "\n",
148 | "以上原文说明了所有设备只能被看作整体,无法从中探测到某个具体的设备。因为TFF是以联邦的视角来设计的,如果能针对某一个具体设备来操作,就无法顾全大局。\n",
149 | "\n",
150 | "事实上,联邦数据集足以代表所有设备(如果不考虑设备异构)。"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {
156 | "id": "zeIiaXfn8okq"
157 | },
158 | "source": [
159 | "## 联邦计算 Federated computations\n",
160 | "接收federated value,输出federated value。\n",
161 | "```\n",
162 | "@tff.federated_computation(tff.type_at_clients(tf.float32))\n",
163 | "def get_average_temperature(sensor_readings):\n",
164 | " return tff.federated_mean(sensor_readings)\n",
165 | "```\n",
166 | "看到这里,你可能有疑问,用tf的现成的方法不是一步就能做出来吗?但是我们这里写的`get_average_temperature`不是tf代码,也不是python代码,是一种分布式系统的语言(it's a specification of a distributed system in an internal platform-independent glue language)。\n",
167 | "\n",
168 | "1. 我们先来看看联邦计算中的函数签名\n",
169 | "```\n",
170 | ">>> str(get_average_temperature.type_signature)\n",
171 | "'({float32}@CLIENTS -> float32@SERVER)'\n",
172 | "```\n",
173 | "这个输出说明,此函数的参数是各个设备上的浮点数据,输出服务器上的一个浮点数据。这告诉我们,不应该将一个联邦计算过程想象是在服务器或者某个机器上执行的过程,而应该想,它完成了一个多方协作的任务。\n",
174 | "\n",
175 | "2. 然后,我们看看联邦计算的调用方法:\n",
176 | "使用python语言即可调用\n",
177 | "\n",
178 | "```\n",
179 | ">>> get_average_temperature([68.5, 70.3, 69.8])\n",
180 | "69.53334\n",
181 | "```\n",
182 | "在执行以上计算时,你就像一个外部观察者,带着全局视野,完成分布式任务的一步操作。\n",
183 | "\n",
184 | "另外,在函数体内的语句一定会被执行:\n",
185 | "```\n",
186 | "@tff.federated_computation(tff.type_at_clients(tf.float32))\n",
187 | "def get_average_temperature(sensor_readings):\n",
188 | "\n",
189 | " print ('Getting traced, the argument is \"{}\".'.format(\n",
190 | " type(sensor_readings).__name__))\n",
191 | "\n",
192 | " return tff.federated_mean(sensor_readings)\n",
193 | "#以下是输出\n",
194 | "Getting traced, the argument is \"ValueImpl\".\n",
195 | "```\n",
196 | "\n",
197 | "3. 接着,我们看一个抽象的联邦计算的例子。记住,tf的函数需要被包装后才能用。\n",
198 | "\n",
199 | "```\n",
200 | "@tff.tf_computation(tf.float32)\n",
201 | "def add_half(x):\n",
202 | " return tf.add(x, 0.5)\n",
203 | " \n",
204 | "@tff.federated_computation(tff.type_at_clients(tf.float32))\n",
205 | "def add_half_on_clients(x):\n",
206 | " return tff.federated_map(add_half, x)\n",
207 | " \n",
208 | "add_half_on_clients([1.0, 3.0, 2.0])\n",
209 | "#以下是输出\n",
210 | "[,\n",
211 | " ,\n",
212 | " ]\n",
213 | "```\n",
214 | "如下代码会出错,因为tf.constant()函数是在@tff.federated_computation包装外使用的。可以理解为,外部环境是tff环境,包装后是tf环境。\n",
215 | "```\n",
216 | "try:\n",
217 | "\n",
218 | " # Eager mode\n",
219 | " constant_10 = tf.constant(10.)\n",
220 | "\n",
221 | " @tff.tf_computation(tf.float32)\n",
222 | " def add_ten(x):\n",
223 | " return x + constant_10\n",
224 | "\n",
225 | "except Exception as err:\n",
226 | " print (err)\n",
227 | "```\n",
228 | "注意,从tf环境中调用的函数,也还是运行在tf环境中,于是如下代码是正确的:\n",
229 | "```\n",
230 | "def get_constant_10():\n",
231 | " return tf.constant(10.)\n",
232 | "\n",
233 | "@tff.tf_computation(tf.float32)\n",
234 | "def add_ten(x):\n",
235 | " return x + get_constant_10()\n",
236 | "\n",
237 | "add_ten(5.0)\n",
238 | "#输出 15.0\n",
239 | "```\n",
240 | "\n",
241 | "4. 最后,我们看一个具体的温度传感器的例子\n",
242 | "\n",
243 | "```\n",
244 | "@tff.tf_computation(tff.SequenceType(tf.float32))\n",
245 | "def get_local_temperature_average(local_temperatures):\n",
246 | " sum_and_count = (\n",
247 | " local_temperatures.reduce((0.0, 0), lambda x, y: (x[0] + y, x[1] + 1)))\n",
248 | " return sum_and_count[0] / tf.cast(sum_and_count[1], tf.float32)\n",
249 | "\n",
250 | "get_local_temperature_average([68.5, 70.3, 69.8])\n",
251 | "```\n"
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "metadata": {
257 | "id": "ir8ylaSVM_RG"
258 | },
259 | "source": [
260 | "## 例子\n",
261 | "\n",
262 | "下面实现传感器内数据平均,再在服务器上实现数据平均的功能"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "metadata": {
268 | "id": "rICVD0ICXulE"
269 | },
270 | "source": [
271 | "@tff.federated_computation(\n",
272 | " tff.type_at_clients(tff.SequenceType(tf.float32)))\n",
273 | "def get_global_temperature_average(sensor_readings):\n",
274 | " return tff.federated_mean(\n",
275 | " tff.federated_map(get_local_temperature_average, sensor_readings))\n",
276 | " "
277 | ],
278 | "execution_count": null,
279 | "outputs": []
280 | },
281 | {
282 | "cell_type": "code",
283 | "metadata": {
284 | "id": "-NfPtmpIX1TK"
285 | },
286 | "source": [
287 | "get_global_temperature_average([[68.0, 70.0], [71.0], [68.0, 72.0, 70.0]])\n",
288 | "#输出70.0"
289 | ],
290 | "execution_count": null,
291 | "outputs": []
292 | }
293 | ]
294 | }
--------------------------------------------------------------------------------
/fl.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "federated learning.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyP0GJS5rzNK2kqdENXBBGia",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "TPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "id": "sotZxUgvOL3B",
33 | "outputId": "45c9c78d-f972-4f77-c10b-035eb3c9c8b8",
34 | "colab": {
35 | "base_uri": "https://localhost:8080/",
36 | "height": 379
37 | }
38 | },
39 | "source": [
40 | "!pip install --quiet --upgrade tensorflow_federated_nightly\n",
41 | "!pip install --quiet --upgrade nest_asyncio\n",
42 | "import nest_asyncio\n",
43 | "nest_asyncio.apply()\n",
44 | "import collections\n",
45 | "import numpy as np\n",
46 | "import tensorflow as tf\n",
47 | "import tensorflow_federated as tff\n",
48 | "\n",
49 | "# TODO(b/148678573,b/148685415): must use the reference context because it\n",
50 | "# supports unbounded references and tff.sequence_* intrinsics.\n",
51 | "tff.backends.reference.set_reference_context()"
52 | ],
53 | "execution_count": null,
54 | "outputs": [
55 | {
56 | "output_type": "stream",
57 | "text": [
58 | "\u001b[K |████████████████████████████████| 512kB 3.4MB/s \n",
59 | "\u001b[K |████████████████████████████████| 393.5MB 40kB/s \n",
60 | "\u001b[K |████████████████████████████████| 112kB 51.8MB/s \n",
61 | "\u001b[K |████████████████████████████████| 3.0MB 44.8MB/s \n",
62 | "\u001b[K |████████████████████████████████| 1.1MB 45.0MB/s \n",
63 | "\u001b[K |████████████████████████████████| 174kB 44.7MB/s \n",
64 | "\u001b[K |████████████████████████████████| 153kB 51.0MB/s \n",
65 | "\u001b[K |████████████████████████████████| 1.3MB 45.9MB/s \n",
66 | "\u001b[K |████████████████████████████████| 10.6MB 45.7MB/s \n",
67 | "\u001b[K |████████████████████████████████| 471kB 46.2MB/s \n",
68 | "\u001b[?25h Building wheel for absl-py (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
69 | "\u001b[31mERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.\u001b[0m\n",
70 | "\u001b[31mERROR: tf-nightly 2.4.0.dev20201020 has requirement absl-py~=0.10, but you'll have absl-py 0.9.0 which is incompatible.\u001b[0m\n",
71 | "\u001b[31mERROR: tf-nightly 2.4.0.dev20201020 has requirement grpcio~=1.32.0, but you'll have grpcio 1.29.0 which is incompatible.\u001b[0m\n",
72 | "\u001b[31mERROR: tf-nightly 2.4.0.dev20201020 has requirement numpy~=1.19.2, but you'll have numpy 1.18.5 which is incompatible.\u001b[0m\n"
73 | ],
74 | "name": "stdout"
75 | },
76 | {
77 | "output_type": "stream",
78 | "text": [
79 | "/usr/local/lib/python3.6/dist-packages/tensorflow_addons/utils/ensure_tf_install.py:43: UserWarning: You are currently using a nightly version of TensorFlow (2.4.0-dev20201020). \n",
80 | "TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. \n",
81 | "If you encounter a bug, do not file an issue on GitHub.\n",
82 | " UserWarning,\n"
83 | ],
84 | "name": "stderr"
85 | }
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "metadata": {
91 | "id": "_Uok1cUSshkN",
92 | "outputId": "8a7bd4e0-4bfe-4ee5-bf8b-5574feb7a5b7",
93 | "colab": {
94 | "base_uri": "https://localhost:8080/",
95 | "height": 35
96 | }
97 | },
98 | "source": [
99 | "import tensorflow as tf\n",
100 | "tf.__version__\n",
101 | "# %load_ext tensorboard\n",
102 | "# logdir = \"/tmp/logs/scalars/training/\"\n",
103 | "# summary_writer = tf.summary.create_file_writer(logdir)\n",
104 | "\n",
105 | "#with summary_writer.as_default():\n",
106 | " #for round_num in range(1, NUM_ROUNDS):\n",
107 | " #tf.summary.scalar(name, value, step=round_num)\n",
108 | "\n",
109 | "#%tensorboard --logdir /tmp/logs/scalars/ --port=0"
110 | ],
111 | "execution_count": null,
112 | "outputs": [
113 | {
114 | "output_type": "execute_result",
115 | "data": {
116 | "application/vnd.google.colaboratory.intrinsic+json": {
117 | "type": "string"
118 | },
119 | "text/plain": [
120 | "'2.4.0-dev20201020'"
121 | ]
122 | },
123 | "metadata": {
124 | "tags": []
125 | },
126 | "execution_count": 3
127 | }
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "metadata": {
133 | "id": "usbrIMi0nNPp",
134 | "outputId": "2571109e-5042-437b-e3b6-a1d61f637845",
135 | "colab": {
136 | "base_uri": "https://localhost:8080/",
137 | "height": 55
138 | }
139 | },
140 | "source": [
141 | "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()\n",
142 | "import copy\n",
143 | "#深拷贝\n",
144 | "ordered_mnist_train=copy.deepcopy(mnist_train)\n",
145 | "#给mnist_train按标签排序\n",
146 | "f=0\n",
147 | "for digit in range(10):\n",
148 | " index=[i for i,d in enumerate(mnist_train[1]) if d==digit]\n",
149 | " for i in range(5420):\n",
150 | " ordered_mnist_train[0][f+i]=mnist_train[0][index[i]]\n",
151 | " ordered_mnist_train[1][f+i]=mnist_train[1][index[i]]\n",
152 | " f+=5420"
153 | ],
154 | "execution_count": null,
155 | "outputs": [
156 | {
157 | "output_type": "stream",
158 | "text": [
159 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
160 | "11493376/11490434 [==============================] - 0s 0us/step\n"
161 | ],
162 | "name": "stdout"
163 | }
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "metadata": {
169 | "id": "AS4T32eau-E2"
170 | },
171 | "source": [
172 | "#检测排序结果\n",
173 | "from matplotlib import pyplot as plt\n",
174 | "for i in range(5420*3-1,5420*3+1):\n",
175 | " plt.imshow(np.array(ordered_mnist_train[0][i]), cmap='gray', aspect='equal')\n",
176 | " plt.grid(False)\n",
177 | " print(ordered_mnist_train[1][i])\n",
178 | " _ = plt.show()"
179 | ],
180 | "execution_count": null,
181 | "outputs": []
182 | },
183 | {
184 | "cell_type": "code",
185 | "metadata": {
186 | "id": "p3Zb9QczOTRV"
187 | },
188 | "source": [
189 | "BATCH_SIZE = 100\n",
190 | "# 定义:20个shard 每个shard包含271张图片\n",
191 | "SHARD_SIZE = 271\n",
192 | "def get_data(source,i,j):\n",
193 | " device=[]\n",
194 | " #根据普通索引ij计算对应shard数目,继而计算对应图片索引\n",
195 | " shard1=(20*j+i)*SHARD_SIZE\n",
196 | " shard2=(j+20*i+10)*SHARD_SIZE\n",
197 | " #600张图片索引存放在all_samples中\n",
198 | " all_samples=list(range(shard1,shard1+SHARD_SIZE))+list(range(shard2,shard2+SHARD_SIZE))\n",
199 | "\n",
200 | " batch_nums=int(2*SHARD_SIZE/BATCH_SIZE)\n",
201 | " #分成batch_nums个batch\n",
202 | " for i in range(int(batch_nums)):\n",
203 | " device_index=all_samples[BATCH_SIZE*i:BATCH_SIZE*i + BATCH_SIZE]\n",
204 | " device.append({\n",
205 | " 'x':\n",
206 | " np.array([source[0][i].flatten() / 255.0 for i in device_index],\n",
207 | " dtype=np.float32),\n",
208 | " 'y':\n",
209 | " np.array([source[1][i] for i in device_index], dtype=np.int32)\n",
210 | " })\n",
211 | " return device\n",
212 | "#构造联邦数据集\n",
213 | "federated_train_data = [get_data(ordered_mnist_train,i,j) for i in range(10) for j in range(10)]"
214 | ],
215 | "execution_count": null,
216 | "outputs": []
217 | },
218 | {
219 | "cell_type": "code",
220 | "metadata": {
221 | "id": "4IX3j3NpOWPV",
222 | "outputId": "2ee082fe-2239-4539-80fb-93533a0d77cb",
223 | "colab": {
224 | "base_uri": "https://localhost:8080/",
225 | "height": 169
226 | }
227 | },
228 | "source": [
229 | "#定义模型\n",
230 | "BATCH_SPEC = collections.OrderedDict(\n",
231 | " x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),\n",
232 | " y=tf.TensorSpec(shape=[None], dtype=tf.int32))\n",
233 | "BATCH_TYPE = tff.to_type(BATCH_SPEC)\n",
234 | "MODEL_SPEC = collections.OrderedDict(\n",
235 | " weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),\n",
236 | " bias=tf.TensorSpec(shape=[10], dtype=tf.float32))\n",
237 | "MODEL_TYPE = tff.to_type(MODEL_SPEC)\n",
238 | "\n",
239 | "import random\n",
240 | "@tf.function\n",
241 | "def forward_pass(model, batch):\n",
242 | " #L2正则项\n",
243 | " loss=10e-4*tf.reduce_sum(model['weights']*model['weights'])+tf.reduce_sum(model['bias']*model['bias'])/(2*BATCH_SIZE)\n",
244 | " #(100, 784) *(784, 10)+ (10,)=(100,10) 即100个预测值\n",
245 | " predicted_y = tf.nn.softmax(\n",
246 | " tf.matmul(batch['x'], model['weights']) + model['bias'])\n",
247 | " #返回这一整个batch的损失值\n",
248 | " return tf.reduce_mean(\n",
249 | " #记录100个图片的损失值\n",
250 | " tf.nn.softmax_cross_entropy_with_logits(\n",
251 | " #tf.math.log(predicted_y)是100*10矩阵\n",
252 | " #对应元素相乘\n",
253 | " tf.one_hot(batch['y'], 10),tf.math.log(predicted_y)))+loss\n",
254 | "\n",
255 | "#批训练\n",
256 | "@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)\n",
257 | "def batch_loss(model, batch):\n",
258 | " return forward_pass(model, batch)\n",
259 | "\n",
260 | "@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)\n",
261 | "def batch_train(initial_model, batch, learning_rate):\n",
262 | " # Define a group of model variables and set them to `initial_model`. Must\n",
263 | " # be defined outside the @tf.function.\n",
264 | " model_vars = collections.OrderedDict([\n",
265 | " (name, tf.Variable(name=name, initial_value=value))\n",
266 | " for name, value in initial_model.items()\n",
267 | " ])\n",
268 | " optimizer = tf.keras.optimizers.SGD(learning_rate)\n",
269 | "\n",
270 | " @tf.function\n",
271 | " def _train_on_batch(model_vars, batch):\n",
272 | " # Perform one step of gradient descent using loss from `batch_loss`.\n",
273 | " with tf.GradientTape() as tape:\n",
274 | " loss = forward_pass(model_vars, batch)\n",
275 | " grads = tape.gradient(loss, model_vars)\n",
276 | " optimizer.apply_gradients(\n",
277 | " zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))\n",
278 | " return model_vars\n",
279 | "\n",
280 | " return _train_on_batch(model_vars, batch)\n",
281 | "\n",
282 | "#设备训练\n",
283 | "LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)\n",
284 | "E=1\n",
285 | "@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)\n",
286 | "def local_train(initial_model, learning_rate, all_batches):\n",
287 | "\n",
288 | " # Mapping function to apply to each batch.\n",
289 | " @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)\n",
290 | " def batch_fn(model, batch):\n",
291 | " for _ in range(E):\n",
292 | " model=batch_train(model, batch, learning_rate)\n",
293 | " return model\n",
294 | "\n",
295 | " return tff.sequence_reduce(all_batches, initial_model, batch_fn)\n",
296 | "\n",
297 | "#设备评估\n",
298 | "@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)\n",
299 | "def local_eval(model, all_batches):\n",
300 | " # TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.\n",
301 | " return tff.sequence_sum(\n",
302 | " tff.sequence_map(\n",
303 | " tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),\n",
304 | " all_batches))\n",
305 | "\n",
306 | "#联邦评估\n",
307 | "SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)\n",
308 | "CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)\n",
309 | "@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)\n",
310 | "def federated_eval(model, data):\n",
311 | " return tff.federated_mean(\n",
312 | " tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))\n",
313 | " \n",
314 | "\n",
315 | "#联邦训练\n",
316 | "SERVER_FLOAT_TYPE = tff.type_at_server(tf.float32)\n",
317 | "\n",
318 | "@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,\n",
319 | " CLIENT_DATA_TYPE,tff.type_at_clients(tf.float32))\n",
320 | "def federated_train(model, learning_rate, data,w):\n",
321 | " return tff.federated_mean(\n",
322 | " tff.federated_map(local_train, [\n",
323 | " tff.federated_broadcast(model),\n",
324 | " tff.federated_broadcast(learning_rate), data\n",
325 | " ]),w)\n",
326 | " \n",
327 | " \n",
328 | "#计算准确率\n",
329 | "def accur(model):\n",
330 | " device_index=range(1000)\n",
331 | " device=[]\n",
332 | " device.append({\n",
333 | " 'x':\n",
334 | " np.array([mnist_test[0][i].flatten() / 255.0 for i in device_index],\n",
335 | " dtype=np.float32),\n",
336 | " 'y':\n",
337 | " np.array([mnist_test[1][i] for i in device_index], dtype=np.int32)\n",
338 | " })\n",
339 | " predicted_y = tf.nn.softmax(\n",
340 | " tf.matmul(device[0]['x'], model['weights']) + model['bias'])\n",
341 | " p=[]\n",
342 | " for i in predicted_y:\n",
343 | " max=-1\n",
344 | " flag=0\n",
345 | " for j in range(1,10):\n",
346 | " if i[j]>max:\n",
347 | " max=i[j]\n",
348 | " flag=j\n",
349 | " p.append(flag)\n",
350 | " cnt=0\n",
351 | " for i in range(1000):\n",
352 | " if p[i]==device[0]['y'][i]:\n",
353 | " cnt+=1\n",
354 | " return cnt/1000"
355 | ],
356 | "execution_count": null,
357 | "outputs": [
358 | {
359 | "output_type": "stream",
360 | "text": [
361 | "WARNING:tensorflow:AutoGraph could not transform and will run it as-is.\n",
362 | "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
363 | "Cause: module 'gast' has no attribute 'Index'\n",
364 | "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
365 | "WARNING: AutoGraph could not transform and will run it as-is.\n",
366 | "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
367 | "Cause: module 'gast' has no attribute 'Index'\n",
368 | "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n"
369 | ],
370 | "name": "stdout"
371 | }
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "metadata": {
377 | "id": "Dbqs99OIOsFE",
378 | "outputId": "9f8be68c-2e77-45f2-e358-5f1a474e908b",
379 | "colab": {
380 | "base_uri": "https://localhost:8080/",
381 | "height": 1000
382 | }
383 | },
384 | "source": [
385 | "initial_model = collections.OrderedDict(\n",
386 | " weights=np.zeros([784, 10], dtype=np.float32),\n",
387 | " bias=np.zeros([10], dtype=np.float32))\n",
388 | "#设置权值\n",
389 | "weight_list = [1.9,0.1]\n",
390 | "for _ in range(98):\n",
391 | " weight_list.append(1)\n",
392 | "weight_list = [x/100 for x in weight_list]\n",
393 | "model = initial_model\n",
394 | "for round_num in range(300):\n",
395 | " learning_rate = 0.1 /(1+round_num)\n",
396 | " model = federated_train(model, learning_rate, federated_train_data,weight_list)\n",
397 | " loss = federated_eval(model, federated_train_data)\n",
398 | " acc = accur(model)\n",
399 | " print('round {}, loss={}, acc={}%'.format(round_num, loss, acc*100))\n"
400 | ],
401 | "execution_count": null,
402 | "outputs": [
403 | {
404 | "output_type": "stream",
405 | "text": [
406 | "round 0, loss=10.372981071472168, acc=70.39999999999999%\n",
407 | "round 1, loss=9.729572296142578, acc=71.8%\n",
408 | "round 2, loss=9.276639938354492, acc=71.8%\n",
409 | "round 3, loss=8.937175750732422, acc=71.8%\n",
410 | "round 4, loss=8.671344757080078, acc=71.8%\n",
411 | "round 5, loss=8.455917358398438, acc=71.8%\n",
412 | "round 6, loss=8.27657699584961, acc=71.89999999999999%\n",
413 | "round 7, loss=8.124074935913086, acc=71.8%\n",
414 | "round 8, loss=7.992163181304932, acc=71.8%\n",
415 | "round 9, loss=7.876453399658203, acc=71.8%\n",
416 | "round 10, loss=7.773778915405273, acc=71.8%\n",
417 | "round 11, loss=7.681778430938721, acc=71.89999999999999%\n",
418 | "round 12, loss=7.598653793334961, acc=71.89999999999999%\n",
419 | "round 13, loss=7.52301025390625, acc=72.0%\n",
420 | "round 14, loss=7.453742027282715, acc=72.0%\n",
421 | "round 15, loss=7.389968395233154, acc=72.0%\n",
422 | "round 16, loss=7.330965995788574, acc=72.0%\n",
423 | "round 17, loss=7.276141166687012, acc=72.0%\n",
424 | "round 18, loss=7.224999904632568, acc=72.0%\n",
425 | "round 19, loss=7.1771321296691895, acc=71.89999999999999%\n",
426 | "round 20, loss=7.132178783416748, acc=71.8%\n",
427 | "round 21, loss=7.0898518562316895, acc=71.89999999999999%\n",
428 | "round 22, loss=7.049884796142578, acc=72.1%\n",
429 | "round 23, loss=7.012059211730957, acc=72.1%\n",
430 | "round 24, loss=6.976179122924805, acc=72.1%\n",
431 | "round 25, loss=6.942073822021484, acc=72.1%\n",
432 | "round 26, loss=6.9095964431762695, acc=72.1%\n",
433 | "round 27, loss=6.878612518310547, acc=72.1%\n",
434 | "round 28, loss=6.849008560180664, acc=72.1%\n",
435 | "round 29, loss=6.820674419403076, acc=72.1%\n",
436 | "round 30, loss=6.793519973754883, acc=72.1%\n",
437 | "round 31, loss=6.767460823059082, acc=72.1%\n",
438 | "round 32, loss=6.742420673370361, acc=72.1%\n",
439 | "round 33, loss=6.718331336975098, acc=72.1%\n",
440 | "round 34, loss=6.695131301879883, acc=72.1%\n",
441 | "round 35, loss=6.672764778137207, acc=72.2%\n",
442 | "round 36, loss=6.651175498962402, acc=72.2%\n",
443 | "round 37, loss=6.6303229331970215, acc=72.2%\n",
444 | "round 38, loss=6.610158920288086, acc=72.3%\n",
445 | "round 39, loss=6.590648651123047, acc=72.3%\n",
446 | "round 40, loss=6.571752548217773, acc=72.3%\n",
447 | "round 41, loss=6.553436279296875, acc=72.3%\n",
448 | "round 42, loss=6.535670757293701, acc=72.3%\n",
449 | "round 43, loss=6.518428325653076, acc=72.3%\n",
450 | "round 44, loss=6.50167989730835, acc=72.39999999999999%\n",
451 | "round 45, loss=6.485401153564453, acc=72.39999999999999%\n",
452 | "round 46, loss=6.469570159912109, acc=72.39999999999999%\n",
453 | "round 47, loss=6.454165458679199, acc=72.39999999999999%\n",
454 | "round 48, loss=6.439167022705078, acc=72.39999999999999%\n",
455 | "round 49, loss=6.424554347991943, acc=72.3%\n",
456 | "round 50, loss=6.410314559936523, acc=72.3%\n",
457 | "round 51, loss=6.396425724029541, acc=72.3%\n",
458 | "round 52, loss=6.382877349853516, acc=72.3%\n",
459 | "round 53, loss=6.369650840759277, acc=72.39999999999999%\n",
460 | "round 54, loss=6.356736660003662, acc=72.39999999999999%\n",
461 | "round 55, loss=6.344120025634766, acc=72.39999999999999%\n",
462 | "round 56, loss=6.331789493560791, acc=72.39999999999999%\n",
463 | "round 57, loss=6.319730758666992, acc=72.39999999999999%\n",
464 | "round 58, loss=6.307937145233154, acc=72.39999999999999%\n",
465 | "round 59, loss=6.296399116516113, acc=72.39999999999999%\n",
466 | "round 60, loss=6.285103797912598, acc=72.39999999999999%\n",
467 | "round 61, loss=6.274044036865234, acc=72.5%\n",
468 | "round 62, loss=6.26320743560791, acc=72.5%\n",
469 | "round 63, loss=6.252591609954834, acc=72.5%\n",
470 | "round 64, loss=6.242187023162842, acc=72.5%\n",
471 | "round 65, loss=6.231985569000244, acc=72.5%\n",
472 | "round 66, loss=6.221979141235352, acc=72.5%\n",
473 | "round 67, loss=6.212162971496582, acc=72.5%\n",
474 | "round 68, loss=6.2025299072265625, acc=72.6%\n",
475 | "round 69, loss=6.193075656890869, acc=72.6%\n",
476 | "round 70, loss=6.1837897300720215, acc=72.6%\n",
477 | "round 71, loss=6.174671649932861, acc=72.5%\n",
478 | "round 72, loss=6.165714740753174, acc=72.5%\n",
479 | "round 73, loss=6.156914710998535, acc=72.39999999999999%\n",
480 | "round 74, loss=6.1482648849487305, acc=72.5%\n",
481 | "round 75, loss=6.1397600173950195, acc=72.5%\n",
482 | "round 76, loss=6.1313982009887695, acc=72.5%\n",
483 | "round 77, loss=6.123175621032715, acc=72.6%\n",
484 | "round 78, loss=6.115087509155273, acc=72.6%\n",
485 | "round 79, loss=6.107127666473389, acc=72.6%\n",
486 | "round 80, loss=6.099295139312744, acc=72.6%\n",
487 | "round 81, loss=6.0915846824646, acc=72.6%\n",
488 | "round 82, loss=6.083992958068848, acc=72.6%\n",
489 | "round 83, loss=6.076519012451172, acc=72.6%\n",
490 | "round 84, loss=6.069157123565674, acc=72.6%\n",
491 | "round 85, loss=6.061906337738037, acc=72.6%\n",
492 | "round 86, loss=6.0547614097595215, acc=72.6%\n",
493 | "round 87, loss=6.0477213859558105, acc=72.6%\n",
494 | "round 88, loss=6.040783882141113, acc=72.6%\n",
495 | "round 89, loss=6.033944129943848, acc=72.6%\n",
496 | "round 90, loss=6.027200698852539, acc=72.6%\n",
497 | "round 91, loss=6.020553112030029, acc=72.6%\n",
498 | "round 92, loss=6.013996124267578, acc=72.6%\n",
499 | "round 93, loss=6.007532119750977, acc=72.6%\n",
500 | "round 94, loss=6.001150608062744, acc=72.6%\n",
501 | "round 95, loss=5.9948554039001465, acc=72.6%\n",
502 | "round 96, loss=5.988643646240234, acc=72.6%\n",
503 | "round 97, loss=5.982513427734375, acc=72.6%\n",
504 | "round 98, loss=5.976461887359619, acc=72.6%\n",
505 | "round 99, loss=5.970491886138916, acc=72.6%\n",
506 | "round 100, loss=5.964592456817627, acc=72.6%\n",
507 | "round 101, loss=5.958770751953125, acc=72.6%\n",
508 | "round 102, loss=5.953021049499512, acc=72.6%\n",
509 | "round 103, loss=5.947341442108154, acc=72.6%\n",
510 | "round 104, loss=5.941730976104736, acc=72.6%\n",
511 | "round 105, loss=5.936187267303467, acc=72.6%\n",
512 | "round 106, loss=5.930712699890137, acc=72.6%\n",
513 | "round 107, loss=5.9253010749816895, acc=72.6%\n",
514 | "round 108, loss=5.919955253601074, acc=72.6%\n",
515 | "round 109, loss=5.914668560028076, acc=72.6%\n",
516 | "round 110, loss=5.909443378448486, acc=72.6%\n",
517 | "round 111, loss=5.904280185699463, acc=72.6%\n",
518 | "round 112, loss=5.899175643920898, acc=72.6%\n",
519 | "round 113, loss=5.894125461578369, acc=72.7%\n",
520 | "round 114, loss=5.889133453369141, acc=72.7%\n",
521 | "round 115, loss=5.884196758270264, acc=72.7%\n",
522 | "round 116, loss=5.8793134689331055, acc=72.7%\n",
523 | "round 117, loss=5.874481678009033, acc=72.7%\n",
524 | "round 118, loss=5.869705677032471, acc=72.7%\n",
525 | "round 119, loss=5.8649773597717285, acc=72.7%\n",
526 | "round 120, loss=5.860297203063965, acc=72.7%\n",
527 | "round 121, loss=5.855671405792236, acc=72.7%\n",
528 | "round 122, loss=5.851091384887695, acc=72.7%\n",
529 | "round 123, loss=5.846558094024658, acc=72.7%\n",
530 | "round 124, loss=5.842071056365967, acc=72.7%\n",
531 | "round 125, loss=5.837630748748779, acc=72.7%\n",
532 | "round 126, loss=5.833232879638672, acc=72.7%\n",
533 | "round 127, loss=5.8288798332214355, acc=72.7%\n",
534 | "round 128, loss=5.824571132659912, acc=72.7%\n",
535 | "round 129, loss=5.820303916931152, acc=72.7%\n",
536 | "round 130, loss=5.816077709197998, acc=72.7%\n",
537 | "round 131, loss=5.811893939971924, acc=72.7%\n",
538 | "round 132, loss=5.807748794555664, acc=72.7%\n",
539 | "round 133, loss=5.803645610809326, acc=72.7%\n",
540 | "round 134, loss=5.799578666687012, acc=72.7%\n",
541 | "round 135, loss=5.7955498695373535, acc=72.7%\n",
542 | "round 136, loss=5.791559219360352, acc=72.7%\n",
543 | "round 137, loss=5.787606239318848, acc=72.7%\n",
544 | "round 138, loss=5.783689022064209, acc=72.7%\n",
545 | "round 139, loss=5.779806613922119, acc=72.7%\n",
546 | "round 140, loss=5.775959491729736, acc=72.7%\n",
547 | "round 141, loss=5.7721476554870605, acc=72.8%\n",
548 | "round 142, loss=5.768371105194092, acc=72.8%\n",
549 | "round 143, loss=5.764625072479248, acc=72.8%\n",
550 | "round 144, loss=5.760915756225586, acc=72.8%\n",
551 | "round 145, loss=5.7572340965271, acc=72.8%\n",
552 | "round 146, loss=5.753586292266846, acc=72.8%\n",
553 | "round 147, loss=5.749969959259033, acc=72.8%\n",
554 | "round 148, loss=5.746387958526611, acc=72.8%\n",
555 | "round 149, loss=5.742830753326416, acc=72.8%\n",
556 | "round 150, loss=5.739307880401611, acc=72.8%\n",
557 | "round 151, loss=5.73581075668335, acc=72.8%\n",
558 | "round 152, loss=5.7323479652404785, acc=72.8%\n",
559 | "round 153, loss=5.728910446166992, acc=72.8%\n",
560 | "round 154, loss=5.7255024909973145, acc=72.8%\n",
561 | "round 155, loss=5.72212028503418, acc=72.8%\n",
562 | "round 156, loss=5.718766689300537, acc=72.8%\n",
563 | "round 157, loss=5.715438365936279, acc=72.8%\n",
564 | "round 158, loss=5.712139129638672, acc=72.8%\n",
565 | "round 159, loss=5.708866596221924, acc=72.8%\n",
566 | "round 160, loss=5.705617904663086, acc=72.8%\n"
567 | ],
568 | "name": "stdout"
569 | }
570 | ]
571 | },
572 | {
573 | "cell_type": "markdown",
574 | "metadata": {
575 | "id": "muC6lh09Vluy"
576 | },
577 | "source": [
578 | "round 0, loss=21.60552406311035\n",
579 | "round 1, loss=20.365678787231445\n",
580 | "round 2, loss=19.27480125427246\n",
581 | "round 3, loss=18.31110954284668\n",
582 | "round 4, loss=17.457256317138672\n",
583 | "\n",
584 | "不改权值的结果:\n",
585 | "```\n",
586 | "round 0, loss=10.36648941040039, acc=70.3%\n",
587 | "round 1, loss=9.720335960388184, acc=71.39999999999999%\n",
588 | "round 2, loss=9.265982627868652, acc=71.5%\n",
589 | "round 3, loss=8.925678253173828, acc=71.5%\n",
590 | "round 4, loss=8.659313201904297, acc=71.6%\n",
591 | "round 5, loss=8.443524360656738, acc=71.5%\n",
592 | "round 6, loss=8.263936042785645, acc=71.7%\n",
593 | "round 7, loss=8.1112642288208, acc=71.6%\n",
594 | "round 8, loss=7.979235649108887, acc=71.6%\n",
595 | "round 9, loss=7.8634490966796875, acc=71.6%\n",
596 | "round 10, loss=7.760721206665039, acc=71.7%\n",
597 | "round 11, loss=7.668695449829102, acc=71.7%\n",
598 | "round 12, loss=7.585556507110596, acc=71.8%\n",
599 | "round 13, loss=7.509913921356201, acc=71.8%\n",
600 | "round 14, loss=7.4406561851501465, acc=71.8%\n",
601 | "round 15, loss=7.376895904541016, acc=71.8%\n",
602 | "round 16, loss=7.317915439605713, acc=71.8%\n",
603 | "round 17, loss=7.263117790222168, acc=71.8%\n",
604 | "round 18, loss=7.212005615234375, acc=71.8%\n",
605 | "round 19, loss=7.164167404174805, acc=71.8%\n",
606 | "round 20, loss=7.119253635406494, acc=71.7%\n",
607 | "round 21, loss=7.076962471008301, acc=71.89999999999999%\n",
608 | "round 22, loss=7.037031650543213, acc=72.0%\n",
609 | "round 23, loss=6.999242782592773, acc=72.0%\n",
610 | "round 24, loss=6.963403701782227, acc=72.0%\n",
611 | "round 25, loss=6.929336071014404, acc=72.0%\n",
612 | "round 26, loss=6.896897792816162, acc=72.0%\n",
613 | "round 27, loss=6.865954399108887, acc=72.0%\n",
614 | "round 28, loss=6.836390495300293, acc=72.0%\n",
615 | "round 29, loss=6.808094501495361, acc=72.0%\n",
616 | "round 30, loss=6.780981063842773, acc=72.0%\n",
617 | "round 31, loss=6.754961013793945, acc=72.0%\n",
618 | "round 32, loss=6.729955673217773, acc=72.0%\n",
619 | "round 33, loss=6.705905914306641, acc=72.0%\n",
620 | "round 34, loss=6.682745933532715, acc=72.0%\n",
621 | "round 35, loss=6.660414218902588, acc=72.0%\n",
622 | "round 36, loss=6.638864040374756, acc=72.0%\n",
623 | "round 37, loss=6.618048667907715, acc=72.0%\n",
624 | "round 38, loss=6.5979228019714355, acc=72.1%\n",
625 | "round 39, loss=6.578448295593262, acc=72.1%\n",
626 | "round 40, loss=6.559586048126221, acc=72.1%\n",
627 | "round 41, loss=6.541306972503662, acc=72.1%\n",
628 | "round 42, loss=6.523574352264404, acc=72.1%\n",
629 | "round 43, loss=6.5063652992248535, acc=72.1%\n",
630 | "round 44, loss=6.489650726318359, acc=72.2%\n",
631 | "round 45, loss=6.473405838012695, acc=72.2%\n",
632 | "round 46, loss=6.457608699798584, acc=72.2%\n",
633 | "round 47, loss=6.442233085632324, acc=72.2%\n",
634 | "round 48, loss=6.427267551422119, acc=72.2%\n",
635 | "round 49, loss=6.412689208984375, acc=72.1%\n",
636 | "round 50, loss=6.398477077484131, acc=72.1%\n",
637 | "round 51, loss=6.384622573852539, acc=72.1%\n",
638 | "round 52, loss=6.371103286743164, acc=72.1%\n",
639 | "round 53, loss=6.357907772064209, acc=72.1%\n",
640 | "round 54, loss=6.345022201538086, acc=72.2%\n",
641 | "round 55, loss=6.332433700561523, acc=72.2%\n",
642 | "round 56, loss=6.320131301879883, acc=72.2%\n",
643 | "round 57, loss=6.308100700378418, acc=72.2%\n",
644 | "round 58, loss=6.296338081359863, acc=72.2%\n",
645 | "round 59, loss=6.284825325012207, acc=72.3%\n",
646 | "round 60, loss=6.273555755615234, acc=72.3%\n",
647 | "round 61, loss=6.262521266937256, acc=72.3%\n",
648 | "round 62, loss=6.251714706420898, acc=72.3%\n",
649 | "round 63, loss=6.241124153137207, acc=72.3%\n",
650 | "round 64, loss=6.230745315551758, acc=72.3%\n",
651 | "round 65, loss=6.220568656921387, acc=72.3%\n",
652 | "round 66, loss=6.210587978363037, acc=72.3%\n",
653 | "round 67, loss=6.200796127319336, acc=72.3%\n",
654 | "round 68, loss=6.191188335418701, acc=72.3%\n",
655 | "round 69, loss=6.181758403778076, acc=72.3%\n",
656 | "round 70, loss=6.172497749328613, acc=72.1%\n",
657 | "round 71, loss=6.16340446472168, acc=72.1%\n",
658 | "round 72, loss=6.154470920562744, acc=72.1%\n",
659 | "round 73, loss=6.145693778991699, acc=72.2%\n",
660 | "round 74, loss=6.137065887451172, acc=72.2%\n",
661 | "round 75, loss=6.128585338592529, acc=72.3%\n",
662 | "round 76, loss=6.120245456695557, acc=72.3%\n",
663 | "round 77, loss=6.1120452880859375, acc=72.3%\n",
664 | "round 78, loss=6.103979110717773, acc=72.3%\n",
665 | "round 79, loss=6.09604024887085, acc=72.3%\n",
666 | "round 80, loss=6.088228702545166, acc=72.3%\n",
667 | "round 81, loss=6.080538749694824, acc=72.3%\n",
668 | "round 82, loss=6.072969913482666, acc=72.3%\n",
669 | "round 83, loss=6.065515518188477, acc=72.3%\n",
670 | "round 84, loss=6.058176040649414, acc=72.3%\n",
671 | "round 85, loss=6.050944805145264, acc=72.3%\n",
672 | "round 86, loss=6.043820858001709, acc=72.3%\n",
673 | "round 87, loss=6.036800384521484, acc=72.3%\n",
674 | "round 88, loss=6.029882907867432, acc=72.3%\n",
675 | "round 89, loss=6.023061752319336, acc=72.3%\n",
676 | "round 90, loss=6.016338348388672, acc=72.3%\n",
677 | "round 91, loss=6.009708881378174, acc=72.3%\n",
678 | "round 92, loss=6.003171920776367, acc=72.3%\n",
679 | "round 93, loss=5.996725082397461, acc=72.3%\n",
680 | "round 94, loss=5.9903645515441895, acc=72.3%\n",
681 | "round 95, loss=5.984086513519287, acc=72.3%\n",
682 | "round 96, loss=5.977894306182861, acc=72.3%\n",
683 | "round 97, loss=5.971782684326172, acc=72.3%\n",
684 | "round 98, loss=5.965747833251953, acc=72.3%\n",
685 | "round 99, loss=5.959793567657471, acc=72.3%\n",
686 | "round 100, loss=5.953916072845459, acc=72.3%\n",
687 | "round 101, loss=5.948107719421387, acc=72.3%\n",
688 | "round 102, loss=5.942374229431152, acc=72.3%\n",
689 | "round 103, loss=5.936712741851807, acc=72.3%\n",
690 | "round 104, loss=5.9311203956604, acc=72.3%\n",
691 | "round 105, loss=5.925593852996826, acc=72.3%\n",
692 | "round 106, loss=5.920135974884033, acc=72.3%\n",
693 | "round 107, loss=5.914740085601807, acc=72.3%\n",
694 | "round 108, loss=5.9094109535217285, acc=72.3%\n",
695 | "round 109, loss=5.904142379760742, acc=72.3%\n",
696 | "round 110, loss=5.898932933807373, acc=72.3%\n",
697 | "round 111, loss=5.8937859535217285, acc=72.3%\n",
698 | "round 112, loss=5.888695240020752, acc=72.39999999999999%\n",
699 | "round 113, loss=5.883661270141602, acc=72.39999999999999%\n",
700 | "round 114, loss=5.8786845207214355, acc=72.39999999999999%\n",
701 | "round 115, loss=5.873762607574463, acc=72.39999999999999%\n",
702 | "round 116, loss=5.868896007537842, acc=72.5%\n",
703 | "round 117, loss=5.864081382751465, acc=72.5%\n",
704 | "round 118, loss=5.859317779541016, acc=72.5%\n",
705 | "round 119, loss=5.8546037673950195, acc=72.5%\n",
706 | "round 120, loss=5.849942684173584, acc=72.5%\n",
707 | "round 121, loss=5.845328330993652, acc=72.5%\n",
708 | "round 122, loss=5.840763568878174, acc=72.5%\n",
709 | "round 123, loss=5.836244106292725, acc=72.5%\n",
710 | "round 124, loss=5.831772327423096, acc=72.5%\n",
711 | "round 125, loss=5.827345371246338, acc=72.5%\n",
712 | "round 126, loss=5.822963237762451, acc=72.5%\n",
713 | "round 127, loss=5.818624496459961, acc=72.5%\n",
714 | "round 128, loss=5.8143310546875, acc=72.5%\n",
715 | "round 129, loss=5.8100762367248535, acc=72.5%\n",
716 | "round 130, loss=5.805863857269287, acc=72.5%\n",
717 | "round 131, loss=5.801693916320801, acc=72.5%\n",
718 | "round 132, loss=5.7975640296936035, acc=72.5%\n",
719 | "round 133, loss=5.79347038269043, acc=72.5%\n",
720 | "round 134, loss=5.789417743682861, acc=72.5%\n",
721 | "round 135, loss=5.785403251647949, acc=72.5%\n",
722 | "round 136, loss=5.781425952911377, acc=72.5%\n",
723 | "round 137, loss=5.7774858474731445, acc=72.5%\n",
724 | "round 138, loss=5.7735819816589355, acc=72.5%\n",
725 | "round 139, loss=5.769712448120117, acc=72.5%\n",
726 | "round 140, loss=5.765878200531006, acc=72.6%\n",
727 | "round 141, loss=5.762081146240234, acc=72.6%\n",
728 | "round 142, loss=5.7583136558532715, acc=72.6%\n",
729 | "round 143, loss=5.754583358764648, acc=72.6%\n",
730 | "round 144, loss=5.750883102416992, acc=72.6%\n",
731 | "round 145, loss=5.747217178344727, acc=72.6%\n",
732 | "round 146, loss=5.743582248687744, acc=72.6%\n",
733 | "round 147, loss=5.7399773597717285, acc=72.6%\n",
734 | "round 148, loss=5.736404895782471, acc=72.6%\n",
735 | "round 149, loss=5.732864856719971, acc=72.6%\n",
736 | "round 150, loss=5.729351043701172, acc=72.6%\n",
737 | "round 151, loss=5.7258687019348145, acc=72.6%\n",
738 | "round 152, loss=5.722413539886475, acc=72.6%\n",
739 | "round 153, loss=5.718986988067627, acc=72.6%\n",
740 | "round 154, loss=5.7155914306640625, acc=72.6%\n",
741 | "round 155, loss=5.712220668792725, acc=72.6%\n",
742 | "round 156, loss=5.7088799476623535, acc=72.6%\n",
743 | "round 157, loss=5.705564498901367, acc=72.6%\n",
744 | "round 158, loss=5.702276229858398, acc=72.6%\n",
745 | "round 159, loss=5.699013710021973, acc=72.6%\n",
746 | "round 160, loss=5.695777416229248, acc=72.5%\n",
747 | "round 161, loss=5.69256591796875, acc=72.5%\n",
748 | "round 162, loss=5.689380645751953, acc=72.5%\n",
749 | "round 163, loss=5.68621826171875, acc=72.5%\n",
750 | "round 164, loss=5.683082103729248, acc=72.5%\n",
751 | "round 165, loss=5.679969310760498, acc=72.5%\n",
752 | "round 166, loss=5.676881313323975, acc=72.5%\n",
753 | "round 167, loss=5.673816680908203, acc=72.5%\n",
754 | "round 168, loss=5.670772075653076, acc=72.5%\n",
755 | "round 169, loss=5.667752742767334, acc=72.5%\n",
756 | "round 170, loss=5.664754867553711, acc=72.5%\n",
757 | "round 171, loss=5.6617817878723145, acc=72.5%\n",
758 | "round 172, loss=5.6588263511657715, acc=72.5%\n",
759 | "round 173, loss=5.655895233154297, acc=72.5%\n",
760 | "round 174, loss=5.65298318862915, acc=72.5%\n",
761 | "round 175, loss=5.650094985961914, acc=72.6%\n",
762 | "round 176, loss=5.647225856781006, acc=72.6%\n",
763 | "round 177, loss=5.6443772315979, acc=72.6%\n",
764 | "round 178, loss=5.6415510177612305, acc=72.6%\n",
765 | "round 179, loss=5.638741493225098, acc=72.6%\n",
766 | "round 180, loss=5.6359543800354, acc=72.6%\n",
767 | "round 181, loss=5.633185386657715, acc=72.6%\n",
768 | "round 182, loss=5.630435943603516, acc=72.6%\n",
769 | "round 183, loss=5.627706527709961, acc=72.6%\n",
770 | "round 184, loss=5.624994277954102, acc=72.6%\n",
771 | "round 185, loss=5.6222991943359375, acc=72.6%\n",
772 | "round 186, loss=5.619626522064209, acc=72.6%\n",
773 | "round 187, loss=5.616968154907227, acc=72.6%\n",
774 | "round 188, loss=5.614329814910889, acc=72.6%\n",
775 | "round 189, loss=5.611709117889404, acc=72.6%\n",
776 | "round 190, loss=5.609104156494141, acc=72.6%\n",
777 | "round 191, loss=5.606518745422363, acc=72.6%\n",
778 | "round 192, loss=5.603947639465332, acc=72.6%\n",
779 | "round 193, loss=5.6013946533203125, acc=72.6%\n",
780 | "round 194, loss=5.598859786987305, acc=72.6%\n",
781 | "round 195, loss=5.596339225769043, acc=72.6%\n",
782 | "round 196, loss=5.593835353851318, acc=72.6%\n",
783 | "round 197, loss=5.591347694396973, acc=72.6%\n",
784 | "round 198, loss=5.588876724243164, acc=72.6%\n",
785 | "round 199, loss=5.586421012878418, acc=72.6%\n",
786 | "round 200, loss=5.583980083465576, acc=72.6%\n",
787 | "round 201, loss=5.581554412841797, acc=72.6%\n",
788 | "round 202, loss=5.579144477844238, acc=72.6%\n",
789 | "round 203, loss=5.576749801635742, acc=72.6%\n",
790 | "round 204, loss=5.57436990737915, acc=72.6%\n",
791 | "round 205, loss=5.5720014572143555, acc=72.6%\n",
792 | "round 206, loss=5.569650173187256, acc=72.6%\n",
793 | "round 207, loss=5.567317008972168, acc=72.6%\n",
794 | "round 208, loss=5.5649919509887695, acc=72.6%\n",
795 | "round 209, loss=5.562684535980225, acc=72.6%\n",
796 | "round 210, loss=5.56038761138916, acc=72.6%\n",
797 | "round 211, loss=5.558109283447266, acc=72.6%\n",
798 | "round 212, loss=5.555840492248535, acc=72.6%\n",
799 | "round 213, loss=5.553586483001709, acc=72.6%\n",
800 | "round 214, loss=5.5513458251953125, acc=72.6%\n",
801 | "round 215, loss=5.549118518829346, acc=72.6%\n",
802 | "round 216, loss=5.546905040740967, acc=72.6%\n",
803 | "round 217, loss=5.544701099395752, acc=72.6%\n",
804 | "round 218, loss=5.542513847351074, acc=72.6%\n",
805 | "round 219, loss=5.540337085723877, acc=72.6%\n",
806 | "round 220, loss=5.538171291351318, acc=72.6%\n",
807 | "round 221, loss=5.5360212326049805, acc=72.6%\n",
808 | "round 222, loss=5.533881664276123, acc=72.6%\n",
809 | "round 223, loss=5.531755447387695, acc=72.6%\n",
810 | "round 224, loss=5.529639720916748, acc=72.6%\n",
811 | "round 225, loss=5.527536392211914, acc=72.6%\n",
812 | "round 226, loss=5.525442123413086, acc=72.6%\n",
813 | "round 227, loss=5.52336311340332, acc=72.6%\n",
814 | "round 228, loss=5.521296977996826, acc=72.6%\n",
815 | "round 229, loss=5.519237041473389, acc=72.6%\n",
816 | "round 230, loss=5.517190933227539, acc=72.6%\n",
817 | "round 231, loss=5.515157699584961, acc=72.6%\n",
818 | "round 232, loss=5.513134002685547, acc=72.6%\n",
819 | "round 233, loss=5.511120796203613, acc=72.7%\n",
820 | "round 234, loss=5.509117603302002, acc=72.7%\n",
821 | "round 235, loss=5.507127285003662, acc=72.7%\n",
822 | "round 236, loss=5.505147933959961, acc=72.7%\n",
823 | "round 237, loss=5.503175258636475, acc=72.7%\n",
824 | "round 238, loss=5.501216411590576, acc=72.89999999999999%\n",
825 | "round 239, loss=5.499266147613525, acc=72.89999999999999%\n",
826 | "round 240, loss=5.49732780456543, acc=72.89999999999999%\n",
827 | "round 241, loss=5.49539852142334, acc=72.89999999999999%\n",
828 | "round 242, loss=5.493480205535889, acc=72.89999999999999%\n",
829 | "round 243, loss=5.491570949554443, acc=72.89999999999999%\n",
830 | "round 244, loss=5.489671230316162, acc=72.89999999999999%\n",
831 | "round 245, loss=5.4877824783325195, acc=72.89999999999999%\n",
832 | "round 246, loss=5.485902786254883, acc=72.89999999999999%\n",
833 | "round 247, loss=5.4840312004089355, acc=72.89999999999999%\n",
834 | "round 248, loss=5.482170581817627, acc=72.89999999999999%\n",
835 | "round 249, loss=5.480320453643799, acc=72.89999999999999%\n",
836 | "round 250, loss=5.47847843170166, acc=72.89999999999999%\n",
837 | "round 251, loss=5.476645469665527, acc=72.89999999999999%\n",
838 | "round 252, loss=5.474822521209717, acc=72.89999999999999%\n",
839 | "round 253, loss=5.473005771636963, acc=72.89999999999999%\n",
840 | "round 254, loss=5.4712018966674805, acc=72.89999999999999%\n",
841 | "round 255, loss=5.469404220581055, acc=72.89999999999999%\n",
842 | "round 256, loss=5.467617988586426, acc=72.89999999999999%\n",
843 | "round 257, loss=5.4658379554748535, acc=72.89999999999999%\n",
844 | "round 258, loss=5.4640679359436035, acc=72.89999999999999%\n",
845 | "round 259, loss=5.462305068969727, acc=72.89999999999999%\n",
846 | "round 260, loss=5.460551738739014, acc=72.89999999999999%\n",
847 | "round 261, loss=5.458808422088623, acc=72.89999999999999%\n",
848 | "round 262, loss=5.457070827484131, acc=72.89999999999999%\n",
849 | "round 263, loss=5.4553422927856445, acc=72.89999999999999%\n",
850 | "round 264, loss=5.453622341156006, acc=72.89999999999999%\n",
851 | "round 265, loss=5.451910495758057, acc=72.89999999999999%\n",
852 | "round 266, loss=5.4502081871032715, acc=72.89999999999999%\n",
853 | "round 267, loss=5.448512554168701, acc=72.89999999999999%\n",
854 | "round 268, loss=5.446822643280029, acc=72.89999999999999%\n",
855 | "round 269, loss=5.445144176483154, acc=72.89999999999999%\n",
856 | "round 270, loss=5.443472385406494, acc=72.89999999999999%\n",
857 | "round 271, loss=5.441807746887207, acc=72.89999999999999%\n",
858 | "round 272, loss=5.440150737762451, acc=72.89999999999999%\n",
859 | "round 273, loss=5.438502311706543, acc=72.89999999999999%\n",
860 | "round 274, loss=5.436861515045166, acc=72.89999999999999%\n",
861 | "round 275, loss=5.435227870941162, acc=72.89999999999999%\n",
862 | "round 276, loss=5.433599948883057, acc=72.89999999999999%\n",
863 | "round 277, loss=5.431982040405273, acc=72.89999999999999%\n",
864 | "round 278, loss=5.430369853973389, acc=72.89999999999999%\n",
865 | "round 279, loss=5.428764820098877, acc=72.89999999999999%\n",
866 | "round 280, loss=5.427168369293213, acc=73.0%\n",
867 | "round 281, loss=5.425579071044922, acc=73.0%\n",
868 | "round 282, loss=5.423994064331055, acc=73.0%\n",
869 | "round 283, loss=5.422417163848877, acc=73.0%\n",
870 | "round 284, loss=5.420848846435547, acc=73.0%\n",
871 | "round 285, loss=5.419287204742432, acc=73.0%\n",
872 | "round 286, loss=5.417731761932373, acc=73.0%\n",
873 | "round 287, loss=5.4161834716796875, acc=73.0%\n",
874 | "round 288, loss=5.4146409034729, acc=73.2%\n",
875 | "round 289, loss=5.4131059646606445, acc=73.2%\n",
876 | "round 290, loss=5.411578178405762, acc=73.2%\n",
877 | "round 291, loss=5.410055637359619, acc=73.2%\n",
878 | "round 292, loss=5.40854024887085, acc=73.2%\n",
879 | "round 293, loss=5.40703010559082, acc=73.2%\n",
880 | "round 294, loss=5.405528545379639, acc=73.2%\n",
881 | "round 295, loss=5.404033184051514, acc=73.2%\n",
882 | "round 296, loss=5.402544021606445, acc=73.2%\n",
883 | "round 297, loss=5.401060104370117, acc=73.2%\n",
884 | "round 298, loss=5.399583339691162, acc=73.2%\n",
885 | "round 299, loss=5.398112773895264, acc=73.2%\n",
886 | "```"
887 | ]
888 | },
889 | {
890 | "cell_type": "markdown",
891 | "metadata": {
892 | "id": "fUg7lHHAXHap"
893 | },
894 | "source": [
895 | "改变权值(0.19,0.01)后的结果\n",
896 | "```\n",
897 | "round 0, loss=10.372981071472168, acc=70.39999999999999%\n",
898 | "round 1, loss=9.729572296142578, acc=71.8%\n",
899 | "round 2, loss=9.276639938354492, acc=71.8%\n",
900 | "round 3, loss=8.937175750732422, acc=71.8%\n",
901 | "round 4, loss=8.671344757080078, acc=71.8%\n",
902 | "round 5, loss=8.455917358398438, acc=71.8%\n",
903 | "round 6, loss=8.27657699584961, acc=71.89999999999999%\n",
904 | "round 7, loss=8.124074935913086, acc=71.8%\n",
905 | "round 8, loss=7.992163181304932, acc=71.8%\n",
906 | "round 9, loss=7.876453399658203, acc=71.8%\n",
907 | "round 10, loss=7.773778915405273, acc=71.8%\n",
908 | "round 11, loss=7.681778430938721, acc=71.89999999999999%\n",
909 | "round 12, loss=7.598653793334961, acc=71.89999999999999%\n",
910 | "round 13, loss=7.52301025390625, acc=72.0%\n",
911 | "round 14, loss=7.453742027282715, acc=72.0%\n",
912 | "round 15, loss=7.389968395233154, acc=72.0%\n",
913 | "```"
914 | ]
915 | }
916 | ]
917 | }
--------------------------------------------------------------------------------