├── .gitignore ├── 001-keras_overview.ipynb ├── 002-keras_function.ipynb ├── 003-keras_training_and_evaluation.ipynb ├── 004-keras_write_layers_from_scratch.ipynb ├── 005-keras_saving_and_serializing_model.ipynb ├── 006-eager_execution.ipynb ├── 007-Variables.ipynb ├── 008-AutoGraph.ipynb ├── 020-Eager ├── 001-Tensor_and_operations.ipynb ├── 002-custom_layers.ipynb ├── 003-automatic_differentiation.ipynb ├── 004-custom_training_base.ipynb ├── 005-custom_training_walkthrough.ipynb └── 006-tf_function_and_autograph.ipynb ├── 021-MLP ├── .ipynb_checkpoints │ ├── 001-MLP-checkpoint.ipynb │ └── 002-MLP2-checkpoint.ipynb ├── 001-MLP.ipynb └── 002-MLP2.ipynb ├── 022-CNN ├── .ipynb_checkpoints │ ├── 001-cnn-checkpoint.ipynb │ ├── 002-cnn_variants-checkpoint.ipynb │ ├── 003-text_cnn-checkpoint.ipynb │ └── 004-pretrained_cnn-checkpoint.ipynb ├── 001-cnn.ipynb ├── 002-cnn_variants.ipynb ├── 003-text_cnn.ipynb ├── 004-pretrained_cnn.ipynb └── dog.jpg ├── 023-RNN ├── .ipynb_checkpoints │ ├── 002-rnn_variance-checkpoint.ipynb │ └── 003-cnn_rnn-checkpoint.ipynb ├── 002-rnn_variance.ipynb └── 003-cnn_rnn.ipynb ├── 024-AutoEncoder ├── .ipynb_checkpoints │ ├── 001-autoencoder-checkpoint.ipynb │ └── 002-cnn_autoencoder-checkpoint.ipynb ├── 001-autoencoder.ipynb ├── 002-cnn_autoencoder.ipynb └── model.png ├── 025-GAN ├── .ipynb_checkpoints │ └── 002-DCGAN-checkpoint.ipynb └── 002-DCGAN.ipynb ├── 026-Transformer └── 001-Transformer.ipynb ├── 031-Image ├── .ipynb_checkpoints │ └── 001-image_classification-checkpoint.ipynb ├── 001-image_classification.ipynb └── pix2pix.ipynb ├── 032-Text ├── .ipynb_checkpoints │ └── 001-word_embeddings-checkpoint.ipynb ├── 001-word_embeddings.ipynb ├── 002-text_classification_with_RNN.ipynb ├── meta.tsv ├── nmt_with_attention.ipynb ├── text_generation.ipynb └── vecs.tsv ├── 033-Estimators ├── .ipynb_checkpoints │ └── 001-boosted_trees-checkpoint.ipynb └── 001-boosted_trees.ipynb ├── 040-App ├── 002-style_transfer.ipynb └── 
403-image_caption_with_attention.ipynb ├── 101-example_image_classification.ipynb ├── 102-example_text_classification.ipynb ├── 103-example_overfitting_and_underfitting.ipynb ├── 104-example_classify_structured_data.ipynb ├── 105-example_regression.ipynb ├── 106-example_save_and_restore_models.ipynb └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb_checkpoints 2 | *.png 3 | *.jpg 4 | -------------------------------------------------------------------------------- /004-keras_write_layers_from_scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 用keras构建自己的网络层\n", 8 | "\n", 9 | "\n", 10 | "## 1.构建一个简单的网络层\n" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "\n", 20 | "from __future__ import absolute_import, division, print_function\n", 21 | "import tensorflow as tf\n", 22 | "tf.keras.backend.clear_session()\n", 23 | "import tensorflow.keras as keras\n", 24 | "import tensorflow.keras.layers as layers" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "name": "stdout", 34 | "output_type": "stream", 35 | "text": [ 36 | "tf.Tensor(\n", 37 | "[[0.06709253 0.06818779 0.09926171 0.0179923 ]\n", 38 | " [0.06709253 0.06818779 0.09926171 0.0179923 ]\n", 39 | " [0.06709253 0.06818779 0.09926171 0.0179923 ]], shape=(3, 4), dtype=float32)\n" 40 | ] 41 | } 42 | ], 43 | "source": [ 44 | "# 定义网络层就是:设置网络权重和输出到输入的计算过程\n", 45 | "class MyLayer(layers.Layer):\n", 46 | " def __init__(self, input_dim=32, unit=32):\n", 47 | " super(MyLayer, self).__init__()\n", 48 | " \n", 49 | " w_init = tf.random_normal_initializer()\n", 50 | " self.weight = tf.Variable(initial_value=w_init(\n", 51 | " shape=(input_dim, unit), 
dtype=tf.float32), trainable=True)\n", 52 | " \n", 53 | " b_init = tf.zeros_initializer()\n", 54 | " self.bias = tf.Variable(initial_value=b_init(\n", 55 | " shape=(unit,), dtype=tf.float32), trainable=True)\n", 56 | " \n", 57 | " def call(self, inputs):\n", 58 | " return tf.matmul(inputs, self.weight) + self.bias\n", 59 | " \n", 60 | "x = tf.ones((3,5))\n", 61 | "my_layer = MyLayer(5, 4)\n", 62 | "out = my_layer(x)\n", 63 | "print(out)\n", 64 | " \n", 65 | " " 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "按上面构建网络层,图层会自动跟踪权重w和b,当然我们也可以直接用add_weight的方法构建权重" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "tf.Tensor(\n", 85 | "[[-0.10401802 -0.05459599 -0.08195674 0.13151655]\n", 86 | " [-0.10401802 -0.05459599 -0.08195674 0.13151655]\n", 87 | " [-0.10401802 -0.05459599 -0.08195674 0.13151655]], shape=(3, 4), dtype=float32)\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "class MyLayer(layers.Layer):\n", 93 | " def __init__(self, input_dim=32, unit=32):\n", 94 | " super(MyLayer, self).__init__()\n", 95 | " self.weight = self.add_weight(shape=(input_dim, unit),\n", 96 | " initializer=keras.initializers.RandomNormal(),\n", 97 | " trainable=True)\n", 98 | " self.bias = self.add_weight(shape=(unit,),\n", 99 | " initializer=keras.initializers.Zeros(),\n", 100 | " trainable=True)\n", 101 | " \n", 102 | " def call(self, inputs):\n", 103 | " return tf.matmul(inputs, self.weight) + self.bias\n", 104 | " \n", 105 | "x = tf.ones((3,5))\n", 106 | "my_layer = MyLayer(5, 4)\n", 107 | "out = my_layer(x)\n", 108 | "print(out)\n", 109 | " " 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "也可以设置不可训练的权重" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | 
"name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "[3. 3. 3.]\n", 129 | "[6. 6. 6.]\n", 130 | "weight: []\n", 131 | "non-trainable weight: []\n", 132 | "trainable weight: []\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "class AddLayer(layers.Layer):\n", 138 | " def __init__(self, input_dim=32):\n", 139 | " super(AddLayer, self).__init__()\n", 140 | " self.sum = self.add_weight(shape=(input_dim,),\n", 141 | " initializer=keras.initializers.Zeros(),\n", 142 | " trainable=False)\n", 143 | " \n", 144 | " \n", 145 | " def call(self, inputs):\n", 146 | " self.sum.assign_add(tf.reduce_sum(inputs, axis=0))\n", 147 | " return self.sum\n", 148 | " \n", 149 | "x = tf.ones((3,3))\n", 150 | "my_layer = AddLayer(3)\n", 151 | "out = my_layer(x)\n", 152 | "print(out.numpy())\n", 153 | "out = my_layer(x)\n", 154 | "print(out.numpy())\n", 155 | "print('weight:', my_layer.weights)\n", 156 | "print('non-trainable weight:', my_layer.non_trainable_weights)\n", 157 | "print('trainable weight:', my_layer.trainable_weights)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "如果定义网络时还不知道输入的维度,可以重写build()函数,用获得的shape构建网络" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 5, 170 | "metadata": {}, 171 | "outputs": [ 172 | { 173 | "name": "stdout", 174 | "output_type": "stream", 175 | "text": [ 176 | "tf.Tensor(\n", 177 | "[[ 0.00949192 -0.02009935 -0.11726624]\n", 178 | " [ 0.00949192 -0.02009935 -0.11726624]\n", 179 | " [ 0.00949192 -0.02009935 -0.11726624]], shape=(3, 3), dtype=float32)\n", 180 | "tf.Tensor(\n", 181 | "[[-0.00516411 -0.04891593 -0.0181773 ]\n", 182 | " [-0.00516411 -0.04891593 -0.0181773 ]], shape=(2, 3), dtype=float32)\n" 183 | ] 184 | } 185 | ], 186 | "source": [ 187 | "class MyLayer(layers.Layer):\n", 188 | " def __init__(self, unit=32):\n", 189 | " super(MyLayer, self).__init__()\n", 190 | " self.unit = unit\n", 191 | " \n", 192 | " def build(self, input_shape):\n", 
193 | " self.weight = self.add_weight(shape=(input_shape[-1], self.unit),\n", 194 | " initializer=keras.initializers.RandomNormal(),\n", 195 | " trainable=True)\n", 196 | " self.bias = self.add_weight(shape=(self.unit,),\n", 197 | " initializer=keras.initializers.Zeros(),\n", 198 | " trainable=True)\n", 199 | " \n", 200 | " def call(self, inputs):\n", 201 | " return tf.matmul(inputs, self.weight) + self.bias\n", 202 | " \n", 203 | "\n", 204 | "my_layer = MyLayer(3)\n", 205 | "x = tf.ones((3,5))\n", 206 | "out = my_layer(x)\n", 207 | "print(out)\n", 208 | "my_layer = MyLayer(3)\n", 209 | "\n", 210 | "x = tf.ones((2,2))\n", 211 | "out = my_layer(x)\n", 212 | "print(out)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "## 2.使用子层递归构建网络层\n" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 12, 225 | "metadata": {}, 226 | "outputs": [ 227 | { 228 | "name": "stdout", 229 | "output_type": "stream", 230 | "text": [ 231 | "trainable weights: 0\n", 232 | "trainable weights: 6\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "class MyBlock(layers.Layer):\n", 238 | " def __init__(self):\n", 239 | " super(MyBlock, self).__init__()\n", 240 | " self.layer1 = MyLayer(32)\n", 241 | " self.layer2 = MyLayer(16)\n", 242 | " self.layer3 = MyLayer(2)\n", 243 | " def call(self, inputs):\n", 244 | " h1 = self.layer1(inputs)\n", 245 | " h1 = tf.nn.relu(h1)\n", 246 | " h2 = self.layer2(h1)\n", 247 | " h2 = tf.nn.relu(h2)\n", 248 | " return self.layer3(h2)\n", 249 | " \n", 250 | "my_block = MyBlock()\n", 251 | "print('trainable weights:', len(my_block.trainable_weights))\n", 252 | "y = my_block(tf.ones(shape=(3, 64)))\n", 253 | "# 构建网络在build()里面,所以执行了才有网络\n", 254 | "print('trainable weights:', len(my_block.trainable_weights)) " 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "可以通过构建网络层的方法来收集loss" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | 
"execution_count": 18, 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "name": "stdout", 271 | "output_type": "stream", 272 | "text": [ 273 | "0\n", 274 | "1\n", 275 | "1\n" 276 | ] 277 | } 278 | ], 279 | "source": [ 280 | "class LossLayer(layers.Layer):\n", 281 | " \n", 282 | " def __init__(self, rate=1e-2):\n", 283 | " super(LossLayer, self).__init__()\n", 284 | " self.rate = rate\n", 285 | " \n", 286 | " def call(self, inputs):\n", 287 | " self.add_loss(self.rate * tf.reduce_sum(inputs))\n", 288 | " return inputs\n", 289 | "\n", 290 | "class OutLayer(layers.Layer):\n", 291 | " def __init__(self):\n", 292 | " super(OutLayer, self).__init__()\n", 293 | " self.loss_fun=LossLayer(1e-2)\n", 294 | " def call(self, inputs):\n", 295 | " return self.loss_fun(inputs)\n", 296 | " \n", 297 | "my_layer = OutLayer()\n", 298 | "print(len(my_layer.losses)) # 还未call\n", 299 | "y = my_layer(tf.zeros(1,1))\n", 300 | "print(len(my_layer.losses)) # 执行call之后\n", 301 | "y = my_layer(tf.zeros(1,1))\n", 302 | "print(len(my_layer.losses)) # call之前会重新置0\n", 303 | "\n" 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "metadata": {}, 309 | "source": [ 310 | "如果中间调用了keras网络层,里面的正则化loss也会被加入进来" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 25, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stdout", 320 | "output_type": "stream", 321 | "text": [ 322 | "[]\n", 323 | "[, ]\n" 334 | ] 335 | } 336 | ], 337 | "source": [ 338 | "class OuterLayer(layers.Layer):\n", 339 | "\n", 340 | " def __init__(self):\n", 341 | " super(OuterLayer, self).__init__()\n", 342 | " self.dense = layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l2(1e-3))\n", 343 | " \n", 344 | " def call(self, inputs):\n", 345 | " return self.dense(inputs)\n", 346 | "\n", 347 | "\n", 348 | "my_layer = OuterLayer()\n", 349 | "y = my_layer(tf.zeros((1,1)))\n", 350 | "print(my_layer.losses) \n", 351 | "print(my_layer.weights) " 352 | ] 353 | }, 354 | { 355 | 
"cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "# 3.其他网络层配置\n", 359 | "使自己的网络层可以序列化" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": 28, 365 | "metadata": {}, 366 | "outputs": [ 367 | { 368 | "name": "stdout", 369 | "output_type": "stream", 370 | "text": [ 371 | "{'name': 'linear_1', 'trainable': True, 'dtype': None, 'units': 125}\n" 372 | ] 373 | } 374 | ], 375 | "source": [ 376 | "class Linear(layers.Layer):\n", 377 | "\n", 378 | " def __init__(self, units=32, **kwargs):\n", 379 | " super(Linear, self).__init__(**kwargs)\n", 380 | " self.units = units\n", 381 | "\n", 382 | " def build(self, input_shape):\n", 383 | " self.w = self.add_weight(shape=(input_shape[-1], self.units),\n", 384 | " initializer='random_normal',\n", 385 | " trainable=True)\n", 386 | " self.b = self.add_weight(shape=(self.units,),\n", 387 | " initializer='random_normal',\n", 388 | " trainable=True)\n", 389 | " def call(self, inputs):\n", 390 | " return tf.matmul(inputs, self.w) + self.b\n", 391 | " \n", 392 | " def get_config(self):\n", 393 | " config = super(Linear, self).get_config()\n", 394 | " config.update({'units':self.units})\n", 395 | " return config\n", 396 | " \n", 397 | " \n", 398 | "layer = Linear(125)\n", 399 | "config = layer.get_config()\n", 400 | "print(config)\n", 401 | "new_layer = Linear.from_config(config)" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": {}, 407 | "source": [ 408 | "配置只有训练时可以执行的网络层" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 30, 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [ 417 | "class MyDropout(layers.Layer):\n", 418 | " def __init__(self, rate, **kwargs):\n", 419 | " super(MyDropout, self).__init__(**kwargs)\n", 420 | " self.rate = rate\n", 421 | " def call(self, inputs, training=None):\n", 422 | " return tf.cond(training, \n", 423 | " lambda: tf.nn.dropout(inputs, rate=self.rate),\n", 424 | " lambda: inputs)\n", 425 | " \n", 
426 | " " 427 | ] 428 | }, 429 | { 430 | "cell_type": "markdown", 431 | "metadata": {}, 432 | "source": [ 433 | "# 4.构建自己的模型\n", 434 | "通常,我们使用Layer类来定义内部计算块,并使用Model类来定义外部模型 - 即要训练的对象。\n", 435 | "\n", 436 | "Model类与Layer的区别:\n", 437 | "- 它公开了内置的训练,评估和预测循环(model.fit(),model.evaluate(),model.predict())。 \n", 438 | "- 它通过model.layers属性公开其内层列表。 \n", 439 | "- 它公开了保存和序列化API。\n", 440 | "\n", 441 | "下面通过构建一个变分自编码器(VAE),来介绍如何构建自己的网络。" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": 46, 447 | "metadata": {}, 448 | "outputs": [], 449 | "source": [ 450 | "# 采样网络\n", 451 | "class Sampling(layers.Layer):\n", 452 | " def call(self, inputs):\n", 453 | " z_mean, z_log_var = inputs\n", 454 | " batch = tf.shape(z_mean)[0]\n", 455 | " dim = tf.shape(z_mean)[1]\n", 456 | " epsilon = tf.keras.backend.random_normal(shape=(batch, dim))\n", 457 | " return z_mean + tf.exp(0.5*z_log_var) * epsilon\n", 458 | "# 编码器\n", 459 | "class Encoder(layers.Layer):\n", 460 | " def __init__(self, latent_dim=32, \n", 461 | " intermediate_dim=64, name='encoder', **kwargs):\n", 462 | " super(Encoder, self).__init__(name=name, **kwargs)\n", 463 | " self.dense_proj = layers.Dense(intermediate_dim, activation='relu')\n", 464 | " self.dense_mean = layers.Dense(latent_dim)\n", 465 | " self.dense_log_var = layers.Dense(latent_dim)\n", 466 | " self.sampling = Sampling()\n", 467 | " \n", 468 | " def call(self, inputs):\n", 469 | " h1 = self.dense_proj(inputs)\n", 470 | " z_mean = self.dense_mean(h1)\n", 471 | " z_log_var = self.dense_log_var(h1)\n", 472 | " z = self.sampling((z_mean, z_log_var))\n", 473 | " return z_mean, z_log_var, z\n", 474 | " \n", 475 | "# 解码器\n", 476 | "class Decoder(layers.Layer):\n", 477 | " def __init__(self, original_dim, \n", 478 | " intermediate_dim=64, name='decoder', **kwargs):\n", 479 | " super(Decoder, self).__init__(name=name, **kwargs)\n", 480 | " self.dense_proj = layers.Dense(intermediate_dim, activation='relu')\n", 481 | " self.dense_output = 
layers.Dense(original_dim, activation='sigmoid')\n", 482 | " def call(self, inputs):\n", 483 | " h1 = self.dense_proj(inputs)\n", 484 | " return self.dense_output(h1)\n", 485 | " \n", 486 | "# 变分自编码器\n", 487 | "class VAE(tf.keras.Model):\n", 488 | " def __init__(self, original_dim, latent_dim=32, \n", 489 | " intermediate_dim=64, name='encoder', **kwargs):\n", 490 | " super(VAE, self).__init__(name=name, **kwargs)\n", 491 | " \n", 492 | " self.original_dim = original_dim\n", 493 | " self.encoder = Encoder(latent_dim=latent_dim,\n", 494 | " intermediate_dim=intermediate_dim)\n", 495 | " self.decoder = Decoder(original_dim=original_dim,\n", 496 | " intermediate_dim=intermediate_dim)\n", 497 | " def call(self, inputs):\n", 498 | " z_mean, z_log_var, z = self.encoder(inputs)\n", 499 | " reconstructed = self.decoder(z)\n", 500 | " \n", 501 | " kl_loss = -0.5*tf.reduce_sum(\n", 502 | " z_log_var-tf.square(z_mean)-tf.exp(z_log_var)+1)\n", 503 | " self.add_loss(kl_loss)\n", 504 | " return reconstructed" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 47, 510 | "metadata": {}, 511 | "outputs": [ 512 | { 513 | "name": "stdout", 514 | "output_type": "stream", 515 | "text": [ 516 | "Epoch 1/3\n", 517 | "60000/60000 [==============================] - 3s 44us/sample - loss: 0.7352\n", 518 | "Epoch 2/3\n", 519 | "60000/60000 [==============================] - 2s 33us/sample - loss: 0.0691\n", 520 | "Epoch 3/3\n", 521 | "60000/60000 [==============================] - 2s 33us/sample - loss: 0.0679\n" 522 | ] 523 | }, 524 | { 525 | "data": { 526 | "text/plain": [ 527 | "" 528 | ] 529 | }, 530 | "execution_count": 47, 531 | "metadata": {}, 532 | "output_type": "execute_result" 533 | } 534 | ], 535 | "source": [ 536 | "\n", 537 | "(x_train, _), _ = tf.keras.datasets.mnist.load_data()\n", 538 | "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n", 539 | "vae = VAE(784,32,64)\n", 540 | "optimizer = 
tf.keras.optimizers.Adam(learning_rate=1e-3)\n", 541 | "\n", 542 | "vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())\n", 543 | "vae.fit(x_train, x_train, epochs=3, batch_size=64)" 544 | ] 545 | }, 546 | { 547 | "cell_type": "markdown", 548 | "metadata": {}, 549 | "source": [ 550 | "自己编写训练方法" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 50, 556 | "metadata": {}, 557 | "outputs": [ 558 | { 559 | "name": "stdout", 560 | "output_type": "stream", 561 | "text": [ 562 | "Start of epoch 0\n", 563 | "step 0: mean loss = tf.Tensor(213.26726, shape=(), dtype=float32)\n", 564 | "step 100: mean loss = tf.Tensor(6.5270114, shape=(), dtype=float32)\n", 565 | "step 200: mean loss = tf.Tensor(3.3300452, shape=(), dtype=float32)\n", 566 | "step 300: mean loss = tf.Tensor(2.2522914, shape=(), dtype=float32)\n", 567 | "step 400: mean loss = tf.Tensor(1.7097591, shape=(), dtype=float32)\n", 568 | "step 500: mean loss = tf.Tensor(1.3835965, shape=(), dtype=float32)\n", 569 | "step 600: mean loss = tf.Tensor(1.1659158, shape=(), dtype=float32)\n", 570 | "step 700: mean loss = tf.Tensor(1.0099952, shape=(), dtype=float32)\n", 571 | "step 800: mean loss = tf.Tensor(0.89288116, shape=(), dtype=float32)\n", 572 | "step 900: mean loss = tf.Tensor(0.8014734, shape=(), dtype=float32)\n", 573 | "Start of epoch 1\n", 574 | "step 0: mean loss = tf.Tensor(0.77191323, shape=(), dtype=float32)\n", 575 | "step 100: mean loss = tf.Tensor(0.70431703, shape=(), dtype=float32)\n", 576 | "step 200: mean loss = tf.Tensor(0.6486862, shape=(), dtype=float32)\n", 577 | "step 300: mean loss = tf.Tensor(0.60191154, shape=(), dtype=float32)\n", 578 | "step 400: mean loss = tf.Tensor(0.56213117, shape=(), dtype=float32)\n", 579 | "step 500: mean loss = tf.Tensor(0.52777255, shape=(), dtype=float32)\n", 580 | "step 600: mean loss = tf.Tensor(0.49796674, shape=(), dtype=float32)\n", 581 | "step 700: mean loss = tf.Tensor(0.47174037, shape=(), dtype=float32)\n", 582 | 
"step 800: mean loss = tf.Tensor(0.4485459, shape=(), dtype=float32)\n", 583 | "step 900: mean loss = tf.Tensor(0.4277973, shape=(), dtype=float32)\n", 584 | "Start of epoch 2\n", 585 | "step 0: mean loss = tf.Tensor(0.42051753, shape=(), dtype=float32)\n", 586 | "step 100: mean loss = tf.Tensor(0.40269083, shape=(), dtype=float32)\n", 587 | "step 200: mean loss = tf.Tensor(0.38661462, shape=(), dtype=float32)\n", 588 | "step 300: mean loss = tf.Tensor(0.3719676, shape=(), dtype=float32)\n", 589 | "step 400: mean loss = tf.Tensor(0.35864368, shape=(), dtype=float32)\n", 590 | "step 500: mean loss = tf.Tensor(0.3463759, shape=(), dtype=float32)\n", 591 | "step 600: mean loss = tf.Tensor(0.33514142, shape=(), dtype=float32)\n", 592 | "step 700: mean loss = tf.Tensor(0.3247494, shape=(), dtype=float32)\n", 593 | "step 800: mean loss = tf.Tensor(0.3151487, shape=(), dtype=float32)\n", 594 | "step 900: mean loss = tf.Tensor(0.3061987, shape=(), dtype=float32)\n" 595 | ] 596 | } 597 | ], 598 | "source": [ 599 | "train_dataset = tf.data.Dataset.from_tensor_slices(x_train)\n", 600 | "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n", 601 | "\n", 602 | "original_dim = 784\n", 603 | "vae = VAE(original_dim, 64, 32)\n", 604 | "\n", 605 | "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n", 606 | "mse_loss_fn = tf.keras.losses.MeanSquaredError()\n", 607 | "\n", 608 | "loss_metric = tf.keras.metrics.Mean()\n", 609 | "\n", 610 | "# Iterate over epochs.\n", 611 | "for epoch in range(3):\n", 612 | " print('Start of epoch %d' % (epoch,))\n", 613 | "\n", 614 | " # Iterate over the batches of the dataset.\n", 615 | " for step, x_batch_train in enumerate(train_dataset):\n", 616 | " with tf.GradientTape() as tape:\n", 617 | " reconstructed = vae(x_batch_train)\n", 618 | " # Compute reconstruction loss\n", 619 | " loss = mse_loss_fn(x_batch_train, reconstructed)\n", 620 | " loss += sum(vae.losses) # Add KLD regularization loss\n", 621 | " \n", 622 | " grads = 
tape.gradient(loss, vae.trainable_variables)\n", 623 | " optimizer.apply_gradients(zip(grads, vae.trainable_variables))\n", 624 | " \n", 625 | " loss_metric(loss)\n", 626 | " \n", 627 | " if step % 100 == 0:\n", 628 | " print('step %s: mean loss = %s' % (step, loss_metric.result()))" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": null, 634 | "metadata": {}, 635 | "outputs": [], 636 | "source": [] 637 | } 638 | ], 639 | "metadata": { 640 | "kernelspec": { 641 | "display_name": "Python 3", 642 | "language": "python", 643 | "name": "python3" 644 | }, 645 | "language_info": { 646 | "codemirror_mode": { 647 | "name": "ipython", 648 | "version": 3 649 | }, 650 | "file_extension": ".py", 651 | "mimetype": "text/x-python", 652 | "name": "python", 653 | "nbconvert_exporter": "python", 654 | "pygments_lexer": "ipython3", 655 | "version": "3.6.6" 656 | } 657 | }, 658 | "nbformat": 4, 659 | "nbformat_minor": 2 660 | } 661 | -------------------------------------------------------------------------------- /005-keras_saving_and_serializing_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# tensorflow2教程-keras模型保存和序列化\n", 8 | "\n", 9 | "## 1.保存序列模型和函数模型" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "Model: \"3_layer_mlp\"\n", 22 | "_________________________________________________________________\n", 23 | "Layer (type) Output Shape Param # \n", 24 | "=================================================================\n", 25 | "digits (InputLayer) [(None, 784)] 0 \n", 26 | "_________________________________________________________________\n", 27 | "dense_1 (Dense) (None, 64) 50240 \n", 28 | "_________________________________________________________________\n", 29 | 
"dense_2 (Dense) (None, 64) 4160 \n", 30 | "_________________________________________________________________\n", 31 | "predictions (Dense) (None, 10) 650 \n", 32 | "=================================================================\n", 33 | "Total params: 55,050\n", 34 | "Trainable params: 55,050\n", 35 | "Non-trainable params: 0\n", 36 | "_________________________________________________________________\n", 37 | "60000/60000 [==============================] - 2s 29us/sample - loss: 0.3116\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "# 构建一个简单的模型并训练\n", 43 | "from __future__ import absolute_import, division, print_function\n", 44 | "import tensorflow as tf\n", 45 | "tf.keras.backend.clear_session()\n", 46 | "from tensorflow import keras\n", 47 | "from tensorflow.keras import layers\n", 48 | "\n", 49 | "inputs = keras.Input(shape=(784,), name='digits')\n", 50 | "x = layers.Dense(64, activation='relu', name='dense_1')(inputs)\n", 51 | "x = layers.Dense(64, activation='relu', name='dense_2')(x)\n", 52 | "outputs = layers.Dense(10, activation='softmax', name='predictions')(x)\n", 53 | "\n", 54 | "model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer_mlp')\n", 55 | "model.summary()\n", 56 | "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", 57 | "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n", 58 | "x_test = x_test.reshape(10000, 784).astype('float32') / 255\n", 59 | "\n", 60 | "model.compile(loss='sparse_categorical_crossentropy',\n", 61 | " optimizer=keras.optimizers.RMSprop())\n", 62 | "history = model.fit(x_train, y_train,\n", 63 | " batch_size=64,\n", 64 | " epochs=1)\n", 65 | "\n", 66 | "predictions = model.predict(x_test)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "### 1.1保存全模型\n", 74 | "可以对整个模型进行保存,其保存的内容包括:\n", 75 | "- 该模型的架构\n", 76 | "- 模型的权重(在训练期间学到的)\n", 77 | "- 模型的训练配置(即传给compile()的参数),如果有的话\n", 78 | "- 优化器及其状态(如果有的话)(这使您可以从中断的地方重新启动训练)\n" 79 | ] 80 | }, 
81 | { 82 | "cell_type": "code", 83 | "execution_count": 2, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "import numpy as np\n", 88 | "model.save('the_save_model.h5')\n", 89 | "new_model = keras.models.load_model('the_save_model.h5')\n", 90 | "new_prediction = new_model.predict(x_test)\n", 91 | "np.testing.assert_allclose(predictions, new_prediction, atol=1e-6) # 预测结果一样\n" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "### 1.2 保存为SavedModel文件" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "name": "stderr", 108 | "output_type": "stream", 109 | "text": [ 110 | "WARNING: Logging before flag parsing goes to stderr.\n", 111 | "W0311 23:50:02.399847 139796031059712 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.\n", 112 | "Instructions for updating:\n", 113 | "This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.\n", 114 | "W0311 23:50:02.400944 139796031059712 tf_logging.py:161] Export includes no default signature!\n", 115 | "W0311 23:50:02.721019 139796031059712 tf_logging.py:161] Export includes no default signature!\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "keras.experimental.export_saved_model(model, 'saved_model')\n", 121 | "new_model = keras.experimental.load_from_saved_model('saved_model')\n", 122 | "new_prediction = new_model.predict(x_test)\n", 123 | "np.testing.assert_allclose(predictions, new_prediction, atol=1e-6) # 预测结果一样\n" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "### 1.3仅保存网络结构\n", 131 | 
"仅保存网络结构,这样导出的模型并未包含训练好的参数" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 6, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "config = model.get_config()\n", 141 | "reinitialized_model = keras.Model.from_config(config)\n", 142 | "new_prediction = reinitialized_model.predict(x_test)\n", 143 | "assert abs(np.sum(predictions-new_prediction)) >0" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "也可以使用json保存网络结构" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 7, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "json_config = model.to_json()\n", 160 | "reinitialized_model = keras.models.model_from_json(json_config)\n", 161 | "new_prediction = reinitialized_model.predict(x_test)\n", 162 | "assert abs(np.sum(predictions-new_prediction)) >0" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "### 1.4仅保存网络参数" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 8, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "weights = model.get_weights()\n", 179 | "model.set_weights(weights)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 12, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "# 可以把结构和参数保存结合起来\n", 189 | "config = model.get_config()\n", 190 | "weights = model.get_weights()\n", 191 | "new_model = keras.Model.from_config(config) # config只能用keras.Model的这个api\n", 192 | "new_model.set_weights(weights)\n", 193 | "new_predictions = new_model.predict(x_test)\n", 194 | "np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "### 1.5完整的模型保存方法" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 15, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 
210 | "json_config = model.to_json()\n", 211 | "with open('model_config.json', 'w') as json_file:\n", 212 | " json_file.write(json_config)\n", 213 | "\n", 214 | "model.save_weights('path_to_my_weights.h5')\n", 215 | "\n", 216 | "with open('model_config.json') as json_file:\n", 217 | " json_config = json_file.read()\n", 218 | "new_model = keras.models.model_from_json(json_config)\n", 219 | "new_model.load_weights('path_to_my_weights.h5')\n", 220 | "\n", 221 | "new_predictions = new_model.predict(x_test)\n", 222 | "np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 16, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "# 当然也可以一步到位\n", 232 | "model.save('path_to_my_model.h5')\n", 233 | "del model\n", 234 | "model = keras.models.load_model('path_to_my_model.h5')" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": {}, 240 | "source": [ 241 | "### 1.6保存网络权重为SavedModel格式" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 18, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "model.save_weights('weight_tf_savedmodel')\n", 251 | "model.save_weights('weight_tf_savedmodel_h5', save_format='h5')" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "### 1.7子类模型参数保存\n", 259 | "子类模型的结构无法保存和序列化,只能保存参数" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 19, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "# 构建模型\n", 269 | "class ThreeLayerMLP(keras.Model):\n", 270 | " \n", 271 | " def __init__(self, name=None):\n", 272 | " super(ThreeLayerMLP, self).__init__(name=name)\n", 273 | " self.dense_1 = layers.Dense(64, activation='relu', name='dense_1')\n", 274 | " self.dense_2 = layers.Dense(64, activation='relu', name='dense_2')\n", 275 | " self.pred_layer = layers.Dense(10, activation='softmax', 
name='predictions')\n", 276 | "\n", 277 | " def call(self, inputs):\n", 278 | " x = self.dense_1(inputs)\n", 279 | " x = self.dense_2(x)\n", 280 | " return self.pred_layer(x)\n", 281 | "\n", 282 | "def get_model():\n", 283 | " return ThreeLayerMLP(name='3_layer_mlp')\n", 284 | "\n", 285 | "model = get_model()" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 20, 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "name": "stdout", 295 | "output_type": "stream", 296 | "text": [ 297 | "60000/60000 [==============================] - 2s 28us/sample - loss: 0.3217\n" 298 | ] 299 | } 300 | ], 301 | "source": [ 302 | "# 训练模型\n", 303 | "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", 304 | "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n", 305 | "x_test = x_test.reshape(10000, 784).astype('float32') / 255\n", 306 | "\n", 307 | "model.compile(loss='sparse_categorical_crossentropy',\n", 308 | " optimizer=keras.optimizers.RMSprop())\n", 309 | "history = model.fit(x_train, y_train,\n", 310 | " batch_size=64,\n", 311 | " epochs=1)" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 21, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "# 保持权重参数\n", 321 | "model.save_weights('my_model_weights', save_format='tf')\n", 322 | "\n", 323 | "# 输出结果,供后面对比\n", 324 | "\n", 325 | "predictions = model.predict(x_test)\n", 326 | "first_batch_loss = model.train_on_batch(x_train[:64], y_train[:64])\n" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 24, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "# 读取保存的模型参数\n", 336 | "new_model = get_model()\n", 337 | "new_model.compile(loss='sparse_categorical_crossentropy',\n", 338 | " optimizer=keras.optimizers.RMSprop())\n", 339 | "\n", 340 | "#new_model.train_on_batch(x_train[:1], y_train[:1])\n", 341 | "\n", 342 | "new_model.load_weights('my_model_weights')\n", 343 | "\n", 344 | 
"new_predictions = new_model.predict(x_test)\n", 345 | "np.testing.assert_allclose(predictions, new_predictions, atol=1e-6)\n", 346 | "\n", 347 | "\n", 348 | "new_first_batch_loss = new_model.train_on_batch(x_train[:64], y_train[:64])\n", 349 | "assert first_batch_loss == new_first_batch_loss" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": null, 355 | "metadata": {}, 356 | "outputs": [], 357 | "source": [] 358 | } 359 | ], 360 | "metadata": { 361 | "kernelspec": { 362 | "display_name": "Python 3", 363 | "language": "python", 364 | "name": "python3" 365 | }, 366 | "language_info": { 367 | "codemirror_mode": { 368 | "name": "ipython", 369 | "version": 3 370 | }, 371 | "file_extension": ".py", 372 | "mimetype": "text/x-python", 373 | "name": "python", 374 | "nbconvert_exporter": "python", 375 | "pygments_lexer": "ipython3", 376 | "version": "3.6.6" 377 | } 378 | }, 379 | "nbformat": 4, 380 | "nbformat_minor": 2 381 | } 382 | -------------------------------------------------------------------------------- /006-eager_execution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TensorFlow2教程-Eager Execution" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "> 最全TensorFlow 2.0 入门教程持续更新:https://zhuanlan.zhihu.com/p/59507137\n", 15 | "> \n", 16 | "> 完整TensorFlow2.0教程代码请看https://github.com/czy36mengfei/tensorflow2_tutorials_chinese (欢迎star)\n", 17 | "> \n", 18 | "> 最新TensorFlow 2教程和相关资源,请关注微信公众号:DoitNLP, 后面我会在DoitNLP上,持续更新深度学习、NLP、Tensorflow的相关教程和前沿资讯,它将成为我们一起学习TensorFlow2的大本营。\n", 19 | ">\n", 20 | "> 本教程主要由tensorflow2.0官方教程的个人学习复现笔记整理而来,中文讲解,方便喜欢阅读中文教程的朋友,tensorflow官方教程:https://www.tensorflow.org" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "TensorFlow 的 Eager Execution 
是一种命令式编程环境,可立即评估操作,无需构建图:操作会返回具体的值,而不是构建以后再运行的计算图。这样可以轻松地使用 TensorFlow 和调试模型,并且还减少了样板代码。\n", 28 | "\n", 29 | "Eager Execution 是一个灵活的机器学习平台,用于研究和实验,可提供:\n", 30 | "\n", 31 | "- 直观的界面 - 自然地组织代码结构并使用 Python 数据结构。快速迭代小模型和小型数据集。\n", 32 | "- 更轻松的调试功能 - 直接调用操作以检查正在运行的模型并测试更改。使用标准 Python 调试工具进行即时错误报告。\n", 33 | "- 自然控制流程 - 使用 Python 控制流程而不是图控制流程,简化了动态模型的规范。" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "2.0.0-beta1\n" 46 | ] 47 | }, 48 | { 49 | "data": { 50 | "text/plain": [ 51 | "True" 52 | ] 53 | }, 54 | "execution_count": 1, 55 | "metadata": {}, 56 | "output_type": "execute_result" 57 | } 58 | ], 59 | "source": [ 60 | "from __future__ import absolute_import, division, print_function\n", 61 | "\n", 62 | "import tensorflow as tf\n", 63 | "print(tf.__version__)\n", 64 | "# 在tensorflow2中默认使用Eager Execution\n", 65 | "tf.executing_eagerly()" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## 1.Eager Execution下运算" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "在Eager Execution下可以直接进行运算,结果会立即返回" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 2, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "tf.Tensor([[9.]], shape=(1, 1), dtype=float32)\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "x = [[3.]]\n", 97 | "m = tf.matmul(x, x)\n", 98 | "print(m)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "启用Eager Execution会改变TensorFlow操作的行为 - 现在他们会立即评估并将其值返回给Python。tf.Tensor对象引用具体值,而不再是指向计算图中节点的符号句柄。由于在会话中没有构建和运行的计算图,因此使用print()或调试程序很容易检查结果。评估,打印和检查张量值不会破坏计算梯度的流程。\n", 106 | "\n", 107 | "Eager Execution可以与NumPy很好地协作。NumPy操作接受tf.Tensor参数。TensorFlow 
数学运算将Python对象和NumPy数组转换为tf.Tensor对象。tf.Tensor.numpy方法将对象的值作为NumPy的ndarray类型返回。" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 3, 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "tf.Tensor(\n", 120 | "[[1 9]\n", 121 | " [3 6]], shape=(2, 2), dtype=int32)\n", 122 | "tf.Tensor(\n", 123 | "[[ 3 11]\n", 124 | " [ 5 8]], shape=(2, 2), dtype=int32)\n", 125 | "tf.Tensor(\n", 126 | "[[ 3 99]\n", 127 | " [15 48]], shape=(2, 2), dtype=int32)\n", 128 | "[[ 3 99]\n", 129 | " [15 48]]\n", 130 | "[[1 9]\n", 131 | " [3 6]]\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "# tf.Tensor对象引用具体值\n", 137 | "a = tf.constant([[1,9],[3,6]])\n", 138 | "print(a)\n", 139 | "\n", 140 | "# 支持broadcasting(广播:不同shape的数据进行数学运算)\n", 141 | "b = tf.add(a, 2)\n", 142 | "print(b)\n", 143 | "\n", 144 | "# 支持运算符重载\n", 145 | "print(a*b)\n", 146 | "\n", 147 | "# 可以当做numpy数据使用\n", 148 | "import numpy as np\n", 149 | "s = np.multiply(a,b)\n", 150 | "print(s)\n", 151 | "\n", 152 | "# 转换为numpy类型\n", 153 | "print(a.numpy())" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "## 2.动态控制流" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "Eager Execution的一个主要好处是在执行模型时可以使用宿主语言(Python)的所有功能。所以,例如,写fizzbuzz很容易:" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 4, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "1\n", 180 | "2\n", 181 | "Fizz\n", 182 | "4\n", 183 | "Buzz\n", 184 | "Fizz\n", 185 | "7\n", 186 | "8\n", 187 | "Fizz\n", 188 | "Buzz\n", 189 | "11\n", 190 | "Fizz\n", 191 | "13\n", 192 | "14\n", 193 | "FizzBuzz\n", 194 | "16\n" 195 | ] 196 | } 197 | ], 198 | "source": [ 199 | "def fizzbuzz(max_num):\n", 200 | " counter = tf.constant(0)\n", 201 | " max_num = tf.convert_to_tensor(max_num)\n", 
202 | " # 使用range遍历\n", 203 | " for num in range(1, max_num.numpy()+1):\n", 204 | " # 重新转为tensor类型\n", 205 | " num = tf.constant(num)\n", 206 | " # 使用if-elif 做判断\n", 207 | " if int(num % 3) == 0 and int(num % 5) == 0:\n", 208 | " print('FizzBuzz')\n", 209 | " elif int(num % 3) == 0:\n", 210 | " print('Fizz')\n", 211 | " elif int(num % 5) == 0:\n", 212 | " print('Buzz')\n", 213 | " else:\n", 214 | " print(num.numpy())\n", 215 | " counter += 1 # 自加运算\n", 216 | "fizzbuzz(16)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "## 3.在Eager Execution下训练\n" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "### 计算梯度\n", 231 | "自动微分对于实现机器学习算法(例如用于训练神经网络的反向传播)来说是很有用的。在 Eager Execution中,使用 tf.GradientTape 来跟踪操作以便稍后计算梯度。\n", 232 | "\n", 233 | "可以用tf.GradientTape来训练和/或计算梯度。它对复杂的训练循环特别有用。\n", 234 | "\n", 235 | "由于在每次调用期间可能发生不同的操作,所有前向传递操作都被记录到“磁带”中。要计算梯度,请反向播放磁带,然后将其丢弃。特定的tf.GradientTape只能计算一次梯度; 后续调用会引发运行时错误。" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 5, 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "name": "stdout", 245 | "output_type": "stream", 246 | "text": [ 247 | "tf.Tensor([[2.]], shape=(1, 1), dtype=float32)\n" 248 | ] 249 | } 250 | ], 251 | "source": [ 252 | "w = tf.Variable([[1.0]])\n", 253 | "# 用tf.GradientTape()记录梯度\n", 254 | "with tf.GradientTape() as tape:\n", 255 | " loss = w*w\n", 256 | "grad = tape.gradient(loss, w) # 计算梯度\n", 257 | "print(grad)" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "### 训练模型" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 6, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "Logits: [[-0.0330125 0.00557926 0.01437856 -0.07978292 -0.05083258 -0.02875553\n", 277 | " -0.02977117 -0.02533271 0.0575003 0.03532691]]\n", 278 | 
"........................................" 279 | ] 280 | }, 281 | { 282 | "data": { 283 | "text/plain": [ 284 | "Text(0, 0.5, 'Loss [entropy]')" 285 | ] 286 | }, 287 | "execution_count": 6, 288 | "metadata": {}, 289 | "output_type": "execute_result" 290 | } 291 | ], 292 | "source": [ 293 | "# 导入mnist数据\n", 294 | "(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data()\n", 295 | "# 数据转换\n", 296 | "dataset = tf.data.Dataset.from_tensor_slices(\n", 297 | " (tf.cast(mnist_images[...,tf.newaxis]/255, tf.float32),\n", 298 | " tf.cast(mnist_labels,tf.int64)))\n", 299 | "# 数据打乱与分批次\n", 300 | "dataset = dataset.shuffle(1000).batch(32)\n", 301 | "# 使用Sequential构建一个卷积网络\n", 302 | "mnist_model = tf.keras.Sequential([\n", 303 | " tf.keras.layers.Conv2D(16,[3,3], activation='relu', \n", 304 | " input_shape=(None, None, 1)),\n", 305 | " tf.keras.layers.Conv2D(16,[3,3], activation='relu'),\n", 306 | " tf.keras.layers.GlobalAveragePooling2D(),\n", 307 | " tf.keras.layers.Dense(10)\n", 308 | "])\n", 309 | "# 展示数据\n", 310 | "# 即使没有经过培训,也可以调用模型并在Eager Execution中检查输出\n", 311 | "for images,labels in dataset.take(1):\n", 312 | " print(\"Logits: \", mnist_model(images[0:1]).numpy())\n", 313 | " \n", 314 | "# 优化器与损失函数\n", 315 | "optimizer = tf.keras.optimizers.Adam()\n", 316 | "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", 317 | "\n", 318 | "# 按批次训练\n", 319 | "# 虽然 keras 模型具有内置训练循环(fit 方法),但有时需要更多自定义设置。下面是一个用 eager 实现的训练循环示例:\n", 320 | "loss_history = []\n", 321 | "for (batch, (images, labels)) in enumerate(dataset.take(400)):\n", 322 | " if batch % 10 == 0:\n", 323 | " print('.', end='')\n", 324 | " with tf.GradientTape() as tape:\n", 325 | " # 获取预测结果\n", 326 | " logits = mnist_model(images, training=True)\n", 327 | " # 获取损失\n", 328 | " loss_value = loss_object(labels, logits)\n", 329 | "\n", 330 | " loss_history.append(loss_value.numpy().mean())\n", 331 | " # 获取本批数据梯度\n", 332 | " grads = tape.gradient(loss_value, 
mnist_model.trainable_variables)\n", 333 | " # 反向传播优化\n", 334 | " optimizer.apply_gradients(zip(grads, mnist_model.trainable_variables))\n", 335 | " \n", 336 | "# 绘图展示loss变化\n", 337 | "import matplotlib.pyplot as plt\n", 338 | "plt.plot(loss_history)\n", 339 | "plt.xlabel('Batch #')\n", 340 | "plt.ylabel('Loss [entropy]')" 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": {}, 346 | "source": [ 347 | "## 4.变量求导优化" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "metadata": {}, 353 | "source": [ 354 | "tf.Variable对象存储在训练期间访问的可变tf.Tensor值,以使自动微分更容易。模型的参数可以作为变量封装在类中。\n", 355 | "\n", 356 | "将tf.Variable 和tf.GradientTape 结合,可以更好地封装模型参数。例如,可以重写上面的自动微分示例为:" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 7, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "name": "stdout", 366 | "output_type": "stream", 367 | "text": [ 368 | "Initial loss: 69.425\n", 369 | "Loss at step 000: 66.713\n", 370 | "Loss at step 020: 30.274\n", 371 | "Loss at step 040: 14.048\n", 372 | "Loss at step 060: 6.822\n", 373 | "Loss at step 080: 3.604\n", 374 | "Loss at step 100: 2.171\n", 375 | "Loss at step 120: 1.533\n", 376 | "Loss at step 140: 1.249\n", 377 | "Loss at step 160: 1.122\n", 378 | "Loss at step 180: 1.066\n", 379 | "Loss at step 200: 1.040\n", 380 | "Loss at step 220: 1.029\n", 381 | "Loss at step 240: 1.024\n", 382 | "Loss at step 260: 1.022\n", 383 | "Loss at step 280: 1.021\n", 384 | "Final loss: 1.021\n", 385 | "W = 2.9795801639556885, B = 2.0041041374206543\n" 386 | ] 387 | } 388 | ], 389 | "source": [ 390 | "class MyModel(tf.keras.Model):\n", 391 | " def __init__(self):\n", 392 | " super(MyModel, self).__init__()\n", 393 | " self.W = tf.Variable(5., name='weight')\n", 394 | " self.B = tf.Variable(10., name='bias')\n", 395 | " def call(self, inputs):\n", 396 | " return inputs * self.W + self.B\n", 397 | "\n", 398 | "# 满足函数3 * x + 2的数据\n", 399 | "NUM_EXAMPLES = 2000\n", 400 | "training_inputs = 
tf.random.normal([NUM_EXAMPLES])\n", 401 | "noise = tf.random.normal([NUM_EXAMPLES])\n", 402 | "training_outputs = training_inputs * 3 + 2 + noise\n", 403 | "\n", 404 | "# 损失函数\n", 405 | "def loss(model, inputs, targets):\n", 406 | " error = model(inputs) - targets\n", 407 | " return tf.reduce_mean(tf.square(error))\n", 408 | "\n", 409 | "# 梯度函数\n", 410 | "def grad(model, inputs, targets):\n", 411 | " with tf.GradientTape() as tape:\n", 412 | " loss_value = loss(model, inputs, targets)\n", 413 | " return tape.gradient(loss_value, [model.W, model.B])\n", 414 | "\n", 415 | "# 模型与优化器\n", 416 | "model = MyModel()\n", 417 | "optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)\n", 418 | "\n", 419 | "print(\"Initial loss: {:.3f}\".format(loss(model, training_inputs, training_outputs)))\n", 420 | "\n", 421 | "# 训练循环, 反向传播优化\n", 422 | "for i in range(300):\n", 423 | " grads = grad(model, training_inputs, training_outputs)\n", 424 | " optimizer.apply_gradients(zip(grads, [model.W, model.B]))\n", 425 | " if i % 20 == 0:\n", 426 | " print(\"Loss at step {:03d}: {:.3f}\".format(i, loss(model, training_inputs, training_outputs)))\n", 427 | "\n", 428 | "print(\"Final loss: {:.3f}\".format(loss(model, training_inputs, training_outputs)))\n", 429 | "print(\"W = {}, B = {}\".format(model.W.numpy(), model.B.numpy()))" 430 | ] 431 | }, 432 | { 433 | "cell_type": "markdown", 434 | "metadata": {}, 435 | "source": [ 436 | "## 5.Eager Execution中的对象" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": {}, 442 | "source": [ 443 | "使用 Graph Execution 时,程序状态(如变量)存储在全局集合中,它们的生命周期由 tf.Session 对象管理。相反,在 Eager Execution 期间,状态对象的生命周期由其对应的 Python 对象的生命周期决定。" 444 | ] 445 | }, 446 | { 447 | "cell_type": "markdown", 448 | "metadata": {}, 449 | "source": [ 450 | "### 变量对象\n", 451 | "变量将持续存在,直到删除对象的最后一个引用,然后变量被删除。" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 8, 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "if 
tf.test.is_gpu_available():\n", 461 | " with tf.device(\"gpu:0\"):\n", 462 | " v = tf.Variable(tf.random.normal([1000, 1000]))\n", 463 | " v = None # v no longer takes up GPU memory" 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": {}, 469 | "source": [ 470 | "### 基于对象的保存\n", 471 | "tf.train.Checkpoint 可以将 tf.Variable 保存到检查点并从中恢复:" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": 9, 477 | "metadata": {}, 478 | "outputs": [ 479 | { 480 | "name": "stdout", 481 | "output_type": "stream", 482 | "text": [ 483 | "\n" 484 | ] 485 | } 486 | ], 487 | "source": [ 488 | "# 使用检测点保存变量\n", 489 | "x = tf.Variable(6.0)\n", 490 | "checkpoint = tf.train.Checkpoint(x=x)\n", 491 | "# 变量的改变会同步到检测点\n", 492 | "x.assign(1.0)\n", 493 | "checkpoint.save('./ckpt/')\n", 494 | "# 检测点保存后,变量的改变对检测点无影响\n", 495 | "x.assign(8.0)\n", 496 | "checkpoint.restore(tf.train.latest_checkpoint('./ckpt/'))\n", 497 | "print(x)" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": {}, 503 | "source": [ 504 | "要保存和加载模型,tf.train.Checkpoint 会存储对象的内部状态,而不需要隐藏变量。要记录 model、optimizer 和全局步的状态,可以将它们传递到 tf.train.Checkpoint:" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 10, 510 | "metadata": {}, 511 | "outputs": [ 512 | { 513 | "data": { 514 | "text/plain": [ 515 | "" 516 | ] 517 | }, 518 | "execution_count": 10, 519 | "metadata": {}, 520 | "output_type": "execute_result" 521 | } 522 | ], 523 | "source": [ 524 | "# 模型保持\n", 525 | "import os\n", 526 | "model = tf.keras.Sequential([\n", 527 | " tf.keras.layers.Conv2D(16,[3,3], activation='relu'),\n", 528 | " tf.keras.layers.GlobalAveragePooling2D(),\n", 529 | " tf.keras.layers.Dense(10)\n", 530 | "])\n", 531 | "optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)\n", 532 | "checkpoint_dir = './ck_model_dir'\n", 533 | "if not os.path.exists(checkpoint_dir):\n", 534 | " os.makedirs(checkpoint_dir)\n", 535 | "checkpoint_prefix = os.path.join(checkpoint_dir, 
\"ckpt\")\n", 536 | "# 将优化器和模型记录至检测点\n", 537 | "root = tf.train.Checkpoint(optimizer=optimizer,\n", 538 | " model=model)\n", 539 | "# 保存检测点\n", 540 | "root.save(checkpoint_prefix)\n", 541 | "# 读取检测点\n", 542 | "root.restore(tf.train.latest_checkpoint(checkpoint_dir))" 543 | ] 544 | }, 545 | { 546 | "cell_type": "markdown", 547 | "metadata": {}, 548 | "source": [ 549 | "### 面向对象的指标\n", 550 | "tf.keras.metrics存储为对象。通过将新数据传递给callable来更新度量标准,并使用tf.keras.metrics.result方法检索结果,例如:" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 11, 556 | "metadata": {}, 557 | "outputs": [ 558 | { 559 | "name": "stdout", 560 | "output_type": "stream", 561 | "text": [ 562 | "tf.Tensor(2.5, shape=(), dtype=float32)\n", 563 | "tf.Tensor(5.5, shape=(), dtype=float32)\n" 564 | ] 565 | } 566 | ], 567 | "source": [ 568 | "m = tf.keras.metrics.Mean('loss')\n", 569 | "m(0)\n", 570 | "m(5)\n", 571 | "print(m.result()) # => 2.5\n", 572 | "m([8, 9])\n", 573 | "print(m.result()) # => 5.5" 574 | ] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "metadata": {}, 579 | "source": [ 580 | "## 6.自动微分高级内容" 581 | ] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "metadata": {}, 586 | "source": [ 587 | "### 动态模型\n", 588 | "tf.GradientTape 也可用于动态模型。这个回溯线搜索算法示例看起来像普通的 NumPy 代码,除了存在梯度并且可微分,尽管控制流比较复杂:" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": 12, 594 | "metadata": {}, 595 | "outputs": [], 596 | "source": [ 597 | "def line_search_step(fn, init_x, rate=1.0):\n", 598 | " with tf.GradientTape() as tape:\n", 599 | " # 变量会自动记录,但需要手动观察张量\n", 600 | " tape.watch(init_x)\n", 601 | " value = fn(init_x)\n", 602 | " grad = tape.gradient(value, init_x)\n", 603 | " grad_norm = tf.reduce_sum(grad * grad)\n", 604 | " init_value = value\n", 605 | " while value > init_value - rate * grad_norm:\n", 606 | " x = init_x - rate * grad\n", 607 | " value = fn(x)\n", 608 | " rate /= 2.0\n", 609 | " return x, value" 610 | ] 611 | }, 612 | { 613 | "cell_type": 
"markdown", 614 | "metadata": {}, 615 | "source": [ 616 | "### 自定义梯度\n", 617 | "自定义梯度是在 Eager Execution 和 Graph Execution 中覆盖梯度的一种简单方式。在正向函数中,定义相对于输入、输出或中间结果的梯度。例如,下面是在反向传播中截断梯度范数的一种简单方式:" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 13, 623 | "metadata": {}, 624 | "outputs": [], 625 | "source": [ 626 | "@tf.custom_gradient\n", 627 | "def clip_gradient_by_norm(x, norm):\n", 628 | " y = tf.identity(x)\n", 629 | " def grad_fn(dresult):\n", 630 | " return [tf.clip_by_norm(dresult, norm), None]\n", 631 | " return y, grad_fn" 632 | ] 633 | }, 634 | { 635 | "cell_type": "markdown", 636 | "metadata": {}, 637 | "source": [ 638 | "自定义梯度可以提供数值稳定的梯度" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": 14, 644 | "metadata": {}, 645 | "outputs": [ 646 | { 647 | "name": "stdout", 648 | "output_type": "stream", 649 | "text": [ 650 | "0.5\n", 651 | "nan\n" 652 | ] 653 | } 654 | ], 655 | "source": [ 656 | "def log1pexp(x):\n", 657 | " return tf.math.log(1 + tf.exp(x))\n", 658 | "\n", 659 | "def grad_log1pexp(x):\n", 660 | " with tf.GradientTape() as tape:\n", 661 | " tape.watch(x)\n", 662 | " value = log1pexp(x)\n", 663 | " return tape.gradient(value, x)\n", 664 | "# 梯度计算在x = 0时工作正常。\n", 665 | "print(grad_log1pexp(tf.constant(0.)).numpy())\n", 666 | "# 但是,由于数值不稳定,x = 100失败。\n", 667 | "print(grad_log1pexp(tf.constant(100.)).numpy())" 668 | ] 669 | }, 670 | { 671 | "cell_type": "markdown", 672 | "metadata": {}, 673 | "source": [ 674 | "这里,log1pexp函数可以使用自定义梯度求导进行分析简化。 下面的实现重用了在前向传递期间计算的tf.exp(x)的值 - 通过消除冗余计算使其更有效:" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": 15, 680 | "metadata": {}, 681 | "outputs": [ 682 | { 683 | "name": "stdout", 684 | "output_type": "stream", 685 | "text": [ 686 | "0.5\n", 687 | "1.0\n" 688 | ] 689 | } 690 | ], 691 | "source": [ 692 | "@tf.custom_gradient\n", 693 | "def log1pexp(x):\n", 694 | " e = tf.exp(x)\n", 695 | " def grad(dy):\n", 696 | " return dy * (1 - 1 / (1 + 
e))\n", 697 | " return tf.math.log(1 + e), grad\n", 698 | "\n", 699 | "def grad_log1pexp(x):\n", 700 | " with tf.GradientTape() as tape:\n", 701 | " tape.watch(x)\n", 702 | " value = log1pexp(x)\n", 703 | " return tape.gradient(value, x)\n", 704 | "# 和以前一样,梯度计算在x = 0时工作正常。\n", 705 | "print(grad_log1pexp(tf.constant(0.)).numpy())\n", 706 | "# 并且梯度计算也适用于x = 100\n", 707 | "print(grad_log1pexp(tf.constant(100.)).numpy())" 708 | ] 709 | }, 710 | { 711 | "cell_type": "markdown", 712 | "metadata": {}, 713 | "source": [ 714 | "## 7.使用gpu提升性能" 715 | ] 716 | }, 717 | { 718 | "cell_type": "markdown", 719 | "metadata": {}, 720 | "source": [ 721 | "在 Eager Execution 期间,计算会自动分流到 GPU。如果要控制计算运行的位置,可以将其放在 tf.device('/gpu:0') 块(或 CPU 等效块)中:" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": 16, 727 | "metadata": {}, 728 | "outputs": [ 729 | { 730 | "name": "stdout", 731 | "output_type": "stream", 732 | "text": [ 733 | "Time to multiply a (1000, 1000) matrix by itself 200 times:\n", 734 | "CPU: 1.698425054550171 secs\n", 735 | "GPU: 0.13264727592468262 secs\n" 736 | ] 737 | } 738 | ], 739 | "source": [ 740 | "import time\n", 741 | "\n", 742 | "def measure(x, steps):\n", 743 | " # TensorFlow在第一次使用时初始化GPU,其不计入时间。\n", 744 | " tf.matmul(x, x)\n", 745 | " start = time.time()\n", 746 | " for i in range(steps):\n", 747 | " x = tf.matmul(x, x)\n", 748 | " # tf.matmul可以在完成矩阵乘法之前返回(例如,\n", 749 | " # 可以在对CUDA流进行操作之后返回)。 \n", 750 | " # 下面的x.numpy()调用将确保所有已排队\n", 751 | " # 的操作都已完成(并且还将结果复制到主机内存,\n", 752 | " # 因此我们只包括一些matmul操作时间)\n", 753 | " _ = x.numpy()\n", 754 | " end = time.time()\n", 755 | " return end - start\n", 756 | "\n", 757 | "shape = (1000, 1000)\n", 758 | "steps = 200\n", 759 | "print(\"Time to multiply a {} matrix by itself {} times:\".format(shape, steps))\n", 760 | "\n", 761 | "# 在CPU上运行:\n", 762 | "with tf.device(\"/cpu:0\"):\n", 763 | " print(\"CPU: {} secs\".format(measure(tf.random.normal(shape), steps)))\n", 764 | "\n", 765 | "# 在GPU上运行,如果可以的话:\n", 
766 | "if tf.test.is_gpu_available():\n", 767 | " with tf.device(\"/gpu:0\"):\n", 768 | " print(\"GPU: {} secs\".format(measure(tf.random.normal(shape), steps)))\n", 769 | "else:\n", 770 | " print(\"GPU: not found\")" 771 | ] 772 | }, 773 | { 774 | "cell_type": "markdown", 775 | "metadata": {}, 776 | "source": [ 777 | "tf.Tensor对象可以被复制到不同的设备来执行其操作" 778 | ] 779 | }, 780 | { 781 | "cell_type": "code", 782 | "execution_count": 17, 783 | "metadata": {}, 784 | "outputs": [], 785 | "source": [ 786 | "if tf.test.is_gpu_available():\n", 787 | " x = tf.random.normal([10, 10])\n", 788 | " # 将tensor对象复制到gpu上\n", 789 | " x_gpu0 = x.gpu()\n", 790 | " x_cpu = x.cpu()\n", 791 | "\n", 792 | " _ = tf.matmul(x_cpu, x_cpu) # Runs on CPU\n", 793 | " _ = tf.matmul(x_gpu0, x_gpu0) # Runs on GPU:0" 794 | ] 795 | }, 796 | { 797 | "cell_type": "code", 798 | "execution_count": null, 799 | "metadata": {}, 800 | "outputs": [], 801 | "source": [] 802 | } 803 | ], 804 | "metadata": { 805 | "kernelspec": { 806 | "display_name": "Python 3", 807 | "language": "python", 808 | "name": "python3" 809 | }, 810 | "language_info": { 811 | "codemirror_mode": { 812 | "name": "ipython", 813 | "version": 3 814 | }, 815 | "file_extension": ".py", 816 | "mimetype": "text/x-python", 817 | "name": "python", 818 | "nbconvert_exporter": "python", 819 | "pygments_lexer": "ipython3", 820 | "version": "3.6.6" 821 | } 822 | }, 823 | "nbformat": 4, 824 | "nbformat_minor": 2 825 | } 826 | -------------------------------------------------------------------------------- /007-Variables.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TensorFlow教程-Variables" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "创建一个变量" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 7, 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | 
"name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "\n", 29 | "no gpu\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "import tensorflow as tf\n", 35 | "my_var = tf.Variable(tf.ones([2,3]))\n", 36 | "print(my_var)\n", 37 | "try:\n", 38 | " with tf.device(\"/device:GPU:0\"):\n", 39 | " v = tf.Variable(tf.zeros([10, 10]))\n", 40 | " print(v)\n", 41 | "except:\n", 42 | " print('no gpu')" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "使用变量\n" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 8, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "tf.Tensor(9.0, shape=(), dtype=float32)\n" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "a = tf.Variable(1.0)\n", 67 | "b = (a+2) *3\n", 68 | "print(b)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 9, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "tf.Tensor(9.0, shape=(), dtype=float32)\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "a = tf.Variable(1.0)\n", 86 | "b = (a.assign_add(2)) *3\n", 87 | "print(b)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "变量跟踪\n" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 11, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "(, , , , , , , , , , , )\n" 107 | ] 108 | }, 109 | { 110 | "data": { 111 | "text/plain": [ 112 | "12" 113 | ] 114 | }, 115 | "execution_count": 11, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "class MyModuleOne(tf.Module):\n", 122 | " def __init__(self):\n", 123 | " self.v0 = tf.Variable(1.0)\n", 124 | " self.vs = [tf.Variable(x) for x in range(10)]\n", 125 | " \n", 126 | "class MyOtherModule(tf.Module):\n", 127 | " 
def __init__(self):\n", 128 | " self.m = MyModuleOne()\n", 129 | " self.v = tf.Variable(10.0)\n", 130 | " \n", 131 | "m = MyOtherModule()\n", 132 | "print(m.variables)\n", 133 | "len(m.variables) " 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "Python 3", 147 | "language": "python", 148 | "name": "python3" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.6.8" 161 | } 162 | }, 163 | "nbformat": 4, 164 | "nbformat_minor": 2 165 | } 166 | -------------------------------------------------------------------------------- /008-AutoGraph.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TensorFlow2.0教程-AutoGraph" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "\n", 15 | "tf.function的一个很酷的新功能是AutoGraph,它允许使用自然的Python语法编写图形代码。" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "from __future__ import absolute_import, division, print_function\n", 25 | "import numpy as np\n", 26 | "import tensorflow as tf\n", 27 | "from tensorflow.python.ops import control_flow_util\n", 28 | "control_flow_util.ENABLE_CONTROL_FLOW_V2 = True" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## 1.tf.function装饰器\n", 36 | "当使用tf.function注释函数时,可以像调用任何其他函数一样调用它。 \n", 37 | "它将被编译成图,这意味着可以获得更快执行,更好地在GPU或TPU上运行或导出到SavedModel。" 38 | ] 39 | }, 40 | { 41 | "cell_type": 
"code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/plain": [ 48 | "" 52 | ] 53 | }, 54 | "execution_count": 2, 55 | "metadata": {}, 56 | "output_type": "execute_result" 57 | } 58 | ], 59 | "source": [ 60 | "@tf.function\n", 61 | "def simple_nn_layer(x, y):\n", 62 | " return tf.nn.relu(tf.matmul(x, y))\n", 63 | "\n", 64 | "\n", 65 | "x = tf.random.uniform((3, 3))\n", 66 | "y = tf.random.uniform((3, 3))\n", 67 | "\n", 68 | "simple_nn_layer(x, y)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "如果我们检查注释的结果,我们可以看到它是一个特殊的可调用函数,它处理与TensorFlow运行时的所有交互。\n", 76 | "\n" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 3, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "" 88 | ] 89 | }, 90 | "execution_count": 3, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "simple_nn_layer" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "如果代码使用多个函数,则无需对它们进行全部注释 \n", 104 | "- 从带注释函数调用的任何函数也将以图形模式运行。" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 4, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "data": { 114 | "text/plain": [ 115 | "" 116 | ] 117 | }, 118 | "execution_count": 4, 119 | "metadata": {}, 120 | "output_type": "execute_result" 121 | } 122 | ], 123 | "source": [ 124 | "def linear_layer(x):\n", 125 | " return 2 * x + 1\n", 126 | "\n", 127 | "\n", 128 | "@tf.function\n", 129 | "def deep_net(x):\n", 130 | " return tf.nn.relu(linear_layer(x))\n", 131 | "\n", 132 | "\n", 133 | "deep_net(tf.constant((1, 2, 3)))" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "## 2.使用Python控制流程\n", 141 | "在tf.function中使用依赖于数据的控制流时,可以使用Python控制流语句,AutoGraph会将它们转换为适当的TensorFlow操作。 例如,如果语句依赖于Tensor,则语句将转换为tf.cond()。" 142 | ] 143 | }, 144 | { 145 | 
"cell_type": "code", 146 | "execution_count": 5, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "name": "stdout", 151 | "output_type": "stream", 152 | "text": [ 153 | "square_if_positive(2) = 4\n", 154 | "square_if_positive(-2) = 0\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "@tf.function\n", 160 | "def square_if_positive(x):\n", 161 | " if x > 0:\n", 162 | " x = x * x\n", 163 | " else:\n", 164 | " x = 0\n", 165 | " return x\n", 166 | "\n", 167 | "\n", 168 | "print('square_if_positive(2) = {}'.format(square_if_positive(tf.constant(2))))\n", 169 | "print('square_if_positive(-2) = {}'.format(square_if_positive(tf.constant(-2))))\n", 170 | "\n" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "AutoGraph支持常见的Python语句,例如while,if,break,continue和return,支持嵌套。 这意味着可以在while和if语句的条件下使用Tensor表达式,或者在for循环中迭代Tensor。" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 6, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "data": { 187 | "text/plain": [ 188 | "" 189 | ] 190 | }, 191 | "execution_count": 6, 192 | "metadata": {}, 193 | "output_type": "execute_result" 194 | } 195 | ], 196 | "source": [ 197 | "@tf.function\n", 198 | "def sum_even(items):\n", 199 | " s = 0\n", 200 | " for c in items:\n", 201 | " if c % 2 > 0:\n", 202 | " continue\n", 203 | " s += c\n", 204 | " return s\n", 205 | "\n", 206 | "\n", 207 | "sum_even(tf.constant([10, 12, 15, 20]))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "AutoGraph还为高级用户提供了低级API。 例如,我们可以使用它来查看生成的代码。" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 7, 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "name": "stdout", 224 | "output_type": "stream", 225 | "text": [ 226 | "from __future__ import print_function\n", 227 | "\n", 228 | "def tf__sum_even(items):\n", 229 | " do_return = False\n", 230 | " retval_ = None\n", 231 | " s = 0\n", 232 | "\n", 233 | 
" def loop_body(loop_vars, s_2):\n", 234 | " c = loop_vars\n", 235 | " continue_ = False\n", 236 | " cond = c % 2 > 0\n", 237 | "\n", 238 | " def if_true():\n", 239 | " continue_ = True\n", 240 | " return continue_\n", 241 | "\n", 242 | " def if_false():\n", 243 | " return continue_\n", 244 | " continue_ = ag__.if_stmt(cond, if_true, if_false)\n", 245 | " cond_1 = ag__.not_(continue_)\n", 246 | "\n", 247 | " def if_true_1():\n", 248 | " s_1, = s_2,\n", 249 | " s_1 += c\n", 250 | " return s_1\n", 251 | "\n", 252 | " def if_false_1():\n", 253 | " return s_2\n", 254 | " s_2 = ag__.if_stmt(cond_1, if_true_1, if_false_1)\n", 255 | " return s_2,\n", 256 | " s, = ag__.for_stmt(items, None, loop_body, (s,))\n", 257 | " do_return = True\n", 258 | " retval_ = s\n", 259 | " return retval_\n", 260 | "\n", 261 | "\n", 262 | "\n", 263 | "tf__sum_even.autograph_info__ = {}\n", 264 | "\n" 265 | ] 266 | } 267 | ], 268 | "source": [ 269 | "print(tf.autograph.to_code(sum_even.python_function, experimental_optional_features=None))" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "metadata": {}, 275 | "source": [ 276 | "一个更复杂的控制流程的例子:" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 8, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "name": "stdout", 286 | "output_type": "stream", 287 | "text": [ 288 | "Fizz\n", 289 | "1\n", 290 | "2\n", 291 | "Fizz\n", 292 | "4\n", 293 | "Buzz\n", 294 | "Fizz\n", 295 | "7\n", 296 | "8\n", 297 | "Fizz\n", 298 | "Buzz\n", 299 | "11\n", 300 | "Fizz\n", 301 | "13\n", 302 | "14\n", 303 | "\n" 304 | ] 305 | } 306 | ], 307 | "source": [ 308 | "@tf.function\n", 309 | "def fizzbuzz(n):\n", 310 | " msg = tf.constant('')\n", 311 | " for i in tf.range(n):\n", 312 | " if tf.equal(i % 3, 0):\n", 313 | " msg += 'Fizz'\n", 314 | " elif tf.equal(i % 5, 0):\n", 315 | " msg += 'Buzz'\n", 316 | " else:\n", 317 | " msg += tf.as_string(i)\n", 318 | " msg += '\\n'\n", 319 | " return msg\n", 320 | "\n", 321 | "\n", 322 | 
"print(fizzbuzz(tf.constant(15)).numpy().decode())" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "## 3.Keras和AutoGraph\n", 330 | "也可以将tf.function与对象方法一起使用。 例如,可以用tf.function装饰自定义Keras模型的call方法。" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 13, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "data": { 340 | "text/plain": [ 341 | "" 342 | ] 343 | }, 344 | "execution_count": 13, 345 | "metadata": {}, 346 | "output_type": "execute_result" 347 | } 348 | ], 349 | "source": [ 350 | "class CustomModel(tf.keras.models.Model):\n", 351 | "\n", 352 | " @tf.function\n", 353 | " def call(self, input_data):\n", 354 | " if tf.reduce_mean(input_data) > 0:\n", 355 | " return input_data\n", 356 | " else:\n", 357 | " return input_data // 2\n", 358 | "\n", 359 | "\n", 360 | "model = CustomModel()\n", 361 | "\n", 362 | "model(tf.constant([-2, -4]))" 363 | ] 364 | }, 365 | { 366 | "cell_type": "markdown", 367 | "metadata": {}, 368 | "source": [ 369 | "副作用\n", 370 | "就像在eager模式下一样,你可以在tf.function中照常使用带有副作用的操作,比如Variable.assign或tf.print,tf.function会自动插入必要的控制依赖项,以确保它们按顺序执行。" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 14, 376 | "metadata": {}, 377 | "outputs": [ 378 | { 379 | "data": { 380 | "text/plain": [ 381 | "" 382 | ] 383 | }, 384 | "execution_count": 14, 385 | "metadata": {}, 386 | "output_type": "execute_result" 387 | } 388 | ], 389 | "source": [ 390 | "v = tf.Variable(5)\n", 391 | "\n", 392 | "@tf.function\n", 393 | "def find_next_odd():\n", 394 | " v.assign(v + 1)\n", 395 | " if tf.equal(v % 2, 0):\n", 396 | " v.assign(v + 1)\n", 397 | "\n", 398 | "\n", 399 | "find_next_odd()\n", 400 | "v" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "metadata": {}, 406 | "source": [ 407 | "## 4.用AutoGraph训练一个简单模型" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 15, 413 | "metadata": {}, 414 | "outputs": [ 415 | { 416 | "name": "stdout", 417 | 
"output_type": "stream", 418 | "text": [ 419 | "Step 10 : loss 1.85892391 ; accuracy 0.37\n", 420 | "Step 20 : loss 1.25817835 ; accuracy 0.5125\n", 421 | "Step 30 : loss 0.871798933 ; accuracy 0.605666637\n", 422 | "Step 40 : loss 0.669722676 ; accuracy 0.66\n", 423 | "Step 50 : loss 0.413253099 ; accuracy 0.6976\n", 424 | "Step 60 : loss 0.340167761 ; accuracy 0.728\n", 425 | "Step 70 : loss 0.438654482 ; accuracy 0.748428583\n", 426 | "Step 80 : loss 0.346158206 ; accuracy 0.766875\n", 427 | "Step 90 : loss 0.241747677 ; accuracy 0.780555546\n", 428 | "Step 100 : loss 0.229299903 ; accuracy 0.7935\n", 429 | "Step 110 : loss 0.300931275 ; accuracy 0.803727269\n", 430 | "Step 120 : loss 0.369899929 ; accuracy 0.812583327\n", 431 | "Step 130 : loss 0.305111647 ; accuracy 0.82\n", 432 | "Step 140 : loss 0.396656752 ; accuracy 0.825857162\n", 433 | "Step 150 : loss 0.308267832 ; accuracy 0.831266642\n", 434 | "Step 160 : loss 0.323994666 ; accuracy 0.836312473\n", 435 | "Step 170 : loss 0.264144391 ; accuracy 0.840941191\n", 436 | "Step 180 : loss 0.450227708 ; accuracy 0.844333351\n", 437 | "Step 190 : loss 0.213473886 ; accuracy 0.848105252\n", 438 | "Step 200 : loss 0.224886 ; accuracy 0.85145\n", 439 | "Final step tf.Tensor(200, shape=(), dtype=int32) : loss tf.Tensor(0.224886, shape=(), dtype=float32) ; accuracy tf.Tensor(0.85145, shape=(), dtype=float32)\n" 440 | ] 441 | } 442 | ], 443 | "source": [ 444 | "def prepare_mnist_features_and_labels(x, y):\n", 445 | " x = tf.cast(x, tf.float32) / 255.0\n", 446 | " y = tf.cast(y, tf.int64)\n", 447 | " return x, y\n", 448 | "\n", 449 | "def mnist_dataset():\n", 450 | " (x, y), _ = tf.keras.datasets.mnist.load_data()\n", 451 | " ds = tf.data.Dataset.from_tensor_slices((x, y))\n", 452 | " ds = ds.map(prepare_mnist_features_and_labels)\n", 453 | " ds = ds.take(20000).shuffle(20000).batch(100)\n", 454 | " return ds\n", 455 | "\n", 456 | "train_dataset = mnist_dataset()\n", 457 | "model = tf.keras.Sequential((\n", 458 | " 
tf.keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),\n", 459 | " tf.keras.layers.Dense(100, activation='relu'),\n", 460 | " tf.keras.layers.Dense(100, activation='relu'),\n", 461 | " tf.keras.layers.Dense(10)))\n", 462 | "model.build()\n", 463 | "optimizer = tf.keras.optimizers.Adam()\n", 464 | "compute_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", 465 | "\n", 466 | "compute_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()\n", 467 | "\n", 468 | "\n", 469 | "def train_one_step(model, optimizer, x, y):\n", 470 | " with tf.GradientTape() as tape:\n", 471 | " logits = model(x)\n", 472 | " loss = compute_loss(y, logits)\n", 473 | "\n", 474 | " grads = tape.gradient(loss, model.trainable_variables)\n", 475 | " optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", 476 | "\n", 477 | " compute_accuracy(y, logits)\n", 478 | " return loss\n", 479 | "\n", 480 | "\n", 481 | "@tf.function\n", 482 | "def train(model, optimizer):\n", 483 | " train_ds = mnist_dataset()\n", 484 | " step = 0\n", 485 | " loss = 0.0\n", 486 | " accuracy = 0.0\n", 487 | " for x, y in train_ds:\n", 488 | " step += 1\n", 489 | " loss = train_one_step(model, optimizer, x, y)\n", 490 | " if tf.equal(step % 10, 0):\n", 491 | " tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())\n", 492 | " return step, loss, accuracy\n", 493 | "\n", 494 | "step, loss, accuracy = train(model, optimizer)\n", 495 | "print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())" 496 | ] 497 | }, 498 | { 499 | "cell_type": "markdown", 500 | "metadata": {}, 501 | "source": [ 502 | "## 5.关于批处理的说明\n", 503 | "在实际应用中,批处理对性能至关重要。 转换为AutoGraph的最佳代码是在批处理级别决定控制流的代码。 如果在单个示例级别做出决策,请尝试使用批处理API来维护性能。" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 16, 509 | "metadata": {}, 510 | "outputs": [ 511 | { 512 | "data": { 513 | "text/plain": [ 514 | "[-5, -4, -3, -2, -1, 0, 1, 4, 9, 16]" 515 | ] 516 | 
}, 517 | "execution_count": 16, 518 | "metadata": {}, 519 | "output_type": "execute_result" 520 | } 521 | ], 522 | "source": [ 523 | "def square_if_positive(x):\n", 524 | " return [i ** 2 if i > 0 else i for i in x]\n", 525 | "\n", 526 | "\n", 527 | "square_if_positive(range(-5, 5))" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 17, 533 | "metadata": {}, 534 | "outputs": [ 535 | { 536 | "data": { 537 | "text/plain": [ 538 | "" 539 | ] 540 | }, 541 | "execution_count": 17, 542 | "metadata": {}, 543 | "output_type": "execute_result" 544 | } 545 | ], 546 | "source": [ 547 | "# 在tensorflow中上面的代码应该改成下面所示\n", 548 | "@tf.function\n", 549 | "def square_if_positive_naive(x):\n", 550 | " result = tf.TensorArray(tf.int32, size=x.shape[0])\n", 551 | " for i in tf.range(x.shape[0]):\n", 552 | " if x[i] > 0:\n", 553 | " result = result.write(i, x[i] ** 2)\n", 554 | " else:\n", 555 | " result = result.write(i, x[i])\n", 556 | " return result.stack()\n", 557 | "\n", 558 | "\n", 559 | "square_if_positive_naive(tf.range(-5, 5))" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 18, 565 | "metadata": {}, 566 | "outputs": [ 567 | { 568 | "data": { 569 | "text/plain": [ 570 | "" 571 | ] 572 | }, 573 | "execution_count": 18, 574 | "metadata": {}, 575 | "output_type": "execute_result" 576 | } 577 | ], 578 | "source": [ 579 | "# 也可以这么写\n", 580 | "def square_if_positive_vectorized(x):\n", 581 | " return tf.where(x > 0, x ** 2, x)\n", 582 | "\n", 583 | "\n", 584 | "square_if_positive_vectorized(tf.range(-5, 5))" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": null, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [] 593 | } 594 | ], 595 | "metadata": { 596 | "kernelspec": { 597 | "display_name": "Python 3", 598 | "language": "python", 599 | "name": "python3" 600 | }, 601 | "language_info": { 602 | "codemirror_mode": { 603 | "name": "ipython", 604 | "version": 3 605 | }, 606 | "file_extension": 
".py", 607 | "mimetype": "text/x-python", 608 | "name": "python", 609 | "nbconvert_exporter": "python", 610 | "pygments_lexer": "ipython3", 611 | "version": "3.6.6" 612 | } 613 | }, 614 | "nbformat": 4, 615 | "nbformat_minor": 2 616 | } 617 | -------------------------------------------------------------------------------- /020-Eager/001-Tensor_and_operations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TensorFlow2.0教程-张量及其操作" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## 导入TensorFlow\n", 15 | "运行tensorflow程序,需要导入tensorflow模块。\n", 16 | "从TensorFlow 2.0开始,默认情况下会启用急切执行。 这为TensorFlow提供了一个更具交互性的前端。" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "from __future__ import absolute_import, division, print_function\n", 26 | "import tensorflow as tf" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## 1 Tensors\n", 34 | "张量是一个多维数组。 与NumPy ndarray对象类似,tf.Tensor对象具有数据类型和形状。 此外,tf.Tensors可以驻留在加速器内存中(如GPU)。 TensorFlow提供了丰富的操作库(tf.add,tf.matmul,tf.linalg.inv等),它们使用和生成tf.Tensors。 这些操作会自动转换原生Python类型,例如:" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 4, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "tf.Tensor(3, shape=(), dtype=int32)\n", 47 | "tf.Tensor([ 5 13], shape=(2,), dtype=int32)\n", 48 | "tf.Tensor(36, shape=(), dtype=int32)\n", 49 | "tf.Tensor(24, shape=(), dtype=int32)\n", 50 | "tf.Tensor(25, shape=(), dtype=int32)\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "print(tf.add(1,2))\n", 56 | "print(tf.add([3,8], [2,5]))\n", 57 | "print(tf.square(6))\n", 58 | "print(tf.reduce_sum([7,8,9]))\n", 59 | "print(tf.square(3)+tf.square(4))" 60 | ] 61 | }, 62 | 
{ 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "每个Tensor都有形状和类型" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 5, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "tf.Tensor(\n", 79 | "[[ 6]\n", 80 | " [12]], shape=(2, 1), dtype=int32)\n", 81 | "(2, 1)\n", 82 | "\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "x = tf.matmul([[3], [6]], [[2]])\n", 88 | "print(x)\n", 89 | "print(x.shape)\n", 90 | "print(x.dtype)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "NumPy数组和tf.Tensors之间最明显的区别是:\n", 98 | "\n", 99 | "张量可以由加速器内存(如GPU,TPU)支持。\n", 100 | "张量是不可变的。" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "NumPy兼容性\n", 108 | "在TensorFlow tf.Tensors和NumPy ndarray之间转换很容易:\n", 109 | "\n", 110 | "TensorFlow操作自动将NumPy ndarrays转换为Tensors。\n", 111 | "NumPy操作自动将Tensors转换为NumPy ndarrays。\n", 112 | "使用.numpy()方法将张量显式转换为NumPy ndarrays。 这些转换通常很容易的,因为如果可能,array和tf.Tensor共享底层内存表示。 但是,共享底层表示并不总是可行的,因为tf.Tensor可以托管在GPU内存中,而NumPy阵列总是由主机内存支持,并且转换涉及从GPU到主机内存的复制。" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 7, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "tf.Tensor(\n", 125 | "[[36. 36.]\n", 126 | " [36. 36.]], shape=(2, 2), dtype=float64)\n", 127 | "[[37. 37.]\n", 128 | " [37. 37.]]\n", 129 | "[[36. 36.]\n", 130 | " [36. 
36.]]\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "import numpy as np\n", 136 | "ndarray = np.ones([2,2])\n", 137 | "tensor = tf.multiply(ndarray, 36)\n", 138 | "print(tensor)\n", 139 | "# 用np.add对tensor进行加运算\n", 140 | "print(np.add(tensor, 1))\n", 141 | "# 转换为numpy类型\n", 142 | "print(tensor.numpy())" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "## 2 GPU加速\n", 150 | "使用GPU进行计算可以加速许多TensorFlow操作。 如果没有任何显式指定,TensorFlow会自动决定是使用GPU还是CPU进行操作,并在必要时在CPU和GPU内存之间复制张量。 由操作产生的张量通常由执行操作的设备的存储器支持,例如:" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 8, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "Is GPU availabel:\n", 163 | "False\n", 164 | "Is the Tensor on gpu #0:\n", 165 | "False\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "x = tf.random.uniform([3, 3])\n", 171 | "print('Is GPU availabel:')\n", 172 | "print(tf.test.is_gpu_available())\n", 173 | "print('Is the Tensor on gpu #0:')\n", 174 | "print(x.device.endswith('GPU:0'))" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": {}, 180 | "source": [ 181 | "**设备名称**\n", 182 | "\n", 183 | "Tensor.device属性提供托管张量内容的设备的完全限定字符串名称。 此名称编码许多详细信息,例如正在执行此程序的主机的网络地址的标识符以及该主机中的设备。 这是分布式执行TensorFlow程序所必需的。 如果张量位于主机上的第N个GPU上,则字符串以`GPU:<N>`结尾。" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "**显式设备放置(Placement)**\n", 191 | "\n", 192 | "在TensorFlow中,放置指的是如何分配(放置)设备以执行各个操作。 如上所述,如果没有提供明确的指导,TensorFlow会自动决定执行操作的设备,并在需要时将张量复制到该设备。 但是,可以使用tf.device上下文管理器将TensorFlow操作显式放置在特定设备上,例如:" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 9, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "On CPU:\n", 205 | "10 loops: 1.2e+02ms\n" 206 | ] 207 | } 208 | ], 209 | "source": [ 210 | "import time\n", 
211 | "def time_matmul(x):\n", 212 | " start = time.time()\n", 213 | " for loop in range(10):\n", 214 | " tf.matmul(x, x)\n", 215 | " result = time.time() - start\n", 216 | " print('10 loops: {:0.2}ms'.format(1000*result))\n", 217 | " \n", 218 | "# 强制使用CPU\n", 219 | "print('On CPU:')\n", 220 | "with tf.device('CPU:0'):\n", 221 | " x = tf.random.uniform([1000, 1000])\n", 222 | " # 使用断言验证当前是否为CPU0\n", 223 | " assert x.device.endswith('CPU:0')\n", 224 | " time_matmul(x) \n", 225 | "\n", 226 | "# 如果存在GPU,强制使用GPU\n", 227 | "if tf.test.is_gpu_available():\n", 228 | " print('On GPU:')\n", 229 | " with tf.device('GPU:0'):\n", 230 | " x = tf.random.uniform([1000, 1000])\n", 231 | " # 使用断言验证当前是否为GPU0\n", 232 | " assert x.device.endswith('GPU:0')\n", 233 | " time_matmul(x) " 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "## 3 数据集\n", 241 | "本节使用tf.data.Dataset API构建管道,以便为模型提供数据。 tf.data.Dataset API用于从简单,可重复使用的部分构建高性能,复杂的输入管道,这些部分将为模型的训练或评估循环提供支持。" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "metadata": {}, 247 | "source": [ 248 | "**创建源数据集**\n", 249 | "使用其中一个工厂函数(如Dataset.from_tensors,Dataset.from_tensor_slices)或使用从TextLineDataset或TFRecordDataset等文件读取的对象创建源数据集。 有关详细信息,请参阅TensorFlow数据集指南。" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 13, 255 | "metadata": {}, 256 | "outputs": [ 257 | { 258 | "name": "stdout", 259 | "output_type": "stream", 260 | "text": [ 261 | "/tmp/tmpvl0kyn0w\n" 262 | ] 263 | } 264 | ], 265 | "source": [ 266 | "# 从列表中获取tensor\n", 267 | "ds_tensors = tf.data.Dataset.from_tensor_slices([6,5,4,3,2,1])\n", 268 | "# 创建csv文件\n", 269 | "import tempfile\n", 270 | "_, filename = tempfile.mkstemp()\n", 271 | "print(filename)\n", 272 | "\n", 273 | "with open(filename, 'w') as f:\n", 274 | " f.write(\"\"\"Line 1\n", 275 | "Line 2\n", 276 | "Line 3\"\"\")\n", 277 | "# 获取TextLineDataset数据集实例\n", 278 | "ds_file = tf.data.TextLineDataset(filename)" 279 | ] 280 
| }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": {}, 284 | "source": [ 285 | "**应用转换**\n", 286 | "\n", 287 | "使用map,batch和shuffle等转换函数将转换应用于数据集记录。\n" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 14, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n", 297 | "ds_file = ds_file.batch(2)" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "**迭代**\n", 305 | "\n", 306 | "tf.data.Dataset对象支持迭代循环记录:" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 15, 312 | "metadata": {}, 313 | "outputs": [ 314 | { 315 | "name": "stdout", 316 | "output_type": "stream", 317 | "text": [ 318 | "ds_tensors中的元素:\n", 319 | "tf.Tensor([36 25], shape=(2,), dtype=int32)\n", 320 | "tf.Tensor([16 9], shape=(2,), dtype=int32)\n", 321 | "tf.Tensor([4 1], shape=(2,), dtype=int32)\n", 322 | "\n", 323 | "ds_file中的元素:\n", 324 | "tf.Tensor([b'Line 1' b'Line 2'], shape=(2,), dtype=string)\n", 325 | "tf.Tensor([b'Line 3'], shape=(1,), dtype=string)\n" 326 | ] 327 | } 328 | ], 329 | "source": [ 330 | "print('ds_tensors中的元素:')\n", 331 | "for x in ds_tensors:\n", 332 | " print(x)\n", 333 | "# 从文件中读取的对象创建的数据源\n", 334 | "print('\\nds_file中的元素:')\n", 335 | "for x in ds_file:\n", 336 | " print(x)" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [] 345 | } 346 | ], 347 | "metadata": { 348 | "kernelspec": { 349 | "display_name": "Python 3", 350 | "language": "python", 351 | "name": "python3" 352 | }, 353 | "language_info": { 354 | "codemirror_mode": { 355 | "name": "ipython", 356 | "version": 3 357 | }, 358 | "file_extension": ".py", 359 | "mimetype": "text/x-python", 360 | "name": "python", 361 | "nbconvert_exporter": "python", 362 | "pygments_lexer": "ipython3", 363 | "version": "3.6.6" 364 | } 365 | }, 366 | "nbformat": 
4, 367 | "nbformat_minor": 2 368 | } 369 | -------------------------------------------------------------------------------- /020-Eager/003-automatic_differentiation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "K5de6EUfWWUb" 8 | }, 9 | "source": [ 10 | "# TensorFlow2.0教程-自动求导\n", 11 | "\n", 12 | "这节我们会介绍使用tensorflow2自动求导的方法。" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": { 19 | "colab": { 20 | "base_uri": "https://localhost:8080/", 21 | "height": 35 22 | }, 23 | "colab_type": "code", 24 | "executionInfo": { 25 | "elapsed": 2148, 26 | "status": "ok", 27 | "timestamp": 1559219648589, 28 | "user": { 29 | "displayName": "Will Chen", 30 | "photoUrl": "", 31 | "userId": "01179718990779759737" 32 | }, 33 | "user_tz": -480 34 | }, 35 | "id": "49BpOoOXWa8C", 36 | "outputId": "6aa76be9-e27c-4b1c-e751-7217f394efbc" 37 | }, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "2.0.0-alpha0\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "from __future__ import absolute_import, division, print_function, unicode_literals\n", 49 | "# !pip uninstall tensorflow\n", 50 | "#!pip install tensorflow==2.0.0-alpha\n", 51 | "import tensorflow as tf\n", 52 | "print(tf.__version__)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 0, 58 | "metadata": { 59 | "colab": {}, 60 | "colab_type": "code", 61 | "id": "qpi0MAurTzfy" 62 | }, 63 | "outputs": [], 64 | "source": [] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": { 69 | "colab_type": "text", 70 | "id": "YqfVrbeTQupu" 71 | }, 72 | "source": [ 73 | "## 一、Gradient tapes\n", 74 | "tensorflow 提供tf.GradientTape api来实现自动求导功能。只要在tf.GradientTape()上下文中执行的操作,都会被记录于“tape”中,然后tensorflow使用反向自动微分来计算相关操作的梯度。\n", 75 | "\n" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | 
"execution_count": 2, 81 | "metadata": { 82 | "colab": { 83 | "base_uri": "https://localhost:8080/", 84 | "height": 71 85 | }, 86 | "colab_type": "code", 87 | "executionInfo": { 88 | "elapsed": 760, 89 | "status": "ok", 90 | "timestamp": 1559219653881, 91 | "user": { 92 | "displayName": "Will Chen", 93 | "photoUrl": "", 94 | "userId": "01179718990779759737" 95 | }, 96 | "user_tz": -480 97 | }, 98 | "id": "S2dV3uiKQZsE", 99 | "outputId": "1dcd5f41-ee73-4eed-837d-b45638fc63e3" 100 | }, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "tf.Tensor(\n", 107 | "[[8. 8.]\n", 108 | " [8. 8.]], shape=(2, 2), dtype=float32)\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "x = tf.ones((2,2))\n", 114 | "\n", 115 | "# 需要计算梯度的操作\n", 116 | "with tf.GradientTape() as t:\n", 117 | " t.watch(x)\n", 118 | " y = tf.reduce_sum(x)\n", 119 | " z = tf.multiply(y,y)\n", 120 | "# 计算z关于x的梯度\n", 121 | "dz_dx = t.gradient(z, x)\n", 122 | "print(dz_dx)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": { 128 | "colab_type": "text", 129 | "id": "fPn28EA3Uqp6" 130 | }, 131 | "source": [ 132 | "也可以输出对中间变量的导数" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 3, 138 | "metadata": { 139 | "colab": { 140 | "base_uri": "https://localhost:8080/", 141 | "height": 35 142 | }, 143 | "colab_type": "code", 144 | "executionInfo": { 145 | "elapsed": 942, 146 | "status": "ok", 147 | "timestamp": 1559219655992, 148 | "user": { 149 | "displayName": "Will Chen", 150 | "photoUrl": "", 151 | "userId": "01179718990779759737" 152 | }, 153 | "user_tz": -480 154 | }, 155 | "id": "C-D2Lf06TLgc", 156 | "outputId": "6b6e57d1-3540-4839-d877-0818733da407" 157 | }, 158 | "outputs": [ 159 | { 160 | "name": "stdout", 161 | "output_type": "stream", 162 | "text": [ 163 | "tf.Tensor(8.0, shape=(), dtype=float32)\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "# 梯度求导只能每个tape一次\n", 169 | "with tf.GradientTape() as 
t:\n", 170 | " t.watch(x)\n", 171 | " y = tf.reduce_sum(x)\n", 172 | " z = tf.multiply(y,y)\n", 173 | " \n", 174 | "dz_dy = t.gradient(z, y)\n", 175 | "print(dz_dy)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": { 181 | "colab_type": "text", 182 | "id": "5Ntrsuv1tewy" 183 | }, 184 | "source": [ 185 | "默认情况下,GradientTape持有的资源会在调用一次gradient()方法后立即被释放。如果想在同一个tape上多次计算梯度,需要创建一个持久的(persistent=True)GradientTape。" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 4, 191 | "metadata": { 192 | "colab": { 193 | "base_uri": "https://localhost:8080/", 194 | "height": 89 195 | }, 196 | "colab_type": "code", 197 | "executionInfo": { 198 | "elapsed": 766, 199 | "status": "ok", 200 | "timestamp": 1559219848611, 201 | "user": { 202 | "displayName": "Will Chen", 203 | "photoUrl": "", 204 | "userId": "01179718990779759737" 205 | }, 206 | "user_tz": -480 207 | }, 208 | "id": "5kZoUqbrVAO5", 209 | "outputId": "65806f79-3d13-492f-fa1b-29e9c86a1566" 210 | }, 211 | "outputs": [ 212 | { 213 | "name": "stdout", 214 | "output_type": "stream", 215 | "text": [ 216 | "tf.Tensor(\n", 217 | "[[8. 8.]\n", 218 | " [8. 
8.]], shape=(2, 2), dtype=float32)\n", 219 | "tf.Tensor(8.0, shape=(), dtype=float32)\n" 220 | ] 221 | } 222 | ], 223 | "source": [ 224 | "with tf.GradientTape(persistent=True) as t:\n", 225 | " t.watch(x)\n", 226 | " y = tf.reduce_sum(x)\n", 227 | " z = tf.multiply(y, y)\n", 228 | " \n", 229 | "dz_dx = t.gradient(z,x)\n", 230 | "print(dz_dx)\n", 231 | "dz_dy = t.gradient(z, y)\n", 232 | "print(dz_dy)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": { 238 | "colab_type": "text", 239 | "id": "meFCHe37wBc8" 240 | }, 241 | "source": [ 242 | "## 二、记录控制流\n", 243 | "因为tapes记录了整个操作,所以即使过程中存在python控制流(如if, while),梯度求导也能正常处理。" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 6, 249 | "metadata": { 250 | "colab": { 251 | "base_uri": "https://localhost:8080/", 252 | "height": 71 253 | }, 254 | "colab_type": "code", 255 | "executionInfo": { 256 | "elapsed": 1517, 257 | "status": "ok", 258 | "timestamp": 1559220443136, 259 | "user": { 260 | "displayName": "Will Chen", 261 | "photoUrl": "", 262 | "userId": "01179718990779759737" 263 | }, 264 | "user_tz": -480 265 | }, 266 | "id": "OP5XoSJovsJs", 267 | "outputId": "b20ae483-321c-4a15-9e76-7a4773137ed3" 268 | }, 269 | "outputs": [ 270 | { 271 | "name": "stdout", 272 | "output_type": "stream", 273 | "text": [ 274 | "tf.Tensor(12.0, shape=(), dtype=float32)\n", 275 | "tf.Tensor(12.0, shape=(), dtype=float32)\n", 276 | "tf.Tensor(4.0, shape=(), dtype=float32)\n" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "def f(x, y):\n", 282 | " output = 1.0\n", 283 | " # 根据y的循环\n", 284 | " for i in range(y):\n", 285 | " # 根据每一项进行判断\n", 286 | " if i> 1 and i<5:\n", 287 | " output = tf.multiply(output, x)\n", 288 | " return output\n", 289 | "\n", 290 | "def grad(x, y):\n", 291 | " with tf.GradientTape() as t:\n", 292 | " t.watch(x)\n", 293 | " out = f(x, y)\n", 294 | " # 返回梯度\n", 295 | " return t.gradient(out, x)\n", 296 | "# x为固定值\n", 297 | "x = tf.convert_to_tensor(2.0)\n", 298 | "\n", 299 
| "print(grad(x, 6))\n", 300 | "print(grad(x, 5))\n", 301 | "print(grad(x, 4))" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": { 307 | "colab_type": "text", 308 | "id": "EihwzFwIyXKX" 309 | }, 310 | "source": [ 311 | "## 三、高阶梯度\n", 312 | "GradientTape上下文管理器在计算梯度的同时也会保持梯度,所以GradientTape也可以实现高阶梯度计算," 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 9, 318 | "metadata": { 319 | "colab": { 320 | "base_uri": "https://localhost:8080/", 321 | "height": 53 322 | }, 323 | "colab_type": "code", 324 | "executionInfo": { 325 | "elapsed": 1144, 326 | "status": "ok", 327 | "timestamp": 1559221001653, 328 | "user": { 329 | "displayName": "Will Chen", 330 | "photoUrl": "", 331 | "userId": "01179718990779759737" 332 | }, 333 | "user_tz": -480 334 | }, 335 | "id": "CCgxtUbNx5s_", 336 | "outputId": "3b68826e-dfa1-40f5-c78d-fee9cb862257" 337 | }, 338 | "outputs": [ 339 | { 340 | "name": "stdout", 341 | "output_type": "stream", 342 | "text": [ 343 | "tf.Tensor(3.0, shape=(), dtype=float32)\n", 344 | "tf.Tensor(6.0, shape=(), dtype=float32)\n" 345 | ] 346 | } 347 | ], 348 | "source": [ 349 | "x = tf.Variable(1.0)\n", 350 | "\n", 351 | "with tf.GradientTape() as t1:\n", 352 | " with tf.GradientTape() as t2:\n", 353 | " y = x * x * x\n", 354 | " dy_dx = t2.gradient(y, x)\n", 355 | " print(dy_dx)\n", 356 | "d2y_d2x = t1.gradient(dy_dx, x)\n", 357 | "print(d2y_d2x)" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 0, 363 | "metadata": { 364 | "colab": {}, 365 | "colab_type": "code", 366 | "id": "qBxk9ly4z60f" 367 | }, 368 | "outputs": [], 369 | "source": [] 370 | } 371 | ], 372 | "metadata": { 373 | "colab": { 374 | "name": "003-automatic_differentiation.ipynb", 375 | "provenance": [], 376 | "version": "0.3.2" 377 | }, 378 | "kernelspec": { 379 | "display_name": "Python 3", 380 | "language": "python", 381 | "name": "python3" 382 | }, 383 | "language_info": { 384 | "codemirror_mode": { 385 | "name": "ipython", 
386 | "version": 3 387 | }, 388 | "file_extension": ".py", 389 | "mimetype": "text/x-python", 390 | "name": "python", 391 | "nbconvert_exporter": "python", 392 | "pygments_lexer": "ipython3", 393 | "version": "3.6.6" 394 | } 395 | }, 396 | "nbformat": 4, 397 | "nbformat_minor": 1 398 | } 399 | -------------------------------------------------------------------------------- /021-MLP/.ipynb_checkpoints/001-MLP-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /021-MLP/.ipynb_checkpoints/002-MLP2-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /021-MLP/001-MLP.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# tensorflow2教程-基础MLP网络" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "2.0.0-alpha0\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "import tensorflow as tf\n", 25 | "import tensorflow.keras as keras\n", 26 | "import tensorflow.keras.layers as layers\n", 27 | "print(tf.__version__)" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## 1.回归任务" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 5, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "(404, 13) (404,)\n", 47 | "(102, 13) (102,)\n" 48 | ] 49 | } 50 
| ], 51 | "source": [ 52 | "# 导入数据\n", 53 | "(x_train, y_train), (x_test, y_test) = keras.datasets.boston_housing.load_data()\n", 54 | "print(x_train.shape, ' ', y_train.shape)\n", 55 | "print(x_test.shape, ' ', y_test.shape)\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 28, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "Model: \"sequential_10\"\n", 68 | "_________________________________________________________________\n", 69 | "Layer (type) Output Shape Param # \n", 70 | "=================================================================\n", 71 | "dense_33 (Dense) (None, 32) 448 \n", 72 | "_________________________________________________________________\n", 73 | "dense_34 (Dense) (None, 32) 1056 \n", 74 | "_________________________________________________________________\n", 75 | "dense_35 (Dense) (None, 32) 1056 \n", 76 | "_________________________________________________________________\n", 77 | "dense_36 (Dense) (None, 1) 33 \n", 78 | "=================================================================\n", 79 | "Total params: 2,593\n", 80 | "Trainable params: 2,593\n", 81 | "Non-trainable params: 0\n", 82 | "_________________________________________________________________\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "# 构建模型\n", 88 | "\n", 89 | "model = keras.Sequential([\n", 90 | " layers.Dense(32, activation='sigmoid', input_shape=(13,)),\n", 91 | " layers.Dense(32, activation='sigmoid'),\n", 92 | " layers.Dense(32, activation='sigmoid'),\n", 93 | " layers.Dense(1)\n", 94 | "])\n", 95 | "\n", 96 | "# 配置模型\n", 97 | "model.compile(optimizer=keras.optimizers.SGD(0.1),\n", 98 | " loss='mean_squared_error', # keras.losses.mean_squared_error\n", 99 | " metrics=['mse'])\n", 100 | "model.summary()" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 29, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "name": "stdout", 110 | 
"output_type": "stream", 111 | "text": [ 112 | "Train on 363 samples, validate on 41 samples\n", 113 | "Epoch 1/50\n", 114 | "363/363 [==============================] - 0s 430us/sample - loss: 371.0176 - mse: 371.0175 - val_loss: 50.0381 - val_mse: 50.0381\n", 115 | "Epoch 2/50\n", 116 | "363/363 [==============================] - 0s 28us/sample - loss: 96.3428 - mse: 96.3428 - val_loss: 51.2528 - val_mse: 51.2528\n", 117 | "Epoch 3/50\n", 118 | "363/363 [==============================] - 0s 44us/sample - loss: 93.5432 - mse: 93.5432 - val_loss: 60.6331 - val_mse: 60.6331\n", 119 | "Epoch 4/50\n", 120 | "363/363 [==============================] - 0s 45us/sample - loss: 96.2177 - mse: 96.2177 - val_loss: 88.8679 - val_mse: 88.8679\n", 121 | "Epoch 5/50\n", 122 | "363/363 [==============================] - 0s 42us/sample - loss: 93.1511 - mse: 93.1511 - val_loss: 100.1161 - val_mse: 100.1161\n", 123 | "Epoch 6/50\n", 124 | "363/363 [==============================] - 0s 35us/sample - loss: 94.3889 - mse: 94.3889 - val_loss: 43.3363 - val_mse: 43.3363\n", 125 | "Epoch 7/50\n", 126 | "363/363 [==============================] - 0s 38us/sample - loss: 90.5968 - mse: 90.5968 - val_loss: 75.2889 - val_mse: 75.2889\n", 127 | "Epoch 8/50\n", 128 | "363/363 [==============================] - 0s 34us/sample - loss: 96.9396 - mse: 96.9396 - val_loss: 42.4718 - val_mse: 42.4718\n", 129 | "Epoch 9/50\n", 130 | "363/363 [==============================] - 0s 44us/sample - loss: 90.4116 - mse: 90.4116 - val_loss: 46.8739 - val_mse: 46.8739\n", 131 | "Epoch 10/50\n", 132 | "363/363 [==============================] - 0s 32us/sample - loss: 89.8437 - mse: 89.8437 - val_loss: 68.0937 - val_mse: 68.0937\n", 133 | "Epoch 11/50\n", 134 | "363/363 [==============================] - 0s 46us/sample - loss: 92.9359 - mse: 92.9359 - val_loss: 42.6696 - val_mse: 42.6696\n", 135 | "Epoch 12/50\n", 136 | "363/363 [==============================] - 0s 35us/sample - loss: 89.6814 - mse: 89.6814 - 
val_loss: 46.7261 - val_mse: 46.7261\n", 137 | "Epoch 13/50\n", 138 | "363/363 [==============================] - 0s 34us/sample - loss: 90.5774 - mse: 90.5774 - val_loss: 41.6673 - val_mse: 41.6673\n", 139 | "Epoch 14/50\n", 140 | "363/363 [==============================] - 0s 44us/sample - loss: 88.5909 - mse: 88.5909 - val_loss: 55.3597 - val_mse: 55.3597\n", 141 | "Epoch 15/50\n", 142 | "363/363 [==============================] - 0s 435us/sample - loss: 90.5146 - mse: 90.5146 - val_loss: 51.8985 - val_mse: 51.8985\n", 143 | "Epoch 16/50\n", 144 | "363/363 [==============================] - 0s 43us/sample - loss: 91.7516 - mse: 91.7516 - val_loss: 43.2147 - val_mse: 43.2147\n", 145 | "Epoch 17/50\n", 146 | "363/363 [==============================] - 0s 32us/sample - loss: 91.9017 - mse: 91.9017 - val_loss: 58.9149 - val_mse: 58.9149\n", 147 | "Epoch 18/50\n", 148 | "363/363 [==============================] - 0s 36us/sample - loss: 96.3520 - mse: 96.3520 - val_loss: 39.5248 - val_mse: 39.5248\n", 149 | "Epoch 19/50\n", 150 | "363/363 [==============================] - 0s 39us/sample - loss: 84.0867 - mse: 84.0867 - val_loss: 54.0607 - val_mse: 54.0607\n", 151 | "Epoch 20/50\n", 152 | "363/363 [==============================] - 0s 38us/sample - loss: 83.8994 - mse: 83.8994 - val_loss: 100.1479 - val_mse: 100.1479\n", 153 | "Epoch 21/50\n", 154 | "363/363 [==============================] - 0s 45us/sample - loss: 101.9472 - mse: 101.9472 - val_loss: 42.8121 - val_mse: 42.8121\n", 155 | "Epoch 22/50\n", 156 | "363/363 [==============================] - 0s 41us/sample - loss: 90.9167 - mse: 90.9167 - val_loss: 53.8228 - val_mse: 53.8228\n", 157 | "Epoch 23/50\n", 158 | "363/363 [==============================] - 0s 48us/sample - loss: 91.6642 - mse: 91.6642 - val_loss: 42.9756 - val_mse: 42.9756\n", 159 | "Epoch 24/50\n", 160 | "363/363 [==============================] - 0s 45us/sample - loss: 92.6146 - mse: 92.6146 - val_loss: 76.5047 - val_mse: 76.5047\n", 161 | 
"Epoch 25/50\n", 162 | "363/363 [==============================] - 0s 38us/sample - loss: 94.8085 - mse: 94.8085 - val_loss: 51.9597 - val_mse: 51.9597\n", 163 | "Epoch 26/50\n", 164 | "363/363 [==============================] - 0s 33us/sample - loss: 92.2917 - mse: 92.2917 - val_loss: 42.8546 - val_mse: 42.8546\n", 165 | "Epoch 27/50\n", 166 | "363/363 [==============================] - 0s 38us/sample - loss: 90.6413 - mse: 90.6413 - val_loss: 42.8628 - val_mse: 42.8628\n", 167 | "Epoch 28/50\n", 168 | "363/363 [==============================] - 0s 38us/sample - loss: 91.5466 - mse: 91.5466 - val_loss: 42.7437 - val_mse: 42.7437\n", 169 | "Epoch 29/50\n", 170 | "363/363 [==============================] - 0s 32us/sample - loss: 93.5539 - mse: 93.5539 - val_loss: 48.5895 - val_mse: 48.5895\n", 171 | "Epoch 30/50\n", 172 | "363/363 [==============================] - 0s 33us/sample - loss: 95.7540 - mse: 95.7540 - val_loss: 43.6151 - val_mse: 43.6151\n", 173 | "Epoch 31/50\n", 174 | "363/363 [==============================] - 0s 32us/sample - loss: 89.4952 - mse: 89.4952 - val_loss: 71.6649 - val_mse: 71.6649\n", 175 | "Epoch 32/50\n", 176 | "363/363 [==============================] - 0s 48us/sample - loss: 89.5014 - mse: 89.5014 - val_loss: 44.1873 - val_mse: 44.1873\n", 177 | "Epoch 33/50\n", 178 | "363/363 [==============================] - 0s 30us/sample - loss: 93.1029 - mse: 93.1029 - val_loss: 93.6874 - val_mse: 93.6874\n", 179 | "Epoch 34/50\n", 180 | "363/363 [==============================] - 0s 53us/sample - loss: 95.2734 - mse: 95.2734 - val_loss: 45.0203 - val_mse: 45.0203\n", 181 | "Epoch 35/50\n", 182 | "363/363 [==============================] - 0s 45us/sample - loss: 90.2837 - mse: 90.2837 - val_loss: 98.1514 - val_mse: 98.1514\n", 183 | "Epoch 36/50\n", 184 | "363/363 [==============================] - 0s 40us/sample - loss: 95.9373 - mse: 95.9373 - val_loss: 44.1146 - val_mse: 44.1146\n", 185 | "Epoch 37/50\n", 186 | "363/363 
[==============================] - 0s 39us/sample - loss: 91.6945 - mse: 91.6945 - val_loss: 42.3970 - val_mse: 42.3970\n", 187 | "Epoch 38/50\n", 188 | "363/363 [==============================] - 0s 45us/sample - loss: 88.5509 - mse: 88.5509 - val_loss: 45.7873 - val_mse: 45.7873\n", 189 | "Epoch 39/50\n", 190 | "363/363 [==============================] - 0s 27us/sample - loss: 93.6409 - mse: 93.6409 - val_loss: 74.0321 - val_mse: 74.0321\n", 191 | "Epoch 40/50\n", 192 | "363/363 [==============================] - 0s 23us/sample - loss: 87.0454 - mse: 87.0454 - val_loss: 37.7646 - val_mse: 37.7646\n", 193 | "Epoch 41/50\n", 194 | "363/363 [==============================] - 0s 26us/sample - loss: 92.6228 - mse: 92.6228 - val_loss: 44.1021 - val_mse: 44.1021\n", 195 | "Epoch 42/50\n", 196 | "363/363 [==============================] - 0s 40us/sample - loss: 88.6639 - mse: 88.6639 - val_loss: 32.3741 - val_mse: 32.3741\n", 197 | "Epoch 43/50\n", 198 | "363/363 [==============================] - 0s 42us/sample - loss: 92.7604 - mse: 92.7604 - val_loss: 42.6321 - val_mse: 42.6321\n", 199 | "Epoch 44/50\n", 200 | "363/363 [==============================] - 0s 43us/sample - loss: 93.3373 - mse: 93.3373 - val_loss: 60.7727 - val_mse: 60.7727\n", 201 | "Epoch 45/50\n", 202 | "363/363 [==============================] - 0s 26us/sample - loss: 94.0312 - mse: 94.0312 - val_loss: 43.1645 - val_mse: 43.1645\n", 203 | "Epoch 46/50\n", 204 | "363/363 [==============================] - 0s 31us/sample - loss: 89.4547 - mse: 89.4548 - val_loss: 42.7366 - val_mse: 42.7366\n", 205 | "Epoch 47/50\n", 206 | "363/363 [==============================] - 0s 31us/sample - loss: 89.8723 - mse: 89.8723 - val_loss: 44.8266 - val_mse: 44.8266\n", 207 | "Epoch 48/50\n", 208 | "363/363 [==============================] - 0s 34us/sample - loss: 90.3171 - mse: 90.3171 - val_loss: 47.8453 - val_mse: 47.8453\n", 209 | "Epoch 49/50\n", 210 | "363/363 [==============================] - 0s 31us/sample - 
loss: 86.2326 - mse: 86.2326 - val_loss: 29.5082 - val_mse: 29.5082\n", 211 | "Epoch 50/50\n", 212 | "363/363 [==============================] - 0s 28us/sample - loss: 80.1490 - mse: 80.1490 - val_loss: 30.6706 - val_mse: 30.6706\n" 213 | ] 214 | }, 215 | { 216 | "data": { 217 | "text/plain": [ 218 | "" 219 | ] 220 | }, 221 | "execution_count": 29, 222 | "metadata": {}, 223 | "output_type": "execute_result" 224 | } 225 | ], 226 | "source": [ 227 | "# 训练\n", 228 | "model.fit(x_train, y_train, batch_size=50, epochs=50, validation_split=0.1, verbose=1)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 31, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "102/102 [==============================] - 0s 116us/sample - loss: 75.0492 - mse: 75.0492\n" 241 | ] 242 | } 243 | ], 244 | "source": [ 245 | "result = model.evaluate(x_test, y_test)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 32, 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 257 | "['loss', 'mse']\n", 258 | "[75.04923741957721, 75.04924]\n" 259 | ] 260 | } 261 | ], 262 | "source": [ 263 | "print(model.metrics_names)\n", 264 | "print(result)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "metadata": {}, 270 | "source": [ 271 | "## 2.分类任务" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 34, 277 | "metadata": {}, 278 | "outputs": [ 279 | { 280 | "name": "stdout", 281 | "output_type": "stream", 282 | "text": [ 283 | "(398, 30) (398,)\n", 284 | "(171, 30) (171,)\n" 285 | ] 286 | } 287 | ], 288 | "source": [ 289 | "from sklearn.datasets import load_breast_cancer\n", 290 | "from sklearn.model_selection import train_test_split\n", 291 | "\n", 292 | "whole_data = load_breast_cancer()\n", 293 | "x_data = whole_data.data\n", 294 | "y_data = whole_data.target\n", 295 | "\n", 
296 | "x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.3, random_state=7)\n", 297 | "\n", 298 | "print(x_train.shape, ' ', y_train.shape)\n", 299 | "print(x_test.shape, ' ', y_test.shape)\n" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 41, 305 | "metadata": {}, 306 | "outputs": [ 307 | { 308 | "name": "stdout", 309 | "output_type": "stream", 310 | "text": [ 311 | "Model: \"sequential_14\"\n", 312 | "_________________________________________________________________\n", 313 | "Layer (type) Output Shape Param # \n", 314 | "=================================================================\n", 315 | "dense_46 (Dense) (None, 32) 992 \n", 316 | "_________________________________________________________________\n", 317 | "dense_47 (Dense) (None, 32) 1056 \n", 318 | "_________________________________________________________________\n", 319 | "dense_48 (Dense) (None, 1) 33 \n", 320 | "=================================================================\n", 321 | "Total params: 2,081\n", 322 | "Trainable params: 2,081\n", 323 | "Non-trainable params: 0\n", 324 | "_________________________________________________________________\n" 325 | ] 326 | } 327 | ], 328 | "source": [ 329 | "# 构建模型\n", 330 | "model = keras.Sequential([\n", 331 | " layers.Dense(32, activation='relu', input_shape=(30,)),\n", 332 | " layers.Dense(32, activation='relu'),\n", 333 | " layers.Dense(1, activation='sigmoid')\n", 334 | "])\n", 335 | "\n", 336 | "model.compile(optimizer=keras.optimizers.Adam(),\n", 337 | " loss=keras.losses.binary_crossentropy,\n", 338 | " metrics=['accuracy'])\n", 339 | "model.summary()\n" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 43, 345 | "metadata": {}, 346 | "outputs": [ 347 | { 348 | "name": "stdout", 349 | "output_type": "stream", 350 | "text": [ 351 | "Epoch 1/10\n", 352 | "398/398 [==============================] - 0s 44us/sample - loss: 0.1982 - accuracy: 0.9196\n", 353 | 
"Epoch 2/10\n", 354 | "398/398 [==============================] - 0s 28us/sample - loss: 0.2094 - accuracy: 0.9221\n", 355 | "Epoch 3/10\n", 356 | "398/398 [==============================] - 0s 40us/sample - loss: 0.2128 - accuracy: 0.9246\n", 357 | "Epoch 4/10\n", 358 | "398/398 [==============================] - 0s 28us/sample - loss: 0.2101 - accuracy: 0.9271\n", 359 | "Epoch 5/10\n", 360 | "398/398 [==============================] - 0s 21us/sample - loss: 0.2175 - accuracy: 0.9146\n", 361 | "Epoch 6/10\n", 362 | "398/398 [==============================] - 0s 31us/sample - loss: 0.2925 - accuracy: 0.8945\n", 363 | "Epoch 7/10\n", 364 | "398/398 [==============================] - 0s 37us/sample - loss: 0.4531 - accuracy: 0.8618\n", 365 | "Epoch 8/10\n", 366 | "398/398 [==============================] - 0s 26us/sample - loss: 0.3105 - accuracy: 0.8920\n", 367 | "Epoch 9/10\n", 368 | "398/398 [==============================] - 0s 34us/sample - loss: 0.2934 - accuracy: 0.8794\n", 369 | "Epoch 10/10\n", 370 | "398/398 [==============================] - 0s 27us/sample - loss: 0.2597 - accuracy: 0.9045\n" 371 | ] 372 | }, 373 | { 374 | "data": { 375 | "text/plain": [ 376 | "" 377 | ] 378 | }, 379 | "execution_count": 43, 380 | "metadata": {}, 381 | "output_type": "execute_result" 382 | } 383 | ], 384 | "source": [ 385 | "model.fit(x_train, y_train, batch_size=64, epochs=10, verbose=1)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 44, 391 | "metadata": {}, 392 | "outputs": [ 393 | { 394 | "name": "stdout", 395 | "output_type": "stream", 396 | "text": [ 397 | "171/171 [==============================] - 0s 463us/sample - loss: 0.3877 - accuracy: 0.8772\n" 398 | ] 399 | }, 400 | { 401 | "data": { 402 | "text/plain": [ 403 | "[0.38765583248340596, 0.877193]" 404 | ] 405 | }, 406 | "execution_count": 44, 407 | "metadata": {}, 408 | "output_type": "execute_result" 409 | } 410 | ], 411 | "source": [ 412 | "model.evaluate(x_test, y_test)" 413 | ] 
414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 45, 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "name": "stdout", 422 | "output_type": "stream", 423 | "text": [ 424 | "['loss', 'accuracy']\n" 425 | ] 426 | } 427 | ], 428 | "source": [ 429 | "print(model.metrics_names)" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": {}, 436 | "outputs": [], 437 | "source": [] 438 | } 439 | ], 440 | "metadata": { 441 | "kernelspec": { 442 | "display_name": "Python 3", 443 | "language": "python", 444 | "name": "python3" 445 | }, 446 | "language_info": { 447 | "codemirror_mode": { 448 | "name": "ipython", 449 | "version": 3 450 | }, 451 | "file_extension": ".py", 452 | "mimetype": "text/x-python", 453 | "name": "python", 454 | "nbconvert_exporter": "python", 455 | "pygments_lexer": "ipython3", 456 | "version": "3.6.8" 457 | } 458 | }, 459 | "nbformat": 4, 460 | "nbformat_minor": 2 461 | } 462 | -------------------------------------------------------------------------------- /022-CNN/.ipynb_checkpoints/001-cnn-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /022-CNN/.ipynb_checkpoints/002-cnn_variants-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /022-CNN/.ipynb_checkpoints/004-pretrained_cnn-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | 
-------------------------------------------------------------------------------- /022-CNN/001-cnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# tensorflow2-基础CNN网络\n", 8 | "![](https://adeshpande3.github.io/assets/Cover.png)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [ 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "2.0.0-alpha0\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "import tensorflow as tf\n", 26 | "from tensorflow import keras\n", 27 | "from tensorflow.keras import layers\n", 28 | "print(tf.__version__)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## 1.构造数据" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "(60000, 28, 28) (60000,)\n", 48 | "(10000, 28, 28) (10000,)\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "\n", 54 | "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", 55 | "print(x_train.shape, ' ', y_train.shape)\n", 56 | "print(x_test.shape, ' ', y_test.shape)\n" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 4, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADoBJREFUeJzt3X2MXOV1x/HfyXq9jo1JvHHYboiLHeMEiGlMOjIgLKCiuA5CMiiKiRVFDiFxmuCktK4EdavGrWjlVgmRQynS0ri2I95CAsJ/0CR0FUGiwpbFMeYtvJlNY7PsYjZgQ4i9Xp/+sdfRBnaeWc/cmTu75/uRVjtzz71zj6792zszz8x9zN0FIJ53Fd0AgGIQfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQU1r5M6mW5vP0KxG7hII5bd6U4f9kE1k3ZrCb2YrJG2W1CLpP9x9U2r9GZqls+2iWnYJIKHHuye8btVP+82sRdJNkj4h6QxJq83sjGofD0Bj1fKaf6mk5919j7sflnSHpJX5tAWg3moJ/8mSfjXm/t5s2e8xs7Vm1mtmvcM6VMPuAOSp7u/2u3uXu5fcvdSqtnrvDsAE1RL+fZLmjbn/wWwZgEmglvA/ImmRmS0ws+mSPi1pRz5tAai3qof63P2Ima2T9CONDvVtcfcnc+sMQF3VNM7v7vdJui+nXgA0EB/vBYIi/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+IKiaZuk1sz5JByWNSDri7qU8mkJ+bFr6n7jl/XPruv9n/np+2drIzKPJbU9ZOJisz/yKJesv3zC9bG1n6c7ktvtH3kzWz75rfbJ+6l89nKw3g5rCn/kTd9+fw+MAaCCe9gNB1Rp+l/RjM3vUzNbm0RCAxqj1af8yd99nZidJut/MfuHuD45dIfujsFaSZmhmjbsDkJeazvzuvi/7PSjpHklLx1mny91L7l5qVVstuwOQo6rDb2azzGz2sduSlkt6Iq/GANRXLU/7OyTdY2bHHuc2d/9hLl0BqLuqw+/ueyR9LMdepqyW0xcl697Wmqy/dMF7k/W3zik/Jt3+nvR49U8/lh7vLtJ//WZ2sv4v/7YiWe8587aytReH30puu2ng4mT9Az/1ZH0yYKgPCIrwA0ERfiAowg8ERfiBoAg/EFQe3+oLb+TCjyfrN2y9KVn/cGv5r55OZcM+kqz//Y2fS9anvZkebjv3rnVla7P3HUlu27Y/PRQ4s7cnWZ8MOPMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM8+eg7ZmXkvVHfzsvWf9w60Ce7eRqff85yfqeN9KX/t668Ptla68fTY/Td3z7f5L1epr8X9itjDM/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRl7o0b0TzR2v1su6hh+2sWQ1eem6wfWJG+vHbL7hOS9ce+cuNx93TM9fv/KFl/5IL0OP7Ia68n635u+au7930tuakWrH4svQLeoce7dcCH0nOXZzjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQFcf5zWyLpEslDbr74mxZu6Q7Jc2X1Cdplbv/utLOoo7zV9Iy933J+sirQ8n6i7eVH6t/8vwtyW2X/vNXk/WTbiruO/U4fnmP82+V9PaJ0K+T1O3uiyR1Z/cBTCIVw+/uD0p6+6lnpaRt2e1tki7LuS8AdVbta/4Od+/Pbr8sqSOnfgA0SM1v+PnomwZl3zgws7Vm1mtmvcM6VOvuAOSk2vAPmFmnJGW/B8ut6O5d7l5y91Kr2qrcHYC8VRv+HZLWZLfXSLo3n3YANErF8JvZ7ZIekvQRM9trZldJ2iTpYjN7TtKfZvcBTCIVr9vv7qvLlBiwz8nI/ldr2n74wPSqt/3
oZ55K1l+5uSX9AEdHqt43isUn/ICgCD8QFOEHgiL8QFCEHwiK8ANBMUX3FHD6tc+WrV15ZnpE9j9P6U7WL/jU1cn67DsfTtbRvDjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQjPNPAalpsl/98unJbf9vx1vJ+nXXb0/W/2bV5cm6//w9ZWvz/umh5LZq4PTxEXHmB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgKk7RnSem6G4+Q58/N1m/9evfSNYXTJtR9b4/un1dsr7olv5k/cievqr3PVXlPUU3gCmI8ANBEX4gKMIPBEX4gaAIPxAU4QeCqjjOb2ZbJF0qadDdF2fLNkr6oqRXstU2uPt9lXbGOP/k4+ctSdZP3LQ3Wb/9Qz+qet+n/eQLyfpH/qH8dQwkaeS5PVXve7LKe5x/q6QV4yz/lrsvyX4qBh9Ac6kYfnd/UNJQA3oB0EC1vOZfZ2a7zWyLmc3JrSMADVFt+G+WtFDSEkn9kr5ZbkUzW2tmvWbWO6xDVe4OQN6qCr+7D7j7iLsflXSLpKWJdbvcveTupVa1VdsngJxVFX4z6xxz93JJT+TTDoBGqXjpbjO7XdKFkuaa2V5JX5d0oZktkeSS+iR9qY49AqgDvs+PmrR0nJSsv3TFqWVrPdduTm77rgpPTD/z4vJk/fVlrybrUxHf5wdQEeEHgiL8QFCEHwiK8ANBEX4gKIb6UJjv7U1P0T3Tpifrv/HDyfqlX72m/GPf05PcdrJiqA9ARYQfCIrwA0ERfiAowg8ERfiBoAg/EFTF7/MjtqPL0pfufuFT6Sm6Fy/pK1urNI5fyY1DZyXrM+/trenxpzrO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8U5yVFifrz34tPdZ+y3nbkvXzZ6S/U1+LQz6crD88tCD9AEf7c+xm6uHMDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBVRznN7N5krZL6pDkkrrcfbOZtUu6U9J8SX2SVrn7r+vXalzTFpySrL9w5QfK1jZecUdy20+esL+qnvKwYaCUrD+w+Zxkfc629HX/kTaRM/8RSevd/QxJ50i62szOkHSdpG53XySpO7sPYJKoGH5373f3ndntg5KelnSypJWSjn38a5uky+rVJID8HddrfjObL+ksST2SOtz92OcnX9boywIAk8SEw29mJ0j6gaRr3P3A2JqPTvg37qR/ZrbWzHrNrHdYh2pqFkB+JhR+M2vVaPBvdfe7s8UDZtaZ1TslDY63rbt3uXvJ3UutasujZwA5qBh+MzNJ35H0tLvfMKa0Q9Ka7PYaSffm3x6AepnIV3rPk/RZSY+b2a5s2QZJmyR9z8yukvRLSavq0+LkN23+Hybrr/9xZ7J+xT/+MFn/8/fenazX0/r+9HDcQ/9efjivfev/Jredc5ShvHqqGH53/5mkcvN9X5RvOwAahU/4AUERfiAowg8ERfiBoAg/EBThB4Li0t0TNK3zD8rWhrbMSm775QUPJOurZw9U1VMe1u1blqzvvDk9Rffc7z+RrLcfZKy+WXHmB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgwozzH/6z9GWiD//lULK+4dT7ytaWv/vNqnrKy8DIW2Vr5+9Yn9z2tL/7RbLe/lp6nP5osopmxpkfCIrwA0ERfiAowg8ERfiBoAg/EBThB4IKM87fd1n679yzZ95Vt33f9NrCZH3zA8uTdRspd+X0Uadd/2LZ2qKBnuS2I8kqpjLO/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QlLl7egWzeZK2S+qQ5JK63H2zmW2U9EVJr2SrbnD38l96l3SitfvZxqzeQL30eLcO+FD6gyGZiXzI54ik9e6+08xmS3rUzO7Pat9y929U2yiA4lQMv7v3S+rPbh80s6clnVzvxgDU13G95jez+ZLOknTsM6PrzGy3mW0xszlltllrZr1m1jusQzU1CyA/Ew6/mZ0g6QeSrnH3A5JulrR
Q0hKNPjP45njbuXuXu5fcvdSqthxaBpCHCYXfzFo1Gvxb3f1uSXL3AXcfcfejkm6RtLR+bQLIW8Xwm5lJ+o6kp939hjHLO8esdrmk9HStAJrKRN7tP0/SZyU9bma7smUbJK02syUaHf7rk/SlunQIoC4m8m7/zySNN26YHNMH0Nz4hB8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiCoipfuznVnZq9I+uWYRXMl7W9YA8enWXtr1r4keqtWnr2d4u7vn8iKDQ3/O3Zu1uvupcIaSGjW3pq1L4neqlVUbzztB4Ii/EBQRYe/q+D9pzRrb83al0Rv1Sqkt0Jf8wMoTtFnfgAFKST8ZrbCzJ4xs+fN7LoieijHzPrM7HEz22VmvQX3ssXMBs3siTHL2s3sfjN7Lvs97jRpBfW20cz2Zcdul5ldUlBv88zsJ2b2lJk9aWZ/kS0v9Ngl+irkuDX8ab+ZtUh6VtLFkvZKekTSand/qqGNlGFmfZJK7l74mLCZnS/pDUnb3X1xtuxfJQ25+6bsD+ccd7+2SXrbKOmNomduziaU6Rw7s7SkyyR9TgUeu0Rfq1TAcSvizL9U0vPuvsfdD0u6Q9LKAvpoeu7+oKShty1eKWlbdnubRv/zNFyZ3pqCu/e7+87s9kFJx2aWLvTYJfoqRBHhP1nSr8bc36vmmvLbJf3YzB41s7VFNzOOjmzadEl6WVJHkc2Mo+LMzY30tpmlm+bYVTPjdd54w++dlrn7xyV9QtLV2dPbpuSjr9maabhmQjM3N8o4M0v/TpHHrtoZr/NWRPj3SZo35v4Hs2VNwd33Zb8HJd2j5pt9eODYJKnZ78GC+/mdZpq5ebyZpdUEx66ZZrwuIvyPSFpkZgvMbLqkT0vaUUAf72Bms7I3YmRmsyQtV/PNPrxD0prs9hpJ9xbYy+9plpmby80srYKPXdPNeO3uDf+RdIlG3/F/QdLfFtFDmb4+JOmx7OfJonuTdLtGnwYOa/S9kaskvU9St6TnJP23pPYm6u27kh6XtFujQessqLdlGn1Kv1vSruznkqKPXaKvQo4bn/ADguINPyAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQf0/sEWOix6VKakAAAAASUVORK5CYII=\n", 67 | "text/plain": [ 68 | "
" 69 | ] 70 | }, 71 | "metadata": { 72 | "needs_background": "light" 73 | }, 74 | "output_type": "display_data" 75 | } 76 | ], 77 | "source": [ 78 | "import matplotlib.pyplot as plt\n", 79 | "plt.imshow(x_train[0])\n", 80 | "plt.show()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 5, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "x_train = x_train.reshape((-1,28,28,1))\n", 90 | "x_test = x_test.reshape((-1,28,28,1))" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "## 2.构造网络" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 6, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "model = keras.Sequential()\n" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "### 卷积层\n", 114 | "![](http://cs231n.github.io/assets/cnn/depthcol.jpeg)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 7, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "model.add(layers.Conv2D(input_shape=(x_train.shape[1], x_train.shape[2], x_train.shape[3]),\n", 124 | " filters=32, kernel_size=(3,3), strides=(1,1), padding='valid',\n", 125 | " activation='relu'))\n" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "### 池化层\n", 133 | "![](http://cs231n.github.io/assets/cnn/maxpool.jpeg)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 8, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "model.add(layers.MaxPool2D(pool_size=(2,2)))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "### 全连接层" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 9, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "model.add(layers.Flatten())\n", 159 | "model.add(layers.Dense(32, activation='relu'))\n", 160 
| "# 分类层\n", 161 | "model.add(layers.Dense(10, activation='softmax'))" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "## 3.模型配置" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 10, 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "Model: \"sequential\"\n", 181 | "_________________________________________________________________\n", 182 | "Layer (type) Output Shape Param # \n", 183 | "=================================================================\n", 184 | "conv2d (Conv2D) (None, 26, 26, 32) 320 \n", 185 | "_________________________________________________________________\n", 186 | "max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0 \n", 187 | "_________________________________________________________________\n", 188 | "flatten (Flatten) (None, 5408) 0 \n", 189 | "_________________________________________________________________\n", 190 | "dense (Dense) (None, 32) 173088 \n", 191 | "_________________________________________________________________\n", 192 | "dense_1 (Dense) (None, 10) 330 \n", 193 | "=================================================================\n", 194 | "Total params: 173,738\n", 195 | "Trainable params: 173,738\n", 196 | "Non-trainable params: 0\n", 197 | "_________________________________________________________________\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "model.compile(optimizer=keras.optimizers.Adam(),\n", 203 | " # loss=keras.losses.CategoricalCrossentropy(), # 需要使用to_categorical\n", 204 | " loss=keras.losses.SparseCategoricalCrossentropy(),\n", 205 | " metrics=['accuracy'])\n", 206 | "model.summary()" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "## 4.模型训练" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 11, 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "name": 
"stdout", 223 | "output_type": "stream", 224 | "text": [ 225 | "Train on 54000 samples, validate on 6000 samples\n", 226 | "Epoch 1/5\n", 227 | "54000/54000 [==============================] - 8s 148us/sample - loss: 1.5187 - accuracy: 0.5294 - val_loss: 0.7187 - val_accuracy: 0.7482\n", 228 | "Epoch 2/5\n", 229 | "54000/54000 [==============================] - 8s 150us/sample - loss: 0.4631 - accuracy: 0.8745 - val_loss: 0.2158 - val_accuracy: 0.9463\n", 230 | "Epoch 3/5\n", 231 | "54000/54000 [==============================] - 8s 154us/sample - loss: 0.1684 - accuracy: 0.9540 - val_loss: 0.1314 - val_accuracy: 0.9642\n", 232 | "Epoch 4/5\n", 233 | "54000/54000 [==============================] - 8s 150us/sample - loss: 0.1067 - accuracy: 0.9699 - val_loss: 0.1097 - val_accuracy: 0.9722\n", 234 | "Epoch 5/5\n", 235 | "54000/54000 [==============================] - 8s 149us/sample - loss: 0.0799 - accuracy: 0.9768 - val_loss: 0.1175 - val_accuracy: 0.9712\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.1)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 12, 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "data": { 250 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD9CAYAAABHnDf0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xl8VPW9//HXN3sgYUvCYgImIMjiAoqIoojrRWrBKmVxKWor/XUTe+/9WfR3q8Vf+7her797b90KrlXLYsQNFattRZCKlqCIiAsQWUKCJgECIevMfH9/nEkYQpYJzMyZmbyfj0cec87MyZwPh8w7n3znO+cYay0iIhJfEtwuQEREQk/hLiIShxTuIiJxSOEuIhKHFO4iInFI4S4iEoc6DHdjzFPGmG+NMZvbeNwYYx40xmwzxmwyxpwV+jJFRKQzgunc/whMbufxK4Gh/q+5wB9OvCwRETkRHYa7tXYNsK+dTaYBz1rHB0AvY8yAUBUoIiKdF4ox91xgd8B6if8+ERFxSVIkd2aMmYszdEP37t3PHj58eCR3LyIS8zZs2FBhrc3paLtQhPseYGDAep7/vmNYax8DHgMYO3asLSoqCsHuRUS6DmPMzmC2C8WwzArgB/5ZM+OBKmttWQieV0REjlOHnbsxZikwCcg2xpQA9wDJANbahcBKYAqwDagBbg5XsSIiEpwOw91aO7uDxy3ws5BVJCIiJyyib6h2pLGxkZKSEurq6twuJS6kpaWRl5dHcnKy26WISIRFVbiXlJSQmZlJfn4+xhi3y4lp1loqKyspKSmhoKDA7XJEJMKi6twydXV1ZGVlKdhDwBhDVlaW/goS6aKiKtwBBXsI6ViKdF1RNSwjIhJrPF4fNY1e6hq81DZ6qfHf1jUcWa5tuvU/funwvpw5sFdY61K4Bzhw4ABLlizhpz/9aae+b8qUKSxZsoRevdr+z7r77ruZOHEil1122YmWKSJB8vksdR5/yDZ4qWtsJXAbvMeEs7Odh9pGn387j//WR22D56jvb/TaTtfVNzNV4R5JBw4c4NFHHz0m3D0eD0lJbR+qlStXdvjc99577wnXJxJR1oK3ARoOQ0O1/7blVzU01vi/wYBJAOO/hRbr5qhlawyNPmj0QoPX0uCDeo+PBh80eCz1Xv+tz1LvsdR7LfUeH/UeQ73XR53HWa/zQn2jjzqvpa7RUuex1Hl81DZaGrwWi8GHOerWtlgPvD8pIYGU5CRSk5Oc26REMlKSyE5OJCXNuT8tOYnU5FTSUpzltBTnKzU5ifSm9eQk0lOTSU9JJDU5iW6pKaT7tzMm/CPiCvcA8+fPZ/v27YwePZrk5GTS0tLo3bs3X3zxBV999RVXX301u3fvpq6ujnnz5jF37lwA8vPzKSoqorq6miuvvJILLriA999/n9zcXF599VXS09O56aabuOqqq5g+fTr5+fnMmTOH1157jcbGRl544QWGDx9OeXk51113HaWlpZx33nn85S9/YcOGDWRnZ7t8ZCTq+bxOyAaG7jHLNW3c3yKoAx/zecJWsgFS/F/dw7GDJE4s4Rr9X+Hwnf+Cc34Ypid3RG24L3jtM7aUHgzpc448qQf3fHdUm4/fd999bN68mY0bN/Luu+/yne98h82bNzdPJXzqqafo06cPtbW1nHPOOVx77bVkZWUd9Rxbt25l6dKlPP7448yYMYMXX3yRG2644Zh9ZWdn89FHH/Hoo4/ywAMP8MQTT7BgwQIuueQS7rzzTv785z/z5JNPhvTfL1HAWvDUHx2gje2FbrU/lNsI46ZlT23wNZgESMmElO6Q0g1SuuNN6k5DSha1qbnUkMYhm8pBbyoHPMnsb0ymoiGZb+uS2FuXyAFPCodJo4Y0Dts06kjBYjD4SE9OoHtyAunJCaQnGdJTEuiW5KynJSeQnmyOPJacQFqSIS3JkJ5kSEtOILVpOcmQmpRAWiKkJieQlmhITTIkJzi/FLA
+51haH2DbWbf+9WC29fn/j4J97mC3b7ls4aQxofl5akfUhns0GDdu3FFzxB988EFefvllAHbv3s3WrVuPCfeCggJGjx4NwNlnn82OHTtafe5rrrmmeZuXXnoJgLVr1zY//+TJk+ndu3dI/z3SST5v22HaWpfb7vBFwOPWG3wNSen+EO4OKRlHljP6Btzf4jH/ckNCGge8qexvSKa8IYlv6pP4pjaRvYct5dUNVFTXU36onoqKBqrrW+/Q+3RPIScjleyeKeTkpnJyZipnZ6SSk5lKtv82KyOFjNQk0pISSUjQDK1oEbXh3l6HHSndux/5Y/Hdd9/lr3/9K+vWraNbt25MmjSp1TnkqampzcuJiYnU1rbeUTVtl5iYiMcTvj99pQWfDw6Xw6GyI18HA5YP7YXqb6C+upPdcCKkZkByi8DN6NcidLu1GcbNy8ndjqwnJB61G4/XR+XhBsoP1VPeFM7+25bLB+vqgGN/RnukJZGT6QTzabk9jwrqnMxUcvzLfbqnkJwYdbOlJUhRG+5uyMzM5NChQ60+VlVVRe/evenWrRtffPEFH3zwQcj3P2HCBAoLC/nVr37F22+/zf79+0O+j7hlLdQfdML5YKlze6i0xXqZE9wtx5FNAnTvC5n9odcgyBsLqT1aCeB2gjkxxXmz8Dh4fZb9NU5gV+xrCudyKqpL/IF9JMz31zQ4Iw0tZKQmkZ2RQk5mKsP6ZTLhlGyn4w4I6+zMVLIzUkhNSjz2CSTuKNwDZGVlMWHCBE477TTS09Pp169f82OTJ09m4cKFjBgxglNPPZXx48eHfP/33HMPs2fP5rnnnuO8886jf//+ZGZmhnw/McdTfyScW+u0m8K78fCx35vWEzIHOF/Zw6DHgCPrmQOc9e59ITG0LwVrLVW1jUd12C2DusJ/u+9wA17fsYmdlpzQ3FWfnNWNs/N7HwnqjCNddnZmCt1S9FKWoxnbWhsQAa1drOPzzz9nxIgRrtQTDerr60lMTCQpKYl169bxk5/8hI0bN57Qc0b1MfX5oKbi2MAO7LQPlUFN5bHfm5jqdNo9TnJuM09qsT7AuU0J3TwMay3V9Z4WQV13dGAHDJW0Nv85OdEcE9CBQyJH7nPGsfUpY2nJGLPBWju2o+306z6K7Nq1ixkzZuDz+UhJSeHxxx93u6TjV3ewjaGRpvUyqN7bylQ747xZmDkAeg6EgeOO7bQzB0B67+MeBmnPvsMNvPXZXsqq6o4K6qZAr/f4jvmexARDVveU5nAe1i/zmKDum5lKTkYaPdIV2BIZCvcoMnToUD7++GO3y2ifp8EJ5aOGRloZKmmoPvZ7U3v6u+sBUHDhsYGdOcAJ9sTIn6J4974annivmOeLdlPX6MMYyOqe0hzQ+VndjxoGyclI89+m0rtbimaJSNRRuIvD53OGP9rrtA+VOcMoLSWmHBka6XcaDL2i9aGSEA6RhMrmPVUsWlPMG5tKSUwwTBudyy0TChjWL4MkzRSRGKZw7wrqq1uMZbcyi+TQXvC1/Diege45TmfdM9eZRdKy084cAN36hGWIJFystfx9WyWL1mznva0VZKQm8aMLB3PzhHwG9Ex3uzyRkFC4x7v6Q3DfoGM/OJPa48ibjvkXtN5pZ/RzZYgkXDxeHys372XR6u18VnqQnMxUfjV5ONedO4ie6fHz7xQBhXt8qz0Atfth8EVw5uyjAzw1w+3qIqamwcMLRSU8/l4xJftrGZzTnf+49nSuHpOrOd8StxTuJyAjI4Pq6mpKS0u57bbbWL58eaefY+PGjZSWljJlyhQAVqxYwZYtW5g/f/6JFddwGPbvdKYMzvxTVI53h9u+ww088/4Onl23g/01jZw1qBd3XzWSy0b00xugEvcU7iFw0kknHVewgxPuRUVFzeE+depUpk6demIFeephX7HzwZzu2V0u2HdV1vDE2mIK/TNfLhvRlx9fNIRz8vu4XZpIxCjcA8yfP5+BAwfys5/9DIDf/OY3JCUlsWr
VKvbv309jYyO//e1vmTZt2lHft2PHDq666io2b97M+PHjefLJJxk1yjk3zqRJk3jggQfw+XzMmzePuro60tPTefrppykoKODuu++mtraWtWvXcuedd1JbW0tRUREPP/wwO3bs4JZbbqGiooKcnByefvppBg0axE033USPHj0oKipi79693H///UyfPt0pxudxgt1ayBoC+76O6DF00+Y9VSxcvZ2Vn5aRmGC4enQucycOZmg/fcpXup7oDfc358PeT0P7nP1Phyvva/PhmTNncvvttzeHe2FhIW+99Ra33XYbPXr0oKKigvHjxzN16tQ2P4gyc+ZMCgsLWbBgAWVlZZSVlTF27FgOHjzIe++9R1JSEn/961+56667ePHFF7n33nubwxzgj3/8Y/Nz/eIXv2DOnDnMmTOHp556ittuu41XXnkFgLKyMtauXcsXX3zB1KlTnXC3PifMPfVOsCenhejARS9rLWu3VbBodTFrtzkzX269cDA3Tyigf8/4//eLtCV6w90FY8aM4dtvv6W0tJTy8nJ69+5N//79+eUvf8maNWtISEhgz549fPPNN/Tv37/V55gxYwZXXHEFCxYsoLCwsLmjrqqqYs6cOWzduhVjDI2NHV8FYN26dc2nA77xxhu54447mh+7+uqrSUhIYOTIkXzzzTdOp35gl/PhoV4nQ2p8d6ser483Pi1j0epitpQdpG9mKvOvdGa+9EjTzBeR6A33djrscPr+97/P8uXL2bt3LzNnzmTx4sWUl5ezYcMGkpOTyc/Pb/VUv01yc3PJyspi06ZNPP/88yxcuBCAX//611x88cW8/PLL7Nixg0mTJp1QnYGnFrbWOvPUa/cfmXcep2oaPBSu380Ta7+mZH8tQ3K6c/+1ZzBtzEma+SISIHrD3SUzZ87k1ltvpaKigtWrV1NYWEjfvn1JTk5m1apV7Ny5M6jnuP/++6mqquKMM84AnM49NzcXOHropb3TDJ9//vksW7aMG2+8kcWLF3PhhRe2vdPqvU6oZ/Rre5sYVlldzzPrdvLsuh0cqGlk7Mm9uee7o7h0eF/NfBFphT5f3cKoUaM4dOgQubm5DBgwgOuvv56ioiJOP/10nn32WYYPH97hc0yfPp1ly5YxY8aM5vvuuOMO7rzzTsaMGXPUxTkuvvhitmzZwujRo3n++eePep6HHnqIp59+mjPOOIPnnnuO3//+98furO6gM9aekumcaCuGPikajF2VNfz6lc2cf987PPi3rZyT34fl/+s8lv/kfC4fqSmNIm3RKX9jWWMtVGx1PkWaPRQSjv1DLFaP6aclVSxac2TmyzVj8rh1YgGn9I3v9xJEOqJT/sY7bwNUbneuItRnSKvBHmustby3tYJFa7bz922VZKYmcevEwdwyoYB+PTTzRaQzYj8RuiKfFyqLnfPFZA2FpBS3KzohTTNfFq4u5vOyg/TrkcqdVw5ntma+iBy3qAt3a60uZtAea2H/DufizX2GONf1bHNTd4bcglXT4OH59bt54r2v2XOgllP6ZnD/9DOYNlozX0ROVFSFe1paGpWVlWRlZSngW2MtVO12LgTdcyCk9WhnU0tlZSVpadE3nFFRXc+z7+/g2Q92cqCmkXPye7Ng6igu0cwXkZCJqnDPy8ujpKSE8vJyt0uJTnUHoe6Ac7reqnKg/eOUlpZGXl5eZGoLws7Kwzz+XjEvFJVQ7/Fxxch+/PiiwZx9cvzOyxdxS1SFe3JyMgUFBW6XEZ02vwiv3gKjroFrn4SE2JnFuqnkAIvWFPPmp2UkJSRwzVm5/OjCwZzSt+ucdlgk0oIKd2PMZOD3QCLwhLX2vhaPnww8BeQA+4AbrLUlIa6169q5Dl7+CQw6D67+Q0wEu7WWNVsrWLR6O+9vd2a+zJ04hJsn5Gvmi0gEdBjuxphE4BHgcqAEWG+MWWGt3RKw2QPAs9baZ4wxlwD/DtwYjoK7nIptsGw29MyDWUui/mRgjV4fb2wqY+Hq7Xyx9xD
9eqRy15ThzB43iEzNfBGJmGA693HANmttMYAxZhkwDQgM95HAP/uXVwGvhLLILutwBSye7sxlv2F5VJ8z5nC9M/PlybXOzJehfTP4z+lnMG10LilJ0f+Xhki8CSbcc4HdAeslwLkttvkEuAZn6OZ7QKYxJstaWxmSKruixlpYOsu5ePWc16DPYLcralVFdb3/akc7qaptZFx+H+6dNoqLT9XMFxE3heoN1X8FHjbG3ASsAfYA3pYbGWPmAnMBBg0aFKJdxyGfD16aCyVFMOMZGDjO7YqOsaPiME+sdWa+NHidmS9zJw7h7JN7u12aiBBcuO8BBgas5/nva2atLcXp3DHGZADXWmsPtHwia+1jwGPgnFvmOGuOf3/5NXy+Aq74HYyc1vH2EfTJ7gM8tqaYNzcfmfly68TBDMnRzBeRaBJMuK8HhhpjCnBCfRZwXeAGxphsYJ+11gfciTNzRo7HPx6HdQ/DuLlw3s/crgZwZr6s/qqcRauLWVdcSWZaEj++aAg3n59PX818EYlKHYa7tdZjjPk58BbOVMinrLWfGWPuBYqstSuAScC/G2MszrBMdKRSrPnyTXjzDhh2JUy+z/XT9zZ6fby+qZRFq4v5Yu8h+vdI4/9MGcGscQM180UkykXVKX+7tNKP4ekpkD0Mbl4JKd1dK+VwvYdl63fz5HvFlFbVMaxfBnMnDmHqmSdp5ouIy3TK31hyYBcsmQndsuG6QteCvaK6nj/+fQfPfeCf+VLQh99+7zQmDdPMF5FYo3B3W+0BWPx9aKyDH6yAzMhfJm9Hhf+cLxtKaPT6+KeR/Zl70WDOGqSZLyKxSuHuJk8DPH+Dc9GNG1+Cvh1fwi+UPtl9gEVrtvPm5r0kJyZw7Vl53HphAYM180Uk5inc3WItvHYb7HgPvrcICiZGaLeWd78qZ9Hq7XxQvI8eaUn8dNIQ5pyfT99MzXwRiRcKd7e8ex98shQm3QVnzgr77hq9Pl77pJTH1jgzXwb0TOPfvjOCWeMGkZGqHwOReKNXtRs2LoHV98Ho6+GiO8K6K4/XxzPrdh418+X/ff9MvquZLyJxTeEeacXvwopfQMFFcNX/hH0u+8LV23ng7a84t6APv/ve6Uw6NUdXuRLpAhTukfTNFnj+Ruei1jOfC/uFrQ/VNfL4e19z2Yi+PDHnnLDuS0Sii/4uj5RDe2HJDEjuBte/AGk9w77LZ97fQVVtI7dfNizs+xKR6KLOPRLqq51gr9nnfPq018COv+cEBXbtp+WG/xeJiEQXhXu4eT2w/BbY+ynMfh5OGh2R3TZ17fMuVdcu0hUp3MPJWudEYFvfgu/8Fwy7IiK7PVTXyBNrna799Dx17SJdkcbcw+n9h6DoSTj/NjjnhxHb7bPrdnKgRl27SFemcA+Xz152Lrox8mq4bEHEduuMtRdz6XB17SJdmcI9HHZ9CC/9GAae65xaICFyh7m5a79saMT2KSLRR+EeapXbnQtb98yFWUshOXLna6mu9zR37Wfk9YrYfkUk+ijcQ+lwJSye7ixfvxy6Z0V098+8v0Ndu4gAmi0TOo11sGw2VO2BOa9B1pCI7r6pa79EXbuIoHAPDZ8PXv4x7P4Qvv8MDDo34iU0d+2XqmsXEQ3LhMbffgNbXoHL/y+Mujriu6+u9/CEv2s/c6C6dhFRuJ+49U/C338P5/wIzv+FKyU8u24H+9W1i0gAhfuJ+OptWPmvMPSfYPJ/hP30va2prvfw+JpiLj41R127iDRTuB+v0o3wwk3Q/3SY/hQkuvP2RXPXrjM/ikgAhfvxOLAblsyEbn3gukJIdeeC0ocDuvbR6tpFJIBmy3RWXZVz+t7GGvjB25DZ37VSnl23U127iLRK4d4ZngbnSkoVX8ENL0LfEa6Vcrjew2NrtjNJXbuItELhHixr4fXb4evVcPUfYPAkV8tp7to1Q0ZEWqEx92Ct+U/YuBgumg+jr3O1lMP+T6NOOjWHMYN6u1qLiEQnhXswPlkGq34
HZ86GSfPdrobnPtjJvsMN6tpFpE0K944Ur4ZXfw4FE+G7D7oylz2QM9ZezEXD1LWLSNsU7u359gvnDdSsITDjOUhKcbuiI127zvwoIu1QuLfl0Dew+PvO+divfwHS3Z+RUtNwpGs/S127iLRDs2Va03DYmcteUwE3r4Reg9yuCIDn1qlrF5HgKNxb8nlh+Q9h7ybnSkonjXG7IsDp2hetKWaiunYRCYLCPZC18Oav4Ks3YcoDcOpktytq1ty1a4aMiARBY+6B1j0C6x+H834O4251u5pmTWPtE4flcPbJ6tpFpGNBhbsxZrIx5ktjzDZjzDETvY0xg4wxq4wxHxtjNhljpoS+1DDb8iq8/W8wYqpz0Y0o8qcPdlKprl1EOqHDcDfGJAKPAFcCI4HZxpiRLTb7N6DQWjsGmAU8GupCw2r3P+CluZB3DlzzGCREzx80NQ0eFq0u5sKh2eraRSRowaTYOGCbtbbYWtsALAOmtdjGAj38yz2B0tCVGGb7imHpLMgcALOXQnK62xUdpalrv10zZESkE4IJ91xgd8B6if++QL8BbjDGlAArgVavN2eMmWuMKTLGFJWXlx9HuSFWsw/+NN15I/WGF6F7ttsVHaVprN3p2vu4XY6IxJBQjT/MBv5orc0DpgDPGWOOeW5r7WPW2rHW2rE5OTkh2vVxaqyDZddBVYnTsWcNcbeeViz+YBcV1eraRaTzggn3PcDAgPU8/32BfggUAlhr1wFpQHS1wYF8PnjlJ7BrHXxvIQwa73ZFx3DmtW9X1y4ixyWYcF8PDDXGFBhjUnDeMF3RYptdwKUAxpgROOEeBeMubXjnXvjsJbhsAZx2jdvVtKqpa9cMGRE5Hh2Gu7XWA/wceAv4HGdWzGfGmHuNMVP9m/0LcKsx5hNgKXCTtdaGq+gTUvQ0rP1vGHsLTJjndjWtqm3wNnftY/PVtYtI5wX1CVVr7UqcN0oD77s7YHkLMCG0pYXB1r/AG/8CQ6+AK//T9dP3tmXxhzvVtYvICYmeCd3hVrYJXrgJ+o2C6U9DYnSeeaG2wcvC1du54BR17SJy/LpGuFeVOGd5TOsF1xVCaobbFbWpuWvXDBkROQHR2b6GUl0VLJ7hnMb3lj9DjwFuV9Qmp2sv5oJTsjlHXbuInID47ty9jVA4Byq+hBnPOkMyUczp2uvVtYvICYvfzt1aeP12KF4F0x6BIRe7XVG7mrr2CadkqWsXkRMWv537ew/Ax3+CiXfAmBvcrqZDzV37pcPcLkVE4kB8hvumQnjnt3DGLLj4Lrer6VBg1z6uQF27iJy4+Av3HWvhlZ9C/oUw9aGoncseaMk/dqlrF5GQiq9wL//SORlYn8Ew8zlISnG7og7VNTrz2s8foq5dREInfsK9+ltYPB0SU+H6FyA9Ni5ssfjDXZQfqtenUUUkpOJjtkzDYVgyEw5XwE1vQO+T3a4oKIFd+7mDs9wuR0TiSOx37j4vvHgrlG2Ea5+E3LPcrihoS9S1i0iYxH7n/tZd8OUbzonAhsfOdbnrGr38YfV2zhusrl1EQi+2O/cP/gAfLoTxP4Nz57pdTac0d+36NKqIhEHshvvnr8Gf74QR34Urfut2NZ0S2LWPV9cuImEQm+FeUuSMs+eNhWseh4TY+mcs/Ye6dhEJr9hKRYB9XzszYzL7waylkJzudkWdUtfo5Q/vqmsXkfCKvXDf8gpYL1z/ImTkuF1Npy39xy6+VdcuImEWe7NlLvilc86YKD4ve1uauvbxg/uoaxeRsIq9zh1iMtgBljV17TqHjIiEWWyGewyqa/TyqL9rP2+IunYRCS+Fe4SoaxeRSFK4R0DTvPZzC9S1i0hkKNwj4Pn1u/nmYD23X6auXUQiQ+EeZs5Y+zZ17SISUQr3MGvq2jWvXUQiSeEeRk1d+7iCPpynee0iEkEK9zAqLGoaax+KiYFruYpI/FC4h0ldo5dHV21X1y4irlC4h0lh0W72Hqzj9kvVtYtI5Cncw6De4+/
a8zVDRkTcoXAPg8L1/q5dY+0i4hKFe4jVe7w8oq5dRFymcA+xpq59nrp2EXGRwj2E6j3OmR/Pye/N+eraRcRFQYW7MWayMeZLY8w2Y8z8Vh7/b2PMRv/XV8aYA6EvNfoVFpVQVlXH7ZcNU9cuIq7q8EpMxphE4BHgcqAEWG+MWWGt3dK0jbX2lwHb/wIYE4Zao5ozQ2abunYRiQrBdO7jgG3W2mJrbQOwDJjWzvazgaWhKC6WNHXt8y5V1y4i7gsm3HOB3QHrJf77jmGMORkoAN458dJiR1PXPvbk3kw4RV27iLgv1G+ozgKWW2u9rT1ojJlrjCkyxhSVl5eHeNfueUFj7SISZYIJ9z3AwID1PP99rZlFO0My1trHrLVjrbVjc3Jygq8yiqlrF5FoFEy4rweGGmMKjDEpOAG+ouVGxpjhQG9gXWhLjG4vFJVQWqV57SISXToMd2utB/g58BbwOVBorf3MGHOvMWZqwKazgGXWWhueUqNPU9d+9sm9ueCUbLfLERFp1uFUSABr7UpgZYv77m6x/pvQlRUblm9wuvb/mH6GunYRiSr6hOpxavD4eOQdde0iEp0U7sfphQ27nbF2na9dRKKQwv04NHh8PLpqO2cN6sWFQ9W1i0j0Ubgfh+UbSthzoFbz2kUkaincO6nB4+ORVdvUtYtIVFO4d1JT1z5PXbuIRDGFeyc0de1jBvViorp2EYliCvdOePEjjbWLSGxQuAepwePj4XfUtYtIbFC4B6mpa9e8dhGJBQr3IDR17aMH9uKiYfFxNksRiW8K9yAcGWtX1y4isUHh3oGmGTLq2kUklijcO/DSRyWU7K/V+dpFJKYo3NvR4PHx8KptnDmwF5PUtYtIDFG4t6Opa9dYu4jEGoV7Gxq96tpFJHYp3NvQ3LVrXruIxCCFeysavT4eemcbZ+b1ZNKp6tpFJPYo3FtxZKxd55ARkdikcG+heaxdXbuIxDCFewsvf7SH3fs0r11EYpvCPUCj18dDq7ZyRl5PLj61r9vliIgcN4V7gKauXfPaRSTWKdz9msba1bWLSDxQuPu9/PEedu2rUdcuInFB4Y6/a39HXbuIxA+FO0e6dl1lSUTiRZcPd4/XOV/76bkChNOWAAAIF0lEQVQ9uWS4unYRiQ9dPtxf/ngPOys11i4i8aVLh7vHP0NGXbuIxJsuHe5NXbvG2kUk3nTZcG/q2k/L7cGlI9S1i0h86bLh/srGUmes/VKd+VFE4k+XDHeP18dD72xV1y4icSuocDfGTDbGfGmM2WaMmd/GNjOMMVuMMZ8ZY5aEtszQaura56lrF5E4ldTRBsaYROAR4HKgBFhvjFlhrd0SsM1Q4E5ggrV2vzEmatthj9fHw+9sZdRJPbhMXbuIxKlgOvdxwDZrbbG1tgFYBkxrsc2twCPW2v0A1tpvQ1tm6Ly6sZQdlTW6ypKIxLVgwj0X2B2wXuK/L9AwYJgx5u/GmA+MMZNDVWAoNY21q2sXkXjX4bBMJ55nKDAJyAPWGGNOt9YeCNzIGDMXmAswaNCgEO06eE1d+2M3nq2uXUTiWjCd+x5gYMB6nv++QCXACmtto7X2a+ArnLA/irX2MWvtWGvt2JycyF6ftGle+8gBPbh8ZL+I7ltEJNKCCff1wFBjTIExJgWYBaxosc0rOF07xphsnGGa4hDWecJWfFLK1xWHdQ4ZEekSOgx3a60H+DnwFvA5UGit/cwYc68xZqp/s7eASmPMFmAV8L+ttZXhKrqznLF2de0i0nUENeZurV0JrGxx390Byxb4Z/9X1Gnq2hdprF1Euoi4/4Sqx3+VpZEDenCFunYR6SLiPtxf21RKccVh5mmsXUS6kLgOd4/Xx0N/28YIde0i0sXEdbg3d+06X7uIdDFxG+5en1XXLiJdVtyG+2ufHOnaExLUtYtI1xKX4e71WR7821aG989U1y4iXVJchntT1377ZeraRaRrirtw9/osD77T1LX3d7scERFXxF24v76plOJ
yde0i0rXFVbh7fZbf/01du4hIXIV7U9euGTIi0tXFTbgHzpD5p1Hq2kWka4ubcH99Uynb1bWLiABxEu5NXfup/dS1i4hAnIR7c9euGTIiIkAchHtg1z5ZXbuICBAH4f7Gp2Xq2kVEWojpcFfXLiLSupgO9zc+LWPbt9XcphkyIiJHidlwb+rah/XL4MrT1LWLiASK2XBf6e/a5106TF27iEgLMRnu6tpFRNoXk+G+8tMytmqsXUSkTTEX7j5/1z60bwZTThvgdjkiIlEp5sJ95Wana9e8dhGRtsVcuHdLSeSKkf3UtYuItCPJ7QI665Lh/bhkuC56LSLSnpjr3EVEpGMKdxGROKRwFxGJQwp3EZE4pHAXEYlDCncRkTikcBcRiUMKdxGROGSste7s2JhyYOdxfns2UBHCckJFdXWO6uq8aK1NdXXOidR1srU2p6ONXAv3E2GMKbLWjnW7jpZUV+eors6L1tpUV+dEoi4Ny4iIxCGFu4hIHIrVcH/M7QLaoLo6R3V1XrTWpro6J+x1xeSYu4iItC9WO3cREWlHVIe7MWayMeZLY8w2Y8z8Vh5PNcY873/8Q2NMfpTUdZMxptwYs9H/9aMI1fWUMeZbY8zmNh43xpgH/XVvMsacFSV1TTLGVAUcr7sjUNNAY8wqY8wWY8xnxph5rWwT8eMVZF1uHK80Y8w/jDGf+Ota0Mo2EX89BlmXK69H/74TjTEfG2Neb+Wx8B4va21UfgGJwHZgMJACfAKMbLHNT4GF/uVZwPNRUtdNwMMuHLOJwFnA5jYenwK8CRhgPPBhlNQ1CXg9wsdqAHCWfzkT+KqV/8eIH68g63LjeBkgw7+cDHwIjG+xjRuvx2DqcuX16N/3PwNLWvv/CvfxiubOfRywzVpbbK1tAJYB01psMw14xr+8HLjUGBPuC6sGU5crrLVrgH3tbDINeNY6PgB6GWPCfr3CIOqKOGttmbX2I//yIeBzILfFZhE/XkHWFXH+Y1DtX032f7V8wy7ir8cg63KFMSYP+A7wRBubhPV4RXO45wK7A9ZLOPaHvHkba60HqAKyoqAugGv9f8ovN8YMDHNNwQq2djec5//T+k1jzKhI7tj/5/AYnK4vkKvHq526wIXj5R9i2Ah8C/zFWtvm8Yrg6zGYusCd1+P/AHcAvjYeD+vxiuZwj2WvAfnW2jOAv3Dkt7O07iOcj1SfCTwEvBKpHRtjMoAXgduttQcjtd+OdFCXK8fLWuu11o4G8oBxxpjTIrHfjgRRV8Rfj8aYq4BvrbUbwr2vtkRzuO8BAn/D5vnva3UbY0wS0BOodLsua22ltbbev/oEcHaYawpWMMc04qy1B5v+tLbWrgSSjTHZ4d6vMSYZJ0AXW2tfamUTV45XR3W5dbwC9n8AWAVMbvGQG6/HDuty6fU4AZhqjNmBM3R7iTHmTy22CevxiuZwXw8MNcYUGGNScN5wWNFimxXAHP/ydOAd6393ws26WozLTsUZN40GK4Af+GeBjAeqrLVlbhdljOnfNNZojBmH83MZ1lDw7+9J4HNr7X+1sVnEj1cwdbl0vHKMMb38y+nA5cAXLTaL+OsxmLrceD1aa++01uZZa/NxMuIda+0NLTYL6/FKCtUThZq11mOM+TnwFs4MlaestZ8ZY+4Fiqy1K3BeBM8ZY7bhvGE3K0rqus0YMxXw+Ou6Kdx1ARhjluLMpMg2xpQA9+C8wYS1diGwEmcGyDagBrg5SuqaDvzEGOMBaoFZEfglPQG4EfjUP14LcBcwKKAuN45XMHW5cbwGAM8YYxJxfpkUWmtfd/v1GGRdrrweWxPJ46VPqIqIxKFoHpYREZHjpHAXEYlDCncRkTikcBcRiUMKdxGROKRwFxGJQwp3EZE4pHAXEYlD/x/9Q05z0ldWoAAAAABJRU5ErkJggg==\n", 251 | "text/plain": [ 252 | "
" 253 | ] 254 | }, 255 | "metadata": { 256 | "needs_background": "light" 257 | }, 258 | "output_type": "display_data" 259 | } 260 | ], 261 | "source": [ 262 | "plt.plot(history.history['accuracy'])\n", 263 | "plt.plot(history.history['val_accuracy'])\n", 264 | "plt.legend(['training', 'valivation'], loc='upper left')\n", 265 | "plt.show()" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 13, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "10000/10000 [==============================] - 1s 62us/sample - loss: 0.1191 - accuracy: 0.9667\n" 278 | ] 279 | } 280 | ], 281 | "source": [ 282 | "res = model.evaluate(x_test, y_test)\n" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "Python 3", 296 | "language": "python", 297 | "name": "python3" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.6.6" 310 | } 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 2 314 | } 315 | -------------------------------------------------------------------------------- /022-CNN/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIspeakeryhl/tensorflow2_tutorials_chinese/0e090c42d8c3363fc96e8e8b93a8be19dd4af49e/022-CNN/dog.jpg -------------------------------------------------------------------------------- /023-RNN/.ipynb_checkpoints/003-cnn_rnn-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | 
"nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /024-AutoEncoder/.ipynb_checkpoints/001-autoencoder-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /024-AutoEncoder/.ipynb_checkpoints/002-cnn_autoencoder-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /024-AutoEncoder/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIspeakeryhl/tensorflow2_tutorials_chinese/0e090c42d8c3363fc96e8e8b93a8be19dd4af49e/024-AutoEncoder/model.png -------------------------------------------------------------------------------- /025-GAN/.ipynb_checkpoints/002-DCGAN-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TensorFlow2教程-DCGAN" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from __future__ import absolute_import, division, print_function\n", 17 | "import tensorflow as tf\n", 18 | "import glob\n", 19 | "import imageio\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "import numpy as np\n", 22 | "import os\n", 23 | "import PIL\n", 24 | "import tensorflow.keras.layers as layers\n", 25 | "import time\n", 26 | "\n", 27 | "from IPython import display" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## 
1.数据导入和预处理" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()\n", 44 | "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", 45 | "train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 3, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "BUFFER_SIZE = 60000\n", 55 | "BATCH_SIZE = 256\n", 56 | "# Batch and shuffle the data\n", 57 | "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## 2.构建模型\n", 65 | "\n", 66 | "### 构建生成器" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "def make_generator_model():\n", 76 | " model = tf.keras.Sequential()\n", 77 | " model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))\n", 78 | " model.add(layers.BatchNormalization())\n", 79 | " model.add(layers.LeakyReLU())\n", 80 | " \n", 81 | " model.add(layers.Reshape((7, 7, 256)))\n", 82 | " assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size\n", 83 | " \n", 84 | " model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))\n", 85 | " assert model.output_shape == (None, 7, 7, 128) \n", 86 | " model.add(layers.BatchNormalization())\n", 87 | " model.add(layers.LeakyReLU())\n", 88 | "\n", 89 | " model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))\n", 90 | " assert model.output_shape == (None, 14, 14, 64) \n", 91 | " model.add(layers.BatchNormalization())\n", 92 | " model.add(layers.LeakyReLU())\n", 93 | "\n", 94 | " 
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))\n", 95 | " assert model.output_shape == (None, 28, 28, 1)\n", 96 | " \n", 97 | " return model" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "生成器生成图片" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 5, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "data": { 114 | "text/plain": [ 115 | "" 116 | ] 117 | }, 118 | "execution_count": 5, 119 | "metadata": {}, 120 | "output_type": "execute_result" 121 | }, 122 | { 123 | "data": { 124 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGHRJREFUeJzt3XuM1eWZB/DvA8Md5OLUYUAEpGgFrUgHMGitF2iFqmg1VNq0LLHSNG2yNLXZBtOuxjRtt2uNtpumsJCCZQu2omJLXRBpgXIRRBy8AiLKHbkP9xnm2T/msJkqv+8zzgznHPt+Pwlh5nznPeedH/Nwzpz3Zu4OEUlPi0J3QEQKQ8UvkigVv0iiVPwiiVLxiyRKxS+SKBW/SKJU/CKJUvGLJKoknw/Wvn1779y5c2ZeW1vb6PuO2rZowf+fO336NM1LShp/qaJZlC1btqR51Dd2/9H3HV236Puurq5uUnumqdeNMTOa19TU0Dy6bk35N43asr4fPHgQx44d499cTpOK38xuBvAogJYA/tvdf8q+vnPnzpgwYUJmfuTIEfp47AepqqqKtu3UqRPNDx48SPPzzz+f5sypU6do3q1bN5ofOnSo0fffrl072vb48eM0j/q2c+dOmpeVlWVmUQFF/7GwJ5JIVGD79u2j+dGjR2netWtXmrOf1+hntXXr1pnZlClTaNv6Gv2y38xaAvgvAKMADAAwzswGNPb+RCS/mvI7/1AAm9x9s7ufAjAbwJjm6ZaInGtNKf6eALbW+3xb7rZ/YGYTzWyNma05duxYEx5ORJrTOX+3392nuHuFu1e0b9/+XD+ciDRQU4p/O4Be9T6/MHebiHwMNKX4VwPob2Z9zaw1gLsBzGuebonIudbooT53rzGz7wD4X9QN9U1399eidmzcuUuXLrRthw4dMrNo2IgNjwBAx44daf7pT3+60Y/91ltv0Tz6vqOhn02bNmVm0RBlNKTF7huIrysbz37llVdo2yFDhtD84osvpvnLL7+cmUVDlMOHD6f5jh07aB69v7V9e/aL5Oj73rVrV2YWzQmpr0nj/O4+H8D8ptyHiBSGpveKJErFL5IoFb9IolT8IolS8YskSsUvkqi8ruevra2l45/RWDsb716/fj1tG40Jt2nThuZbt27NzD71qU/RttGa+m3bttE8Wj565ZVXNvq+oyW9AwbwhZqtWrWiORsPv+GGG2jbaCw+Wip94MCBzGzQoEG07fvvv0/z6N+kd+/eNGfj8dEcAjZ346PscaBnfpFEqfhFEqXiF0mUil8kUSp+kUSp+EUSldehPoBvxxztwFtZWZmZlZeX07bRb
qpz586l+a233pqZLViwgLZt27Ytzf/0pz/R/DOf+QzNX3rppcws2uE22hk4GsqLhmfZkFp0zYcOHUrzDRs20Jz1PRpO27t3L82jXY23bNlCc7bMu7S0tEmP3VB65hdJlIpfJFEqfpFEqfhFEqXiF0mUil8kUSp+kUTlfZyfHS/MtuYG+PLSaDyajYUDwFe+8hWab9y4MTMbNmwYbfvCCy/QfOrUqTSPTl5lY+1NGYcH4uu6cuVKmrMlpqNGjaJtozkI0fbYbEv1aOlrNDcj6lt03QcPHpyZRcuJm3KUfX165hdJlIpfJFEqfpFEqfhFEqXiF0mUil8kUSp+kUQ1aZzfzLYAqAJwGkCNu1fQByspoWuR161bRx+PrWuP1oZH22svWbKE5hMnTszM5s/nBxX369eP5rNmzaJ5tMV19+7dM7PomrKjxwF+lDQAtG/fnubnnXdeZvbEE0/QthUV9McpfOz7778/M3vooYdo22jb8AjbtwLg812OHDlC2/bt2zczi45Mr685Jvnc4O585wMRKTp62S+SqKYWvwNYYGYvmVn262IRKTpNfdl/rbtvN7MLACw0szfd/R9+ec79pzARiPeTE5H8adIzv7tvz/29B8BTAD6046K7T3H3CneviBbuiEj+NLr4zayDmXU68zGAzwN4tbk6JiLnVlNe9pcBeCo3ZFEC4H/c/blm6ZWInHONLn533wwg+2zoszAzOg553XXX0fbsWOThw4fTtvv27aN5NA/g8ccfz8yi47+j9dcDBw6k+auv8hdUF1xwQWbGjoIGgN27d9N806ZNNI+u+7Rp0zKzcePG0bbLly+n+Re+8AWas7kZ/fv3p20vu+yyJuV//etfad6jR4/MrLq6mrZ98803M7MTJ07QtvVpqE8kUSp+kUSp+EUSpeIXSZSKXyRRKn6RROV16+5Tp07Ro4uj7ZS7dOmSmbGlowBf9grEW3v36dMnM7vySj7iOW/ePJr37t2b5jU1NTR/+OGHM7OSEv5PHPUtOtr8qquuovmXvvSlzGzXrl20LbvmALB+/Xqas5+JaHvsaMvy6Aju48eP05wtGY6WA7dokf2czZYKf+h+GvyVIvJPRcUvkigVv0iiVPwiiVLxiyRKxS+SKBW/SKLyOs7fsmVLunV3mzZtaHu2tLVdu3a0bZTffvvtNGfzAP7whz/Qtmx+AhDPb4i2qGZLghctWkTb3nrrrTSPlhP/7ne/oznDtmIHgL17+abQnTp1ojm7rtEy7KqqKprv2LGD5idPnmz0/UdzDD7KWD6jZ36RRKn4RRKl4hdJlIpfJFEqfpFEqfhFEqXiF0lUXsf5a2pq6DrqSy65hLZn694XLlxI244aNYrm0Vj7mDFjMrPnnuPHFbz33ns0j45kjvrGsPX0QLx1d3Td2DbSAJ+jMGfOHNp27NixNH/77bdpzvoWbWleXl5O83vvvZfmCxYsoPmGDRsys2hbcTYfJtq/oT4984skSsUvkigVv0iiVPwiiVLxiyRKxS+SKBW/SKLCQUEzmw7gFgB73P3y3G3dAMwB0AfAFgBj3f1AA+6LjllH65jZPuzRePaxY8do3qtXL5qz/e2jsdUDB/iliY4Hj/YD2L9/f2YWzTGI+h6dhxAdkz1z5szMbPLkybRttBdBdF4COxfghhtuoG0rKytp/thjj9E8+nm76KKLaM6wn6do/kJ9DXnm/y2Amz9w2w8ALHL3/gAW5T4XkY+RsPjdfQmADz61jAEwI/fxDAB8GxwRKTqN/Z2/zN3PnDe0C0BZM/VHRPKkyW/4ed3BYpmHi5nZRDNbY2ZrovPLRCR/Glv8u82sHAByf+/J+kJ3n+LuFe5eEW2iKSL509jinwdgfO7j8QCeaZ7uiEi+hMVvZr8HsALApWa2zczuAfBTACPNbCOAEbnPReRjJBznd/dxGdFNH/XBzAytW7fOzNeuXUvbHzp0KDObMGECbRudcT937lyas3Xp0Zhu27ZtaR6tS2fj+AD/3
l9++WXalu35DwA/+tGPaP7000/TfPDgwZnZ0qVLadtofsPo0aNpzvaO+OEPf0jb3nfffTSPzpg4ePAgzadNm5aZ3XXXXbTtFVdckZl9lL0fNMNPJFEqfpFEqfhFEqXiF0mUil8kUSp+kUTl/Yjuzp07Z+bRkcxsGeT8+fNp29LSUpofPXqU5idOnMjMou2rhw0bRvMePXrQfNWqVTRnx4ez4S4AuPPOO2keLdl94oknaM6WFF9++eW0bTT82rNnT5pv27YtM3vooYdo2+ho8ui6RsOUP/vZzzKzFStW0LbseO/mXtIrIv+EVPwiiVLxiyRKxS+SKBW/SKJU/CKJUvGLJCqv4/ynT5+mY5TRuC0ba1+8eDFte/vtfI/RaNy2U6dOmVmfPn1oWzbeDMTLP6Nlmuxo85tu4iuvn332WZqfOnWK5n/84x9pftlll2Vmmzdvpm2/9rWv0by6uprmbPn4pk2baNto3keHDh1oHm0r/sgjj2Rm0bHoZtao7IP0zC+SKBW/SKJU/CKJUvGLJErFL5IoFb9IolT8IonK6zg/wLfQjo4tZttQf//736dt2Zp3ABg+fDjNt27dmpn179+ftmXHewPAyJEjaf7zn/+c5uzxozkGI0aMoPmyZctoft1119F89uzZmdmkSZNo2wceeIDm0fHhbH5EbW0tbVt3Cl22aF5ItL/EPffck5lt2LCBtmVbwUf9rk/P/CKJUvGLJErFL5IoFb9IolT8IolS8YskSsUvkqhwnN/MpgO4BcAed788d9sDAO4FcGawc7K784FNAK1atUKvXr0y82hdOzuiOxpX7d69O82jI7yXLFmSmUVjxsePH6f5L3/5S5pHY8psfsTGjRtp25MnT9I8UlLCf4TY2vTx48fTtrfccgvNL730Upq/+OKLmdlPfvIT2vZb3/oWzSNsHwOA780f7aHQu3fvzKxFi4Y/nzfkK38L4Oaz3P6Iuw/K/QkLX0SKS1j87r4EwP489EVE8qgpv/N/x8wqzWy6mXVtth6JSF40tvh/DaAfgEEAdgJ4OOsLzWyima0xszXRvmgikj+NKn533+3up929FsBUAEPJ105x9wp3r4g2PRSR/GlU8ZtZeb1P7wDAjzQVkaLTkKG+3wO4HkCpmW0D8O8ArjezQQAcwBYA3zyHfRSRcyAsfncfd5abpzXmwU6fPo19+/Zl5tH+9507d87MovHN6P2GVatW0bxr1+z3NDt27Ejb9u3bl+alpaU0Z2cGRHl01nuPHj1ofvjwYZp369aN5mxN/Xe/+13advr06TQ/duwYzd9+++3M7Mc//jFty/ZvAIChQzN/0wUQnwvA6uATn/gEbXv69OnMTOv5RSSk4hdJlIpfJFEqfpFEqfhFEqXiF0lUXrfurq2tpctb33nnHdqeDe1EW1RHw2XsOGcAePDBBzOzv/3tb7Tt7t27aR6Jhn5Wr16dmUVbcz/55JM0Hzt2LM2jrb3Zkt9oKfQdd9xBczaUBwCDBg3KzKKhuMGDB9M8Otr8s5/9LM3Z9tvR8CurIQ31iUhIxS+SKBW/SKJU/CKJUvGLJErFL5IoFb9IovJ+RDcbh4zGKNlR1Ndccw1tW1lZSfMuXbrQfPHixZnZFVdcQdtGx4OPGTOG5kuXLqX55z73ucxswYIFtO2QIUNo/ve//53ml1xyCc3ffffdzCw6Fv2FF16gec+ePWnOtoKPtmqP5mZ88YtfpPnKlStpzuadRFu9s6PJ2RLqD9Izv0iiVPwiiVLxiyRKxS+SKBW/SKJU/CKJUvGLJCqv4/wlJSV0m+roiG52THZFRQVtG40p//nPf6Y521Z84cKFtG15eTnNZ82aRXO2VTPAx4XZmnYA2Lx5M82jvkf3f8EFF2Rm0b93VVVVkx6bzTGI5k5E4/zR/Iirr76a5vPmzcvMomPVH3vsscws2u68Pj3ziyRKxS+SKBW/SKJU/CKJUvGLJErFL5IoFb9IoixaQ29mvQDMBFAGwAFMcfdHz
awbgDkA+gDYAmCsux9g91VWVuZf/vKXM/Nof/oDB7LvPtp3/8UXX6R5NGa8d+/ezCw6Hpz1GwCeeeYZmt911100v+mmmzKzaDz7G9/4Bs0fffRRmt944400P3nyZGZ24YUX0rbR8eLRPIFLL700Mxs2bBhtO2PGDJpHe+uvXbuW5tdff31mxo7vBvheBHPmzMGePXuM3kFOQ575awB8z90HALgawLfNbACAHwBY5O79ASzKfS4iHxNh8bv7Tndfm/u4CsAbAHoCGAPgzH+PMwDcfq46KSLN7yP9zm9mfQBcBWAVgDJ335mLdqHu1wIR+ZhocPGbWUcATwKY5O6H62de98bBWd88MLOJZrbGzNZEe5OJSP40qPjNrBXqCn+Wu8/N3bzbzMpzeTmAPWdr6+5T3L3C3SvatWvXHH0WkWYQFr+ZGYBpAN5w91/Ui+YBGJ/7eDwA/pa1iBSVhgz1XQtgKYD1AM6cqTwZdb/3PwHgIgDvom6obz+7r9LSUr/ttttYTvvChvOeeuop2nbAgAE0b9OmDc3ZsFLfvn1pWzbkBMTDQh07dqQ528p5/376T4JTp07RPBrSYkt2AWDq1KmZGRvuAvhW7QDw9NNP05wdsx0d6f71r3+9SY8dbWm+fPnyzKy6upq2ZdtzL168GAcOHGjQUF+4nt/dlwHIurPsAWYRKWqa4SeSKBW/SKJU/CKJUvGLJErFL5IoFb9IovK6dXeLFi3oeDrbHhvgY7MPPvggbTtz5kyaR8uJ2Xj4nXfeSdtOmjSJ5tGy2uioajbW/t5779G248aNo3l0vHi0lPruu+/OzAYOHEjbzp49u0mPXVKS/eM9YsQI2nbKlCk0v/nmm2kebed+7bXXZmbRvxmbe7F69Wratj4984skSsUvkigVv0iiVPwiiVLxiyRKxS+SKBW/SKLyfkR39+7dM/NXX32VtmdbGkdbhPXu3Zvm0fprdtzztGnTaNtom+io7x06dKB5v379MrMuXbrQtq+99hrNo/0eou2zjxw5kpl16tSJtmXHWAPxPAA2L+TZZ5+lbd9//32aV1ZW0nzHjh00f+655zKz9u3b07ZDhw7NzFq1akXb1qdnfpFEqfhFEqXiF0mUil8kUSp+kUSp+EUSpeIXSVRex/lra2vpuG+0Tzsbwzx27Bhte/ToUZqfOHGC5ps3b87Mzj//fNp23bp1NI+Oqt6+fTvN2RHdrN8A6LwLAFi5ciXNozFpdv/R9x2ddzB37lyas/ufPHkybfurX/2K5tH8iWi9P7su0RwE9m/KjkT/ID3ziyRKxS+SKBW/SKJU/CKJUvGLJErFL5IoFb9Ioixar21mvQDMBFAGwAFMcfdHzewBAPcCOLPwebK7z2f3VVZW5l/96lcz8127dtG+sLHV6Lz1T37ykzRfsWIFzW+77bbMrKqqirYtLS2l+V/+8heaT5gwgebsunXu3Jm2XbZsGc2jvm/YsIHmbC+CaA+F/fv303z9+vU0HzVqVGYWzV8YPXo0zaOfpzfeeIPmbB+ErVu30rZsjsBvfvMbbN++3egd5DRkkk8NgO+5+1oz6wTgJTM7cyLBI+7+nw15IBEpLmHxu/tOADtzH1eZ2RsAep7rjonIufWRfuc3sz4ArgKwKnfTd8ys0symm1nXjDYTzWyNma2JtqsSkfxpcPGbWUcATwKY5O6HAfwaQD8Ag1D3yuDhs7Vz9ynuXuHuFe3atWuGLotIc2hQ8ZtZK9QV/ix3nwsA7r7b3U+7ey2AqQCydxUUkaITFr+ZGYBpAN5w91/Uu7283pfdAYBvvSsiRaUhQ33XAlgKYD2A2tzNkwGMQ91LfgewBcA3c28OZoqG+qIlns8//3xmFi2rjY7gjpamlpWVZWavv/46bXvo0CGaR0eTl5eX05xtK/7mm2/Sttdccw3No1/VouXK7HuLtre+6KKLaF5TU0Pzli1bNioDgPPOO4/mvXr1ovl9991H85EjR2Zm0ffFjvBesWIFDh061
DxDfe6+DMDZ7oyO6YtIcdMMP5FEqfhFEqXiF0mUil8kUSp+kUSp+EUSldetu80MJSXZD/nOO+/Q9sOHD8/MoqWr0Th+tLS1bdu2mdmNN95I2+7du5fmr7zyCs1bt25Nc7Yt+ZAhQ2jbaNnszp106kbYNzaPpGPHjrTt8uXLaT5ixAiaHz58ODMbOHAgbbtp0yaaV1dX07yiooLm7Fj1aGk7+3mLjrmvT8/8IolS8YskSsUvkigVv0iiVPwiiVLxiyRKxS+SqHA9f7M+mNn7AOovPi8FwAfBC6dY+1as/QLUt8Zqzr71dne+eUVOXov/Qw9utsbd+WyIAinWvhVrvwD1rbEK1Te97BdJlIpfJFGFLv4pBX58plj7Vqz9AtS3xipI3wr6O7+IFE6hn/lFpEAKUvxmdrOZvWVmm8zsB4XoQxYz22Jm681snZmtKXBfppvZHjN7td5t3cxsoZltzP191mPSCtS3B8xse+7arTMzftTtuetbLzNbbGavm9lrZvavudsLeu1Ivwpy3fL+st/MWgLYAGAkgG0AVgMY5+588/s8MbMtACrcveBjwmZ2HYAjAGa6++W52/4DwH53/2nuP86u7v5vRdK3BwAcKfTJzbkDZcrrnywN4HYA/4ICXjvSr7EowHUrxDP/UACb3H2zu58CMBvAmAL0o+i5+xIAH9xtYwyAGbmPZ6DuhyfvMvpWFNx9p7uvzX1cBeDMydIFvXakXwVRiOLvCWBrvc+3obiO/HYAC8zsJTObWOjOnEVZvZORdgHIPkqoMMKTm/PpAydLF821a8yJ181Nb/h92LXuPhjAKADfzr28LUpe9ztbMQ3XNOjk5nw5y8nS/6+Q166xJ143t0IU/3YA9Q86uzB3W1Fw9+25v/cAeArFd/rw7jOHpOb+3lPg/vy/Yjq5+WwnS6MIrl0xnXhdiOJfDaC/mfU1s9YA7gYwrwD9+BAz65B7IwZm1gHA51F8pw/PAzA+9/F4AM8UsC//oFhObs46WRoFvnZFd+K1u+f9D4DRqHvH/20A9xeiDxn9uhjAK7k/rxW6bwB+j7qXgdWoe2/kHgDnA1gEYCOA5wF0K6K+PY6605wrUVdo5QXq27Woe0lfCWBd7s/oQl870q+CXDfN8BNJlN7wE0mUil8kUSp+kUSp+EUSpeIXSZSKXyRRKn6RRKn4RRL1fwp5YkSSP4AUAAAAAElFTkSuQmCC\n", 125 | "text/plain": [ 126 | "
" 127 | ] 128 | }, 129 | "metadata": { 130 | "needs_background": "light" 131 | }, 132 | "output_type": "display_data" 133 | } 134 | ], 135 | "source": [ 136 | "generator = make_generator_model()\n", 137 | "\n", 138 | "noise = tf.random.normal([1, 100])\n", 139 | "generated_image = generator(noise, training=False)\n", 140 | "\n", 141 | "plt.imshow(generated_image[0, :, :, 0], cmap='gray')" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "### 构造判别器" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 6, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "def make_discriminator_model():\n", 158 | " model = tf.keras.Sequential()\n", 159 | " model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', \n", 160 | " input_shape=[28, 28, 1]))\n", 161 | " model.add(layers.LeakyReLU())\n", 162 | " model.add(layers.Dropout(0.3))\n", 163 | " \n", 164 | " model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))\n", 165 | " model.add(layers.LeakyReLU())\n", 166 | " model.add(layers.Dropout(0.3))\n", 167 | " \n", 168 | " model.add(layers.Flatten())\n", 169 | " model.add(layers.Dense(1))\n", 170 | " \n", 171 | " return model" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": {}, 177 | "source": [ 178 | "判别器判别" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 7, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "tf.Tensor([[-0.00016926]], shape=(1, 1), dtype=float32)\n" 191 | ] 192 | } 193 | ], 194 | "source": [ 195 | "discriminator = make_discriminator_model()\n", 196 | "decision = discriminator(generated_image)\n", 197 | "print (decision)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "## 3.定义损失函数" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 19, 210 | 
"metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "# This method returns a helper function to compute cross entropy loss\n", 214 | "cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 20, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "# 判别器损失\n", 224 | "def discriminator_loss(real_output, fake_output):\n", 225 | " real_loss = cross_entropy(tf.ones_like(real_output), real_output)\n", 226 | " fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)\n", 227 | " total_loss = real_loss + fake_loss\n", 228 | " return total_loss" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 21, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "# 生成器损失\n", 238 | "def generator_loss(fake_output):\n", 239 | " return cross_entropy(tf.ones_like(fake_output), fake_output)\n", 240 | "\n" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 22, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "generator_optimizer = tf.keras.optimizers.Adam(1e-4)\n", 250 | "discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "checkpoint保存" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 23, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "checkpoint_dir = './training_checkpoints'\n", 267 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", 268 | "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n", 269 | " discriminator_optimizer=discriminator_optimizer,\n", 270 | " generator=generator,\n", 271 | " discriminator=discriminator)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": {}, 277 | "source": [ 278 | "## 4.训练函数" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | 
"execution_count": 24, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "EPOCHS = 50\n", 288 | "noise_dim = 100\n", 289 | "num_examples_to_generate = 16\n", 290 | "\n", 291 | "# We will reuse this seed overtime (so it's easier)\n", 292 | "# to visualize progress in the animated GIF)\n", 293 | "seed = tf.random.normal([num_examples_to_generate, noise_dim])" 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "metadata": {}, 299 | "source": [ 300 | "训练迭代函数" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 25, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "# Notice the use of `tf.function`\n", 310 | "# This annotation causes the function to be \"compiled\".\n", 311 | "@tf.function\n", 312 | "def train_step(images):\n", 313 | " noise = tf.random.normal([BATCH_SIZE, noise_dim])\n", 314 | "\n", 315 | " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", 316 | " generated_images = generator(noise, training=True)\n", 317 | "\n", 318 | " real_output = discriminator(images, training=True)\n", 319 | " fake_output = discriminator(generated_images, training=True)\n", 320 | "\n", 321 | " gen_loss = generator_loss(fake_output)\n", 322 | " disc_loss = discriminator_loss(real_output, fake_output)\n", 323 | "\n", 324 | " gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)\n", 325 | " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)\n", 326 | "\n", 327 | " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))\n", 328 | " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "训练函数" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 26, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": 
[ 344 | "def train(dataset, epochs): \n", 345 | " for epoch in range(epochs):\n", 346 | " start = time.time()\n", 347 | " \n", 348 | " for image_batch in dataset:\n", 349 | " train_step(image_batch)\n", 350 | "\n", 351 | " # Produce images for the GIF as we go\n", 352 | " display.clear_output(wait=True)\n", 353 | " generate_and_save_images(generator,\n", 354 | " epoch + 1,\n", 355 | " seed)\n", 356 | " \n", 357 | " # Save the model every 15 epochs\n", 358 | " if (epoch + 1) % 15 == 0:\n", 359 | " checkpoint.save(file_prefix = checkpoint_prefix)\n", 360 | " \n", 361 | " print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))\n", 362 | " \n", 363 | " # Generate after the final epoch\n", 364 | " display.clear_output(wait=True)\n", 365 | " generate_and_save_images(generator,\n", 366 | " epochs,\n", 367 | " seed)" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "metadata": {}, 373 | "source": [ 374 | "生成和保存图像" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 27, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "def generate_and_save_images(model, epoch, test_input):\n", 384 | " # Notice `training` is set to False. 
\n", 385 | " # This is so all layers run in inference mode (batchnorm).\n", 386 | " predictions = model(test_input, training=False)\n", 387 | "\n", 388 | " fig = plt.figure(figsize=(4,4))\n", 389 | " \n", 390 | " for i in range(predictions.shape[0]):\n", 391 | " plt.subplot(4, 4, i+1)\n", 392 | " plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')\n", 393 | " plt.axis('off')\n", 394 | " \n", 395 | " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", 396 | " plt.show()" 397 | ] 398 | }, 399 | { 400 | "cell_type": "markdown", 401 | "metadata": {}, 402 | "source": [ 403 | "## 5.模型训练" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": null, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "%%time\n", 413 | "train(train_dataset, EPOCHS)" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "# 生成一张动图\n", 423 | "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))\n", 424 | "def display_image(epoch_no):\n", 425 | " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": null, 431 | "metadata": {}, 432 | "outputs": [], 433 | "source": [ 434 | "display_image(EPOCHS)" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "metadata": {}, 441 | "outputs": [], 442 | "source": [ 443 | "## 6.训练过程的动图" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", 453 | " filenames = glob.glob('image*.png')\n", 454 | " filenames = sorted(filenames)\n", 455 | " last = -1\n", 456 | " for i,filename in enumerate(filenames):\n", 457 | " frame = 2*(i**0.5)\n", 458 | " if round(frame) > round(last):\n", 459 | " last = frame\n", 460 | " else:\n", 461 
| " continue\n", 462 | " image = imageio.imread(filename)\n", 463 | " writer.append_data(image)\n", 464 | " image = imageio.imread(filename)\n", 465 | " writer.append_data(image)\n", 466 | " \n", 467 | "# A hack to display the GIF inside this notebook\n", 468 | "os.rename('dcgan.gif', 'dcgan.gif.png')" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": null, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "display.Image(filename=\"dcgan.gif.png\")" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": null, 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [] 486 | } 487 | ], 488 | "metadata": { 489 | "kernelspec": { 490 | "display_name": "Python 3", 491 | "language": "python", 492 | "name": "python3" 493 | }, 494 | "language_info": { 495 | "codemirror_mode": { 496 | "name": "ipython", 497 | "version": 3 498 | }, 499 | "file_extension": ".py", 500 | "mimetype": "text/x-python", 501 | "name": "python", 502 | "nbconvert_exporter": "python", 503 | "pygments_lexer": "ipython3", 504 | "version": "3.6.8" 505 | } 506 | }, 507 | "nbformat": 4, 508 | "nbformat_minor": 2 509 | } 510 | -------------------------------------------------------------------------------- /031-Image/.ipynb_checkpoints/001-image_classification-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /104-example_classify_structured_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TensorFlow2.0教程-结构化数据分类\n", 8 | "\n", 9 | "tensorflow2教程知乎专栏:https://zhuanlan.zhihu.com/c_1091021863043624960\n", 10 | "\n", 11 | 
"本教程展示了如何对结构化数据进行分类(例如CSV中的表格数据)。我们使用Keras定义模型,并将csv中各列的特征转化为训练的输入。 本教程包含以下功能代码:\n", 12 | "\n", 13 | "- 使用Pandas加载CSV文件。\n", 14 | "- 构建一个输入的pipeline,使用tf.data批处理和打乱数据。\n", 15 | "- 将CSV中的列映射为用于训练模型的输入特征。\n", 16 | "- 使用Keras构建,训练和评估模型。" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "2.0.0-alpha0\n" 29 | ] 30 | } 31 | ], 32 | "source": [ 33 | "from __future__ import absolute_import, division, print_function\n", 34 | "\n", 35 | "import numpy as np\n", 36 | "import pandas as pd\n", 37 | "\n", 38 | "import tensorflow as tf\n", 39 | "\n", 40 | "from tensorflow import feature_column\n", 41 | "from tensorflow.keras import layers\n", 42 | "from sklearn.model_selection import train_test_split\n", 43 | "print(tf.__version__)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "## 1.数据集\n", 51 | "我们将使用克利夫兰诊所心脏病基金会提供的一个小数据集。 CSV中有几百行。 每行描述一个患者,每列描述一个属性。 我们将使用此信息来预测患者是否患有心脏病,这在该数据集中是一个二元分类任务。\n", 52 | "\n", 53 | ">Column| Description| Feature Type | Data Type\n", 54 | ">------------|--------------------|----------------------|-----------------\n", 55 | ">Age | Age in years | Numerical | integer\n", 56 | ">Sex | (1 = male; 0 = female) | Categorical | integer\n", 57 | ">CP | Chest pain type (0, 1, 2, 3, 4) | Categorical | integer\n", 58 | ">Trestbps | Resting blood pressure (in mm Hg on admission to the hospital) | Numerical | integer\n", 59 | ">Chol | Serum cholesterol in mg/dl | Numerical | integer\n", 60 | ">FBS | (fasting blood sugar > 120 mg/dl) (1 = true; 0 = false) | Categorical | integer\n", 61 | ">RestECG | Resting electrocardiographic results (0, 1, 2) | Categorical | integer\n", 62 | ">Thalach | Maximum heart rate achieved | Numerical | integer\n", 63 | ">Exang | Exercise induced angina (1 = yes; 0 = no) | Categorical | integer\n", 64 | ">Oldpeak | ST depression induced by exercise
relative to rest | Numerical | float\n", 65 | ">Slope | The slope of the peak exercise ST segment | Numerical | integer\n", 66 | ">CA | Number of major vessels (0-3) colored by fluoroscopy | Numerical | integer\n", 67 | ">Thal | 3 = normal; 6 = fixed defect; 7 = reversible defect | Categorical | string\n", 68 | ">Target | Diagnosis of heart disease (1 = true; 0 = false) | Classification | integer\n" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "## 2.准备数据\n", 76 | "使用pandas读取数据" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 3, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/html": [ 87 | "
\n", 88 | "\n", 101 | "\n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | "
agesexcptrestbpscholfbsrestecgthalachexangoldpeakslopecathaltarget
063111452331215002.330fixed0
167141602860210811.523normal1
267141202290212912.622reversible0
337131302500018703.530normal0
441021302040217201.410normal0
\n", 209 | "
" 210 | ], 211 | "text/plain": [ 212 | " age sex cp trestbps chol fbs restecg thalach exang oldpeak slope \\\n", 213 | "0 63 1 1 145 233 1 2 150 0 2.3 3 \n", 214 | "1 67 1 4 160 286 0 2 108 1 1.5 2 \n", 215 | "2 67 1 4 120 229 0 2 129 1 2.6 2 \n", 216 | "3 37 1 3 130 250 0 0 187 0 3.5 3 \n", 217 | "4 41 0 2 130 204 0 2 172 0 1.4 1 \n", 218 | "\n", 219 | " ca thal target \n", 220 | "0 0 fixed 0 \n", 221 | "1 3 normal 1 \n", 222 | "2 2 reversible 0 \n", 223 | "3 0 normal 0 \n", 224 | "4 0 normal 0 " 225 | ] 226 | }, 227 | "execution_count": 3, 228 | "metadata": {}, 229 | "output_type": "execute_result" 230 | } 231 | ], 232 | "source": [ 233 | "URL = 'https://storage.googleapis.com/applied-dl/heart.csv'\n", 234 | "dataframe = pd.read_csv(URL)\n", 235 | "dataframe.head()" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": {}, 241 | "source": [ 242 | "划分训练集验证集和测试集" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 4, 248 | "metadata": {}, 249 | "outputs": [ 250 | { 251 | "name": "stdout", 252 | "output_type": "stream", 253 | "text": [ 254 | "193 train examples\n", 255 | "49 validation examples\n", 256 | "61 test examples\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "train, test = train_test_split(dataframe, test_size=0.2)\n", 262 | "train, val = train_test_split(train, test_size=0.2)\n", 263 | "print(len(train), 'train examples')\n", 264 | "print(len(val), 'validation examples')\n", 265 | "print(len(test), 'test examples')" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "使用tf.data构造输入pipeline" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 7, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "def df_to_dataset(dataframe, shuffle=True, batch_size=32):\n", 282 | " dataframe = dataframe.copy()\n", 283 | " labels = dataframe.pop('target')\n", 284 | " ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), 
labels))\n", 285 | " if shuffle:\n", 286 | " ds = ds.shuffle(buffer_size=len(dataframe))\n", 287 | " ds = ds.batch(batch_size)\n", 288 | " return ds\n", 289 | "\n", 290 | "batch_size = 5\n", 291 | "train_ds = df_to_dataset(train, batch_size=batch_size)\n", 292 | "val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)\n", 293 | "test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)\n", 294 | "\n", 295 | "\n" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 8, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "Every feature: ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal']\n", 308 | "A batch of ages: tf.Tensor([61 51 57 51 44], shape=(5,), dtype=int32)\n", 309 | "A batch of targets: tf.Tensor([0 0 0 1 0], shape=(5,), dtype=int32)\n" 310 | ] 311 | } 312 | ], 313 | "source": [ 314 | "for feature_batch, label_batch in train_ds.take(1):\n", 315 | " print('Every feature:', list(feature_batch.keys()))\n", 316 | " print('A batch of ages:', feature_batch['age'])\n", 317 | " print('A batch of targets:', label_batch )" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": {}, 323 | "source": [ 324 | "## 3.tensorflow的feature column" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 9, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "example_batch = next(iter(train_ds))[0]" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 10, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "def demo(feature_column):\n", 343 | " feature_layer = layers.DenseFeatures(feature_column)\n", 344 | " print(feature_layer(example_batch).numpy())" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "metadata": {}, 350 | "source": [ 351 | "### 数字列\n", 352 | "特征列的输出成为模型的输入。 数字列是最简单的列类型。 
它用于表示实数值(real-valued)特征。 使用此列时,模型将从数据框中接收未更改的列值。" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 11, 358 | "metadata": {}, 359 | "outputs": [ 360 | { 361 | "name": "stderr", 362 | "output_type": "stream", 363 | "text": [ 364 | "WARNING: Logging before flag parsing goes to stderr.\n", 365 | "W0324 13:43:10.728773 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:2758: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", 366 | "Instructions for updating:\n", 367 | "Use `tf.cast` instead.\n" 368 | ] 369 | }, 370 | { 371 | "name": "stdout", 372 | "output_type": "stream", 373 | "text": [ 374 | "[[61.]\n", 375 | " [51.]\n", 376 | " [57.]\n", 377 | " [51.]\n", 378 | " [44.]]\n" 379 | ] 380 | } 381 | ], 382 | "source": [ 383 | "age = feature_column.numeric_column(\"age\")\n", 384 | "demo(age)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": {}, 390 | "source": [ 391 | "### Bucketized列(桶列)\n", 392 | "通常,您不希望将数字直接输入模型,而是根据数值范围将其值分成不同的类别。 考虑代表一个人年龄的原始数据。 我们可以使用bucketized列将年龄分成几个桶,而不是将年龄表示为数字列。 请注意,下面的one-hot描述了每行匹配的年龄范围。" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": 12, 398 | "metadata": {}, 399 | "outputs": [ 400 | { 401 | "name": "stderr", 402 | "output_type": "stream", 403 | "text": [ 404 | "W0324 13:48:31.327955 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:2902: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", 405 | "Instructions for updating:\n", 406 | "Use `tf.cast` instead.\n" 407 | ] 408 | }, 409 | { 410 | "name": "stdout", 411 | "output_type": "stream", 412 | "text": [ 413 | "[[0. 0. 0. 0. 0. 0. 1.]\n", 414 | " [0. 0. 0. 0. 0. 0. 1.]\n", 415 | " [0. 0. 0. 0. 0. 0.
1.]\n", 416 | " [0. 0. 0. 0. 0. 0. 1.]\n", 417 | " [0. 0. 0. 0. 0. 1. 0.]]\n" 418 | ] 419 | } 420 | ], 421 | "source": [ 422 | "age_buckets = feature_column.bucketized_column(age, boundaries=[\n", 423 | " 18, 25, 30, 35, 40, 50\n", 424 | "])\n", 425 | "demo(age_buckets)" 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "metadata": {}, 431 | "source": [ 432 | "### 类别列\n", 433 | "在该数据集中,thal表示为字符串(例如“固定”,“正常”或“可逆”)。 我们无法直接将字符串提供给模型。 相反,我们必须首先将它们映射到数值。 类别列提供了一种将字符串表示为单热矢量的方法(就像上面用年龄段看到的那样)。 类别表可以使用categorical_column_with_vocabulary_list作为列表传递,或者使用categorical_column_with_vocabulary_file从文件加载。" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 13, 439 | "metadata": {}, 440 | "outputs": [ 441 | { 442 | "name": "stderr", 443 | "output_type": "stream", 444 | "text": [ 445 | "W0324 13:55:01.628555 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4307: IndicatorColumn._variable_shape (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", 446 | "Instructions for updating:\n", 447 | "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n", 448 | "W0324 13:55:01.629235 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4362: VocabularyListCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", 449 | "Instructions for updating:\n", 450 | "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n" 451 | ] 452 | }, 453 | { 454 | "name": "stdout", 455 | "output_type": "stream", 456 | "text": [ 457 | "[[0. 0. 1.]\n", 458 | " [0. 1. 0.]\n", 459 | " [0. 0. 1.]\n", 460 | " [0. 0. 1.]\n", 461 | " [0. 1. 
0.]]\n" 462 | ] 463 | } 464 | ], 465 | "source": [ 466 | "thal = feature_column.categorical_column_with_vocabulary_list('thal', ['fixed', 'normal', 'reversible'])\n", 467 | "thal_one_hot = feature_column.indicator_column(thal)\n", 468 | "demo(thal_one_hot)" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "### 嵌入列\n", 476 | "假设我们不是只有几个可能的字符串,而是每个类别有数千(或更多)值。 由于多种原因,随着类别数量的增加,使用单热编码训练神经网络变得不可行。 我们可以使用嵌入列来克服此限制。 嵌入列不是将数据表示为多维度的单热矢量,而是将数据表示为低维密集向量,其中每个单元格可以包含任意数字,而不仅仅是0或1.嵌入的大小是必须训练调整的参数。\n", 477 | "\n", 478 | "注:当分类列具有许多可能的值时,最好使用嵌入列。" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 16, 484 | "metadata": {}, 485 | "outputs": [ 486 | { 487 | "name": "stdout", 488 | "output_type": "stream", 489 | "text": [ 490 | "[[ 0.21029451 0.28502795 0.27186757 -0.13927 0.44176006 0.18506278\n", 491 | " -0.14189719 0.2901029 ]\n", 492 | " [-0.02674027 -0.21359333 -0.26675928 0.6544374 0.12530805 -0.5243998\n", 493 | " -0.23030454 -0.10796055]\n", 494 | " [ 0.21029451 0.28502795 0.27186757 -0.13927 0.44176006 0.18506278\n", 495 | " -0.14189719 0.2901029 ]\n", 496 | " [ 0.21029451 0.28502795 0.27186757 -0.13927 0.44176006 0.18506278\n", 497 | " -0.14189719 0.2901029 ]\n", 498 | " [-0.02674027 -0.21359333 -0.26675928 0.6544374 0.12530805 -0.5243998\n", 499 | " -0.23030454 -0.10796055]]\n" 500 | ] 501 | } 502 | ], 503 | "source": [ 504 | "thal_embedding = feature_column.embedding_column(thal, dimension=8)\n", 505 | "demo(thal_embedding)" 506 | ] 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | "metadata": {}, 511 | "source": [ 512 | "### 哈希特征列\n", 513 | "表示具有大量值的分类列的另一种方法是使用categorical_column_with_hash_bucket。 此功能列计算输入的哈希值,然后选择一个hash_bucket_size存储桶来编码字符串。 使用此列时,您不需要提供词汇表,并且可以选择使hash_buckets的数量远远小于实际类别的数量以节省空间。\n", 514 | "\n", 515 | "注:该技术的一个重要缺点是可能存在冲突,其中不同的字符串被映射到同一个桶。" 516 | ] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": 17, 521 | "metadata": {}, 522 | "outputs": [ 523 | { 
524 | "name": "stderr", 525 | "output_type": "stream", 526 | "text": [ 527 | "W0324 14:03:23.451644 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4362: HashedCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", 528 | "Instructions for updating:\n", 529 | "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n" 530 | ] 531 | }, 532 | { 533 | "name": "stdout", 534 | "output_type": "stream", 535 | "text": [ 536 | "[[0. 0. 0. ... 0. 0. 0.]\n", 537 | " [0. 0. 0. ... 0. 0. 0.]\n", 538 | " [0. 0. 0. ... 0. 0. 0.]\n", 539 | " [0. 0. 0. ... 0. 0. 0.]\n", 540 | " [0. 0. 0. ... 0. 0. 0.]]\n" 541 | ] 542 | } 543 | ], 544 | "source": [ 545 | "thal_hashed = feature_column.categorical_column_with_hash_bucket('thal', hash_bucket_size=1000)\n", 546 | "demo(feature_column.indicator_column(thal_hashed))" 547 | ] 548 | }, 549 | { 550 | "cell_type": "markdown", 551 | "metadata": {}, 552 | "source": [ 553 | "### 交叉功能列\n", 554 | "将特征组合成单个特征(更好地称为特征交叉),使模型能够为每个特征组合学习单独的权重。 在这里,我们将创建一个与age和thal交叉的新功能。 请注意,crossed_column不会构建所有可能组合的完整表(可能非常大)。 相反,它由hashed_column支持,因此您可以选择表的大小。" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 18, 560 | "metadata": {}, 561 | "outputs": [ 562 | { 563 | "name": "stderr", 564 | "output_type": "stream", 565 | "text": [ 566 | "W0324 14:09:05.265740 140513109178112 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4362: CrossedColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n", 567 | "Instructions for updating:\n", 568 | "The old _FeatureColumn APIs are being deprecated. 
Please use the new FeatureColumn APIs instead.\n" 569 | ] 570 | }, 571 | { 572 | "name": "stdout", 573 | "output_type": "stream", 574 | "text": [ 575 | "[[0. 0. 0. ... 0. 0. 0.]\n", 576 | " [0. 0. 0. ... 0. 0. 0.]\n", 577 | " [0. 0. 0. ... 0. 0. 0.]\n", 578 | " [0. 0. 0. ... 0. 0. 0.]\n", 579 | " [0. 0. 0. ... 0. 0. 0.]]\n" 580 | ] 581 | } 582 | ], 583 | "source": [ 584 | "crossed_feature = feature_column.crossed_column([age_buckets, thal], hash_bucket_size=1000)\n", 585 | "demo(feature_column.indicator_column(crossed_feature))" 586 | ] 587 | }, 588 | { 589 | "cell_type": "markdown", 590 | "metadata": {}, 591 | "source": [ 592 | "## 4.选择使用feature column" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 19, 598 | "metadata": {}, 599 | "outputs": [], 600 | "source": [ 601 | "feature_columns = []\n", 602 | "\n", 603 | "# numeric cols\n", 604 | "for header in ['age', 'trestbps', 'chol', 'thalach', 'oldpeak', 'slope', 'ca']:\n", 605 | " feature_columns.append(feature_column.numeric_column(header))\n", 606 | "\n", 607 | "# bucketized cols\n", 608 | "age_buckets = feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])\n", 609 | "feature_columns.append(age_buckets)\n", 610 | "\n", 611 | "# indicator cols\n", 612 | "thal = feature_column.categorical_column_with_vocabulary_list(\n", 613 | " 'thal', ['fixed', 'normal', 'reversible'])\n", 614 | "thal_one_hot = feature_column.indicator_column(thal)\n", 615 | "feature_columns.append(thal_one_hot)\n", 616 | "\n", 617 | "# embedding cols\n", 618 | "thal_embedding = feature_column.embedding_column(thal, dimension=8)\n", 619 | "feature_columns.append(thal_embedding)\n", 620 | "\n", 621 | "# crossed cols\n", 622 | "crossed_feature = feature_column.crossed_column([age_buckets, thal], hash_bucket_size=1000)\n", 623 | "crossed_feature = feature_column.indicator_column(crossed_feature)\n", 624 | "feature_columns.append(crossed_feature)" 625 | ] 626 | }, 627 | { 628 | 
"cell_type": "markdown", 629 | "metadata": {}, 630 | "source": [ 631 | "构建特征层" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 20, 637 | "metadata": {}, 638 | "outputs": [], 639 | "source": [ 640 | "feature_layer = tf.keras.layers.DenseFeatures(feature_columns)" 641 | ] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "execution_count": 21, 646 | "metadata": {}, 647 | "outputs": [], 648 | "source": [ 649 | "batch_size = 32\n", 650 | "train_ds = df_to_dataset(train, batch_size=batch_size)\n", 651 | "val_ds = df_to_dataset(val, shuffle=False, batch_size=batch_size)\n", 652 | "test_ds = df_to_dataset(test, shuffle=False, batch_size=batch_size)" 653 | ] 654 | }, 655 | { 656 | "cell_type": "markdown", 657 | "metadata": {}, 658 | "source": [ 659 | "## 5.构建模型并训练" 660 | ] 661 | }, 662 | { 663 | "cell_type": "code", 664 | "execution_count": 22, 665 | "metadata": {}, 666 | "outputs": [ 667 | { 668 | "name": "stdout", 669 | "output_type": "stream", 670 | "text": [ 671 | "Epoch 1/5\n", 672 | "7/7 [==============================] - 1s 133ms/step - loss: 1.1864 - accuracy: 0.6357 - val_loss: 0.6905 - val_accuracy: 0.5714\n", 673 | "Epoch 2/5\n", 674 | "7/7 [==============================] - 0s 24ms/step - loss: 0.9603 - accuracy: 0.6804 - val_loss: 0.4047 - val_accuracy: 0.8163\n", 675 | "Epoch 3/5\n", 676 | "7/7 [==============================] - 0s 24ms/step - loss: 0.5744 - accuracy: 0.7389 - val_loss: 0.6673 - val_accuracy: 0.7755\n", 677 | "Epoch 4/5\n", 678 | "7/7 [==============================] - 0s 24ms/step - loss: 0.4890 - accuracy: 0.8092 - val_loss: 0.6298 - val_accuracy: 0.6122\n", 679 | "Epoch 5/5\n", 680 | "7/7 [==============================] - 0s 24ms/step - loss: 0.5618 - accuracy: 0.6795 - val_loss: 0.3861 - val_accuracy: 0.8367\n" 681 | ] 682 | }, 683 | { 684 | "data": { 685 | "text/plain": [ 686 | "" 687 | ] 688 | }, 689 | "execution_count": 22, 690 | "metadata": {}, 691 | "output_type": "execute_result" 692 | } 693 | ], 694 | 
"source": [ 695 | "model = tf.keras.Sequential([\n", 696 | " feature_layer,\n", 697 | " layers.Dense(128, activation='relu'),\n", 698 | " layers.Dense(128, activation='relu'),\n", 699 | " layers.Dense(1, activation='sigmoid')\n", 700 | "])\n", 701 | "\n", 702 | "model.compile(optimizer='adam',\n", 703 | " loss='binary_crossentropy',\n", 704 | " metrics=['accuracy'])\n", 705 | "model.fit(train_ds, validation_data=val_ds,epochs=5)" 706 | ] 707 | }, 708 | { 709 | "cell_type": "markdown", 710 | "metadata": {}, 711 | "source": [ 712 | "测试" 713 | ] 714 | }, 715 | { 716 | "cell_type": "code", 717 | "execution_count": 23, 718 | "metadata": {}, 719 | "outputs": [ 720 | { 721 | "name": "stdout", 722 | "output_type": "stream", 723 | "text": [ 724 | "2/2 [==============================] - 0s 16ms/step - loss: 0.8278 - accuracy: 0.6066\n", 725 | "Accuracy 0.60655737\n" 726 | ] 727 | } 728 | ], 729 | "source": [ 730 | "loss, accuracy = model.evaluate(test_ds)\n", 731 | "print(\"Accuracy\", accuracy)" 732 | ] 733 | }, 734 | { 735 | "cell_type": "code", 736 | "execution_count": null, 737 | "metadata": {}, 738 | "outputs": [], 739 | "source": [] 740 | } 741 | ], 742 | "metadata": { 743 | "kernelspec": { 744 | "display_name": "Python 3", 745 | "language": "python", 746 | "name": "python3" 747 | }, 748 | "language_info": { 749 | "codemirror_mode": { 750 | "name": "ipython", 751 | "version": 3 752 | }, 753 | "file_extension": ".py", 754 | "mimetype": "text/x-python", 755 | "name": "python", 756 | "nbconvert_exporter": "python", 757 | "pygments_lexer": "ipython3", 758 | "version": "3.6.6" 759 | } 760 | }, 761 | "nbformat": 4, 762 | "nbformat_minor": 2 763 | } 764 | -------------------------------------------------------------------------------- /106-example_save_and_restore_models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 
TensorFlow2.0教程-保存和读取模型" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "data": { 17 | "text/plain": [ 18 | "'2.0.0-alpha0'" 19 | ] 20 | }, 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "output_type": "execute_result" 24 | } 25 | ], 26 | "source": [ 27 | "from __future__ import absolute_import, division, print_function\n", 28 | "\n", 29 | "import os\n", 30 | "\n", 31 | "import tensorflow as tf\n", 32 | "from tensorflow import keras\n", 33 | "\n", 34 | "tf.__version__" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "导入数据" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n", 51 | "\n", 52 | "train_labels = train_labels[:1000]\n", 53 | "test_labels = test_labels[:1000]\n", 54 | "\n", 55 | "train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0\n", 56 | "test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## 1.定义一个模型" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 5, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "Model: \"sequential_2\"\n", 76 | "_________________________________________________________________\n", 77 | "Layer (type) Output Shape Param # \n", 78 | "=================================================================\n", 79 | "dense_4 (Dense) (None, 128) 100480 \n", 80 | "_________________________________________________________________\n", 81 | "dropout_2 (Dropout) (None, 128) 0 \n", 82 | "_________________________________________________________________\n", 83 | "dense_5 (Dense) (None, 10) 1290 \n", 84 |
"=================================================================\n", 85 | "Total params: 101,770\n", 86 | "Trainable params: 101,770\n", 87 | "Non-trainable params: 0\n", 88 | "_________________________________________________________________\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "def create_model():\n", 94 | " model = keras.Sequential([\n", 95 | " keras.layers.Dense(128, activation='relu', input_shape=(784,)),\n", 96 | " keras.layers.Dropout(0.5),\n", 97 | " keras.layers.Dense(10, activation='softmax')\n", 98 | " ])\n", 99 | "\n", 100 | " model.compile(optimizer='adam',\n", 101 | " loss=keras.losses.sparse_categorical_crossentropy,\n", 102 | " metrics=['accuracy'])\n", 103 | " return model\n", 104 | "model = create_model()\n", 105 | "model.summary()\n", 106 | " " 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "## 2.checkpoint回调" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 7, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "Train on 1000 samples, validate on 1000 samples\n", 126 | "Epoch 1/10\n", 127 | " 544/1000 [===============>..............] - ETA: 0s - loss: 2.0658 - accuracy: 0.2831 \n", 128 | "Epoch 00001: saving model to 106save/model.ckpt\n", 129 | "1000/1000 [==============================] - 1s 855us/sample - loss: 1.8036 - accuracy: 0.4190 - val_loss: 1.3101 - val_accuracy: 0.6700\n", 130 | "Epoch 2/10\n", 131 | " 800/1000 [=======================>......] - ETA: 0s - loss: 1.0327 - accuracy: 0.7125\n", 132 | "Epoch 00002: saving model to 106save/model.ckpt\n", 133 | "1000/1000 [==============================] - 0s 132us/sample - loss: 1.0101 - accuracy: 0.7190 - val_loss: 0.8742 - val_accuracy: 0.7650\n", 134 | "Epoch 3/10\n", 135 | " 768/1000 [======================>.......] 
- ETA: 0s - loss: 0.7168 - accuracy: 0.7865\n", 136 | "Epoch 00003: saving model to 106save/model.ckpt\n", 137 | "1000/1000 [==============================] - 0s 113us/sample - loss: 0.7214 - accuracy: 0.7900 - val_loss: 0.7212 - val_accuracy: 0.7950\n", 138 | "Epoch 4/10\n", 139 | " 992/1000 [============================>.] - ETA: 0s - loss: 0.5918 - accuracy: 0.8367\n", 140 | "Epoch 00004: saving model to 106save/model.ckpt\n", 141 | "1000/1000 [==============================] - 0s 90us/sample - loss: 0.5904 - accuracy: 0.8380 - val_loss: 0.6292 - val_accuracy: 0.8140\n", 142 | "Epoch 5/10\n", 143 | " 864/1000 [========================>.....] - ETA: 0s - loss: 0.4970 - accuracy: 0.8600\n", 144 | "Epoch 00005: saving model to 106save/model.ckpt\n", 145 | "1000/1000 [==============================] - 0s 105us/sample - loss: 0.4997 - accuracy: 0.8600 - val_loss: 0.5710 - val_accuracy: 0.8410\n", 146 | "Epoch 6/10\n", 147 | " 896/1000 [=========================>....] - ETA: 0s - loss: 0.4247 - accuracy: 0.8839\n", 148 | "Epoch 00006: saving model to 106save/model.ckpt\n", 149 | "1000/1000 [==============================] - 0s 97us/sample - loss: 0.4316 - accuracy: 0.8810 - val_loss: 0.5430 - val_accuracy: 0.8420\n", 150 | "Epoch 7/10\n", 151 | " 32/1000 [..............................] - ETA: 0s - loss: 0.2628 - accuracy: 0.9688\n", 152 | "Epoch 00007: saving model to 106save/model.ckpt\n", 153 | "1000/1000 [==============================] - 0s 81us/sample - loss: 0.3724 - accuracy: 0.8930 - val_loss: 0.5041 - val_accuracy: 0.8480\n", 154 | "Epoch 8/10\n", 155 | " 32/1000 [..............................] - ETA: 0s - loss: 0.2136 - accuracy: 0.9375\n", 156 | "Epoch 00008: saving model to 106save/model.ckpt\n", 157 | "1000/1000 [==============================] - 0s 75us/sample - loss: 0.3221 - accuracy: 0.9030 - val_loss: 0.4861 - val_accuracy: 0.8510\n", 158 | "Epoch 9/10\n", 159 | " 960/1000 [===========================>..] 
- ETA: 0s - loss: 0.3195 - accuracy: 0.9177\n", 160 | "Epoch 00009: saving model to 106save/model.ckpt\n", 161 | "1000/1000 [==============================] - 0s 108us/sample - loss: 0.3230 - accuracy: 0.9150 - val_loss: 0.4580 - val_accuracy: 0.8600\n", 162 | "Epoch 10/10\n", 163 | " 704/1000 [====================>.........] - ETA: 0s - loss: 0.2577 - accuracy: 0.9219\n", 164 | "Epoch 00010: saving model to 106save/model.ckpt\n", 165 | "1000/1000 [==============================] - 0s 128us/sample - loss: 0.2701 - accuracy: 0.9170 - val_loss: 0.4465 - val_accuracy: 0.8620\n" 166 | ] 167 | }, 168 | { 169 | "data": { 170 | "text/plain": [ 171 | "" 172 | ] 173 | }, 174 | "execution_count": 7, 175 | "metadata": {}, 176 | "output_type": "execute_result" 177 | } 178 | ], 179 | "source": [ 180 | "check_path = '106save/model.ckpt'\n", 181 | "check_dir = os.path.dirname(check_path)\n", 182 | "\n", 183 | "cp_callback = tf.keras.callbacks.ModelCheckpoint(check_path, \n", 184 | " save_weights_only=True, verbose=1)\n", 185 | "model = create_model()\n", 186 | "model.fit(train_images, train_labels, epochs=10,\n", 187 | " validation_data=(test_images, test_labels),\n", 188 | " callbacks=[cp_callback])" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 9, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "name": "stdout", 198 | "output_type": "stream", 199 | "text": [ 200 | "checkpoint model.ckpt.data-00000-of-00001 model.ckpt.index\r\n" 201 | ] 202 | } 203 | ], 204 | "source": [ 205 | "!ls {check_dir}" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 10, 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "name": "stdout", 215 | "output_type": "stream", 216 | "text": [ 217 | "1000/1000 [==============================] - 0s 69us/sample - loss: 2.4036 - accuracy: 0.0890\n", 218 | "Untrained model, accuracy: 8.90%\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "model = create_model()\n", 224 | "\n", 225 | "loss, acc 
= model.evaluate(test_images, test_labels)\n", 226 | "print(\"Untrained model, accuracy: {:5.2f}%\".format(100*acc))" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 11, 232 | "metadata": {}, 233 | "outputs": [ 234 | { 235 | "name": "stdout", 236 | "output_type": "stream", 237 | "text": [ 238 | "1000/1000 [==============================] - 0s 47us/sample - loss: 0.4465 - accuracy: 0.8620\n", 239 | "Untrained model, accuracy: 86.20%\n" 240 | ] 241 | } 242 | ], 243 | "source": [ 244 | "model.load_weights(check_path)\n", 245 | "loss, acc = model.evaluate(test_images, test_labels)\n", 246 | "print(\"Untrained model, accuracy: {:5.2f}%\".format(100*acc))" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "## 3.设置checkpoint回调" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 12, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "Train on 1000 samples, validate on 1000 samples\n", 266 | "Epoch 1/10\n", 267 | "1000/1000 [==============================] - 1s 1ms/sample - loss: 1.7242 - accuracy: 0.4490 - val_loss: 1.2205 - val_accuracy: 0.6890\n", 268 | "Epoch 2/10\n", 269 | "1000/1000 [==============================] - 0s 102us/sample - loss: 0.9133 - accuracy: 0.7450 - val_loss: 0.8194 - val_accuracy: 0.7800\n", 270 | "Epoch 3/10\n", 271 | "1000/1000 [==============================] - 0s 88us/sample - loss: 0.6489 - accuracy: 0.8360 - val_loss: 0.6748 - val_accuracy: 0.8050\n", 272 | "Epoch 4/10\n", 273 | "1000/1000 [==============================] - 0s 78us/sample - loss: 0.5492 - accuracy: 0.8360 - val_loss: 0.6144 - val_accuracy: 0.8150\n", 274 | "Epoch 5/10\n", 275 | " 32/1000 [..............................] 
- ETA: 0s - loss: 0.4468 - accuracy: 0.9062\n", 276 | "Epoch 00005: saving model to 106save02/cp-0005.ckpt\n", 277 | "1000/1000 [==============================] - 0s 130us/sample - loss: 0.4755 - accuracy: 0.8750 - val_loss: 0.5483 - val_accuracy: 0.8330\n", 278 | "Epoch 6/10\n", 279 | "1000/1000 [==============================] - 0s 94us/sample - loss: 0.4191 - accuracy: 0.8790 - val_loss: 0.5164 - val_accuracy: 0.8500\n", 280 | "Epoch 7/10\n", 281 | "1000/1000 [==============================] - 0s 107us/sample - loss: 0.3699 - accuracy: 0.8980 - val_loss: 0.4935 - val_accuracy: 0.8420\n", 282 | "Epoch 8/10\n", 283 | "1000/1000 [==============================] - 0s 87us/sample - loss: 0.3404 - accuracy: 0.9070 - val_loss: 0.4559 - val_accuracy: 0.8600\n", 284 | "Epoch 9/10\n", 285 | "1000/1000 [==============================] - 0s 85us/sample - loss: 0.3060 - accuracy: 0.9250 - val_loss: 0.4513 - val_accuracy: 0.8630\n", 286 | "Epoch 10/10\n", 287 | " 800/1000 [=======================>......] 
- ETA: 0s - loss: 0.3016 - accuracy: 0.9150\n", 288 | "Epoch 00010: saving model to 106save02/cp-0010.ckpt\n", 289 | "1000/1000 [==============================] - 0s 120us/sample - loss: 0.2845 - accuracy: 0.9220 - val_loss: 0.4402 - val_accuracy: 0.8580\n" 290 | ] 291 | }, 292 | { 293 | "data": { 294 | "text/plain": [ 295 | "" 296 | ] 297 | }, 298 | "execution_count": 12, 299 | "metadata": {}, 300 | "output_type": "execute_result" 301 | } 302 | ], 303 | "source": [ 304 | "check_path = '106save02/cp-{epoch:04d}.ckpt'\n", 305 | "check_dir = os.path.dirname(check_path)\n", 306 | "\n", 307 | "cp_callback = tf.keras.callbacks.ModelCheckpoint(check_path,save_weights_only=True, \n", 308 | " verbose=1, period=5) # 每5个epoch保存一次\n", 309 | "model = create_model()\n", 310 | "model.fit(train_images, train_labels, epochs=10,\n", 311 | " validation_data=(test_images, test_labels),\n", 312 | " callbacks=[cp_callback])" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 14, 318 | "metadata": {}, 319 | "outputs": [ 320 | { 321 | "name": "stdout", 322 | "output_type": "stream", 323 | "text": [ 324 | "checkpoint\t\t\t cp-0010.ckpt.data-00000-of-00001\r\n", 325 | "cp-0005.ckpt.data-00000-of-00001 cp-0010.ckpt.index\r\n", 326 | "cp-0005.ckpt.index\r\n" 327 | ] 328 | } 329 | ], 330 | "source": [ 331 | "!ls {check_dir}" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "载入最新版模型" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 16, 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "name": "stdout", 348 | "output_type": "stream", 349 | "text": [ 350 | "106save02/cp-0010.ckpt\n" 351 | ] 352 | } 353 | ], 354 | "source": [ 355 | "latest = tf.train.latest_checkpoint(check_dir)\n", 356 | "print(latest)\n" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 18, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "name": "stdout", 366 | "output_type": "stream", 367 | 
"text": [ 368 | "1000/1000 [==============================] - 0s 78us/sample - loss: 0.4402 - accuracy: 0.8580\n", 369 | "restored model accuracy: 85.80%\n" 370 | ] 371 | } 372 | ], 373 | "source": [ 374 | "model = create_model()\n", 375 | "model.load_weights(latest)\n", 376 | "loss, acc = model.evaluate(test_images, test_labels)\n", 377 | "print('restored model accuracy: {:5.2f}%'.format(acc*100))" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": {}, 383 | "source": [ 384 | "## 5.手动保存权重" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 20, 390 | "metadata": {}, 391 | "outputs": [ 392 | { 393 | "name": "stdout", 394 | "output_type": "stream", 395 | "text": [ 396 | "1000/1000 [==============================] - 0s 69us/sample - loss: 0.4402 - accuracy: 0.8580\n", 397 | "restored model accuracy: 85.80%\n" 398 | ] 399 | } 400 | ], 401 | "source": [ 402 | "model.save_weights('106save03/manually_model.ckpt')\n", 403 | "model = create_model()\n", 404 | "model.load_weights('106save03/manually_model.ckpt')\n", 405 | "loss, acc = model.evaluate(test_images, test_labels)\n", 406 | "print('restored model accuracy: {:5.2f}%'.format(acc*100))" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": {}, 412 | "source": [ 413 | "## 6.保存整个模型" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 22, 419 | "metadata": {}, 420 | "outputs": [ 421 | { 422 | "name": "stdout", 423 | "output_type": "stream", 424 | "text": [ 425 | "Train on 1000 samples, validate on 1000 samples\n", 426 | "Epoch 1/10\n", 427 | "1000/1000 [==============================] - 0s 240us/sample - loss: 1.7636 - accuracy: 0.4460 - val_loss: 1.2041 - val_accuracy: 0.7230\n", 428 | "Epoch 2/10\n", 429 | "1000/1000 [==============================] - 0s 82us/sample - loss: 0.9278 - accuracy: 0.7410 - val_loss: 0.7989 - val_accuracy: 0.7880\n", 430 | "Epoch 3/10\n", 431 | "1000/1000 [==============================] - 0s 
97us/sample - loss: 0.6722 - accuracy: 0.7970 - val_loss: 0.6739 - val_accuracy: 0.8110\n", 432 | "Epoch 4/10\n", 433 | "1000/1000 [==============================] - 0s 110us/sample - loss: 0.5326 - accuracy: 0.8530 - val_loss: 0.6027 - val_accuracy: 0.8170\n", 434 | "Epoch 5/10\n", 435 | "1000/1000 [==============================] - 0s 88us/sample - loss: 0.4674 - accuracy: 0.8640 - val_loss: 0.5623 - val_accuracy: 0.8270\n", 436 | "Epoch 6/10\n", 437 | "1000/1000 [==============================] - 0s 91us/sample - loss: 0.3986 - accuracy: 0.8900 - val_loss: 0.5429 - val_accuracy: 0.8370\n", 438 | "Epoch 7/10\n", 439 | "1000/1000 [==============================] - 0s 87us/sample - loss: 0.3717 - accuracy: 0.8830 - val_loss: 0.5205 - val_accuracy: 0.8340\n", 440 | "Epoch 8/10\n", 441 | "1000/1000 [==============================] - 0s 100us/sample - loss: 0.3492 - accuracy: 0.8980 - val_loss: 0.4844 - val_accuracy: 0.8480\n", 442 | "Epoch 9/10\n", 443 | "1000/1000 [==============================] - 0s 90us/sample - loss: 0.3048 - accuracy: 0.9200 - val_loss: 0.4603 - val_accuracy: 0.8550\n", 444 | "Epoch 10/10\n", 445 | "1000/1000 [==============================] - 0s 90us/sample - loss: 0.2574 - accuracy: 0.9290 - val_loss: 0.4674 - val_accuracy: 0.8540\n" 446 | ] 447 | } 448 | ], 449 | "source": [ 450 | "model = create_model()\n", 451 | "model.fit(train_images, train_labels, epochs=10,\n", 452 | " validation_data=(test_images, test_labels),\n", 453 | " )\n", 454 | "model.save('106save03.h5')" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 23, 460 | "metadata": {}, 461 | "outputs": [ 462 | { 463 | "name": "stdout", 464 | "output_type": "stream", 465 | "text": [ 466 | "Model: \"sequential_11\"\n", 467 | "_________________________________________________________________\n", 468 | "Layer (type) Output Shape Param # \n", 469 | "=================================================================\n", 470 | "dense_22 (Dense) (None, 128) 100480 
\n", 471 | "_________________________________________________________________\n", 472 | "dropout_11 (Dropout) (None, 128) 0 \n", 473 | "_________________________________________________________________\n", 474 | "dense_23 (Dense) (None, 10) 1290 \n", 475 | "=================================================================\n", 476 | "Total params: 101,770\n", 477 | "Trainable params: 101,770\n", 478 | "Non-trainable params: 0\n", 479 | "_________________________________________________________________\n" 480 | ] 481 | } 482 | ], 483 | "source": [ 484 | "new_model = keras.models.load_model('106save03.h5')\n", 485 | "new_model.summary()" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 24, 491 | "metadata": {}, 492 | "outputs": [ 493 | { 494 | "name": "stdout", 495 | "output_type": "stream", 496 | "text": [ 497 | "1000/1000 [==============================] - 1s 810us/sample - loss: 0.4674 - accuracy: 0.8540\n", 498 | "restored model accuracy: 85.40%\n" 499 | ] 500 | } 501 | ], 502 | "source": [ 503 | "loss, acc = model.evaluate(test_images, test_labels)\n", 504 | "print('restored model accuracy: {:5.2f}%'.format(acc*100))" 505 | ] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "metadata": {}, 510 | "source": [ 511 | "## 7.其他导出模型的方法" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": 26, 517 | "metadata": {}, 518 | "outputs": [ 519 | { 520 | "name": "stderr", 521 | "output_type": "stream", 522 | "text": [ 523 | "WARNING: Logging before flag parsing goes to stderr.\n", 524 | "W0326 20:00:41.243743 140450529666816 deprecation.py:323] From /home/czy/anaconda3/envs/tf2_0/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.\n", 525 | "Instructions for updating:\n", 526 | "This function will only be available through the v1 compatibility library as 
tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.\n", 527 | "W0326 20:00:41.244926 140450529666816 tf_logging.py:161] Export includes no default signature!\n", 528 | "W0326 20:00:41.639915 140450529666816 tf_logging.py:161] Export includes no default signature!\n" 529 | ] 530 | }, 531 | { 532 | "data": { 533 | "text/plain": [ 534 | "'./saved_models/1553601639'" 535 | ] 536 | }, 537 | "execution_count": 26, 538 | "metadata": {}, 539 | "output_type": "execute_result" 540 | } 541 | ], 542 | "source": [ 543 | "import time\n", 544 | "saved_model_path = \"./saved_models/{}\".format(int(time.time()))\n", 545 | "\n", 546 | "tf.keras.experimental.export_saved_model(model, saved_model_path)\n", 547 | "saved_model_path" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 27, 553 | "metadata": {}, 554 | "outputs": [ 555 | { 556 | "name": "stdout", 557 | "output_type": "stream", 558 | "text": [ 559 | "Model: \"sequential_11\"\n", 560 | "_________________________________________________________________\n", 561 | "Layer (type) Output Shape Param # \n", 562 | "=================================================================\n", 563 | "dense_22 (Dense) (None, 128) 100480 \n", 564 | "_________________________________________________________________\n", 565 | "dropout_11 (Dropout) (None, 128) 0 \n", 566 | "_________________________________________________________________\n", 567 | "dense_23 (Dense) (None, 10) 1290 \n", 568 | "=================================================================\n", 569 | "Total params: 101,770\n", 570 | "Trainable params: 101,770\n", 571 | "Non-trainable params: 0\n", 572 | "_________________________________________________________________\n" 573 | ] 574 | } 575 | ], 576 | "source": [ 577 | "new_model = tf.keras.experimental.load_from_saved_model(saved_model_path)\n", 578 | "new_model.summary()" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 28, 584 | 
"metadata": {}, 585 | "outputs": [ 586 | { 587 | "name": "stdout", 588 | "output_type": "stream", 589 | "text": [ 590 | "1000/1000 [==============================] - 0s 131us/sample - loss: 0.4674 - accuracy: 0.8540\n", 591 | "Restored model, accuracy: 85.40%\n" 592 | ] 593 | } 594 | ], 595 | "source": [ 596 | "# 该方法必须先运行compile函数\n", 597 | "new_model.compile(optimizer=model.optimizer, # keep the optimizer that was loaded\n", 598 | " loss='sparse_categorical_crossentropy',\n", 599 | " metrics=['accuracy'])\n", 600 | "\n", 601 | "# Evaluate the restored model.\n", 602 | "loss, acc = new_model.evaluate(test_images, test_labels)\n", 603 | "print(\"Restored model, accuracy: {:5.2f}%\".format(100*acc))" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": null, 609 | "metadata": {}, 610 | "outputs": [], 611 | "source": [] 612 | } 613 | ], 614 | "metadata": { 615 | "kernelspec": { 616 | "display_name": "Python 3", 617 | "language": "python", 618 | "name": "python3" 619 | }, 620 | "language_info": { 621 | "codemirror_mode": { 622 | "name": "ipython", 623 | "version": 3 624 | }, 625 | "file_extension": ".py", 626 | "mimetype": "text/x-python", 627 | "name": "python", 628 | "nbconvert_exporter": "python", 629 | "pygments_lexer": "ipython3", 630 | "version": "3.6.8" 631 | } 632 | }, 633 | "nbformat": 4, 634 | "nbformat_minor": 2 635 | } 636 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflow2_tutorials_chinese 2 | 3 | tensorflow2中文教程,持续更新(不定期更新) 4 | 5 | 6 | 7 | tensorflow 2.0 正式版已上线, 后面将持续更新TensorFlow2的相关教程和学习资料。 8 | 9 | 最新tensorflow教程和相关资源,请关注微信公众号:DoitNLP, 10 | 后面我会在DoitNLP上,持续更新深度学习、NLP、Tensorflow的相关教程和前沿资讯,它将成为我们一起学习tensorflow的大本营。 11 | 12 | 13 | 当前tensorflow版本:tensorflow2.0 14 | 15 | 16 | 17 | 18 | **最全Tensorflow 2.0 教程持续更新:** 19 | https://zhuanlan.zhihu.com/p/59507137 20 | 21 | 22 | 
本教程主要由tensorflow2.0官方教程的个人学习复现笔记整理而来,并借鉴了一些keras构造神经网络的方法,中文讲解,方便喜欢阅读中文教程的朋友,tensorflow官方教程:https://www.tensorflow.org 23 | 24 | 25 | [TensorFlow 2.0 教程- Keras 快速入门](https://zhuanlan.zhihu.com/p/58825020) 26 | 27 | [TensorFlow 2.0 教程-keras 函数api](https://zhuanlan.zhihu.com/p/58825710) 28 | 29 | [TensorFlow 2.0 教程-使用keras训练模型](https://zhuanlan.zhihu.com/p/58826227) 30 | 31 | [TensorFlow 2.0 教程-用keras构建自己的网络层](https://zhuanlan.zhihu.com/p/59481536) 32 | 33 | [TensorFlow 2.0 教程-keras模型保存和序列化](https://zhuanlan.zhihu.com/p/59481985) 34 | 35 | [TensorFlow 2.0 教程-eager模式](https://zhuanlan.zhihu.com/p/59482373) 36 | 37 | [TensorFlow 2.0 教程-Variables](https://zhuanlan.zhihu.com/p/59482589) 38 | 39 | [TensorFlow 2.0 教程--AutoGraph](https://zhuanlan.zhihu.com/p/59482934) 40 | 41 | TensorFlow 2.0 深度学习实践 42 | 43 | [TensorFlow2.0 教程-图像分类](https://zhuanlan.zhihu.com/p/59506238) 44 | 45 | [TensorFlow2.0 教程-文本分类](https://zhuanlan.zhihu.com/p/59506402) 46 | 47 | [TensorFlow2.0 教程-过拟合和欠拟合](https://zhuanlan.zhihu.com/p/59506543) 48 | 49 | [TensorFlow2.0教程-结构化数据分类](https://zhuanlan.zhihu.com/p/60232704) 50 | 51 | [TensorFlow2.0教程-回归](https://zhuanlan.zhihu.com/p/60238056) 52 | 53 | [TensorFlow2.0教程-保存和读取模型](https://zhuanlan.zhihu.com/p/60485936) 54 | 55 | TensorFlow 2.0 基础网络结构 56 | 57 | [TensorFlow2教程-基础MLP网络](https://zhuanlan.zhihu.com/p/60899040) 58 | 59 | [TensorFlow2教程-MLP及深度学习常见技巧](https://zhuanlan.zhihu.com/p/60900318) 60 | 61 | [TensorFlow2教程-基础CNN网络](https://zhuanlan.zhihu.com/p/60900649) 62 | 63 | [TensorFlow2教程-CNN变体网络](https://zhuanlan.zhihu.com/p/60900902) 64 | 65 | [TensorFlow2教程-文本卷积](https://zhuanlan.zhihu.com/p/60901179) 66 | 67 | [TensorFlow2教程-LSTM和GRU](https://zhuanlan.zhihu.com/p/60966714) 68 | 69 | [TensorFlow2教程-自编码器](https://zhuanlan.zhihu.com/p/61077346) 70 | 71 | [TensorFlow2教程-卷积自编码器](https://zhuanlan.zhihu.com/p/61080045) 72 | 73 | [TensorFlow2教程-词嵌入](https://zhuanlan.zhihu.com/p/61224215) 74 | 75 | [TensorFlow2教程-DCGAN](https://zhuanlan.zhihu.com/p/61280722) 
76 | 77 | [TensorFlow2教程-使用Estimator构建Boosted trees](https://zhuanlan.zhihu.com/p/61400276) 78 | 79 | TensorFlow 2.0 安装 80 | 81 | [TensorFlow2教程-Ubuntu安装TensorFlow 2.0](https://zhuanlan.zhihu.com/p/61472293) 82 | 83 | [TensorFlow2教程-Windows安装tensorflow2.0](https://zhuanlan.zhihu.com/p/62036280) 84 | 85 | 86 | 完整tensorflow2.0教程代码请看[tensorflow2.0:中文教程tensorflow2_tutorials_chinese(欢迎star)](https://github.com/czy36mengfei/tensorflow2_tutorials_chinese) 87 | 88 | 更多TensorFlow 2.0 入门教程请持续关注专栏:[Tensorflow2教程](https://zhuanlan.zhihu.com/c_1091021863043624960) 89 | 90 | 深度学习入门书籍和资源推荐:https://zhuanlan.zhihu.com/p/65371424 --------------------------------------------------------------------------------