├── .gitattributes
├── README.md
└── tensorflow_estimator_learn
    ├── CNNClassifier.ipynb
    ├── CNNClassifier_dataset.ipynb
    ├── CNN_raw.ipynb
    ├── DNNClassifier.ipynb
    ├── DNNClassifier_dataset.ipynb
    ├── data_csv
    │   ├── mnist_test.csv
    │   ├── mnist_train.csv
    │   └── mnist_val.csv
    ├── images
    │   ├── 02_convolution.png
    │   ├── 02_convolution.svg
    │   ├── 02_network_flowchart.png
    │   ├── 02_network_flowchart.svg
    │   ├── 0_TF_HELLO.png
    │   ├── dataset_classes.png
    │   ├── estimator_types.png
    │   ├── feed_tf.png
    │   ├── feed_tf_out.png
    │   ├── inputs_to_model_bridge.jpg
    │   ├── pt_sum_code.png
    │   ├── pt_sum_output.png
    │   ├── tensorflow_programming_environment.png
    │   ├── tensors_flowing.gif
    │   ├── tf_feed_out_wrong2.png
    │   ├── tf_feed_wrong.png
    │   ├── tf_feed_wrong_out_1.png
    │   ├── tf_graph.png
    │   ├── tf_sess_code.png
    │   ├── tf_sess_output.png
    │   ├── tf_sum_graph.png
    │   ├── tf_sum_output.png
    │   ├── tf_sum_sess.png
    │   ├── tf_sum_sess_code.png
    │   ├── tf_sum_sess_out.png
    │   ├── tfe_sum_code.png
    │   └── tfe_sum_output.png
    └── tmp
        ├── basic_pt.py
        ├── basic_tf.py
        ├── basic_tfe.py
        ├── feed_tf.py
        └── feed_tf_wrong.py

/.gitattributes: -------------------------------------------------------------------------------- 1 | *.csv filter=lfs diff=lfs merge=lfs -text 2 | *.tfrecords filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflow_estimator_tutorial 2 | - Note: the TensorFlow version used here is outdated (the `CNNClassifier.ipynb` run reports TensorFlow 1.8.0), so watch out for API incompatibilities with newer releases.
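Since the notebooks rely on TensorFlow 1.x APIs (`tf.estimator.inputs.numpy_input_fn`, `tf.contrib.layers.flatten`, `tf.layers`, `tf.train.AdamOptimizer`), it may help to check the installed version before running them. `tf.contrib` was removed in TensorFlow 2.0, and to our understanding the other calls survive only under `tf.compat.v1`. The sketch below is a hypothetical helper (`check_tf1_version` is not part of this repository). It only parses a version string, so it can be tried without TensorFlow installed and called as `check_tf1_version(tf.__version__)` inside a notebook.

```python
import re


def parse_tf_version(version):
    """Return (major, minor) parsed from a version string like '1.8.0' or '2.15.0rc1'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognised version string: %r" % version)
    return int(match.group(1)), int(match.group(2))


def check_tf1_version(version):
    """Hypothetical helper: True if `version` is a TensorFlow 1.x release.

    The notebooks call TF 1.x-only APIs (tf.contrib, tf.estimator.inputs),
    so any major version other than 1 is treated as incompatible.
    """
    major, _ = parse_tf_version(version)
    return major == 1


if __name__ == "__main__":
    # In a notebook, pass tf.__version__ instead of these literal strings.
    for v in ("1.8.0", "1.15.0", "2.15.0"):
        verdict = "OK" if check_tf1_version(v) else "expect TF 1.x API errors"
        print("%s: %s" % (v, verdict))
```

Treating only the major version as significant is a deliberate simplification: a 1.x release other than 1.8 should run the notebooks, but details such as log output may differ.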
3 | **Enjoy tf.estimator** 4 | 5 | ## Code structure 6 | ``` 7 | |--tensorflow_estimator_learn 8 | |--data_csv 9 | |--mnist_test.csv 10 | |--mnist_train.csv 11 | |--mnist_val.csv 12 | 13 | |--images 14 | |--ZJUAI_2018_AUT 15 | |--ZJUAI_2018_AUT 16 | 17 | |--tmp 18 | |--ZJUAI_2018_AUT 19 | |--ZJUAI_2018_AUT 20 | 21 | |--CNNClassifier.ipynb 22 | 23 | |--CNNClassifier_dataset.ipynb 24 | 25 | |--CNN_raw.ipynb 26 | 27 | |--DNNClassifier.ipynb 28 | 29 | |--DNNClassifier_dataset.ipynb 30 | ``` 31 | ## File descriptions 32 | ### data_csv 33 | The data_csv folder holds the raw **MNIST** CSV files, split into validation, training and test sets 34 | ### images 35 | The images folder holds the figures used in the **Jupyter notebooks** 36 | ### tmp 37 | The tmp folder holds some scratch code 38 | ### CNNClassifier.ipynb 39 | Custom estimator implementation without the `tf.data` API 40 | ### CNNClassifier_dataset.ipynb 41 | Custom estimator implementation using the `tf.data` API 42 | ### CNN_raw.ipynb 43 | **A CNN for MNIST classification** built without the high-level APIs 44 | ### DNNClassifier.ipynb 45 | Premade estimator implementation without the `tf.data` API 46 | ### DNNClassifier_dataset.ipynb 47 | Premade estimator implementation using the `tf.data` API 48 | 49 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/CNNClassifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TensorFlow Essentials: the HELLO WORLD of DL\n", 8 | "\n", 9 | "- Build a simple convolutional neural network on the MNIST dataset with **tf.estimator.Estimator** from TensorFlow's **tf.estimator** module, then train, validate and test the model\n", 10 | "\n", 11 | "- Basic use of TensorBoard\n" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## Import libraries" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "data": { 28 | "text/plain": [ 29 | "'1.8.0'" 30 | ] 31 | }, 32 | "execution_count": 1, 33 | "metadata": {}, 34 | "output_type": "execute_result" 35 | } 36 | ], 37 | "source": [ 38 | "%matplotlib inline\n", 39 | "import tensorflow as tf\n", 40 | "import
matplotlib.pyplot as plt\n", 41 | "import numpy as np\n", 42 | "import pandas as pd\n", 43 | "import multiprocessing\n", 44 | "\n", 45 | "\n", 46 | "from tensorflow import data\n", 47 | "from tensorflow.python.feature_column import feature_column\n", 48 | "\n", 49 | "tf.__version__" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## Loading the MNIST dataset" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "### What MNIST data looks like\n", 64 | "\n", 65 | "![MNIST Dataset](http://neuralnetworksanddeeplearning.com/images/mnist_100_digits.png)\n", 66 | "\n", 67 | "More info: http://yann.lecun.com/exdb/mnist/" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "- The MNIST dataset contains 70,000 images and their labels (the class of each image). It is split into 3 subsets: a training set, a validation set and a test set.\n", 75 | "\n", 76 | "- Define basic information about the **MNIST** data" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 2, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "TRAIN_DATA_FILES_PATTERN = 'data_csv/mnist_train.csv'\n", 86 | "VAL_DATA_FILES_PATTERN = 'data_csv/mnist_val.csv'\n", 87 | "TEST_DATA_FILES_PATTERN = 'data_csv/mnist_test.csv'\n", 88 | "\n", 89 | "MULTI_THREADING = True\n", 90 | "RESUME_TRAINING = False\n", 91 | "\n", 92 | "NUM_CLASS = 10\n", 93 | "IMG_SHAPE = [28,28]\n", 94 | "\n", 95 | "IMG_WIDTH = 28\n", 96 | "IMG_HEIGHT = 28\n", 97 | "IMG_FLAT = 784\n", 98 | "NUM_CHANNEL = 1\n", 99 | "\n", 100 | "BATCH_SIZE = 128\n", 101 | "NUM_TRAIN = 55000\n", 102 | "NUM_VAL = 5000\n", 103 | "NUM_TEST = 10000" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "### Read the CSV files and inspect the data" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 3, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "name": "stdout", 120 | "output_type": "stream", 121 | "text": [ 122 | "test_data (10000, 784)\n", 123 | "test_label (10000,)\n", 124 | "val_data (5000, 784)\n",
125 | "val_label (5000,)\n", 126 | "train_data (55000, 784)\n", 127 | "train_label (55000,)\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN)\n", 133 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None, names=HEADER )\n", 134 | "train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None)\n", 135 | "test_data = pd.read_csv(TEST_DATA_FILES_PATTERN, header=None)\n", 136 | "val_data = pd.read_csv(VAL_DATA_FILES_PATTERN, header=None)\n", 137 | "\n", 138 | "train_values = train_data.values\n", 139 | "train_data = train_values[:,1:]/255.0\n", 140 | "train_label = train_values[:,0:1].squeeze()\n", 141 | "\n", 142 | "val_values = val_data.values\n", 143 | "val_data = val_values[:,1:]/255.0\n", 144 | "val_label = val_values[:,0:1].squeeze()\n", 145 | "\n", 146 | "test_values = test_data.values\n", 147 | "test_data = test_values[:,1:]/255.0\n", 148 | "test_label = test_values[:,0:1].squeeze()\n", 149 | "\n", 150 | "print('test_data',np.shape(test_data))\n", 151 | "print('test_label',np.shape(test_label))\n", 152 | "\n", 153 | "print('val_data',np.shape(val_data))\n", 154 | "print('val_label',np.shape(val_label))\n", 155 | "\n", 156 | "print('train_data',np.shape(train_data))\n", 157 | "print('train_label',np.shape(train_label))\n", 158 | "\n", 159 | "# train_data.head(10)\n", 160 | "# test_data.head(10)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "## Try writing our own estimator\n", 168 | "\n", 169 | "- Build a simple convolutional neural network on the **MNIST dataset** with **tf.estimator.Estimator** from TensorFlow's **tf.estimator** module, then train, validate and test the model\n", 170 | "\n", 171 | "- [Official API docs](https://tensorflow.google.cn/api_docs/python/tf/estimator/Estimator)\n", 172 | "\n", 173 | "- Its constructor parameters:\n", 174 | "\n", 175 | "```python\n", 176 | "__init__(\n", 177 | " model_fn,\n", 178 | " model_dir=None,\n", 179 | " config=None,\n", 180 | " params=None,\n", 181 | " warm_start_from=None\n", 182 | ")\n", 183 | "```\n",
184 | "- In this example, the focus is on the `model_fn` argument of **tf.estimator.Estimator**\n" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "### A quick look at the data flow\n", 192 | "\n", 193 | "The chart below shows how data flows through this MNIST example: **2 convolutional layers**, each followed by max pooling for downsampling (not drawn in the chart), and then **2 fully-connected layers** that classify the MNIST digits\n", 194 | "\n", 195 | "![Flowchart](images/02_network_flowchart.png)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "### First, input_fn: creating the input functions" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 5, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "batch_size = BATCH_SIZE\n", 212 | "\n", 213 | "# Define the input function for training\n", 214 | "train_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 215 | " x = {'images': np.array(train_data)},\n", 216 | " y = np.array(train_label),\n", 217 | " batch_size=batch_size,\n", 218 | " num_epochs=None, \n", 219 | " shuffle=True)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 6, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "# Evaluate the Model\n", 229 | "# Define the input function for evaluating\n", 230 | "eval_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 231 | " x = {'images': np.array(test_data)},\n", 232 | " y = np.array(test_label),\n", 233 | " batch_size=batch_size, \n", 234 | " shuffle=False)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 7, 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "some images (9, 784)\n" 247 | ] 248 | } 249 | ], 250 | "source": [ 251 | "# Predict some images\n", 252 | "some_images = test_data[0:9]\n", 253 | "print('some images',np.shape(some_images))\n", 254 | "\n", 255 | "# Define the input function for predicting\n", 256 | "test_input_fn = tf.estimator.inputs.numpy_input_fn(\n", 257 | " x={'images': some_images},\n", 258 | "
num_epochs=1,\n", 259 | " shuffle=False)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "metadata": {}, 265 | "source": [ 266 | "### Define feature_columns" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 10, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "feature_x = tf.feature_column.numeric_column('images', shape=IMG_SHAPE)\n", 276 | "\n", 277 | "feature_columns = [feature_x]" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "### The key part: model_fn\n", 285 | "\n", 286 | "\n", 287 | "#### model_fn: Model function. Follows the signature:\n", 288 | "\n", 289 | "* Args:\n", 290 | " * `features`: This is the first item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single `tf.Tensor` or `dict` of same.\n", 291 | " * `labels`: This is the second item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single `tf.Tensor` or `dict` of same (for multi-head models). If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will be passed. If the `model_fn`'s signature does not accept `mode`, the `model_fn` must still be able to handle `labels=None`.\n", 292 | " * `mode`: Optional. Specifies if this is training, evaluation or prediction. See `tf.estimator.ModeKeys`.\n", 293 | " * `params`: Optional `dict` of hyperparameters. Will receive what is passed to Estimator in `params` parameter. This allows configuring Estimators from hyperparameter tuning.\n", 294 | " * `config`: Optional `estimator.RunConfig` object. Will receive what is passed to Estimator as its `config` parameter, or a default value.
Allows setting up things in your `model_fn` based on configuration such as `num_ps_replicas`, or `model_dir`.\n", 295 | "* Returns:\n", 296 | " `tf.estimator.EstimatorSpec`\n", 297 | " \n", 298 | "#### Note the tf.estimator.EstimatorSpec returned by model_fn\n", 299 | "\n", 300 | "\n", 301 | "\n" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "### Define our own model_fn" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 19, 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [ 317 | "def model_fn(features, labels, mode, params):\n", 318 | " # Args:\n", 319 | " #\n", 320 | " # features: This is the x-arg from the input_fn.\n", 321 | " # labels: This is the y-arg from the input_fn,\n", 322 | " # see e.g. train_input_fn for these two.\n", 323 | " # mode: Either TRAIN, EVAL, or PREDICT\n", 324 | " # params: User-defined hyper-parameters, e.g. learning-rate.\n", 325 | " \n", 326 | " # Build a dense 2-rank input tensor from the 'images' feature column.\n", 327 | "# x = features[\"images\"]\n", 328 | " x = tf.feature_column.input_layer(features, params['feature_columns'])\n", 329 | " # The convolutional layers expect 4-rank tensors\n", 330 | " # but x is a 2-rank tensor, so reshape it.\n", 331 | " net = tf.reshape(x, [-1, IMG_HEIGHT, IMG_WIDTH, NUM_CHANNEL]) \n", 332 | "\n", 333 | " # First convolutional layer.\n", 334 | " net = tf.layers.conv2d(inputs=net, name='layer_conv1',\n", 335 | " filters=16, kernel_size=5,\n", 336 | " padding='same', activation=tf.nn.relu)\n", 337 | " net = tf.layers.max_pooling2d(inputs=net, pool_size=2, strides=2)\n", 338 | "\n", 339 | " # Second convolutional layer.\n", 340 | " net = tf.layers.conv2d(inputs=net, name='layer_conv2',\n", 341 | " filters=36, kernel_size=5,\n", 342 | " padding='same', activation=tf.nn.relu)\n", 343 | " net = tf.layers.max_pooling2d(inputs=net, pool_size=2, strides=2) \n", 344 | "\n", 345 | " # Flatten to a 2-rank tensor.\n", 346 | " net =
tf.contrib.layers.flatten(net)\n", 347 | " # Eventually this should be replaced with:\n", 348 | " # net = tf.layers.flatten(net)\n", 349 | "\n", 350 | " # First fully-connected / dense layer.\n", 351 | " # This uses the ReLU activation function.\n", 352 | " net = tf.layers.dense(inputs=net, name='layer_fc1',\n", 353 | " units=128, activation=tf.nn.relu) \n", 354 | "\n", 355 | " # Second fully-connected / dense layer.\n", 356 | " # This is the last layer so it does not use an activation function.\n", 357 | " net = tf.layers.dense(inputs=net, name='layer_fc2',\n", 358 | " units=10)\n", 359 | "\n", 360 | " # Logits output of the neural network.\n", 361 | " logits = net\n", 362 | "\n", 363 | " # Softmax output of the neural network.\n", 364 | " y_pred = tf.nn.softmax(logits=logits)\n", 365 | " \n", 366 | " # Classification output of the neural network.\n", 367 | " y_pred_cls = tf.argmax(y_pred, axis=1)\n", 368 | "\n", 369 | " if mode == tf.estimator.ModeKeys.PREDICT:\n", 370 | " # If the estimator is supposed to be in prediction-mode\n", 371 | " # then use the predicted class-number that is output by\n", 372 | " # the neural network. Optimization etc. is not needed.\n", 373 | " spec = tf.estimator.EstimatorSpec(mode=mode,\n", 374 | " predictions=y_pred_cls)\n", 375 | " else:\n", 376 | " # Otherwise the estimator is supposed to be in either\n", 377 | " # training or evaluation-mode. 
Note that the loss-function\n", 378 | " # is also required in Evaluation mode.\n", 379 | " \n", 380 | " # Define the loss-function to be optimized, by first\n", 381 | " # calculating the cross-entropy between the output of\n", 382 | " # the neural network and the true labels for the input data.\n", 383 | " # This gives the cross-entropy for each image in the batch.\n", 384 | " cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,\n", 385 | " logits=logits)\n", 386 | "\n", 387 | " # Reduce the cross-entropy batch-tensor to a single number\n", 388 | " # which can be used in optimization of the neural network.\n", 389 | " loss = tf.reduce_mean(cross_entropy)\n", 390 | "\n", 391 | " # Define the optimizer for improving the neural network.\n", 392 | " optimizer = tf.train.AdamOptimizer(learning_rate=params[\"learning_rate\"])\n", 393 | "\n", 394 | " # Get the TensorFlow op for doing a single optimization step.\n", 395 | " train_op = optimizer.minimize(\n", 396 | " loss=loss, global_step=tf.train.get_global_step())\n", 397 | "\n", 398 | " # Define the evaluation metrics,\n", 399 | " # in this case the classification accuracy.\n", 400 | " metrics = \\\n", 401 | " {\n", 402 | " \"accuracy\": tf.metrics.accuracy(labels, y_pred_cls)\n", 403 | " }\n", 404 | "\n", 405 | " # Wrap all of this in an EstimatorSpec.\n", 406 | " spec = tf.estimator.EstimatorSpec(\n", 407 | " mode=mode,\n", 408 | " loss=loss,\n", 409 | " train_op=train_op,\n", 410 | " eval_metric_ops=metrics)\n", 411 | " \n", 412 | " return spec" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "metadata": {}, 418 | "source": [ 419 | "### Our custom estimator\n", 420 | "\n", 421 | "We can specify hyperparameters, such as the optimizer's learning rate." 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 23, 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "params = {\"learning_rate\": 1e-4,\n", 431 | " 'feature_columns': feature_columns}" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code",
436 | "execution_count": 24, 437 | "metadata": {}, 438 | "outputs": [ 439 | { 440 | "name": "stdout", 441 | "output_type": "stream", 442 | "text": [ 443 | "INFO:tensorflow:Using default config.\n", 444 | "INFO:tensorflow:Using config: {'_evaluation_master': '', '_session_config': None, '_save_checkpoints_secs': 600, '_log_step_count_steps': 100, '_master': '', '_is_chief': True, '_task_type': 'worker', '_keep_checkpoint_max': 5, '_service': None, '_tf_random_seed': None, '_cluster_spec': , '_model_dir': './cnn_classifer/', '_task_id': 0, '_num_worker_replicas': 1, '_num_ps_replicas': 0, '_save_checkpoints_steps': None, '_global_id_in_cluster': 0, '_train_distribute': None, '_keep_checkpoint_every_n_hours': 10000, '_save_summary_steps': 100}\n" 445 | ] 446 | } 447 | ], 448 | "source": [ 449 | "model = tf.estimator.Estimator(model_fn=model_fn,\n", 450 | " params=params,\n", 451 | " model_dir=\"./cnn_classifer/\")" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": {}, 457 | "source": [ 458 | "### Let's train it" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 26, 464 | "metadata": {}, 465 | "outputs": [ 466 | { 467 | "name": "stdout", 468 | "output_type": "stream", 469 | "text": [ 470 | "INFO:tensorflow:Calling model_fn.\n", 471 | "INFO:tensorflow:Done calling model_fn.\n", 472 | "INFO:tensorflow:Create CheckpointSaverHook.\n", 473 | "INFO:tensorflow:Graph was finalized.\n", 474 | "INFO:tensorflow:Running local_init_op.\n", 475 | "INFO:tensorflow:Done running local_init_op.\n", 476 | "INFO:tensorflow:Saving checkpoints for 1 into ./cnn_classifer/model.ckpt.\n", 477 | "INFO:tensorflow:step = 1, loss = 2.3124514\n", 478 | "INFO:tensorflow:global_step/sec: 5.13683\n", 479 | "INFO:tensorflow:step = 101, loss = 1.004812 (19.469 sec)\n", 480 | "INFO:tensorflow:global_step/sec: 4.44593\n", 481 | "INFO:tensorflow:step = 201, loss = 0.40566427 (22.492 sec)\n", 482 | "INFO:tensorflow:global_step/sec: 5.59063\n", 483 |
"INFO:tensorflow:step = 301, loss = 0.28785554 (17.887 sec)\n", 484 | "INFO:tensorflow:global_step/sec: 5.88434\n", 485 | "INFO:tensorflow:step = 401, loss = 0.23790869 (16.994 sec)\n", 486 | "INFO:tensorflow:global_step/sec: 5.28758\n", 487 | "INFO:tensorflow:step = 501, loss = 0.2865603 (18.912 sec)\n", 488 | "INFO:tensorflow:global_step/sec: 5.55467\n", 489 | "INFO:tensorflow:step = 601, loss = 0.27893203 (18.004 sec)\n", 490 | "INFO:tensorflow:global_step/sec: 5.46903\n", 491 | "INFO:tensorflow:step = 701, loss = 0.13836136 (18.286 sec)\n", 492 | "INFO:tensorflow:global_step/sec: 5.41053\n", 493 | "INFO:tensorflow:step = 801, loss = 0.12664635 (18.480 sec)\n", 494 | "INFO:tensorflow:global_step/sec: 5.21324\n", 495 | "INFO:tensorflow:step = 901, loss = 0.22681555 (19.184 sec)\n", 496 | "INFO:tensorflow:global_step/sec: 5.59755\n", 497 | "INFO:tensorflow:step = 1001, loss = 0.19516315 (17.862 sec)\n", 498 | "INFO:tensorflow:global_step/sec: 5.7998\n", 499 | "INFO:tensorflow:step = 1101, loss = 0.15528539 (17.242 sec)\n", 500 | "INFO:tensorflow:global_step/sec: 5.63879\n", 501 | "INFO:tensorflow:step = 1201, loss = 0.07765657 (17.734 sec)\n", 502 | "INFO:tensorflow:global_step/sec: 5.5637\n", 503 | "INFO:tensorflow:step = 1301, loss = 0.11297858 (17.974 sec)\n", 504 | "INFO:tensorflow:global_step/sec: 5.15256\n", 505 | "INFO:tensorflow:step = 1401, loss = 0.13372605 (19.412 sec)\n", 506 | "INFO:tensorflow:global_step/sec: 5.43482\n", 507 | "INFO:tensorflow:step = 1501, loss = 0.13708562 (18.397 sec)\n", 508 | "INFO:tensorflow:global_step/sec: 5.36527\n", 509 | "INFO:tensorflow:step = 1601, loss = 0.050685763 (18.639 sec)\n", 510 | "INFO:tensorflow:global_step/sec: 5.23113\n", 511 | "INFO:tensorflow:step = 1701, loss = 0.06853628 (19.115 sec)\n", 512 | "INFO:tensorflow:global_step/sec: 5.25113\n", 513 | "INFO:tensorflow:step = 1801, loss = 0.11101746 (19.058 sec)\n", 514 | "INFO:tensorflow:global_step/sec: 5.02226\n", 515 | "INFO:tensorflow:step = 1901, loss = 
0.091775164 (19.900 sec)\n", 516 | "INFO:tensorflow:Saving checkpoints for 2000 into ./cnn_classifer/model.ckpt.\n", 517 | "INFO:tensorflow:Loss for final step: 0.08684543.\n" 518 | ] 519 | }, 520 | { 521 | "data": { 522 | "text/plain": [ 523 | "" 524 | ] 525 | }, 526 | "execution_count": 26, 527 | "metadata": {}, 528 | "output_type": "execute_result" 529 | } 530 | ], 531 | "source": [ 532 | "model.train(input_fn=train_input_fn, steps=2000)" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 27, 538 | "metadata": {}, 539 | "outputs": [ 540 | { 541 | "name": "stdout", 542 | "output_type": "stream", 543 | "text": [ 544 | "INFO:tensorflow:Calling model_fn.\n", 545 | "INFO:tensorflow:Done calling model_fn.\n", 546 | "INFO:tensorflow:Starting evaluation at 2018-10-25-04:44:07\n", 547 | "INFO:tensorflow:Graph was finalized.\n", 548 | "INFO:tensorflow:Restoring parameters from ./cnn_classifer/model.ckpt-2000\n", 549 | "INFO:tensorflow:Running local_init_op.\n", 550 | "INFO:tensorflow:Done running local_init_op.\n", 551 | "INFO:tensorflow:Finished evaluation at 2018-10-25-04:44:14\n", 552 | "INFO:tensorflow:Saving dict for global step 2000: accuracy = 0.9761, global_step = 2000, loss = 0.07788641\n" 553 | ] 554 | }, 555 | { 556 | "data": { 557 | "text/plain": [ 558 | "{'accuracy': 0.9761, 'global_step': 2000, 'loss': 0.07788641}" 559 | ] 560 | }, 561 | "execution_count": 27, 562 | "metadata": {}, 563 | "output_type": "execute_result" 564 | } 565 | ], 566 | "source": [ 567 | "# Use the Estimator 'evaluate' method\n", 568 | "model.evaluate(eval_input_fn)" 569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "metadata": {}, 574 | "source": [ 575 | "### Let's test it" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": 28, 581 | "metadata": {}, 582 | "outputs": [ 583 | { 584 | "name": "stdout", 585 | "output_type": "stream", 586 | "text": [ 587 | "INFO:tensorflow:Calling model_fn.\n", 588 | "INFO:tensorflow:Done calling
model_fn.\n", 589 | "INFO:tensorflow:Graph was finalized.\n", 590 | "INFO:tensorflow:Restoring parameters from ./cnn_classifer/model.ckpt-2000\n", 591 | "INFO:tensorflow:Running local_init_op.\n", 592 | "INFO:tensorflow:Done running local_init_op.\n" 593 | ] 594 | }, 595 | { 596 | "data": { 597 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADO5JREFUeJzt3V2IXfW5x/Hf76QpiOlFYjUMNpqeogerSKKjCMYS9VhyYiEWg9SLkkLJ9CJKCyVU7EVzWaQv1JvAlIbGkmMrpNUoYmNjMQ1qcSJqEmNiElIzMW9lhCaCtNGnF7Nsp3H2f+/st7XH5/uBYfZez3p52Mxv1lp77bX/jggByOe/6m4AQD0IP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpD7Vz43Z5uOEQI9FhFuZr6M9v+1ltvfZPmD7gU7WBaC/3O5n+23PkrRf0h2SxiW9LOneiHijsAx7fqDH+rHnv1HSgYg4FBF/l/RrSSs6WB+APuok/JdKOjLl+Xg17T/YHrE9Znusg20B6LKev+EXEaOSRiUO+4FB0sme/6ikBVOef66aBmAG6CT8L0u6wvbnbX9a0tckbelOWwB6re3D/og4a/s+Sb+XNEvShojY07XOAPRU25f62toY5/xAz/XlQz4AZi7CDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkmp7iG5Jsn1Y0mlJH0g6GxHD3WgKQO91FP7KrRHx1y6sB0AfcdgPJNVp+EPSVts7bY90oyEA/dHpYf+SiDhq+xJJz9p+MyK2T52h+qfAPwZgwDgiurMie52kMxHxo8I83dkYgIYiwq3M1/Zhv+0LbX/mo8eSvixpd7vrA9BfnRz2z5f0O9sfref/I+KZrnQFoOe6dtjf0sY47Ad6rueH/QBmNsIPJEX4gaQIP5AU4QeSIvxAUt24qy+FlStXNqytXr26uOw777xTrL///vvF+qZNm4r148ePN6wdOHCguCzyYs8PJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lxS2+LDh061LC2cOHC/jUyjdOnTzes7dmzp4+dDJbx8fGGtYceeqi47NjYWLfb6Rtu6QVQRPiBpAg/kBThB5Ii/EBShB9IivADSXE/f4tK9+xfe+21xWX37t1brF911VXF+nXXXVesL126tGHtpptuKi575MiRYn3BggXFeifOnj1brJ86dapYHxoaanvbb7/9drE+k6/zt4o9P5AU4QeSIvxAUoQfSIrwA0kRfiApwg8k1fR+ftsbJH1F0smIuKaaNk/SbyQtlHRY0j0R8W7Tjc3g+/kH2dy5cxvWFi1aVFx2586dxfoNN9zQVk+taDZewf79+4v1Zp+fmDdvXsPamjVrisuuX7++WB9k3byf/5eSlp0z7QFJ2yLiCknbqucAZpCm4Y+I7ZImzpm8QtLG6vFGSXd1uS8APdbuOf/8iDhWPT4uaX6X+gHQJx1/tj8ionQub3tE0kin2wHQXe3u+U/YHpKk6vfJRjNGxGhEDEfEcJvbAtAD7YZ/i6RV1eNVkp7oTjsA+qVp+G0/KulFSf9je9z2NyX9
UNIdtt+S9L/VcwAzCN/bj4F19913F+uPPfZYsb579+6GtVtvvbW47MTEuRe4Zg6+tx9AEeEHkiL8QFKEH0iK8ANJEX4gKS71oTaXXHJJsb5r166Oll+5cmXD2ubNm4vLzmRc6gNQRPiBpAg/kBThB5Ii/EBShB9IivADSTFEN2rT7OuzL7744mL93XfL3xa/b9++8+4pE/b8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU9/Ojp26++eaGteeee6647OzZs4v1pUuXFuvbt28v1j+puJ8fQBHhB5Ii/EBShB9IivADSRF+ICnCDyTV9H5+2xskfUXSyYi4ppq2TtJqSaeq2R6MiKd71SRmruXLlzesNbuOv23btmL9xRdfbKsnTGplz/9LScummf7TiFhU/RB8YIZpGv6I2C5pog+9AOijTs7577P9uu0Ntud2rSMAfdFu+NdL+oKkRZKOSfpxoxltj9gesz3W5rYA9EBb4Y+IExHxQUR8KOnnkm4szDsaEcMRMdxukwC6r63w2x6a8vSrknZ3px0A/dLKpb5HJS2V9Fnb45J+IGmp7UWSQtJhSd/qYY8AeoD7+dGRCy64oFjfsWNHw9rVV19dXPa2224r1l944YViPSvu5wdQRPiBpAg/kBThB5Ii/EBShB9IiiG60ZG1a9cW64sXL25Ye+aZZ4rLcimvt9jzA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBS3NKLojvvvLNYf/zxx4v19957r2Ft2bLpvhT631566aViHdPjll4ARYQfSIrwA0kRfiApwg8kRfiBpAg/kBT38yd30UUXFesPP/xwsT5r1qxi/emnGw/gzHX8erHnB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkmt7Pb3uBpEckzZcUkkYj4me250n6jaSFkg5Luici3m2yLu7n77Nm1+GbXWu//vrri/WDBw8W66V79psti/Z0837+s5K+GxFflHSTpDW2vyjpAUnbIuIKSduq5wBmiKbhj4hjEfFK9fi0pL2SLpW0QtLGaraNku7qVZMAuu+8zvltL5S0WNKfJc2PiGNV6bgmTwsAzBAtf7bf9hxJmyV9JyL+Zv/7tCIiotH5vO0RSSOdNgqgu1ra89uercngb4qI31aTT9gequpDkk5Ot2xEjEbEcEQMd6NhAN3RNPye3MX/QtLeiPjJlNIWSauqx6skPdH99gD0SiuX+pZI+pOkXZI+rCY/qMnz/sckXSbpL5q81DfRZF1c6uuzK6+8slh/8803O1r/ihUrivUnn3yyo/Xj/LV6qa/pOX9E7JDUaGW3n09TAAYHn/ADkiL8QFKEH0iK8ANJEX4gKcIPJMVXd38CXH755Q1rW7du7Wjda9euLdafeuqpjtaP+rDnB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkuM7/CTAy0vhb0i677LKO1v38888X682+DwKDiz0/kBThB5Ii/EBShB9IivADSRF+ICnCDyTFdf4ZYMmSJcX6/fff36dO8EnCnh9IivADSRF+ICnCDyRF+IGkCD+QFOEHkmp6nd/2AkmPSJovKSSNRsTPbK+TtFrSqWrWByPi6V41mtktt9xSrM+ZM6ftdR88eLBYP3PmTNvrxmBr5UM+ZyV9NyJesf0ZSTttP1vVfhoRP+pdewB6pWn4I+KYpGPV49O290q6tNeNAeit8zrnt71Q0mJJf64m3Wf7ddsbbM9tsMyI7THbYx11CqCrWg6/7TmSNkv6TkT8TdJ6SV+QtEiTRwY/nm65iBiNiOGIGO5CvwC6pKXw256tyeBviojfSlJEnIiIDyLiQ0k/l3Rj79oE0G1Nw2/bkn4haW9E/GTK9KEps31V0u7utwegV1p5t/9mSV+XtMv2q9W0ByXda3uRJi//HZb0rZ50iI689tprxfrtt99erE9MTHSzHQyQVt7t3yHJ05S4pg/MYHzCD0iK8ANJEX4gKcIPJEX4gaQIP5CU+znEsm3GcwZ6LCKm
uzT/Mez5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpfg/R/VdJf5ny/LPVtEE0qL0Nal8SvbWrm71d3uqMff2Qz8c2bo8N6nf7DWpvg9qXRG/tqqs3DvuBpAg/kFTd4R+tefslg9rboPYl0Vu7aumt1nN+APWpe88PoCa1hN/2Mtv7bB+w/UAdPTRi+7DtXbZfrXuIsWoYtJO2d0+ZNs/2s7bfqn5PO0xaTb2ts320eu1etb28pt4W2P6j7Tds77H97Wp6ra9doa9aXre+H/bbniVpv6Q7JI1LelnSvRHxRl8bacD2YUnDEVH7NWHbX5J0RtIjEXFNNe0hSRMR8cPqH+fciPjegPS2TtKZukdurgaUGZo6srSkuyR9QzW+doW+7lENr1sde/4bJR2IiEMR8XdJv5a0ooY+Bl5EbJd07qgZKyRtrB5v1OQfT9816G0gRMSxiHilenxa0kcjS9f62hX6qkUd4b9U0pEpz8c1WEN+h6SttnfaHqm7mWnMr4ZNl6TjkubX2cw0mo7c3E/njCw9MK9dOyNedxtv+H3ckoi4TtL/SVpTHd4OpJg8ZxukyzUtjdzcL9OMLP0vdb527Y543W11hP+opAVTnn+umjYQIuJo9fukpN9p8EYfPvHRIKnV75M19/MvgzRy83QjS2sAXrtBGvG6jvC/LOkK25+3/WlJX5O0pYY+Psb2hdUbMbJ9oaQva/BGH94iaVX1eJWkJ2rs5T8MysjNjUaWVs2v3cCNeB0Rff+RtFyT7/gflPT9Onpo0Nd/S3qt+tlTd2+SHtXkYeA/NPneyDclXSRpm6S3JP1B0rwB6u1XknZJel2TQRuqqbclmjykf13Sq9XP8rpfu0JftbxufMIPSIo3/ICkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJPVP82g/p9/JjhUAAAAASUVORK5CYII=\n", 598 | "text/plain": [ 599 | "
" 600 | ] 601 | }, 602 | "metadata": {}, 603 | "output_type": "display_data" 604 | }, 605 | { 606 | "name": "stdout", 607 | "output_type": "stream", 608 | "text": [ 609 | "Model prediction: 7\n" 610 | ] 611 | }, 612 | { 613 | "data": { 614 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADXZJREFUeJzt3X+IHPUZx/HPU5uAaFGT0uMwttGohSj+CKcUCaVFjVZiYkA0wT9SWnr9o0LF+ItUUChiKf1B/wpEDCba2jRcjFFL0zZUTSEJOSVGo1ETuWjCJdcQ0QSRmuTpHzvXXvXmu5uZ2Z29PO8XHLc7z+7Mw3Kfm5md3e/X3F0A4vlS3Q0AqAfhB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8Q1Jc7uTEz4+OEQJu5u7XyuFJ7fjO70czeNrPdZvZAmXUB6Cwr+tl+MztN0juSrpe0T9I2SYvc/c3Ec9jzA23WiT3/1ZJ2u/t77v5vSX+UNL/E+gB0UJnwnyvpgzH392XL/o+Z9ZvZoJkNltgWgIq1/Q0/d18uabnEYT/QTcrs+fdLOm/M/WnZMgATQJnwb5N0kZmdb2aTJS2UtL6atgC0W+HDfnc/ZmZ3Stog6TRJK9x9Z2WdAWirwpf6Cm2Mc36g7TryIR8AExfhB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBdXTobhRzzz33JOunn356bu2yyy5LPvfWW28t1NOoZcuWJeubN2/OrT355JOlto1y2PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM3tsFVq9enayXvRZfpz179uTWrrvuuuRz33///arbCYHRewEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUKW+z29mQ5KOSDou6Zi791XR1Kmmzuv4u3btStY3bNiQrF9wwQXJ+s0335ysz5gxI7d2xx13JJ/76KOPJusop4rBPL7r7ocqWA+ADuKwHwiqbPhd0l/N7BUz66+iIQCdUfawf7a77zezr0n6m5ntcveXxz4g+6fAPwagy5Ta87v7/uz3iKRnJF09zmOWu3sfbwYC3aVw+M3sDDP7yuhtSXMkvVFVYwDaq8xhf4+kZ8xsdD1/cPe/VNIVgLYrHH53f0/S5RX2MmH19aXPaBYsWFBq/Tt37kzW582bl1s7dCh9Ffbo0aPJ+uTJk5P1LVu2JOuXX57/JzJ16tTkc9FeXOoDgiL8QFCEHwiK8ANBEX4gKMIPBMUU3RXo7e1N1rPPQuRqdinvhhtuSNaHh4eT9TKWLFmSrM+cObPwul944YXCz0V57PmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICiu81fgueeeS9YvvPDCZP3IkSPJ+uHDh0+6p6osXLgwWZ80aVKHOkHV2PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBc5++AvXv31t1CrnvvvTdZv/jii0utf+vWrYVqaD/2/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QlLl7+gFmKyTNlTTi7pdmy6ZIWi1puqQhSbe5+4dNN2aW3hgqN3fu3GR9zZo1yXqzKbpHRkaS9dR4AC+99FLyuSjG3dMTRWRa2fM/IenGzy17QNJGd79I0sbsPoAJpGn43f1lSZ8fSma+pJXZ7ZWSbqm4LwBtVvScv8fdR+eIOiCpp6J+AHRI6c/2u7unzuXNrF9Sf9ntAKhW
0T3/QTPrlaTsd+67Pu6+3N373L2v4LYAtEHR8K+XtDi7vVjSs9W0A6BTmobfzJ6WtFnSN81sn5n9UNIvJF1vZu9Kui67D2ACaXrO7+6LckrXVtwL2qCvL3221ew6fjOrV69O1rmW3734hB8QFOEHgiL8QFCEHwiK8ANBEX4gKIbuPgWsW7cutzZnzpxS6161alWy/uCDD5ZaP+rDnh8IivADQRF+ICjCDwRF+IGgCD8QFOEHgmo6dHelG2Po7kJ6e3uT9ddeey23NnXq1ORzDx06lKxfc801yfqePXuSdXRelUN3AzgFEX4gKMIPBEX4gaAIPxAU4QeCIvxAUHyffwIYGBhI1ptdy0956qmnknWu45+62PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBNr/Ob2QpJcyWNuPul2bKHJf1I0r+yhy119z+3q8lT3bx585L1WbNmFV73iy++mKw/9NBDhdeNia2VPf8Tkm4cZ/lv3f2K7IfgAxNM0/C7+8uSDnegFwAdVOac/04z22FmK8zsnMo6AtARRcO/TNIMSVdIGpb067wHmlm/mQ2a2WDBbQFog0Lhd/eD7n7c3U9IekzS1YnHLnf3PnfvK9okgOoVCr+ZjR1OdoGkN6ppB0CntHKp72lJ35H0VTPbJ+khSd8xsyskuaQhST9uY48A2qBp+N190TiLH29DL6esZt+3X7p0abI+adKkwtvevn17sn706NHC68bExif8gKAIPxAU4QeCIvxAUIQfCIrwA0ExdHcHLFmyJFm/6qqrSq1/3bp1uTW+sos87PmBoAg/EBThB4Ii/EBQhB8IivADQRF+IChz985tzKxzG+sin376abJe5iu7kjRt2rTc2vDwcKl1Y+Jxd2vlcez5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAovs9/CpgyZUpu7bPPPutgJ1/00Ucf5daa9dbs8w9nnXVWoZ4k6eyzz07W77777sLrbsXx48dza/fff3/yuZ988kklPbDnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgml7nN7PzJK2S1CPJJS1399+Z2RRJqyVNlzQk6TZ3/7B9rSLPjh076m4h15o1a3JrzcYa6OnpSdZvv/32Qj11uwMHDiTrjzzySCXbaWXPf0zSEnefKelbkn5iZjMlPSBpo7tfJGljdh/ABNE0/O4+7O6vZrePSHpL0rmS5ktamT1spaRb2tUkgOqd1Dm/mU2XdKWkrZJ63H30uO2AGqcFACaIlj/bb2ZnShqQdJe7f2z2v2HC3N3zxuczs35J/WUbBVCtlvb8ZjZJjeD/3t3XZosPmllvVu+VNDLec919ubv3uXtfFQ0DqEbT8FtjF/+4pLfc/TdjSuslLc5uL5b0bPXtAWiXpkN3m9lsSZskvS7pRLZ4qRrn/X+S9HVJe9W41He4ybpCDt29du3aZH3+/Pkd6iSWY8eO5dZOnDiRW2vF+vXrk/XBwcHC6960aVOyvmXLlmS91aG7m57zu/s/JeWt7NpWNgKg+/AJPyAowg8ERfiBoAg/EBThB4Ii/EBQTNHdBe67775kvewU3imXXHJJst7Or82uWLEiWR8aGiq1/oGBgdzarl27Sq27mzFFN4Akwg8ERfiBoAg/EBThB4Ii/EBQhB8Iiuv8wCmG6/wAkgg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqKbhN7PzzOwfZvamme00s59myx82s/1mtj37uan97QKoStPBPMysV1Kvu79qZl+R9IqkWyTdJumou/+q5Y0xmAfQdq0O5vHlFlY0LGk4u33EzN6SdG659gDU7aTO+c1suqQrJW3NFt1pZjvMbIWZnZPznH4zGzSzwVKdAqhUy2P4mdmZkl6S9Ii7rzWzHkmHJLmkn6txavCDJuvgsB9os1YP+1sKv5lNkvS8pA3u/ptx6tMlPe/ulzZZD+EH2qyyATzNzCQ9Lumt
scHP3ggctUDSGyfbJID6tPJu/2xJmyS9LulEtnippEWSrlDjsH9I0o+zNwdT62LPD7RZpYf9VSH8QPsxbj+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQTQfwrNghSXvH3P9qtqwbdWtv3dqXRG9FVdnbN1p9YEe/z/+FjZsNuntfbQ0kdGtv3dqXRG9F1dUbh/1AUIQfCKru8C+vefsp3dpbt/Yl0VtRtfRW6zk/gPrUvecHUJNawm9mN5rZ22a228weqKOHPGY2ZGavZzMP1zrFWDYN2oiZvTFm2RQz+5uZvZv9HneatJp664qZmxMzS9f62nXbjNcdP+w3s9MkvSPpekn7JG2TtMjd3+xoIznMbEhSn7vXfk3YzL4t6aikVaOzIZnZLyUddvdfZP84z3H3+7ukt4d1kjM3t6m3vJmlv68aX7sqZ7yuQh17/qsl7Xb399z935L+KGl+DX10PXd/WdLhzy2eL2lldnulGn88HZfTW1dw92F3fzW7fUTS6MzStb52ib5qUUf4z5X0wZj7+9RdU367pL+a2Stm1l93M+PoGTMz0gFJPXU2M46mMzd30udmlu6a167IjNdV4w2/L5rt7rMkfU/ST7LD267kjXO2brpcs0zSDDWmcRuW9Os6m8lmlh6QdJe7fzy2VudrN05ftbxudYR/v6Tzxtyfli3rCu6+P/s9IukZNU5TusnB0UlSs98jNffzX+5+0N2Pu/sJSY+pxtcum1l6QNLv3X1ttrj21268vup63eoI/zZJF5nZ+WY2WdJCSetr6OMLzOyM7I0YmdkZkuao+2YfXi9pcXZ7saRna+zl/3TLzM15M0ur5teu62a8dveO/0i6SY13/PdI+lkdPeT0dYGk17KfnXX3JulpNQ4DP1PjvZEfSpoqaaOkdyX9XdKULurtSTVmc96hRtB6a+ptthqH9Dskbc9+bqr7tUv0Vcvrxif8gKB4ww8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFD/Abw9Wv8QfFP9AAAAAElFTkSuQmCC\n", 615 | "text/plain": [ 616 | "
" 617 | ] 618 | }, 619 | "metadata": {}, 620 | "output_type": "display_data" 621 | }, 622 | { 623 | "name": "stdout", 624 | "output_type": "stream", 625 | "text": [ 626 | "Model prediction: 2\n" 627 | ] 628 | }, 629 | { 630 | "data": { 631 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADCRJREFUeJzt3X/oXfV9x/Hne1n6h2n/MKvGYMV0RaclYjK+iGCYHdXiRND8I1UYkcnSPxqwsD8m7o8JYyCydgz/KKQ0NJXOZkSDWqdtJ8N0MKpRM383OvmWJsREUahVpDN574/viXzV7z33m3vPvecm7+cDLt9zz+eee94c8srn/LrnE5mJpHr+oO8CJPXD8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKuoPp7myiPB2QmnCMjOW87mxev6IuCYifhURr0XE7eN8l6TpilHv7Y+IFcAB4GrgIPAUcFNmvtSyjD2/NGHT6PkvA17LzNcz8/fAj4Hrx/g+SVM0TvjPBX6z6P3BZt7HRMTWiNgXEfvGWJekjk38hF9mbge2g7v90iwZp+c/BJy36P0XmnmSTgHjhP8p4IKI+GJEfAb4OvBQN2VJmrSRd/sz88OI2Ab8FFgB7MjMFzurTNJEjXypb6SVecwvTdxUbvKRdOoy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqmoqQ7RrXouvPDCgW2vvPJK67K33XZba/s999wzUk1aYM8vFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0WNdZ0/IuaBd4FjwIeZOddFUTp9bNy4cWDb8ePHW5c9ePBg1+VokS5u8vnzzHyrg++RNEXu9ktFjRv+BH4WEU9HxNYuCpI0HePu9m/KzEMRcTbw84h4JTP3Lv5A85+C/zFIM2asnj8zDzV/jwJ7gMuW+Mz2zJzzZKA0W0YOf0SsiojPnZgGvga80FVhkiZrnN3+NcCeiDjxPf+amY91UpWkiRs5/Jn5OnBph7XoNLRhw4aBbe+9917rsnv27Om6HC3ipT6pKMMvFWX4paIMv1SU4ZeKMvxSUT66W2NZv359a/u2bdsGtt17771dl6OTYM8vFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0V5nV9jueiii1rbV61aNbBt165dXZejk2DPLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFRWZOb2UR01uZpuLJJ59sbT/rrLMGtg17FsCwR3traZkZy/mcPb9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFTX09/wRsQO4DjiameubeauBXcA6YB64MTPfmVyZ6su6deta2+fm5lrbDxw4MLDN6/j9Wk7P/wPgmk/Mux14PDMvAB5v3ks6hQwNf2buBd7+xOzrgZ3N9E7gho7rkjRhox7zr8nMw830G8CajuqRNCVjP8MvM7Ptnv2I2ApsHXc9kro1as9/JCLWAjR/jw76YGZuz8y5zGw/MyRpqkYN/0PAlmZ6C/BgN+VImpah4Y+I+4D/Bv4kIg5GxK3AXcDVEfEqcFXzXtIpZOgxf2beNKDpqx3Xohl05ZVXjrX8m2++2VEl6pp3+ElFGX6pKMMvFWX4paIMv1SU4ZeKcohutbrkkkvGWv7uu+/uqBJ1zZ5fKsrwS0UZfqkowy8VZfilogy/VJThl4py
iO7iLr/88tb2Rx55pLV9fn6+tf2KK64Y2PbBBx+0LqvROES3pFaGXyrK8EtFGX6pKMMvFWX4paIMv1SUv+cv7qqrrmptX716dWv7Y4891trutfzZZc8vFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UNvc4fETuA64Cjmbm+mXcn8NfAifGX78jMf59UkZqcSy+9tLV92PMedu/e3WU5mqLl9Pw/AK5ZYv4/Z+aG5mXwpVPM0PBn5l7g7SnUImmKxjnm3xYRz0XEjog4s7OKJE3FqOH/LvAlYANwGPj2oA9GxNaI2BcR+0Zcl6QJGCn8mXkkM49l5nHge8BlLZ/dnplzmTk3apGSujdS+CNi7aK3m4EXuilH0rQs51LffcBXgM9HxEHg74GvRMQGIIF54BsTrFHSBPjc/tPcOeec09q+f//+1vZ33nmntf3iiy8+6Zo0WT63X1Irwy8VZfilogy/VJThl4oy/FJRPrr7NHfLLbe0tp999tmt7Y8++miH1WiW2PNLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlFe5z/NnX/++WMtP+wnvTp12fNLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlFe5z/NXXfddWMt//DDD3dUiWaNPb9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFTX0On9EnAf8EFgDJLA9M/8lIlYDu4B1wDxwY2b64+8ebNq0aWDbsCG6Vddyev4Pgb/JzC8DlwPfjIgvA7cDj2fmBcDjzXtJp4ih4c/Mw5n5TDP9LvAycC5wPbCz+dhO4IZJFSmpeyd1zB8R64CNwC+BNZl5uGl6g4XDAkmniGXf2x8RnwXuB76Vmb+NiI/aMjMjIgcstxXYOm6hkrq1rJ4/IlayEPwfZeYDzewjEbG2aV8LHF1q2czcnplzmTnXRcGSujE0/LHQxX8feDkzv7Oo6SFgSzO9BXiw+/IkTcpydvuvAP4SeD4i9jfz7gDuAv4tIm4Ffg3cOJkSNczmzZsHtq1YsaJ12Weffba1fe/evSPVpNk3NPyZ+V9ADGj+arflSJoW7/CTijL8UlGGXyrK8EtFGX6pKMMvFeWju08BZ5xxRmv7tddeO/J37969u7X92LFjI3+3Zps9v1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VFZlLPn1rMisb8KgvtVu5cmVr+xNPPDGw7ejRJR+w9JGbb765tf39999vbdfsycxBP8H/GHt+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK6/zSacbr/JJaGX6pKMMvFWX4paIMv1SU4ZeKMvxSUUPDHxHnRcR/RsRLEfFiRNzWzL8zIg5FxP7mNfrD4yVN3dCbfCJiLbA2M5+JiM8BTwM3ADcCv8vMf1r2yrzJR5q45d7kM3TEnsw8DBxupt+NiJeBc8crT1LfTuqYPyLWARuBXzaztkXEcxGxIyLOHLDM1ojYFxH7xqpUUqeWfW9/RHwWeAL4x8x8ICLWAG8BCfwDC4cGfzXkO9ztlyZsubv9ywp/RKwEfgL8NDO/s0T7OuAnmbl+yPcYfmnCOvthT0QE8H3g5cXBb04EnrAZeOFki5TUn+Wc7d8E/AJ4HjjezL4DuAnYwMJu/zzwjebkYNt32fNLE9bpbn9XDL80ef6eX1Irwy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlFDH+DZsbeAXy96//lm3iya1dpmtS6wtlF1Wdv5y/3gVH/P/6mVR+zLzLneCmgxq7XNal1gbaPqqzZ3+6WiDL9UVN/h397z+tvMam2zWhdY26h6qa3XY35J/em755fUk17CHxHXRMSvIuK1iLi9jxoGiYj5iHi+GXm41yHGmmHQjkbEC4vmrY6In0fEq83fJYdJ66m2mRi5uWVk6V633ayNeD313f6IWAEcAK4GDgJPATdl5ktTLWSAiJgH5jKz92vCEfFnwO+AH54YDSki7gbezsy7mv84z8zM
v52R2u7kJEdunlBtg0aWvoUet12XI153oY+e/zLgtcx8PTN/D/wYuL6HOmZeZu4F3v7E7OuBnc30Thb+8UzdgNpmQmYezsxnmul3gRMjS/e67Vrq6kUf4T8X+M2i9weZrSG/E/hZRDwdEVv7LmYJaxaNjPQGsKbPYpYwdOTmafrEyNIzs+1GGfG6a57w+7RNmfmnwF8A32x2b2dSLhyzzdLlmu8CX2JhGLfDwLf7LKYZWfp+4FuZ+dvFbX1uuyXq6mW79RH+Q8B5i95/oZk3EzLzUPP3KLCHhcOUWXLkxCCpzd+jPdfzkcw8kpnHMvM48D163HbNyNL3Az/KzAea2b1vu6Xq6mu79RH+p4ALIuKLEfEZ4OvAQz3U8SkRsao5EUNErAK+xuyNPvwQsKWZ3gI82GMtHzMrIzcPGlmanrfdzI14nZlTfwHXsnDG/3+Bv+ujhgF1/THwP83rxb5rA+5jYTfw/1g4N3Ir8EfA48CrwH8Aq2eotntZGM35ORaCtran2jaxsEv/HLC/eV3b97ZrqauX7eYdflJRnvCTijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1TU/wNRj+er2ohshAAAAABJRU5ErkJggg==\n", 632 | "text/plain": [ 633 | "
" 634 | ] 635 | }, 636 | "metadata": {}, 637 | "output_type": "display_data" 638 | }, 639 | { 640 | "name": "stdout", 641 | "output_type": "stream", 642 | "text": [ 643 | "Model prediction: 1\n" 644 | ] 645 | }, 646 | { 647 | "data": { 648 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADbdJREFUeJzt3W+MFPUdx/HPF2qfYB9ouRL8U7DFYIhJpTmxDwi2thowGvCBijGGRtNDg2KTPqiBxGKaJo22NE0kkGskPRtrbYLGCyGVlphSE9J4mPrvrv7NQSEniDQqIaYI3z7YufaU298suzM7c3zfr+Ryu/Pdnf068rmZ3d/M/szdBSCeaVU3AKAahB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBf6OaLmRmnEwIlc3dr5XEd7fnNbKmZvWFmb5vZA52sC0B3Wbvn9pvZdElvSrpW0gFJL0q6zd2HE89hzw+UrBt7/kWS3nb3d939P5L+IGl5B+sD0EWdhP9CSf+acP9AtuwzzKzPzIbMbKiD1wJQsNI/8HP3fkn9Eof9QJ10suc/KOniCfcvypYBmAI6Cf+Lki41s0vM7IuSVkoaLKYtAGVr+7Df3T81s3slPSdpuqSt7v56YZ0BKFXbQ31tvRjv+YHSdeUkHwBTF+EHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQXV1im5034wZM5L1Rx55JFlfvXp1sr53795k/eabb25a27dvX/K5KBd7fiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IqqNZes1sVNLHkk5K+tTde3Mezyy9XTZv3rxkfWRkpKP1T5uW3n+sXbu2aW3Tpk0dvTYm1+osvUWc5PMddz9SwHoAdBGH/UBQnYbfJe00s71m1ldEQwC6o9PD/sXuftDMviLpz2b2T3ffPfEB2R8F/jAANdPRnt/dD2a/D0t6RtKiSR7T7+69eR8GAuiutsNvZjPM7EvjtyVdJ+m1ohoDUK5ODvtnSXrGzMbX83t3/1MhXQEoXdvhd/d3JX2jwF7Qpp6enqa1gYGBLnaCqYShPiAowg8ERfiBoAg/EBThB4Ii/EBQfHX3FJC6LFaSVqxY0bS2aNFpJ1121ZIlS5rW8i4Hfvnll5P13bt3J+tIY88PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0F19NXdZ/xifHV3W06ePJmsnzp1qkudnC5vrL6T3vKm8L711luT9bzpw89WrX51N3t+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf4a2LFjR7K+bNmyZL3Kcf4PPvggWT927FjT2pw5c4pu5zOmT59e6vrrinF+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBU7vf2m9lWSTdIOuzul2fLzpf0lKS5kkYl3eLu/y6vzant6quvTtbnz5+frOeN45c5zr9ly5ZkfefOncn6hx9+2LR2zTXXJJ+7fv36ZD3PPffc07S2efPmjtZ9Nmhlz/9bSUs/t+wBSbvc/VJJu7L7AKaQ3PC7+25JRz+3eLmkgez2gKTmU8YAqKV23/PPcvex7PZ7kmYV1A+ALul4rj5399Q5+2bWJ6mv09cBUKx29/yHzGy2JGW/Dzd7oLv3u3uvu/e2+VoAStBu+Aclrcpur5L0bDHtAOiW3PCb2ZOS
9kiab2YHzOwuST+XdK2ZvSXpe9l9AFMI1/MXYO7cucn6nj17kvWZM2cm6518N37ed99v27YtWX/ooYeS9ePHjyfrKXnX8+dtt56enmT9k08+aVp78MEHk8999NFHk/UTJ04k61Xien4ASYQfCIrwA0ERfiAowg8ERfiBoBjqK8C8efOS9ZGRkY7WnzfU9/zzzzetrVy5MvncI0eOtNVTN9x3333J+saNG5P11HbLuwz6sssuS9bfeeedZL1KDPUBSCL8QFCEHwiK8ANBEX4gKMIPBEX4gaA6/hovlG9oaChZv/POO5vW6jyOn2dwcDBZv/3225P1K6+8ssh2zjrs+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMb5uyDvevw8V111VUGdTC1m6cvS87ZrJ9t9w4YNyfodd9zR9rrrgj0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwSVO85vZlsl3SDpsLtfni3bIOkHkt7PHrbO3XeU1WTd3X333cl63nfEY3I33nhjsr5w4cJkPbXd8/6f5I3znw1a2fP/VtLSSZb/yt2vyH7CBh+YqnLD7+67JR3tQi8AuqiT9/z3mtkrZrbVzM4rrCMAXdFu+DdL+rqkKySNSfplsweaWZ+ZDZlZ+ovoAHRVW+F390PuftLdT0n6jaRFicf2u3uvu/e22ySA4rUVfjObPeHuTZJeK6YdAN3SylDfk5K+LWmmmR2Q9BNJ3zazKyS5pFFJq0vsEUAJcsPv7rdNsvixEnqZsvLGoyPr6elpWluwYEHyuevWrSu6nf95//33k/UTJ06U9tp1wRl+QFCEHwiK8ANBEX4gKMIPBEX4gaD46m6Uav369U1ra9asKfW1R0dHm9ZWrVqVfO7+/fsL7qZ+2PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM86MjO3akv7h5/vz5XerkdMPDw01rL7zwQhc7qSf2/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8BTCzZH3atM7+xi5btqzt5/b39yfrF1xwQdvrlvL/26qcnpyvVE9jzw8ERfiBoAg/EBThB4Ii/EBQhB8IivADQeWO85vZxZIelzRLkkvqd/dfm9n5kp6SNFfSqKRb3P3f5bVaX5s3b07WH3744Y7Wv3379mS9k7H0ssfhy1z/li1bSlt3BK3s+T+V9CN3XyDpW5LWmNkCSQ9I2uXul0rald0HMEXkht/dx9z9pez2x5JGJF0oabmkgexhA5JWlNUkgOKd0Xt+M5sraaGkv0ua5e5jWek9Nd4WAJgiWj6338zOlbRN0g/d/aOJ57O7u5uZN3len6S+ThsFUKyW9vxmdo4awX/C3Z/OFh8ys9lZfbakw5M919373b3X3XuLaBhAMXLDb41d/GOSRtx944TSoKTxqU5XSXq2+PYAlMXcJz1a//8DzBZL+pukVyWNj9usU+N9/x8lfVXSPjWG+o7mrCv9YlPUnDlzkvU9e/Yk6z09Pcl6nS+bzevt0KFDTWsjIyPJ5/b1pd8tjo2NJevHjx9P1s9W7p6+xjyT+57f3V+Q1Gxl3z2TpgDUB2f4AUERfiAowg8ERfiBoAg/EBThB4LKHecv9MXO0nH+PEuWLEnWV6xIXxN1//33J+t1Hudfu3Zt09qmTZuKbgdqfZyfPT8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMU4/xSwdOnSZD113XveNNWDg4PJet4U33nTkw8PDzet7d+/P/lctIdxfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8wFmGcX4ASYQfCIrwA0ERfiAowg8ERfiBoAg/EFRu+M3sYjN73syGzex1M7s/W77BzA6a2T+yn+vLbxdAUXJP8jGz2ZJmu/tLZvYlSXslrZB0i6Rj7v6Lll+Mk3yA0rV6ks8XWljRmKSx7PbHZjYi6cLO2gNQtTN6z29mcyUtlPT3bNG9ZvaKmW01s/OaPKfPzIbMbKijTgEUquVz+83sXEl/
lfQzd3/azGZJOiLJJf1UjbcGd+asg8N+oGStHva3FH4zO0fSdknPufvGSepzJW1398tz1kP4gZIVdmGPNb6e9TFJIxODn30QOO4mSa+daZMAqtPKp/2LJf1N0quSxueCXifpNklXqHHYPyppdfbhYGpd7PmBkhV62F8Uwg+Uj+v5ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgsr9As+CHZG0b8L9mdmyOqprb3XtS6K3dhXZ25xWH9jV6/lPe3GzIXfvrayBhLr2Vte+JHprV1W9cdgPBEX4gaCqDn9/xa+fUtfe6tqXRG/tqqS3St/zA6hO1Xt+ABWpJPxmttTM3jCzt83sgSp6aMbMRs3s1Wzm4UqnGMumQTtsZq9NWHa+mf3ZzN7Kfk86TVpFvdVi5ubEzNKVbru6zXjd9cN+M5su6U1J10o6IOlFSbe5+3BXG2nCzEYl9bp75WPCZrZE0jFJj4/PhmRmD0s66u4/z/5wnufuP65Jbxt0hjM3l9Rbs5mlv68Kt12RM14XoYo9/yJJb7v7u+7+H0l/kLS8gj5qz913Szr6ucXLJQ1ktwfU+MfTdU16qwV3H3P3l7LbH0san1m60m2X6KsSVYT/Qkn/mnD/gOo15bdL2mlme82sr+pmJjFrwsxI70maVWUzk8idubmbPjezdG22XTszXheND/xOt9jdvylpmaQ12eFtLXnjPVudhms2S/q6GtO4jUn6ZZXNZDNLb5P0Q3f/aGKtym03SV+VbLcqwn9Q0sUT7l+ULasFdz+Y/T4s6Rk13qbUyaHxSVKz34cr7ud/3P2Qu59091OSfqMKt102s/Q2SU+4+9PZ4sq33WR9VbXdqgj/i5IuNbNLzOyLklZKGqygj9OY2YzsgxiZ2QxJ16l+sw8PSlqV3V4l6dkKe/mMuszc3GxmaVW87Wo347W7d/1H0vVqfOL/jqT1VfTQpK+vSXo5+3m96t4kPanGYeAJNT4buUvSlyXtkvSWpL9IOr9Gvf1OjdmcX1EjaLMr6m2xGof0r0j6R/ZzfdXbLtFXJduNM/yAoPjADwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUP8FAfaK+yOWZZUAAAAASUVORK5CYII=\n", 649 | "text/plain": [ 650 | "
" 651 | ] 652 | }, 653 | "metadata": {}, 654 | "output_type": "display_data" 655 | }, 656 | { 657 | "name": "stdout", 658 | "output_type": "stream", 659 | "text": [ 660 | "Model prediction: 0\n" 661 | ] 662 | }, 663 | { 664 | "data": { 665 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADXVJREFUeJzt3W+oXPWdx/HPZ00bMQ2Su8FwScPeGmUlBDfViygb1krXmI2VWPxDQliyKr19UGGL+2BFhRV1QWSbpU8MpBgal27aRSOGWvpnQ1xXWEpuJKvRu60xpCQh5o9paCKBau53H9wTuSZ3ztzMnJkzc7/vF1zuzPmeM/PlJJ/7O2fOzPwcEQKQz5/U3QCAehB+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJzermk9nm7YRAh0WEp7NeWyO/7ZW2f2N7n+1H23ksAN3lVt/bb/sySb+VdLukQ5J2SVobEe+VbMPID3RYN0b+myTti4j9EfFHST+WtLqNxwPQRe2Ef6Gkg5PuHyqWfY7tEdujtkfbeC4AFev4C34RsUnSJonDfqCXtDPyH5a0aNL9LxfLAPSBdsK/S9K1tr9i+4uS1kjaXk1bADqt5cP+iPjU9sOSfiHpMkmbI+LdyjoD0FEtX+pr6ck45wc6ritv8gHQvwg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IquUpuiXJ9gFJpyWdk/RpRAxX0RQ+74Ybbiitb9u2rWFtaGio4m56x4oVK0rrY2NjDWsHDx6sup2+01b4C7dFxIkKHgdAF3HYDyTVbvhD0i9t77Y9UkVDALqj3cP+5RFx2PZVkn5l+/8i4o3JKxR/FPjDAPSYtkb+iDhc/D4m6RVJN02xzqaIGObFQKC3tBx+23Nszz1/W9IKSXuragxAZ7Vz2L9A0iu2zz/Ov0fEzyvpCkDHtRz+iNgv6S8q7AUN3HHHHaX12bNnd6mT3nLXXXeV1h988MGGtTVr1lTdTt/hUh+QFOEHkiL8QFKEH0iK8ANJEX4gqSo+1Yc2zZpV/s+watWqLnXSX3bv3l1af+SRRxrW5syZU7rtxx9/3FJP/YSRH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeS4jp/D7jttttK67fccktp/bnnnquynb4xb9680vqSJUsa1q644orSbbnOD2DGIvxAUoQfSIrwA0kRfiApwg8kRfiBpBwR3Xsyu3tP1kOWLl1aWn/99ddL6x999FFp/cYbb2xYO3PmTOm2/azZflu+fHnD2uDgYOm2x48fb6WlnhARns56jPxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kFTTz/Pb3izpG5KORcTSYtmApJ9IGpJ0QNL9EfH7zrXZ35544onSerPvkF+5cmVpfaZeyx8YGCit33rrraX18fHxKtuZcaYz8v9Q0oX/+x6VtCMirpW0o7gPoI80DX9EvCHp5AWLV0vaUtzeIunuivsC0GGtnvMviIgjxe0PJS2oqB8AXdL2d/hFRJS9Z9/2iKSRdp8HQLVaHfmP2h6UpOL3sUYrRsSmiBiOiOEWnwtAB7Qa/u2S1he310t6tZp2AHRL0/Db3irpfyT9ue1Dth+S9Kyk222/L+mvi/sA
+kjTc/6IWNug9PWKe+lb9957b2l91apVpfV9+/aV1kdHRy+5p5ng8ccfL603u45f9nn/U6dOtdLSjMI7/ICkCD+QFOEHkiL8QFKEH0iK8ANJMUV3Be67777SerPpoJ9//vkq2+kbQ0NDpfV169aV1s+dO1daf+aZZxrWPvnkk9JtM2DkB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkuM4/TVdeeWXD2s0339zWY2/cuLGt7fvVyEj5t7vNnz+/tD42NlZa37lz5yX3lAkjP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kxXX+aZo9e3bD2sKFC0u33bp1a9XtzAiLFy9ua/u9e/dW1ElOjPxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kFTT6/y2N0v6hqRjEbG0WPakpG9JOl6s9lhE/KxTTfaC06dPN6zt2bOndNvrr7++tD4wMFBaP3nyZGm9l1111VUNa82mNm/mzTffbGv77KYz8v9Q0soplv9rRCwrfmZ08IGZqGn4I+INSf079ACYUjvn/A/bftv2ZtvzKusIQFe0Gv6NkhZLWibpiKTvNVrR9ojtUdujLT4XgA5oKfwRcTQizkXEuKQfSLqpZN1NETEcEcOtNgmgei2F3/bgpLvflMTHq4A+M51LfVslfU3SfNuHJP2TpK/ZXiYpJB2Q9O0O9gigA5qGPyLWTrH4hQ700tPOnj3bsPbBBx+UbnvPPfeU1l977bXS+oYNG0rrnbR06dLS+tVXX11aHxoaaliLiFZa+sz4+Hhb22fHO/yApAg/kBThB5Ii/EBShB9IivADSbndyy2X9GR2956si6677rrS+lNPPVVav/POO0vrZV8b3mknTpworTf7/1M2zbbtlno6b+7cuaX1ssuzM1lETGvHMvIDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFJc5+8By5YtK61fc801XerkYi+99FJb22/ZsqVhbd26dW099qxZzDA/Fa7zAyhF+IGkCD+QFOEHkiL8QFKEH0iK8ANJcaG0BzSb4rtZvZft37+/Y4/d7GvF9+5lLpkyjPxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kFTT6/y2F0l6UdICSSFpU0R83/aApJ9IGpJ0QNL9EfH7zrWKflT23fztfm8/1/HbM52R/1NJ/xARSyTdLOk7tpdIelTSjoi4VtKO4j6APtE0/BFxJCLeKm6fljQmaaGk1ZLOf03LFkl3d6pJANW7pHN+20OSvirp15IWRMSRovShJk4LAPSJab+33/aXJL0s6bsR8YfJ52sREY2+n8/2iKSRdhsFUK1pjfy2v6CJ4P8oIrYVi4/aHizqg5KOTbVtRGyKiOGIGK6iYQDVaBp+TwzxL0gai4gNk0rbJa0vbq+X9Gr17QHolOkc9v+lpL+V9I7t858tfUzSs5L+w/ZDkn4n6f7OtIh+VvbV8N382nhcrGn4I+JNSY0uyH692nYAdAvv8AOSIvxAUoQfSIrwA0kRfiApwg8kxVd3o6Muv/zylrc9e/ZshZ3gQoz8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU1/nRUQ888EDD2qlTp0q3ffrpp6tuB5Mw8gNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUlznR0ft2rWrYW3Dhg0Na5K0c+fOqtvBJIz8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5CUm82RbnuRpBclLZAUkjZFxPdtPynpW5KOF6s+FhE/a/JYTMgOdFhEeDrrTSf8g5IGI+It23Ml7ZZ0t6T7JZ2JiH+ZblOEH+i86Ya/6Tv8IuKIpCPF7dO2xyQtbK89AHW7pHN+20OSvirp18Wih22/bXuz7XkNthmxPWp7tK1OAVSq6WH/ZyvaX5L0X5L+OSK22V4g6YQmXgd4WhOnBg82eQwO+4EOq+ycX5Jsf0HSTyX9IiIu+jRGcUTw04hY2uRxCD/QYdMNf9PDftuW9IKkscnBL14IPO+bkvZeapMA6jOd
V/uXS/pvSe9IGi8WPyZpraRlmjjsPyDp28WLg2WPxcgPdFilh/1VIfxA51V22A9gZiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8k1e0puk9I+t2k+/OLZb2oV3vr1b4kemtVlb392XRX7Orn+S96cns0IoZra6BEr/bWq31J9NaqunrjsB9IivADSdUd/k01P3+ZXu2tV/uS6K1VtfRW6zk/gPrUPfIDqEkt4be90vZvbO+z/WgdPTRi+4Dtd2zvqXuKsWIatGO2905aNmD7V7bfL35POU1aTb09aftwse/22F5VU2+LbO+0/Z7td23/fbG81n1X0lct+63rh/22L5P0W0m3SzokaZektRHxXlcbacD2AUnDEVH7NWHbfyXpjKQXz8+GZPs5SScj4tniD+e8iPjHHuntSV3izM0d6q3RzNJ/pxr3XZUzXlehjpH/Jkn7ImJ/RPxR0o8lra6hj54XEW9IOnnB4tWSthS3t2jiP0/XNeitJ0TEkYh4q7h9WtL5maVr3XclfdWijvAvlHRw0v1D6q0pv0PSL23vtj1SdzNTWDBpZqQPJS2os5kpNJ25uZsumFm6Z/ZdKzNeV40X/C62PCJukPQ3kr5THN72pJg4Z+ulyzUbJS3WxDRuRyR9r85mipmlX5b03Yj4w+Ranftuir5q2W91hP+wpEWT7n+5WNYTIuJw8fuYpFc0cZrSS46enyS1+H2s5n4+ExFHI+JcRIxL+oFq3HfFzNIvS/pRRGwrFte+76bqq679Vkf4d0m61vZXbH9R0hpJ22vo4yK25xQvxMj2HEkr1HuzD2+XtL64vV7SqzX28jm9MnNzo5mlVfO+67kZryOi6z+SVmniFf8PJD1eRw8N+rpa0v8WP+/W3ZukrZo4DPxEE6+NPCTpTyXtkPS+pP+UNNBDvf2bJmZzflsTQRusqbflmjikf1vSnuJnVd37rqSvWvYb7/ADkuIFPyApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSf0/fhI1ni26LDgAAAAASUVORK5CYII=\n", 666 | "text/plain": [ 667 | "
" 668 | ] 669 | }, 670 | "metadata": {}, 671 | "output_type": "display_data" 672 | }, 673 | { 674 | "name": "stdout", 675 | "output_type": "stream", 676 | "text": [ 677 | "Model prediction: 4\n" 678 | ] 679 | }, 680 | { 681 | "data": { 682 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADGdJREFUeJzt3X/oXfV9x/HnW5cK2v5hUhaCCUsXZFAU7PiqIwvSsVmdVGJRpP4xMiZN/2hghf0xMX9MGAMZa0f+iqYYGqVLO/BXKGVNFoauMkoSyTTqWrOS2ISYNPijFgwxyXt/fE/ct/q95369v8795v18wJd77/mce86bQ175nB/3nE9kJpLquazrAiR1w/BLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrqdya5sojw54TSmGVmLGS+oXr+iLg9In4WEYcj4oFhliVpsmLQ3/ZHxOXAz4FbgWPAPuC+zHy15Tv2/NKYTaLnvwk4nJm/yMyzwPeB9UMsT9IEDRP+a4Bfzvl8rJn2WyJiY0Tsj4j9Q6xL0oiN/YRfZm4DtoG7/dI0GabnPw6smvN5ZTNN0iIwTPj3AddGxOci4lPAV4FdoylL0rgNvNufmeciYhPwY+ByYHtmvjKyyiSN1cCX+gZamcf80thN5Ec+khYvwy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oaeIhugIg4ArwHnAfOZebMKIrSpePOO+/s2bZr167W727atKm1/ZFHHmltP3/+fGt7dUOFv/EnmXl6BMuRNEHu9ktFDRv+BHZHxIGI2DiKgiRNxrC7/esy83hE/C6wJyL+JzOfnztD85+C/zFIU2aonj8zjzevp4CngZvmmWdbZs54MlCaLgOHPyKuiojPXHwPfAk4NKrCJI3XMLv9y4GnI+Licv4lM/9tJFVJGrvIzMmtLGJyK9NELFu2rLX94MGDPdtWrlw51LqvvPLK1vb3339/qOUvVpkZC5nPS31SUYZfKsrwS0UZfqkowy8VZfilokZxV58Ku+WWW1rbh7mct3Pnztb2M2fODLxs2fNLZRl+qSjDLxVl+KWiDL9UlOGXijL8UlFe51erK664orV98+bNY1v3E0880do+ydvRL0X2/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlI/uVquZmfaBlvbt2zfwss+dO9favmTJkoGXXZmP7pbUyvBLRRl+qSjDLxVl+KWiDL9UlOGXiup7P39EbAe+DJzKzOuaaUuBHwCrgSPAvZn59vjKVFfuvvvusS179+7dY1u2+ltIz/9d4PaPTHsA2JuZ1wJ7m8+SFpG+4c/M54G3PjJ5PbCjeb8DuGvEdUkas0GP+Zdn5onm/ZvA8hHVI2lChn6GX2Zm22/2I2IjsHHY9UgarUF7/pMRsQKgeT3Va8bM3JaZM5nZfoeIpIkaNPy7gA3N+w3As6MpR9Kk9A1/ROwE/gv4g4g4FhH3Aw8Dt0bE68CfNZ8lLSLez69WL7zwQmv72rVrW9vPnj3bs+3mm29u/e7Bgwdb2zU/7+eX1MrwS0UZfqkowy8VZfilogy/VJSX+orrd6mu36W+ft5+u/ed3kuXLh1q2Zqfl/oktTL8UlGGXyrK8EtFGX6p
KMMvFWX4paKGfoyXFrcbb7xxrMvfunXrWJevwdnzS0UZfqkowy8VZfilogy/VJThl4oy/FJRXucvbmZmuIGU3nnnndZ2r/NPL3t+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyqq73P7I2I78GXgVGZe10x7CPga8Ktmtgcz80d9V+Zz+ydu3bp1re3PPfdca/tll7X3D0ePHm1tX716dWu7Rm+Uz+3/LnD7PNP/OTNvaP76Bl/SdOkb/sx8HnhrArVImqBhjvk3RcRLEbE9Iq4eWUWSJmLQ8G8F1gA3ACeAb/WaMSI2RsT+iNg/4LokjcFA4c/Mk5l5PjMvAN8BbmqZd1tmzmTmcHeQSBqpgcIfESvmfPwKcGg05UialL639EbETuCLwGcj4hjwd8AXI+IGIIEjwNfHWKOkMegb/sy8b57Jj42hFo3BsmXLWtv7XcfvZ8+ePUN9X93xF35SUYZfKsrwS0UZfqkowy8VZfilonx09yXunnvuGer7/R7N/eijjw61fHXHnl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXiur76O6RrsxHd4/FypUre7b1e7R2v1t6Dx1qf07L9ddf39quyRvlo7slXYIMv1SU4ZeKMvxSUYZfKsrwS0UZfqko7+e/BKxdu7Zn27CP5n7mmWeG+r6mlz2/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxXV9zp/RKwCHgeWAwlsy8wtEbEU+AGwGjgC3JuZb4+vVPXSbxjuNqdPn25t37Jly8DL1nRbSM9/DvibzPw88EfANyLi88ADwN7MvBbY23yWtEj0DX9mnsjMF5v37wGvAdcA64EdzWw7gLvGVaSk0ftEx/wRsRr4AvBTYHlmnmia3mT2sEDSIrHg3/ZHxKeBJ4FvZuavI/7/MWGZmb2ezxcRG4GNwxYqabQW1PNHxBJmg/+9zHyqmXwyIlY07SuAU/N9NzO3ZeZMZs6MomBJo9E3/DHbxT8GvJaZ357TtAvY0LzfADw7+vIkjctCdvv/GPgL4OWIONhMexB4GPjXiLgfOArcO54S1c9tt9028HffeOON1vZ333134GVruvUNf2b+BOj1HPA/HW05kibFX/hJRRl+qSjDLxVl+KWiDL9UlOGXivLR3YvAkiVLWtvXrFkz8LLPnDnT2v7BBx8MvGxNN3t+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK6/yLwIULF1rb9+/f37Ptuuuua/3u4cOHB6pJi589v1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8V5XX+ReD8+fOt7Zs3b+7ZljnvKGofOnDgwEA1afGz55eKMvxSUYZfKsrwS0UZfqkowy8VZfiloqLfdeCIWAU8DiwHEtiWmVsi4iHga8CvmlkfzMwf9VlW+8okDS0zYyHzLST8K4AVmfliRHwGOADcBdwL/CYz/2mhRRl+afwWGv6+v/DLzBPAieb9exHxGnDNcOVJ6tonOuaPiNXAF4CfNpM2RcRLEbE9Iq7u8Z2NEbE/Ino/a0rSxPXd7f9wxohPA88B/5CZT0XEcuA0s+cB/p7ZQ4O/6rMMd/ulMRvZMT9ARCwBfgj8ODO/PU/7auCHmdn6tEjDL43fQsPfd7c/IgJ4DHhtbvCbE4EXfQU49EmLlNSdhZztXwf8J/AycPEZ0g8C9wE3MLvbfwT4enNysG1Z9vzSmI10t39UDL80fiPb7Zd0aTL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VNekhuk8DR+d8/mwzbRpNa23TWhdY26BGWdvvLXTGid7P/7GVR+zPzJnOCmgxrbVNa11gbYPqqjZ3+6WiDL9UVNfh39bx+ttMa23TWhdY26A6qa3TY35J3em655fUkU7CHxG3R8TPIuJwRDzQRQ29RMSRiHg5Ig52PcRYMwzaqYg4NGfa0ojYExGvN6/zDpPWUW0PRcTxZtsdjIg7OqptVUT8R0S8GhGv
RMRfN9M73XYtdXWy3Sa+2x8RlwM/B24FjgH7gPsy89WJFtJDRBwBZjKz82vCEXEL8Bvg8YujIUXEPwJvZebDzX+cV2fm305JbQ/xCUduHlNtvUaW/ks63HajHPF6FLro+W8CDmfmLzLzLPB9YH0HdUy9zHweeOsjk9cDO5r3O5j9xzNxPWqbCpl5IjNfbN6/B1wcWbrTbddSVye6CP81wC/nfD7GdA35ncDuiDgQERu7LmYey+eMjPQmsLzLYubRd+TmSfrIyNJTs+0GGfF61Dzh93HrMvMPgT8HvtHs3k6lnD1mm6bLNVuBNcwO43YC+FaXxTQjSz8JfDMzfz23rcttN09dnWy3LsJ/HFg15/PKZtpUyMzjzesp4GlmD1OmycmLg6Q2r6c6rudDmXkyM89n5gXgO3S47ZqRpZ8EvpeZTzWTO99289XV1XbrIvz7gGsj4nMR8Sngq8CuDur4mIi4qjkRQ0RcBXyJ6Rt9eBewoXm/AXi2w1p+y7SM3NxrZGk63nZTN+J1Zk78D7iD2TP+/wts7qKGHnX9PvDfzd8rXdcG7GR2N/ADZs+N3A8sA/YCrwP/DiydotqeYHY055eYDdqKjmpbx+wu/UvAwebvjq63XUtdnWw3f+EnFeUJP6kowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRf0f7V4JFFPw3M8AAAAASUVORK5CYII=\n", 683 | "text/plain": [ 684 | "
" 685 | ] 686 | }, 687 | "metadata": {}, 688 | "output_type": "display_data" 689 | }, 690 | { 691 | "name": "stdout", 692 | "output_type": "stream", 693 | "text": [ 694 | "Model prediction: 1\n" 695 | ] 696 | }, 697 | { 698 | "data": { 699 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADbxJREFUeJzt3X+MFPUZx/HPU6uJEWKOoicqKWhME38Vy8WYFIXGiqhN0BiNROsZiYfxR6ppDIYaazRNTFNs/EeSMxDOH1X8hRL8hZKmtKExAjnA06onOQU8OVSM518oPP1jh/bE2+8uu7M7ezzvV3K53Xl2Zp4MfG5md2b2a+4uAPH8qOgGABSD8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCOrHzVyZmXE5IdBg7m7VvK6uPb+ZzTGz982s38zurmdZAJrLar2238yOkPSBpIsk7ZD0tqR57v5uYh72/ECDNWPPf66kfnff5u57JT0taW4dywPQRPWE/yRJ20c835FN+x4z6zKzDWa2oY51AchZwz/wc/duSd0Sh/1AK6lnz79T0uQRz0/OpgEYA+oJ/9uSTjOzqWZ2lKRrJK3Kpy0AjVbzYb+7f2dmt0l6XdIRkpa5e19unQFoqJpP9dW0Mt7zAw3XlIt8AIxdhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRV8xDdkmRmA5KGJe2T9J27d+TRVDTHH398sv7MM88k6+vXry9b6+7uTs47MDCQrB+ujj322GT9ggsuSNZfe+21ZP3bb7895J6ara7wZ37l7p/nsBwATcRhPxBUveF3SWvMbKOZdeXREIDmqPewf4a77zSz4yW9YWb/cfd1I1+Q/VHgDwPQYura87v7zuz3kKSVks4d5TXd7t7Bh4FAa6k5/GZ2jJmNP/BY0mxJ7+TVGIDGquewv13SSjM7sJy/uXv6/AeAllFz+N19m6Sf59jLYautrS1Z7+vrS9YrnZPetWtX2VrU8/hSertt3LgxOe9xxx2XrE+fPj1Z7+/vT9ZbAaf6gKAIPxAU4QeCIvxAUIQfCIrwA0HlcVdfeBMnTkzWV6xYkaxPmDAhWX/kkUeS9dtvvz1Zj+qee+4pW5s6dWpy3gULFiTrY+FUXiXs+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKHP35q3MrHkra6LZs2cn66+++mpdyz/hhBOS9d27d9e1/LHqjDPOSNa3bt1atrZy5crkvDfccEOyPjw8nKwXyd2tmtex5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoLifv0qpYbSvvPLKupY9f/78ZJ3z+KN78803a152pfP8rXwePy/s+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqIrn+c1smaTfSBpy9zOzaRMkrZA0RdKApKvdfU/j2ize4sWLy9auu+665LyVhoN+9tlna+rpcHf++ecn6+3t7cn68uXLy9aeeOKJWlo6rFSz518uac5B0+6WtNbdT5O0NnsOYAypGH53Xyfpy4Mmz5XUkz3ukXR5zn0BaLBa3/O3u/tg9vgzSenjLwAtp+5r+93dU9/NZ2ZdkrrqXQ+AfNW6599lZpMkKfs9VO6F7t7t7h3u3lHjugA0
QK3hXyWpM3vcKemlfNoB0CwVw29mT0n6t6SfmdkOM5sv6UFJF5nZh5J+nT0HMIZUfM/v7vPKlC7MuZeWlhrfYP/+/cl5P/3002R97969NfU0Fhx99NFla4sWLUrOe8sttyTrlcacuPHGG5P16LjCDwiK8ANBEX4gKMIPBEX4gaAIPxAUX93dBJdddlmyvmbNmmT9q6++StaXLFlyyD3lZebMmcn6rFmzytbOO++8utb93HPP1TV/dOz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAoq3RbZK4rS3zdV6ubPn162dqLL76YnPfEE0+sa91mlqw389/wYI3sbdu2bcn6nDkHf6n093300Uc1r3ssc/f0P0qGPT8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMX9/FVKDbN99tlnJ+edNm1asl7pfPVdd92VrO/evbtsraenp2wtD48//niyvnnz5pqXvX79+mQ96nn8vLDnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgKt7Pb2bLJP1G0pC7n5lNu0/STZIOnGBe5O6vVFzZGL6fH6M75ZRTkvX+/v6ytd7e3uS8F198cbKeur4hsjzv518uabSrUP7q7tOyn4rBB9BaKobf3ddJ+rIJvQBoonre899mZlvMbJmZteXWEYCmqDX8SySdKmmapEFJi8u90My6zGyDmW2ocV0AGqCm8Lv7Lnff5+77JT0q6dzEa7vdvcPdO2ptEkD+agq/mU0a8fQKSe/k0w6AZql4S6+ZPSVplqSJZrZD0h8lzTKzaZJc0oCkBQ3sEUADVAy/u88bZfLSBvSCMejee+9N1lPXkSxcuDA5L+fxG4sr/ICgCD8QFOEHgiL8QFCEHwiK8ANB8dXdSLrqqquS9euvvz5ZHx4eLlv74osvauoJ+WDPDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBcZ4fSZdcckld869evbpsbdOmTXUtG/Vhzw8ERfiBoAg/EBThB4Ii/EBQhB8IivADQVUcojvXlTFE95gzODiYrI8bNy5ZnzlzZtka5/kbI88hugEchgg/EBThB4Ii/EBQhB8IivADQRF+IKiK9/Ob2WRJj0lql+SSut39YTObIGmFpCmSBiRd7e57GtcqGuHmm29O1tvb25P1oaGhZJ1z+a2rmj3/d5J+7+6nSzpP0q1mdrqkuyWtdffTJK3NngMYIyqG390H3X1T9nhY0nuSTpI0V1JP9rIeSZc3qkkA+Tuk9/xmNkXSOZLektTu7geu/fxMpbcFAMaIqr/Dz8zGSXpe0h3u/rXZ/y8fdncvd92+mXVJ6qq3UQD5qmrPb2ZHqhT8J939hWzyLjOblNUnSRr1kx9373b3DnfvyKNhAPmoGH4r7eKXSnrP3R8aUVolqTN73CnppfzbA9Ao1Rz2/1LSbyVtNbPebNoiSQ9KesbM5kv6WNLVjWkRjVTpVF+lW75ffvnlmtc9fvz4ZL2trS1Z/+STT2peN6oIv7v/S1K5+4MvzLcdAM3CFX5AUIQfCIrwA0ERfiAowg8ERfiBoBiiG3XZt29fsn7ttdeWrd15553Jefv6+pL1zs7OZB1p7PmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICiG6A6ut7c3WT/rrLOS9ZFf5zaa1P+vpUuXJud94IEHkvXt27cn61ExRDeAJMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrz/MHNmDEjWb///vuT9XXr1iXrS5YsKVvbsyc9ovvevXuTdYyO8/wAkgg/EBThB4Ii/EBQhB8IivADQRF+IKiK5/nNbLKkxyS1S3JJ3e7+sJndJ+kmSbuzly5y91cqLIvz/ECDVXuev5rwT5I0yd03mdl4SRslXS7paknfuPtfqm2K8AONV234K47Y4+6Dkgazx8Nm9p6kk+prD0DRDuk9v5lNkXSOpLeySbeZ2RYzW2ZmbWXm6TKzDWa2oa5OAeSq6mv7zWycpH9I+pO7v2Bm7ZI+V+lzgAdU
emtwY4VlcNgPNFhu7/klycyOlLRa0uvu/tAo9SmSVrv7mRWWQ/iBBsvtxh4rfT3rUknvjQx+9kHgAVdIeudQmwRQnGo+7Z8h6Z+Stkran01eJGmepGkqHfYPSFqQfTiYWhZ7fqDBcj3szwvhBxqP+/kBJBF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCqvgFnjn7XNLHI55PzKa1olbtrVX7kuitVnn29tNqX9jU+/l/sHKzDe7eUVgDCa3aW6v2JdFbrYrqjcN+ICjCDwRVdPi7C15/Sqv21qp9SfRWq0J6K/Q9P4DiFL3nB1CQQsJvZnPM7H0z6zezu4vooRwzGzCzrWbWW/QQY9kwaENm9s6IaRPM7A0z+zD7PeowaQX1dp+Z7cy2Xa+ZXVpQb5PN7O9m9q6Z9ZnZ77LphW67RF+FbLemH/ab2RGSPpB0kaQdkt6WNM/d321qI2WY2YCkDncv/JywmV0g6RtJjx0YDcnM/izpS3d/MPvD2ebuC1ukt/t0iCM3N6i3ciNL36ACt12eI17noYg9/7mS+t19m7vvlfS0pLkF9NHy3H2dpC8PmjxXUk/2uEel/zxNV6a3luDug+6+KXs8LOnAyNKFbrtEX4UoIvwnSdo+4vkOtdaQ3y5pjZltNLOuopsZRfuIkZE+k9ReZDOjqDhyczMdNLJ0y2y7Wka8zhsf+P3QDHf/haRLJN2aHd62JC+9Z2ul0zVLJJ2q0jBug5IWF9lMNrL085LucPevR9aK3Haj9FXIdisi/DslTR7x/ORsWktw953Z7yFJK1V6m9JKdh0YJDX7PVRwP//j7rvcfZ+775f0qArcdtnI0s9LetLdX8gmF77tRuurqO1WRPjflnSamU01s6MkXSNpVQF9/ICZHZN9ECMzO0bSbLXe6MOrJHVmjzslvVRgL9/TKiM3lxtZWgVvu5Yb8drdm/4j6VKVPvH/SNIfiuihTF+nSNqc/fQV3Zukp1Q6DPxWpc9G5kv6iaS1kj6U9KakCS3U2+Mqjea8RaWgTSqotxkqHdJvkdSb/Vxa9LZL9FXIduMKPyAoPvADgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxDUfwEJEYHZ+iI4owAAAABJRU5ErkJggg==\n", 700 | "text/plain": [ 701 | "
" 702 | ] 703 | }, 704 | "metadata": {}, 705 | "output_type": "display_data" 706 | }, 707 | { 708 | "name": "stdout", 709 | "output_type": "stream", 710 | "text": [ 711 | "Model prediction: 4\n" 712 | ] 713 | }, 714 | { 715 | "data": { 716 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADa9JREFUeJzt3XuMVOUZx/Hfo1YxggasRSJ4AUm1wWRpVq0JqTZi4y0iiReQGJoYVhMwNeEPCU0smnhJbYuGP0yWiKLilkZRiGlalDSRmtqIt0WxRWyWCAHWSrUSJXh5+scc2lV33rPMnJlzluf7STY7c545c57M8uOcmXfOec3dBSCeI8puAEA5CD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaCOaufGzIyvEwIt5u42lMc1tec3s0vN7B9mts3MFjXzXADayxr9br+ZHSlpq6RLJO2Q9Iqk2e6+JbEOe36gxdqx5z9P0jZ3/6e7H5D0O0kzmng+AG3UTPhPkfT+gPs7smVfY2ZdZrbJzDY1sS0ABWv5B37u3i2pW+KwH6iSZvb8OyVNGHB/fLYMwDDQTPhfkTTZzM4ws6MlzZK0rpi2ALRaw4f97v6FmS2Q9CdJR0pa4e5vF9YZgJZqeKivoY3xnh9oubZ8yQfA8EX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUA1P0S1JZtYn6RNJX0r6wt07i2gKQOs1Ff7MT9z9XwU8D4A24rAfCKrZ8Luk9Wb2qpl1FdEQgPZo9rB/mrvvNLPvSXrezP7u7i8OfED2nwL/MQAVY+5ezBOZLZG0z91/nXhMMRsDUJe721Ae1/Bhv5kdZ2ajDt6W9FNJbzX6fADaq5nD/rGSnjGzg8/zpLv/sZCuALRcYYf9Q9oYh/0tcfzxx9et3Xvvvcl1p0yZkqxPnz49Wf/888+TdbRfyw/7AQxvhB8IivADQRF+ICjCDwRF+IGgijirDy02Z86cZP3uu++uW5swYUJT204NI0rShx9+2NTzozzs+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKE7prYDx48cn66+//nqyfuKJJ9atNfv3Xb16dbK+YMGCZH3v3r1NbR+HjlN6ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQjPNXwAMPPJCs33rrrcl6NnfCoFr99/3444+T9dS1BpYtW5Zc98CBAw31FB3j/ACSCD8QFOEHgiL8QFCEHwiK8ANBEX4gqNxxfjNbIelKSf3uPiVbNkbSakmnS+qTdJ27/zt3Y0HH+U877bRkvbe3N1kfOXJksr558+a6tT179iTXzZuCu1n9/f11a1OnTk2uu3v37qLbCaHIcf5HJV36jWWLJG1w98mSNmT3AQwjueF39xclffNyLDMkrcxur5R0dcF9AWixRt/zj3X3Xdnt3ZLGFtQPgDZpeq4+d/fUe3kz65LU1ex2ABSr0T3/HjMbJ0nZ77qf6rh7t7t3untng9sC0AKNhn+dpLnZ7bmS1hbTDoB2yQ2/mfVI+quk75vZDjO7SdJ9ki4xs3clTc/uAxhGct/zu/vsOqWLC+7lsNXR0ZGsjxo1KlnfuHFjsn7hhRfWrY0YMSK57uzZ9f68NYsXL07W
J02alKyffPLJdWtr16YPGC+77LJknTkBmsM3/ICgCD8QFOEHgiL8QFCEHwiK8ANBNf31XuQ75phjkvW806qXLl3a8Lb379+frD/yyCPJ+rXXXpusT5w48ZB7OujTTz9N1rl0d2ux5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoBjnb4O802bzXHHFFcn6s88+29Tzp3R2tu4CTC+//HKyvm/fvpZtG+z5gbAIPxAU4QeCIvxAUIQfCIrwA0ERfiAoxvnboKenJ1m/6qqrkvVzzz03WT/rrLPq1s4555zkujNnzkzWR48enax/9NFHDa8/b9685LqPP/54sr5ly5ZkHWns+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMu7ZryZrZB0paR+d5+SLVsiaZ6kD7KHLXb3P+RuzCy9scPUmDFjkvVt27Yl6yeccEKybmZ1a3l/3zwvvPBCsj5//vxk/bnnnqtbmzx5cnLd5cuXJ+u33HJLsh6Vu9f/BzHAUPb8j0q6dJDlS929I/vJDT6AaskNv7u/KGlvG3oB0EbNvOdfYGa9ZrbCzNLfAQVQOY2G/yFJkyR1SNol6Tf1HmhmXWa2ycw2NbgtAC3QUPjdfY+7f+nuX0laLum8xGO73b3T3Vt3JUgAh6yh8JvZuAF3Z0p6q5h2ALRL7im9ZtYj6SJJ3zWzHZJ+KekiM+uQ5JL6JN3cwh4BtEDuOH+hGws6zp9n+vTpyfpTTz2VrKe+B5D39122bFmyfvvttyfr+/fvT9bvueeeurVFixYl192+fXuynve6vffee8n64arIcX4AhyHCDwRF+IGgCD8QFOEHgiL8QFAM9Q0DeUNaN9xwQ91a3qW177jjjmS92Wmyjz322Lq1J598Mrlu3iXNn3jiiWR97ty5yfrhiqE+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4/wozaxZs5L1VatWJes7d+5M1js6OurW9u49fK9Jyzg/gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcX6U5ogj0vuevPP1r7/++mT9zjvvrFu76667kusOZ4zzA0gi/EBQhB8IivADQRF+ICjCDwRF+IGgcsf5zWyCpMckjZXkkrrd/UEzGyNptaTTJfVJus7d/53zXIzzY8hS5+NL0ksvvZSsjxgxom7t7LPPTq67devWZL3Kihzn/0LSQnf/gaQfSZpvZj+QtEjSBnefLGlDdh/AMJEbfnff5e6vZbc/kfSOpFMkzZC0MnvYSklXt6pJAMU7pPf8Zna6pKmS/iZprLvvykq7VXtbAGCYOGqoDzSzkZKelnSbu//H7P9vK9zd672fN7MuSV3NNgqgWEPa85vZd1QL/ip3X5Mt3mNm47L6OEn9g63r7t3u3ununUU0DKAYueG32i7+YUnvuPtvB5TWSTo4DepcSWuLbw9AqwxlqG+apI2SNkv6Klu8WLX3/b+XdKqk7aoN9SWvh8xQH4q0cOHCZP3++++vW1uzZk3dmiTdeOONyfpnn32WrJdpqEN9ue/53f0vkuo92cWH0hSA6uAbfkBQhB8IivADQRF+ICjCDwRF+IGguHQ3hq2TTjopWU+d8nvmmWcm1807nbi3tzdZLxOX7gaQRPiBoAg/EBThB4Ii/EBQhB8IivADQTHOj8PWqaeeWrfW19eXXLenpydZnzNnTiMttQXj/ACSCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMb5EdL69euT9QsuuCBZP//885P1LVu2HHJPRWGcH0AS4QeCIvxAUIQfCIrwA0ERfiAowg8ElTtFt5lNkPSYpLGSXFK3uz9oZkskzZP0QfbQxe7+h1Y1ChTpmmuuSdbffPPNZD3vuv9ljvMPVW74JX0haaG7v2ZmoyS9ambPZ7Wl7v7r1rUHoFVyw+/uuyTtym5/YmbvSDql1Y0BaK1Des9vZqdLmirpb9miBWbWa2YrzGx0nXW6zGyTmW1qqlMAhRpy+M1spKSnJd3m7v+R9JCk
SZI6VDsy+M1g67l7t7t3untnAf0CKMiQwm9m31Et+KvcfY0kufsed//S3b+StFzSea1rE0DRcsNvZibpYUnvuPtvBywfN+BhMyW9VXx7AFol95ReM5smaaOkzZK+yhYvljRbtUN+l9Qn6ebsw8HUc3FKL9BiQz2ll/P5gcMM5/MDSCL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ENZSr9xbpX5K2D7j/3WxZFVW1t6r2JdFbo4rs7bShPrCt5/N/a+Nmm6p6bb+q9lbVviR6a1RZvXHYDwRF+IGgyg5/d8nbT6lqb1XtS6K3RpXSW6nv+QGUp+w9P4CSlBJ+M7vUzP5hZtvMbFEZPdRjZn1mttnM3ih7irFsGrR+M3trwLIxZva8mb2b/R50mrSSeltiZjuz1+4NM7u8pN4mmNmfzWyLmb1tZj/Plpf62iX6KuV1a/thv5kdKWmrpEsk7ZD0iqTZ7l6JOY3NrE9Sp7uXPiZsZj+WtE/SY+4+JVv2K0l73f2+7D/O0e5+e0V6WyJpX9kzN2cTyowbOLO0pKsl/UwlvnaJvq5TCa9bGXv+8yRtc/d/uvsBSb+TNKOEPirP3V+UtPcbi2dIWpndXqnaP562q9NbJbj7Lnd/Lbv9iaSDM0uX+tol+ipFGeE/RdL7A+7vULWm/HZJ683sVTPrKruZQYwdMDPSbkljy2xmELkzN7fTN2aWrsxr18iM10XjA79vm+buP5R0maT52eFtJXntPVuVhmuGNHNzuwwys/T/lPnaNTrjddHKCP9OSRMG3B+fLasEd9+Z/e6X9IyqN/vwnoOTpGa/+0vu53+qNHPzYDNLqwKvXZVmvC4j/K9ImmxmZ5jZ0ZJmSVpXQh/fYmbHZR/EyMyOk/RTVW/24XWS5ma350paW2IvX1OVmZvrzSytkl+7ys147e5t/5F0uWqf+L8n6Rdl9FCnr4mS3sx+3i67N0k9qh0Gfq7aZyM3STpR0gZJ70p6QdKYCvX2uGqzOfeqFrRxJfU2TbVD+l5Jb2Q/l5f92iX6KuV14xt+QFB84AcERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+IKj/AlLXkc59O3KwAAAAAElFTkSuQmCC\n", 717 | "text/plain": [ 718 | "
" 719 | ] 720 | }, 721 | "metadata": {}, 722 | "output_type": "display_data" 723 | }, 724 | { 725 | "name": "stdout", 726 | "output_type": "stream", 727 | "text": [ 728 | "Model prediction: 9\n" 729 | ] 730 | }, 731 | { 732 | "data": { 733 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADbFJREFUeJzt3W+MVPW9x/HP1xUMgT5AiRsirPSCNKkmwnU1xmBD47XxaiPwhKDR0LRhfYCJ6H1w0fvgYq6aeu2f9FENWCw1xfYmaiC1sVRSKzVKXAWV9Q9ym8UuQVZCYy0x9MJ++2AON1vc8zvDzJk5Z/m+X8lmZ853zpwvEz57zszvzPmZuwtAPOdV3QCAahB+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBnd/NjZkZpxMCHebu1szj2trzm9lNZvaBmR0ws/XtPBeA7rJWz+03sx5J+yXdKGlE0uuSbnP3dxPrsOcHOqwbe/5rJB1w9z+6+98k/ULSsjaeD0AXtRP+SyT9adz9kWzZPzCzATMbNLPBNrYFoGQd/8DP3TdK2ihx2A/USTt7/kOS5o67PydbBmASaCf8r0u6zMy+bGZTJa2StL2ctgB0WsuH/e5+0szulvQbST2SNrv7UGmdAeiolof6WtoY7/mBjuvKST4AJi/CDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Lq6hTdwHgzZ85M1vv6+jq27YMHDybr9957b7K+b9++ZH3//v3J+ltvvZWsdwN7fiAowg8ERfiBoAg/EBThB4Ii/EBQhB8Iqq1xfjMblvSZpFOSTrp7fxlNYfK45ZZbkvVbb701t7Z06dLkugsWLGilpaYUjcNfeumlyfoFF1zQ1vZ7enraWr8MZZzk83V3P1rC8wDoIg77gaDaDb9L2mFmb5jZQBkNAeiOdg/7l7j7ITO7WNJvzex9d395/AOyPwr8YQBqpq09v7sfyn6PSnpO0jUTPGaju/fzYSBQLy2H38ymm9mXTt+W9A1J6a86AaiNdg77eyU9Z2ann2eru79QSlcAOs7cvXsbM+vexiBJmj9/frK+du3aZH3NmjXJ+rRp05L1bOeAM3RynN/dm3rRGeoDgiL8QFCEHwiK8ANBEX4gKMIPBMWlu89xc+bMSdbvueeeLnXSfe+//35ubWhoqIud1BN7fiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IinH+Lpg1a1ayXjTW/sorryTrL7yQfxmFEydOJNf99NNPk/Xjx48n69OnT0/Wd+zYkVsrmuZ69+7dyfqePXuS9c8//zy3VvTvioA9PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExaW7S1A01r1r165k/corr0zWV6xYkaxv3749WU+ZN29esj48PJys9/X1JesjIyO5tbGxseS6aA2X7gaQRPiBoAg/EBThB4Ii/EBQhB8IivADQRV+n9/MNkv6pqRRd78iW3ahpF9KmidpWNJKd/9z59qs3tSpU3NrW7duTa5bNI7/yCOPJOsvvvhist6OonH8Ih999FE5jaDrmtnz/1TSTWcsWy9pp7tfJmlndh/AJFIYfnd/WdKxMxYvk7Qlu71F0vKS+wLQYa2+5+9198PZ7Y8l9ZbUD4Auafsafu7uqXP2zWxA0kC72wFQrlb3/EfM
bLYkZb9H8x7o7hvdvd/d+1vcFoAOaDX82yWtzm6vlrStnHYAdEth+M3saUmvSvqKmY2Y2XckfVfSjWb2oaR/ye4DmET4Pn9mxowZyfr999+fW1u/Pj3SefTo0WR94cKFyXrRtfWB8fg+P4Akwg8ERfiBoAg/EBThB4Ii/EBQTNGdWb48/d2k1HBe0ddar7/++mSdoTxUgT0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwTFOH/muuuua3ndPXv2JOupaaqBqrDnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGguHR3ZnQ0d9IhSdJFF12UWztx4kRy3UcffTRZ37YtPefJ3r17k3VgPC7dDSCJ8ANBEX4gKMIPBEX4gaAIPxAU4QeCKhznN7PNkr4padTdr8iWbZC0RtIn2cMecPdfF26sxuP8Ra/D2NhYx7Zd9NyPP/54sv7aa6/l1vr6+pLrHjhwIFkfGhpK1otcfvnlubVXX301uS7XQWhNmeP8P5V00wTLf+jui7KfwuADqJfC8Lv7y5KOdaEXAF3Uznv+u83sbTPbbGYzS+sIQFe0Gv4fS5ovaZGkw5K+n/dAMxsws0EzG2xxWwA6oKXwu/sRdz/l7mOSNkm6JvHYje7e7+79rTYJoHwthd/MZo+7u0LSvnLaAdAthZfuNrOnJS2VNMvMRiT9p6SlZrZIkksalnRXB3sE0AF8nz/z2GOPJev33XdflzqJ45NPPknWX3rppWR91apVJXZz7uD7/ACSCD8QFOEHgiL8QFCEHwiK8ANBMdSX6enpSdYXL16cW9u6dWty3fPPT59OMXfu3GT9vPNi/o0u+r+5YcOGZP2hhx4qsZvJg6E+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBU4ff5ozh16lSyPjiYfxWyhQsXtrXtG264IVmfMmVKsp4a77766qtbaakWzNLD1VdddVWXOjk3secHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY56+BnTt3trX+okWLcmtF4/wnT55M1p988slkfdOmTcn6unXrcmu33357cl10Fnt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiqcJzfzOZK+pmkXkkuaaO7/8jMLpT0S0nzJA1LWunuf+5cq8izY8eO3NrDDz+cXLdoToE1a9Yk6wsWLEjWly5dmqy3Y2RkpGPPHUEze/6Tkv7N3b8q6VpJa83sq5LWS9rp7pdJ2pndBzBJFIbf3Q+7+5vZ7c8kvSfpEknLJG3JHrZF0vJONQmgfGf1nt/M5klaLGm3pF53P5yVPlbjbQGASaLpc/vNbIakZyStc/e/jL++mrt73jx8ZjYgaaDdRgGUq6k9v5lNUSP4P3f3Z7PFR8xsdlafLWl0onXdfaO797t7fxkNAyhHYfitsYv/iaT33P0H40rbJa3Obq+WtK389gB0SuEU3Wa2RNIuSe9IGssWP6DG+/7/kdQn6aAaQ33HCp6rtlN0T2bTpk3LrW3evDm57sqVK8tup2lFl0t//vnnk/U77rgjWT9+/PhZ93QuaHaK7sL3/O7+B0l5T5a+4DyA2uIMPyAowg8ERfiBoAg/EBThB4Ii/EBQheP8pW6Mcf6u6+1Nf+XiiSeeSNb7+9MnZl588cXJ+vDwcG7tqaeeSq6bmnoc+Zod52fPDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBMc6PpDvvvDNZv/baa5P1Bx98MLc2OjrhxZ/QJsb5ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQjPMD5xjG+QEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIXhN7O5ZvY7M3vXzIbM7J5s+QYzO2Rme7OfmzvfLoCyFJ7kY2azJc129zfN7EuS3pC0XNJKSX919+81vTFO8gE6rtmTfM5v4okOSzqc3f7MzN6TdEl77QGo2lm95zezeZIWS9qdLbrbzN42s81mNjNnnQEzGzSzwbY6BVCqps/tN7MZkn4v6WF3f9bMeiUd
leSS/kuNtwbfLngODvuBDmv2sL+p8JvZFEm/kvQbd//BBPV5kn7l7lcUPA/hBzqstC/2mJlJ+omk98YHP/sg8LQVkvadbZMAqtPMp/1LJO2S9I6ksWzxA5Juk7RIjcP+YUl3ZR8Opp6LPT/QYaUe9peF8AOdx/f5ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiq8gGfJjko6OO7+rGxZHdW1t7r2JdFbq8rs7dJmH9jV7/N/YeNmg+7eX1kDCXXtra59SfTWqqp647AfCIrwA0FVHf6NFW8/pa691bUvid5aVUlvlb7nB1Cdqvf8ACpSSfjN7CYz+8DMDpjZ+ip6yGNmw2b2TjbzcKVTjGXToI2a2b5xyy40s9+a2YfZ7wmnSauot1rM3JyYWbrS165uM153/bDfzHok7Zd0o6QRSa9Lus3d3+1qIznMbFhSv7tXPiZsZl+T9FdJPzs9G5KZ/bekY+7+3ewP50x3//ea9LZBZzlzc4d6y5tZ+luq8LUrc8brMlSx579G0gF3/6O7/03SLyQtq6CP2nP3lyUdO2PxMklbsttb1PjP03U5vdWCux929zez259JOj2zdKWvXaKvSlQR/ksk/Wnc/RHVa8pvl7TDzN4ws4Gqm5lA77iZkT6W1FtlMxMonLm5m86YWbo2r10rM16XjQ/8vmiJu/+zpH+VtDY7vK0lb7xnq9NwzY8lzVdjGrfDkr5fZTPZzNLPSFrn7n8ZX6vytZugr0petyrCf0jS3HH352TLasHdD2W/RyU9p8bblDo5cnqS1Oz3aMX9/D93P+Lup9x9TNImVfjaZTNLPyPp5+7+bLa48tduor6qet2qCP/rki4zsy+b2VRJqyRtr6CPLzCz6dkHMTKz6ZK+ofrNPrxd0urs9mpJ2yrs5R/UZebmvJmlVfFrV7sZr9296z+SblbjE///lfQfVfSQ09c/SXor+xmqujdJT6txGPh/anw28h1JF0naKelDSS9KurBGvT2lxmzOb6sRtNkV9bZEjUP6tyXtzX5urvq1S/RVyevGGX5AUHzgBwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqL8DmYaFlMuCxPsAAAAASUVORK5CYII=\n", 734 | "text/plain": [ 735 | "
" 736 | ] 737 | }, 738 | "metadata": {}, 739 | "output_type": "display_data" 740 | }, 741 | { 742 | "name": "stdout", 743 | "output_type": "stream", 744 | "text": [ 745 | "Model prediction: 5\n" 746 | ] 747 | } 748 | ], 749 | "source": [ 750 | "# Use the model to predict the images class\n", 751 | "preds = list(model.predict(test_input_fn))\n", 752 | "\n", 753 | "n_images = 9\n", 754 | "# Display\n", 755 | "for i in range(n_images):\n", 756 | " plt.imshow(np.reshape(some_images[i], [28, 28]), cmap='gray')\n", 757 | " plt.show()\n", 758 | " print(\"Model prediction:\", preds[i])" 759 | ] 760 | }, 761 | { 762 | "cell_type": "code", 763 | "execution_count": null, 764 | "metadata": {}, 765 | "outputs": [], 766 | "source": [] 767 | } 768 | ], 769 | "metadata": { 770 | "kernelspec": { 771 | "display_name": "Python 3", 772 | "language": "python", 773 | "name": "python3" 774 | }, 775 | "language_info": { 776 | "codemirror_mode": { 777 | "name": "ipython", 778 | "version": 3 779 | }, 780 | "file_extension": ".py", 781 | "mimetype": "text/x-python", 782 | "name": "python", 783 | "nbconvert_exporter": "python", 784 | "pygments_lexer": "ipython3", 785 | "version": "3.5.0" 786 | } 787 | }, 788 | "nbformat": 4, 789 | "nbformat_minor": 2 790 | } 791 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/CNNClassifier_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TensorFlow 那些事儿之DL中的 HELLO WORLD\n", 8 | "\n", 9 | "- 基于MNIST数据集,运用TensorFlow中的 **tf.estimator** 中的 **tf.estimator.Estimator** 搭建一个简单的卷积神经网络,实现模型的训练,验证和测试\n", 10 | "\n", 11 | "- TensorBoard的简单使用\n" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## 导入各个库" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "metadata": {}, 25 | "outputs": [ 26 | { 
27 | "data": { 28 | "text/plain": [ 29 | "'1.11.0'" 30 | ] 31 | }, 32 | "execution_count": 1, 33 | "metadata": {}, 34 | "output_type": "execute_result" 35 | } 36 | ], 37 | "source": [ 38 | "%matplotlib inline\n", 39 | "import tensorflow as tf\n", 40 | "import matplotlib.pyplot as plt\n", 41 | "import numpy as np\n", 42 | "import pandas as pd\n", 43 | "import multiprocessing\n", 44 | "\n", 45 | "\n", 46 | "from tensorflow import data\n", 47 | "from tensorflow.python.feature_column import feature_column\n", 48 | "\n", 49 | "tf.__version__" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## MNIST数据集载入" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "### 看看MNIST数据长什么样子的\n", 64 | "\n", 65 | "![MNIST Dataset](http://neuralnetworksanddeeplearning.com/images/mnist_100_digits.png)\n", 66 | "\n", 67 | "More info: http://yann.lecun.com/exdb/mnist/" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "- MNIST数据集包含70000张图像和对应的标签(图像的分类)。数据集被划为3个子集:训练集,验证集和测试集。\n", 75 | "\n", 76 | "- 定义**MNIST**数据的相关信息" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 2, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "TRAIN_DATA_FILES_PATTERN = 'data_csv/mnist_train.csv'\n", 86 | "VAL_DATA_FILES_PATTERN = 'data_csv/mnist_val.csv'\n", 87 | "TEST_DATA_FILES_PATTERN = 'data_csv/mnist_test.csv'\n", 88 | "\n", 89 | "MULTI_THREADING = True\n", 90 | "RESUME_TRAINING = False\n", 91 | "\n", 92 | "NUM_CLASS = 10\n", 93 | "IMG_SHAPE = [28,28]\n", 94 | "\n", 95 | "IMG_WIDTH = 28\n", 96 | "IMG_HEIGHT = 28\n", 97 | "IMG_FLAT = 784\n", 98 | "NUM_CHANNEL = 1\n", 99 | "\n", 100 | "BATCH_SIZE = 128\n", 101 | "NUM_TRAIN = 55000\n", 102 | "NUM_VAL = 5000\n", 103 | "NUM_TEST = 10000" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 3, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | 
"output_type": "stream", 114 | "text": [ 115 | "test_data (10000, 784)\n", 116 | "test_label (10000,)\n", 117 | "val_data (5000, 784)\n", 118 | "val_label (5000,)\n", 119 | "train_data (55000, 784)\n", 120 | "train_label (55000,)\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN)\n", 126 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None, names=HEADER )\n", 127 | "train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None)\n", 128 | "test_data = pd.read_csv(TEST_DATA_FILES_PATTERN, header=None)\n", 129 | "val_data = pd.read_csv(VAL_DATA_FILES_PATTERN, header=None)\n", 130 | "\n", 131 | "train_values = train_data.values\n", 132 | "train_data = train_values[:,1:]/255.0\n", 133 | "train_label = train_values[:,0:1].squeeze()\n", 134 | "\n", 135 | "val_values = val_data.values\n", 136 | "val_data = val_values[:,1:]/255.0\n", 137 | "val_label = val_values[:,0:1].squeeze()\n", 138 | "\n", 139 | "test_values = test_data.values\n", 140 | "test_data = test_values[:,1:]/255.0\n", 141 | "test_label = test_values[:,0:1].squeeze()\n", 142 | "\n", 143 | "print('test_data',np.shape(test_data))\n", 144 | "print('test_label',np.shape(test_label))\n", 145 | "\n", 146 | "print('val_data',np.shape(val_data))\n", 147 | "print('val_label',np.shape(val_label))\n", 148 | "\n", 149 | "print('train_data',np.shape(train_data))\n", 150 | "print('train_label',np.shape(train_label))\n", 151 | "\n", 152 | "# train_data.head(10)\n", 153 | "# test_data.head(10)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "## 试试自己写一个estimator\n", 161 | "\n", 162 | "- 基于MNIST数据集,运用TensorFlow中的 **tf.estimator** 中的 **tf.estimator.Estimator** 搭建一个简单的卷积神经网络,实现模型的训练,验证和测试\n", 163 | "\n", 164 | "- [官网API](https://tensorflow.google.cn/api_docs/python/tf/estimator/Estimator)\n", 165 | "\n", 166 | "- 看看有哪些参数\n", 167 | "\n", 168 | "```python\n", 169 | "__init__(\n", 170 | " model_fn,\n", 171 | 
" model_dir=None,\n", 172 | " config=None,\n", 173 | " params=None,\n", 174 | " warm_start_from=None\n", 175 | ")\n", 176 | "```\n", 177 | "- 本例中,主要用了 **tf.estimator.Estimator** 中的 `model_fn`,`model_dir`\n" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "### 先简单看看数据流\n", 185 | "\n", 186 | "下面的图表直接显示了本次MNIST例子的数据流向,共有**2个卷积层**,每一层卷积之后采用最大池化进行下采样(图中并未画出),最后接**2个全连接层**,实现对MNIST数据集的分类\n", 187 | "\n", 188 | "![Flowchart](images/02_network_flowchart.png)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "### 先看看input_fn之创建输入函数" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 4, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "# validate tf.data.TextLineDataset() using make_one_shot_iterator()\n", 205 | "\n", 206 | "def decode_line(line):\n", 207 | " # Decode the csv_line to tensor.\n", 208 | " record_defaults = [[1.0] for col in range(785)]\n", 209 | " items = tf.decode_csv(line, record_defaults)\n", 210 | " features = items[1:785]\n", 211 | " label = items[0]\n", 212 | "\n", 213 | " features = tf.cast(features, tf.float32)\n", 214 | " features = tf.reshape(features,[28,28,1])\n", 215 | " features = tf.image.flip_left_right(features)\n", 216 | "# print('features_aug',features_aug)\n", 217 | " label = tf.cast(label, tf.int64)\n", 218 | "# label = tf.one_hot(label,num_class)\n", 219 | " return features,label" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 5, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "def csv_input_fn(files_name_pattern, mode=tf.estimator.ModeKeys.TRAIN, \n", 229 | " skip_header_lines=1, \n", 230 | " num_epochs=None, \n", 231 | " batch_size=128):\n", 232 | " shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False\n", 233 | " \n", 234 | " num_threads = multiprocessing.cpu_count() if MULTI_THREADING else 1\n", 235 | " \n", 236 | " 
print(\"\")\n", 237 | " print(\"* data input_fn:\")\n", 238 | " print(\"================\")\n", 239 | " print(\"Input file(s): {}\".format(files_name_pattern))\n", 240 | " print(\"Batch size: {}\".format(batch_size))\n", 241 | " print(\"Epoch Count: {}\".format(num_epochs))\n", 242 | " print(\"Mode: {}\".format(mode))\n", 243 | " print(\"Thread Count: {}\".format(num_threads))\n", 244 | " print(\"Shuffle: {}\".format(shuffle))\n", 245 | " print(\"================\")\n", 246 | " print(\"\")\n", 247 | "\n", 248 | " file_names = tf.matching_files(files_name_pattern)\n", 249 | " dataset = data.TextLineDataset(filenames=file_names).skip(1)\n", 250 | "# dataset = tf.data.TextLineDataset(filenames).skip(1)\n", 251 | " print(\"DATASET\",dataset)\n", 252 | "\n", 253 | " # Use `Dataset.map()` to build a pair of a feature dictionary and a label\n", 254 | " # tensor for each example.\n", 255 | " dataset = dataset.map(decode_line)\n", 256 | " print(\"DATASET_1\",dataset)\n", 257 | " dataset = dataset.shuffle(buffer_size=10000)\n", 258 | " print(\"DATASET_2\",dataset)\n", 259 | " dataset = dataset.batch(32)\n", 260 | " print(\"DATASET_3\",dataset)\n", 261 | " dataset = dataset.repeat(num_epochs)\n", 262 | " print(\"DATASET_4\",dataset)\n", 263 | " iterator = dataset.make_one_shot_iterator()\n", 264 | " \n", 265 | " # `features` is a dictionary in which each value is a batch of values for\n", 266 | " # that feature; `labels` is a batch of labels.\n", 267 | " features, labels = iterator.get_next()\n", 268 | " \n", 269 | " features = {'images':features}\n", 270 | " \n", 271 | " return features,labels\n" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 6, 277 | "metadata": {}, 278 | "outputs": [ 279 | { 280 | "name": "stdout", 281 | "output_type": "stream", 282 | "text": [ 283 | "\n", 284 | "* data input_fn:\n", 285 | "================\n", 286 | "Input file(s): data_csv/mnist_train.csv\n", 287 | "Batch size: 128\n", 288 | "Epoch Count: None\n", 289 | 
"Mode: train\n", 290 | "Thread Count: 4\n", 291 | "Shuffle: True\n", 292 | "================\n", 293 | "\n", 294 | "DATASET \n", 295 | "features_aug Tensor(\"flip_left_right/ReverseV2:0\", shape=(28, 28, 1), dtype=float32)\n", 296 | "DATASET_1 \n", 297 | "DATASET_2 \n", 298 | "DATASET_3 \n", 299 | "DATASET_4 \n", 300 | "Features in CSV: ['images']\n", 301 | "Target in CSV: Tensor(\"IteratorGetNext:1\", shape=(?,), dtype=int64)\n" 302 | ] 303 | } 304 | ], 305 | "source": [ 306 | "features, target = csv_input_fn(files_name_pattern=TRAIN_DATA_FILES_PATTERN)\n", 307 | "print(\"Features in CSV: {}\".format(list(features.keys())))\n", 308 | "print(\"Target in CSV: {}\".format(target))" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "### 定义feature_columns" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 10, 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "feature_x = tf.feature_column.numeric_column('images', shape=IMG_SHAPE)\n", 325 | "\n", 326 | "feature_columns = [feature_x]" 327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "### 重点在这里——model_fn\n", 334 | "\n", 335 | "\n", 336 | "#### model_fn: Model function. Follows the signature:\n", 337 | "\n", 338 | "* Args:\n", 339 | " * `features`: This is the first item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single `tf.Tensor` or `dict` of same.\n", 340 | " * `labels`: This is the second item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single `tf.Tensor` or `dict` of same (for multi-head models).If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will be passed. If the `model_fn`'s signature does not accept `mode`, the `model_fn` must still be able to handle `labels=None`.\n", 341 | " * `mode`: Optional. Specifies if this training, evaluation or prediction. 
See `tf.estimator.ModeKeys`.\n", 342 | " * `params`: Optional `dict` of hyperparameters. Will receive what is passed to Estimator in `params` parameter. This allows Estimators to be configured by hyperparameter tuning.\n", 343 | " * `config`: Optional `estimator.RunConfig` object. Will receive what is passed to Estimator as its `config` parameter, or a default value. Allows setting up things in your `model_fn` based on configuration such as `num_ps_replicas`, or `model_dir`.\n", 344 | "* Returns:\n", 345 | " `tf.estimator.EstimatorSpec`\n", 346 | " \n", 347 | "#### Note the tf.estimator.EstimatorSpec returned by model_fn\n", 348 | "\n", 349 | "\n", 350 | "\n" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": {}, 356 | "source": [ 357 | "### Defining our own model_fn" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 11, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "def model_fn(features, labels, mode, params):\n", 367 | " # Args:\n", 368 | " #\n", 369 | " # features: This is the x-arg from the input_fn.\n", 370 | " # labels: This is the y-arg from the input_fn,\n", 371 | " # see e.g. csv_input_fn for these two.\n", 372 | " # mode: Either TRAIN, EVAL, or PREDICT\n", 373 | " # params: User-defined hyper-parameters, e.g. 
learning-rate.\n", 374 | " \n", 375 | " # Reference to the tensor named \"x\" in the input-function.\n", 376 | "# x = features[\"images\"]\n", 377 | " x = tf.feature_column.input_layer(features, params['feature_columns'])\n", 378 | " # The convolutional layers expect 4-rank tensors\n", 379 | " # but x is a 2-rank tensor, so reshape it.\n", 380 | " net = tf.reshape(x, [-1, IMG_HEIGHT, IMG_WIDTH, NUM_CHANNEL]) \n", 381 | "\n", 382 | " # First convolutional layer.\n", 383 | " net = tf.layers.conv2d(inputs=net, name='layer_conv1',\n", 384 | " filters=16, kernel_size=5,\n", 385 | " padding='same', activation=tf.nn.relu)\n", 386 | " net = tf.layers.max_pooling2d(inputs=net, pool_size=2, strides=2)\n", 387 | "\n", 388 | " # Second convolutional layer.\n", 389 | " net = tf.layers.conv2d(inputs=net, name='layer_conv2',\n", 390 | " filters=36, kernel_size=5,\n", 391 | " padding='same', activation=tf.nn.relu)\n", 392 | " net = tf.layers.max_pooling2d(inputs=net, pool_size=2, strides=2) \n", 393 | "\n", 394 | " # Flatten to a 2-rank tensor.\n", 395 | " net = tf.contrib.layers.flatten(net)\n", 396 | " # Eventually this should be replaced with:\n", 397 | " # net = tf.layers.flatten(net)\n", 398 | "\n", 399 | " # First fully-connected / dense layer.\n", 400 | " # This uses the ReLU activation function.\n", 401 | " net = tf.layers.dense(inputs=net, name='layer_fc1',\n", 402 | " units=128, activation=tf.nn.relu) \n", 403 | "\n", 404 | " # Second fully-connected / dense layer.\n", 405 | " # This is the last layer so it does not use an activation function.\n", 406 | " net = tf.layers.dense(inputs=net, name='layer_fc2',\n", 407 | " units=10)\n", 408 | "\n", 409 | " # Logits output of the neural network.\n", 410 | " logits = net\n", 411 | "\n", 412 | " # Softmax output of the neural network.\n", 413 | " y_pred = tf.nn.softmax(logits=logits)\n", 414 | " \n", 415 | " # Classification output of the neural network.\n", 416 | " y_pred_cls = tf.argmax(y_pred, axis=1)\n", 417 | "\n", 418 | " 
if mode == tf.estimator.ModeKeys.PREDICT:\n", 419 | " # If the estimator is supposed to be in prediction-mode\n", 420 | " # then use the predicted class-number that is output by\n", 421 | " # the neural network. Optimization etc. is not needed.\n", 422 | " spec = tf.estimator.EstimatorSpec(mode=mode,\n", 423 | " predictions=y_pred_cls)\n", 424 | " else:\n", 425 | " # Otherwise the estimator is supposed to be in either\n", 426 | " # training or evaluation-mode. Note that the loss-function\n", 427 | " # is also required in Evaluation mode.\n", 428 | " \n", 429 | " # Define the loss-function to be optimized, by first\n", 430 | " # calculating the cross-entropy between the output of\n", 431 | " # the neural network and the true labels for the input data.\n", 432 | " # This gives the cross-entropy for each image in the batch.\n", 433 | " cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,\n", 434 | " logits=logits)\n", 435 | "\n", 436 | " # Reduce the cross-entropy batch-tensor to a single number\n", 437 | " # which can be used in optimization of the neural network.\n", 438 | " loss = tf.reduce_mean(cross_entropy)\n", 439 | "\n", 440 | " # Define the optimizer for improving the neural network.\n", 441 | " optimizer = tf.train.AdamOptimizer(learning_rate=params[\"learning_rate\"])\n", 442 | "\n", 443 | " # Get the TensorFlow op for doing a single optimization step.\n", 444 | " train_op = optimizer.minimize(\n", 445 | " loss=loss, global_step=tf.train.get_global_step())\n", 446 | "\n", 447 | " # Define the evaluation metrics,\n", 448 | " # in this case the classification accuracy.\n", 449 | " metrics = \\\n", 450 | " {\n", 451 | " \"accuracy\": tf.metrics.accuracy(labels, y_pred_cls)\n", 452 | " }\n", 453 | "\n", 454 | " # Wrap all of this in an EstimatorSpec.\n", 455 | " spec = tf.estimator.EstimatorSpec(\n", 456 | " mode=mode,\n", 457 | " loss=loss,\n", 458 | " train_op=train_op,\n", 459 | " eval_metric_ops=metrics)\n", 460 | " \n", 461 | " 
return spec" 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": {}, 467 | "source": [ 468 | "### Our custom estimator" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 12, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "params = {\"learning_rate\": 1e-4,\n", 478 | " 'feature_columns': feature_columns}" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 13, 484 | "metadata": {}, 485 | "outputs": [ 486 | { 487 | "name": "stdout", 488 | "output_type": "stream", 489 | "text": [ 490 | "INFO:tensorflow:Using default config.\n", 491 | "INFO:tensorflow:Using config: {'_model_dir': './cnn_classifer_dataset/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", 492 | "graph_options {\n", 493 | " rewrite_options {\n", 494 | " meta_optimizer_iterations: ONE\n", 495 | " }\n", 496 | "}\n", 497 | ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n" 498 | ] 499 | } 500 | ], 501 | "source": [ 502 | "model = tf.estimator.Estimator(model_fn=model_fn,\n", 503 | " params=params,\n", 504 | " model_dir=\"./cnn_classifer_dataset/\")" 505 | ] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "metadata": {}, 510 | "source": [ 511 | "### Training the model" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": null, 517 | "metadata": {}, 518 | "outputs": [ 519 | { 520 | "name": "stdout", 521 | "output_type": "stream", 522 | "text": [ 523 | "\n", 524 | "* data input_fn:\n", 525 |
"================\n", 526 | "Input file(s): data_csv/mnist_train.csv\n", 527 | "Batch size: 128\n", 528 | "Epoch Count: None\n", 529 | "Mode: train\n", 530 | "Thread Count: 4\n", 531 | "Shuffle: True\n", 532 | "================\n", 533 | "\n", 534 | "DATASET \n", 535 | "DATASET_1 \n", 536 | "DATASET_2 \n", 537 | "DATASET_3 \n", 538 | "DATASET_4 \n", 539 | "INFO:tensorflow:Calling model_fn.\n", 540 | "INFO:tensorflow:Done calling model_fn.\n", 541 | "INFO:tensorflow:Create CheckpointSaverHook.\n", 542 | "INFO:tensorflow:Graph was finalized.\n", 543 | "INFO:tensorflow:Running local_init_op.\n", 544 | "INFO:tensorflow:Done running local_init_op.\n", 545 | "INFO:tensorflow:Saving checkpoints for 0 into ./cnn_classifer_dataset/model.ckpt.\n", 546 | "INFO:tensorflow:loss = 40.126976, step = 1\n", 547 | "INFO:tensorflow:global_step/sec: 11.057\n", 548 | "INFO:tensorflow:loss = 1.1582, step = 101 (9.049 sec)\n", 549 | "INFO:tensorflow:global_step/sec: 12.9123\n", 550 | "INFO:tensorflow:loss = 0.778288, step = 201 (7.743 sec)\n", 551 | "INFO:tensorflow:global_step/sec: 13.889\n", 552 | "INFO:tensorflow:loss = 1.0873605, step = 301 (7.200 sec)\n", 553 | "INFO:tensorflow:global_step/sec: 14.1931\n", 554 | "INFO:tensorflow:loss = 0.07414566, step = 401 (7.045 sec)\n", 555 | "INFO:tensorflow:global_step/sec: 14.2251\n", 556 | "INFO:tensorflow:loss = 0.32521993, step = 501 (7.029 sec)\n", 557 | "INFO:tensorflow:global_step/sec: 12.7967\n", 558 | "INFO:tensorflow:loss = 0.2568686, step = 601 (7.815 sec)\n", 559 | "INFO:tensorflow:global_step/sec: 12.4253\n", 560 | "INFO:tensorflow:loss = 0.54189134, step = 701 (8.048 sec)\n", 561 | "INFO:tensorflow:global_step/sec: 12.5796\n", 562 | "INFO:tensorflow:loss = 0.15989298, step = 801 (7.949 sec)\n", 563 | "INFO:tensorflow:global_step/sec: 13.7096\n", 564 | "INFO:tensorflow:loss = 0.90422636, step = 901 (7.295 sec)\n", 565 | "INFO:tensorflow:global_step/sec: 13.8366\n", 566 | "INFO:tensorflow:loss = 0.20136827, step = 1001 (7.227 
sec)\n", 567 | "INFO:tensorflow:global_step/sec: 13.5184\n", 568 | "INFO:tensorflow:loss = 0.53505665, step = 1101 (7.398 sec)\n", 569 | "INFO:tensorflow:global_step/sec: 12.8457\n", 570 | "INFO:tensorflow:loss = 0.22107196, step = 1201 (7.784 sec)\n", 571 | "INFO:tensorflow:global_step/sec: 13.0342\n", 572 | "INFO:tensorflow:loss = 0.31935138, step = 1301 (7.672 sec)\n" 573 | ] 574 | } 575 | ], 576 | "source": [ 577 | "input_fn = lambda: csv_input_fn(\\\n", 578 | " files_name_pattern= TRAIN_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.TRAIN)\n", 579 | "# Train the Model\n", 580 | "model.train(input_fn, steps=2000)" 581 | ] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "metadata": {}, 586 | "source": [ 587 | "### Evaluating the model" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": 68, 593 | "metadata": {}, 594 | "outputs": [ 595 | { 596 | "name": "stdout", 597 | "output_type": "stream", 598 | "text": [ 599 | "\n", 600 | "* data input_fn:\n", 601 | "================\n", 602 | "Input file(s): data_csv/mnist_val.csv\n", 603 | "Batch size: 128\n", 604 | "Epoch Count: None\n", 605 | "Mode: eval\n", 606 | "Thread Count: 4\n", 607 | "Shuffle: False\n", 608 | "================\n", 609 | "\n", 610 | "DATASET \n", 611 | "DATASET_1 \n", 612 | "DATASET_2 \n", 613 | "DATASET_3 \n", 614 | "DATASET_4 \n", 615 | "INFO:tensorflow:Calling model_fn.\n", 616 | "INFO:tensorflow:Done calling model_fn.\n", 617 | "INFO:tensorflow:Starting evaluation at 2018-10-23-12:36:20\n", 618 | "INFO:tensorflow:Graph was finalized.\n", 619 | "INFO:tensorflow:Restoring parameters from trained_models/simple_cnn/model.ckpt-4000\n", 620 | "INFO:tensorflow:Running local_init_op.\n", 621 | "INFO:tensorflow:Done running local_init_op.\n", 622 | "INFO:tensorflow:Evaluation [1/1]\n", 623 | "INFO:tensorflow:Finished evaluation at 2018-10-23-12:36:29\n", 624 | "INFO:tensorflow:Saving dict for global step 4000: accuracy = 0.96875, global_step = 4000, loss = 0.1153331\n" 625 | ] 626 | },
627 | { 628 | "data": { 629 | "text/plain": [ 630 | "{'accuracy': 0.96875, 'global_step': 4000, 'loss': 0.1153331}" 631 | ] 632 | }, 633 | "execution_count": 68, 634 | "metadata": {}, 635 | "output_type": "execute_result" 636 | } 637 | ], 638 | "source": [ 639 | "input_fn = lambda: csv_input_fn(files_name_pattern= VAL_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.EVAL)\n", 640 | "\n", 641 | "model.evaluate(input_fn,steps=1)" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 69, 647 | "metadata": {}, 648 | "outputs": [ 649 | { 650 | "name": "stdout", 651 | "output_type": "stream", 652 | "text": [ 653 | "\n", 654 | "* data input_fn:\n", 655 | "================\n", 656 | "Input file(s): data_csv/mnist_test.csv\n", 657 | "Batch size: 10\n", 658 | "Epoch Count: None\n", 659 | "Mode: infer\n", 660 | "Thread Count: 4\n", 661 | "Shuffle: False\n", 662 | "================\n", 663 | "\n", 664 | "DATASET \n", 665 | "DATASET_1 \n", 666 | "DATASET_2 \n", 667 | "DATASET_3 \n", 668 | "DATASET_4 \n", 669 | "INFO:tensorflow:Calling model_fn.\n", 670 | "INFO:tensorflow:Done calling model_fn.\n", 671 | "INFO:tensorflow:Graph was finalized.\n", 672 | "INFO:tensorflow:Restoring parameters from trained_models/simple_cnn/model.ckpt-4000\n", 673 | "INFO:tensorflow:Running local_init_op.\n", 674 | "INFO:tensorflow:Done running local_init_op.\n", 675 | "PREDICTIONS [6, 1, 2, 7, 0, 8, 0, 3, 0, 0]\n" 676 | ] 677 | } 678 | ], 679 | "source": [ 680 | "import itertools\n", 681 | "\n", 682 | "input_fn = lambda: csv_input_fn(\\\n", 683 | " files_name_pattern= TEST_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.PREDICT,batch_size=10)\n", 684 | "\n", 685 | "predictions = list(itertools.islice(model.predict(input_fn=input_fn),10))\n", 686 | "print('PREDICTIONS',predictions)\n", 687 | "# print(\"\")\n", 688 | "# print(\"* Predicted Classes: {}\".format(list(map(lambda item: item[\"classes\"][0]\n", 689 | "# ,predictions))))" 690 | ] 691 | } 692 | ], 693 | "metadata": { 694 | 
"kernelspec": { 695 | "display_name": "Python 3", 696 | "language": "python", 697 | "name": "python3" 698 | }, 699 | "language_info": { 700 | "codemirror_mode": { 701 | "name": "ipython", 702 | "version": 3 703 | }, 704 | "file_extension": ".py", 705 | "mimetype": "text/x-python", 706 | "name": "python", 707 | "nbconvert_exporter": "python", 708 | "pygments_lexer": "ipython3", 709 | "version": "3.6.6" 710 | } 711 | }, 712 | "nbformat": 4, 713 | "nbformat_minor": 2 714 | } 715 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/CNN_raw.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# TensorFlow 那些事儿之DL中的 HELLO WORLD\n", 8 | "\n", 9 | "- 基于MNIST数据集,运用TensorFlow搭建一个简单的卷积神经网络,并实现模型训练/验证/测试\n", 10 | "\n", 11 | "- TensorBoard的简单使用\n" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## 看看MNIST数据长什么样子的\n", 19 | "\n", 20 | "![MNIST Dataset](http://neuralnetworksanddeeplearning.com/images/mnist_100_digits.png)\n", 21 | "\n", 22 | "More info: http://yann.lecun.com/exdb/mnist/" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## 流程图" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "下面的图表直接显示了之后实现的卷积神经网络中数据的传递。\n", 37 | "\n", 38 | "![Flowchart](images/02_network_flowchart.png)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "输入图像在第一层卷积层里使用权重过滤器处理。结果在16张新图里,每张代表了卷积层里一个过滤器(的处理结果)。图像经过降采样,分辨率从28x28减少到14x14。\n", 46 | "\n", 47 | "16张小图在第二个卷积层中处理。这16个通道以及这层输出的每个通道都需要一个过滤权重。总共有36个输出,所以在第二个卷积层有16 x 36 = 576个滤波器。输出图再一次降采样到7x7个像素。\n", 48 | "\n", 49 | "第二个卷积层的输出是36张7x7像素的图像。它们被转换到一个长为7 x 7 x 36 = 1764的向量中去,它作为一个有128个神经元(或元素)的全连接网络的输入。这些又输入到另一个有10个神经元的全连接层中,每个神经元代表一个类别,用来确定图像的类别,即图像上的数字。\n", 50 | "\n", 51 | 
The convolutional filters are initially chosen at random, so the initial classification is essentially random. The error between the predicted and the true class of each input image is measured with cross-entropy. The optimizer then uses the chain rule to propagate this error back through the convolutional network automatically and updates the filter weights to improve the classification. This process is repeated thousands of times until the classification error is sufficiently low.\n", 52 | "\n", 53 | "These particular filter weights and intermediate images are the result of one optimization run and may differ from what you see when you run the code yourself.\n", 54 | "\n", 55 | "Note that TensorFlow performs these computations on a batch of images rather than on a single image, which makes the computation more efficient. It also means that, when implemented in TensorFlow, the data in this flowchart carries an extra (batch) dimension.\n" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "## Importing the libraries" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 2, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "%matplotlib inline\n", 72 | "import matplotlib.pyplot as plt\n", 73 | "import tensorflow as tf\n", 74 | "import numpy as np\n", 75 | "import pandas as pd\n", 76 | "\n", 77 | "import time\n", 78 | "from datetime import timedelta\n", 79 | "import math" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "Developed with Python 3.6 (Anaconda); the TensorFlow version is:" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "data": { 96 | "text/plain": [ 97 | "'1.8.0'" 98 | ] 99 | }, 100 | "execution_count": 4, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "tf.__version__" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "## Loading the MNIST dataset\n", 114 | "\n", 115 | "- The MNIST dataset consists of 70,000 images and their corresponding labels (i.e. the class of each image). It is split into three mutually exclusive subsets.\n", 116 | "\n", 117 | "- Define the relevant **MNIST** data constants" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 3, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "TRAIN_DATA_FILES_PATTERN = 'data_csv/mnist_train.csv'\n", 127 | "VAL_DATA_FILES_PATTERN = 'data_csv/mnist_val.csv'\n", 128 | "TEST_DATA_FILES_PATTERN = 'data_csv/mnist_test.csv'\n", 129 | "\n", 130 | "MULTI_THREADING = True\n", 131 | "RESUME_TRAINING = False\n", 132 | "\n", 133 | "NUM_CLASS = 10\n", 134 | "IMG_SHAPE = [28,28]\n", 135
| "\n", 136 | "IMG_WIDTH = 28\n", 137 | "IMG_HEIGHT = 28\n", 138 | "IMG_FLAT = 784\n", 139 | "NUM_CHANNEL = 1\n", 140 | "\n", 141 | "BATCH_SIZE = 128\n", 142 | "NUM_TRAIN = 55000\n", 143 | "NUM_VAL = 5000\n", 144 | "NUM_TEST = 10000" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 4, 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "name": "stdout", 154 | "output_type": "stream", 155 | "text": [ 156 | "test_data (10000, 784)\n", 157 | "test_label (10000,)\n", 158 | "val_data (5000, 784)\n", 159 | "val_label (5000,)\n", 160 | "train_data (55000, 784)\n", 161 | "train_label (55000,)\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN)\n", 167 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None, names=HEADER )\n", 168 | "train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None)\n", 169 | "test_data = pd.read_csv(TEST_DATA_FILES_PATTERN, header=None)\n", 170 | "val_data = pd.read_csv(VAL_DATA_FILES_PATTERN, header=None)\n", 171 | "\n", 172 | "train_values = train_data.values\n", 173 | "train_data = train_values[:,1:]/255.0\n", 174 | "train_label = train_values[:,0:1].squeeze()\n", 175 | "\n", 176 | "val_values = val_data.values\n", 177 | "val_data = val_values[:,1:]/255.0\n", 178 | "val_label = val_values[:,0:1].squeeze()\n", 179 | "\n", 180 | "test_values = test_data.values\n", 181 | "test_data = test_values[:,1:]/255.0\n", 182 | "test_label = test_values[:,0:1].squeeze()\n", 183 | "\n", 184 | "print('test_data',np.shape(test_data))\n", 185 | "print('test_label',np.shape(test_label))\n", 186 | "\n", 187 | "print('val_data',np.shape(val_data))\n", 188 | "print('val_label',np.shape(val_label))\n", 189 | "\n", 190 | "print('train_data',np.shape(train_data))\n", 191 | "print('train_label',np.shape(train_label))\n", 192 | "\n", 193 | "# train_data.head(10)\n", 194 | "# test_data.head(10)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 
200 | "source": [ 201 | "### One-hot encoding" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 5, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "def one_hot_encoded(class_numbers, num_classes=None):\n", 211 | " \"\"\"\n", 212 | " Generate the One-Hot encoded class-labels from an array of integers.\n", 213 | "\n", 214 | " For example, if class_number=2 and num_classes=4 then\n", 215 | " the one-hot encoded label is the float array: [0. 0. 1. 0.]\n", 216 | "\n", 217 | " :param class_numbers:\n", 218 | " Array of integers with class-numbers.\n", 219 | " Assume the integers are from zero to num_classes-1 inclusive.\n", 220 | "\n", 221 | " :param num_classes:\n", 222 | " Number of classes. If None then use max(class_numbers)+1.\n", 223 | "\n", 224 | " :return:\n", 225 | " 2-dim array of shape: [len(class_numbers), num_classes]\n", 226 | " \"\"\"\n", 227 | "\n", 228 | " # Find the number of classes if None is provided.\n", 229 | " # Assumes the lowest class-number is zero.\n", 230 | " if num_classes is None:\n", 231 | " num_classes = np.max(class_numbers) + 1\n", 232 | "\n", 233 | " return np.eye(num_classes, dtype=float)[class_numbers]" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "### Batching the training data" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 45, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "# idx = 0\n", 250 | "\n", 251 | "def batch(batch_size=32):\n", 252 | " \"\"\"\n", 253 | " Create a random batch of training-data.\n", 254 | "\n", 255 | " :param batch_size: Number of images in the batch.\n", 256 | " :return: 3 numpy arrays (x, y, y_cls)\n", 257 | " \"\"\"\n", 258 | "# global idx\n", 259 | " # Create a random index into the training-set.\n", 260 | " idx = np.random.randint(low=0, high=NUM_TRAIN, size=batch_size)\n", 261 | "# idx = iterations\n", 262 | " # Use the index to lookup random training-data.\n", 263 | "
x_batch = train_data[idx]\n", 264 | " y_batch = train_label_onehot[idx]\n", 265 | " y_batch_cls = train_label[idx]\n", 266 | "# idx = idx + batch_size\n", 267 | "# print('IDX',idx)\n", 268 | " return x_batch, y_batch, y_batch_cls\n" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 7, 274 | "metadata": {}, 275 | "outputs": [ 276 | { 277 | "name": "stdout", 278 | "output_type": "stream", 279 | "text": [ 280 | "shape (55000, 10)\n", 281 | "[[0. 0. 0. ... 0. 0. 0.]\n", 282 | " [1. 0. 0. ... 0. 0. 0.]\n", 283 | " [0. 0. 0. ... 0. 0. 0.]\n", 284 | " ...\n", 285 | " [1. 0. 0. ... 0. 0. 0.]\n", 286 | " [0. 0. 0. ... 0. 0. 0.]\n", 287 | " [1. 0. 0. ... 0. 0. 0.]]\n" 288 | ] 289 | } 290 | ], 291 | "source": [ 292 | "train_label_onehot = one_hot_encoded(train_label.T,10).squeeze()\n", 293 | "print('shape',np.shape(train_label_onehot))\n", 294 | "print(train_label_onehot)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": {}, 300 | "source": [ 301 | "## Network configuration\n", 302 | "\n", 303 | "For convenience, the configuration of the neural network is defined here, so you can easily find and change these values and then re-run the Notebook." 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 8, 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "# Convolutional Layer 1.\n", 313 | "filter_size1 = 5 # Convolution filters are 5 x 5 pixels.\n", 314 | "num_filters1 = 16 # There are 16 of these filters.\n", 315 | "\n", 316 | "# Convolutional Layer 2.\n", 317 | "filter_size2 = 5 # Convolution filters are 5 x 5 pixels.\n", 318 | "num_filters2 = 36 # There are 36 of these filters.\n", 319 | "\n", 320 | "# Fully-connected layer.\n", 321 | "fc_size = 128 # Number of neurons in fully-connected layer."
322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "### Helper function for plotting images" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "This function plots 9 images in a 3x3 grid and writes the true and predicted classes below each image." 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 9, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "img_shape = IMG_SHAPE\n", 345 | "\n", 346 | "def plot_images(images, cls_true, cls_pred=None):\n", 347 | " assert len(images) == len(cls_true) == 9\n", 348 | " \n", 349 | " # Create figure with 3x3 sub-plots.\n", 350 | " fig, axes = plt.subplots(3, 3)\n", 351 | " fig.subplots_adjust(hspace=0.3, wspace=0.3)\n", 352 | "\n", 353 | " for i, ax in enumerate(axes.flat):\n", 354 | " # Plot image.\n", 355 | " ax.imshow(images[i].reshape(img_shape), cmap='binary')\n", 356 | "\n", 357 | " # Show true and predicted classes.\n", 358 | " if cls_pred is None:\n", 359 | " xlabel = \"True: {0}\".format(cls_true[i])\n", 360 | " else:\n", 361 | " xlabel = \"True: {0}, Pred: {1}\".format(cls_true[i], cls_pred[i])\n", 362 | "\n", 363 | " # Show the classes as the label on the x-axis.\n", 364 | " ax.set_xlabel(xlabel)\n", 365 | " \n", 366 | " # Remove ticks from the plot.\n", 367 | " ax.set_xticks([])\n", 368 | " ax.set_yticks([])\n", 369 | " \n", 370 | " # Ensure the plot is shown correctly with multiple plots\n", 371 | " # in a single Notebook cell.\n", 372 | " plt.show()" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": {}, 378 | "source": [ 379 | "### Plot a few images to check that the data is correct" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 10, 385 | "metadata": {}, 386 | "outputs": [ 387 | { 388 | "data": { 389 | "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAUMAAAD5CAYAAAC9FVegAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAHihJREFUeJzt3XmUFNXZx/HvA0LYVQQFFWdOwAVCFBWDu0aBKCogccG4EGM0osEtAaNx1xglKBzRE7YD4QQNigKCUVFAEV8EJIIi4wYiCsRlhLggIsJ9/5i5XdUzPXtXVU/7+5zjmequ6qpnvPSdp27dxZxziIj80DVIOgARkVygylBEBFWGIiKAKkMREUCVoYgIoMpQRARQZSgiAqgyFBEBVBmKiACwS00ObtOmjSssLIwolNzzwQcfUFxcbEnHESeVcf5TGWdWo8qwsLCQZcuW1T6qeqZ79+5JhxA7lXH+UxlnpttkERFUGYqIAKoMRUQAVYYiIoAqQxERoIZPk0Vqa8SIEQBs3boVgDfeeAOAxx9/vNyxgwcPBuCoo44C4MILL4wjRPmBU2YoIoIyQ4nYueeeC8C0adMy7jcr3xd2zJgxAMydOxeAE044AYD99tsvihAlQe+++y4ABx54IAAPPPAAAEOGDIk9FmWGIiIoM5QI+GwQKs4IDzroIABOOeUUAN5///3UvlmzZgGwevVqAKZMmQLAjTfemP1gJVHLly8HoEGDkrxsn332SSwWZYYiIigzlCzy411nzJhRbl/Xrl2BIOtr06YNAC1atADgu+++Sx3bo0cPAF5//XUAPv/884gilqStWLECCP4dDBgwILFYlBmKiBBDZuj7kY0fPx6AvffeO7WvSZMmAJx//vkAtGvXDoBOnTpFHZZE4L///S8AzrnUez4jnDNnDgDt27fP+FnfDxHgrbfeStt3+umnZzVOSd7KlSsBGD16NAAXXXRRkuEAygxFRIAYMsOhQ4cCJRMsVsT3K2vVqhUAXbp0ycq1O3ToAMCwYcOAH+bcdXE644wzgOApMEDLli0BaN26daWfffTRR1Pb4fZDyU/vvPMOAFu2bAHSeyAkRZmhiAiqDEVEgBhukydMmAAE3STCt8BFRUVA0PHyxRdfBGDx4sVAMPzqww8/rPD8jRo1AoKuGr4RP3wef7us2+R4FBQUVPvYv/3tb0AwLCvMd7HxPyV/DB8+HChZggBy47upzFBEhBgyw5NPPjntZ5gfiuVt3rwZCDJF/9fi1VdfrfD8P/rRj4BgoLcf5gWwadMmADp27Fir2CU6Tz31FAC33HILANu2bUvt22uvvQC45557AGjWrFnM0UkUwg9R/Xfaf2+bN2+eREhplBmKiJBjw/F23313AE466aS09zNllWU98cQTQJBdAhx88MEADBw4MFshSpb4oXvhjNDz3Sz81F2SHxYsWFDuvbZt2yYQSWbKDEVEyLHMsDY+/fRTAK644gogfSiYb4+qqsOvxKd///5AMDzPGzRoUGr7rrvuijUmiYdf6iHMD4jIBcoMRUTIg8zwoYceAoIMcbfddkvt80+qJHm+/+eiRYuAoK3QtxnddNNNqWP9dE6SH1555RUAJk2alHrv0EMPBaBXr16JxJSJMkMREepxZvjyyy8DQV8078knn0xt++mjJHl+0s7i4uK09/30beoLmr/mzZsHpPf08H2M/TR+uUCZoYgIqgxFRIB6fJv89NNPA8Hcdz179gTgqKOOSiwmKc+veeKHWHonnngiAHfccUfcIUnM/CQtYWeffXYCkVROmaGICPUwM9y6dSsAzz77LBBM1HD77bcDwZRekpzwanZ33303UH726m7dugHqRpPPPv74YwAWLlwIpE+icuaZZyYSU2WUGYqIUA8zQz8ZqG+DOvXUUwE4+uijE4tJ0t13332p7aVLl6bt88Px1FaY//7xj38A8MknnwDBdzVXKTMUEaGeZIZ+IlCAO++8E4Bdd90
VgJtvvjmRmKRi999/f4X7/PBJtRXmv3Xr1qW99lP05SplhiIi5Hhm6J9KXnXVVan3vv/+ewD69OkDqF9hfePLtDpP/X3274/dvn07AF988UW5Y/1Qr5EjR2Y8V8OGDVPb9957L6DlBKI2e/bstNenn356QpFUjzJDERFUGYqIADl6m7xjxw4gmNli7dq1qX2dOnUCggcpUr/4dWmq45xzzgGgffv2QNBFY+rUqXWKwa++F55DUbLHd7L25VVfKDMUESFHM8M1a9YAwQpqYb7bhua/y13+4RbAzJkza32exx57rMpj/MOVBg3S/6737dsXCNbeDjv22GNrHZNUbcaMGUDwsNPPap3rqx0qMxQRIccyQ99Js3fv3mnvjxgxIrWd64/nBaZPn57aHj58OFB+ogavqKgIqLwd8JJLLgGgoKCg3L5f/vKXAHTu3Ll2wUrWfPPNNwA888wzae/76brC3ZtykTJDERFyLDMcO3YsUH4YT7itwcxijUnqprrr4j7yyCMRRyJR8+23foXKfv36AXD11VcnFlNNKDMUESFHMkPfL+nBBx9MOBIRqS2fGfp1kusbZYYiIuRIZujXQP7qq6/S3vejTTTdk4hETZmhiAiqDEVEgBy5TS7Lr5w2b948AFq3bp1kOCLyA6DMUESEHMkMb7jhhrSfIiJxU2YoIgKYc676B5t9Bqyr8sD8UeCca5t0EHFSGec/lXFmNaoMRUTylW6TRURQZSgiAkT8NNnM9gDmlb5sB+wAPit9/TPnXOYZP+t2zS5AeD6ojsANzjnNAhGBhMq4AJgM7Ak44O8q3+gkUcal150M9AE2OOe6RXGNtOvF1WZoZrcBXzvnRpR530rj2BnBNRsBG4DDnHPrs31+SRdXGZvZ3sCezrkVZtYKWA6c6px7Nxvnl4rF+T02sxOArcC4OCrDRG6TzayTmRWZ2cPAKqCDmf0vtH+gmU0o3d7LzKab2TIzW2pmR9bgUr2At1QRxi/KMnbObXTOrSjd/hJ4G9gnut9GMon6e+ycWwBsiuwXKCPJNsODgJHOuS6UZG8VeQAY7pzrDpwD+P+5PcxsTBXXGAj8KxvBSq1EXsZm9mOgK/BqdkKWGorjexyLJEegrHHOlV8LtLyewIGh6f53N7OmzrklwJKKPmRmTYDTgOvqHKnUVtRl3Ap4AhjinPu6ztFKbURaxnFKsjLcEtreCYQXN2kS2jZq10h7GrDEOVdcy/ik7iIrYzNrDEwHJjnnZtUpSqmLqL/HscmJrjWlja6bzWx/M2sAnBnaPRe40r8ws+o2pJ6HbpFzRjbLuLSx/h/ACufcAxGEK7UQ0fc4NjlRGZa6HpgDLALCDzyuBI4xszfMrAi4FCpvazCzlsDPgZnRhiw1lK0yPoGSP3a9zGxF6X+/iDh2qZ5sfo+nAQuBLma23sx+HWXgGo4nIkJuZYYiIolRZSgigipDERFAlaGICKDKUEQEqGGn6zZt2rjCwsKIQsk9H3zwAcXFxVb1kflDZZz/VMaZ1agyLCwsZNmy6oy8yQ/du3dPOoTYqYzzn8o4M90mi4igylBEBFBlKCICqDIUEQFUGYqIAKoMRUSAZCd3rdCWLSXzRQ4dOhSAMWOCGX78Y/Jp06YBUFBQEHN0IpKPlBmKiJCjmeHGjRsBGD9+PAANGzZM7fOdRWfPng3A73//+5ijk9p47bXXABgwYABQMiqgtp577rnUdufOnQHo0KFD7YOTxPjvcd++fQEYPXo0AIMHD04dE/7+R0mZoYgIOZYZfvbZZwAMGjQo4Ugk2+bMmQPAtm3b6nyuWbOC9Z8mTpwIwNSpU+t8XonP559/DqRngABDhgwB4JJLLkm917Rp01hiUmYoIkKOZIYPPFCywNnMmSXrN736atXrgS9cuBAAv4bLIYccAsDxxx8fRYhSS99//z0ATz/9dNbOGR54f//99wNBD4TmzZtn7ToSnZdeegmADRvS150/77zzAGjSpEm5z0RNmaGICDm
SGV5zzTVAzZ4aTZ8+Pe3nfvvtB8Bjjz2WOubwww/PVohSSy+88AIAixYtAuD666+v8zk3bdqU2l61ahUA33zzDaDMMJeF24vvuuuujMdceOGFAJQsjR0vZYYiIqgyFBEBEr5N7tOnDxA8BNmxY0eVn2nTpg0Q3A6tW7cOgLVr1wJwxBFHpI7duXNn9oKValu5cmVqe+DAgQB06tQJgBtvvLHO5w93rZH644033kht+0743i67lFRFp556aqwxhSkzFBEhgcxwwYIFqe23334bCBpLK3qAcvnll6e2e/fuDcCuu+4KwPz58wH4y1/+Uu5zf//734HyHTslWuGy8A82pkyZAkCLFi1qfV7/4CT8byiJhnapHf+wM5NevXrFGElmygxFRIgxM/QD830bEkBxcXHGY303mbPOOguAW2+9NbWvWbNmacf6KbzGjh1b7pzDhg0D4NtvvwWCSR0aNWpUu19CKvX4448D6R2sfVthuC23tnx3jHA2eOKJJwKw22671fn8Eq1wRu81btwYgLvvvjvucMpRZigiQoyZ4fbt24GKs0EIhtI9+uijQPDkuDI+M/RPKa+77rrUPj9Ey2eIfpqgjh071ih2qR4/4a7//w7Zaa/1dxWPPPIIEDx5BLjpppsAZfu5zHe4f+WVV8rt83d63bp1izWmTJQZioiQI8PxfHvSpEmTgOplhGX5rO/hhx9Ovbd06dIsRCdV+eKLLwBYvHhxuX1XXHFFnc8/btw4IJjirUuXLql9J510Up3PL9GqbOKVXOrpocxQRIQEMsNMo0yWLFlS5/P6USzhUSdlR7b4p9K+z5tkhx+Av379eiCYhilb1qxZk/a6a9euWT2/RCtTZuif/mfjziFblBmKiKDKUEQEiPE22a99HNVKV36VreXLl6feKzvM7/bbb4/k2j90LVu2BILuEeGJGvwQutatW9f4vJ9++ikQdNnxjjnmmFrFKfF6+eWXgaBLVJgfTrvvvvvGGlNllBmKiBBjZvjUU09l9Xy+m0VRURFQ+XAe31VHHXOj4Vcv80Pv/LA8gNNOOw1I7wyfyZtvvpna9g9M/PRsZSdjaNBAf8PrA78Cnn+QGZYLEzOUpX9VIiLkSKfr2vDTRD300EMVHlNYWAjA5MmTgWACCInGbbfdBqRnAv6OIDxBRyZt27ZNbftMsKKhmxdffHFdwpSYlG3rDU+mcdlll8UdTpWUGYqIUA8zQ79UgJ8YtjJ+2NZxxx0XaUxSonPnzkD6CoX+6X7ZjtNl+enawgYNGgSU7yTv2yglN/nO92WfIoefHGdjSrdsU2YoIkKMmWFliz4988wzaa8vvfRSADZu3Fjheaoz3Xu2n2BLzR166KFpP2vixz/+ccb3w/0Yf/rTn9YuMImMn7Kr7FPkfv36JRFOtSkzFBFBlaGICBDjbbKft8zPOh3mO+aWHaqXaeiev82uzkp6Ur/526yyt1u6Nc5tvrO15wc9XHPNNUmEU23KDEVEiDEzHDBgAADDhw9PvVfZeihV8X9tfHeO8ePHA9C+fftan1Nyi39IprWR65c5c+akve7QoQMQTM6Qq5QZiogQY2boV7HzK98BzJw5E4BRo0bV+Hx//vOfgWAtZMk/fr1rT52tc5tfAXP16tVp7zdp0gTI/YlSlBmKiJDAcDy/NnJ4u3fv3kCwCpqfqPWMM84A4He/+13qM/7JYniFNMlPfrVEP8D/lltuSTIcqYKfWs0PtVu1ahUA+++/f2Ix1YQyQxERcmSihlNOOSXtpwgEGca1114LaI3kXOf7/vrp9XwvgMMOOyyxmGpCmaGICDmSGYpk4tuOpX7Ze++9AZg4cWLCkdSMMkMREVQZiogAqgxFRABVhiIigCpDERFAlaGICACWabX7Cg82+wxYF104OafAOde26sPyh8o4/6mMM6tRZSgikq90mywigipDERFAlaGICBDx2GQz2wOYV/qyHbAD+Kz09c+cc99FdN0+wEigITDWOfe3KK4jyZV
x6bV3AV4D3nfO9Y/qOj90CX6PJwN9gA3OuW5RXCPtenE9QDGz24CvnXMjyrxvpXHszNJ1GgHvAD8HPgaWAb90zr2bjfNLxeIq49B5hwHdgGaqDOMRZxmb2QnAVmBcHJVhIrfJZtbJzIrM7GFgFdDBzP4X2j/QzCaUbu9lZtPNbJmZLTWzI6s4/ZHAW865dc65bcBjQL+ofhfJLOIyxswKgF7ApKh+B6lc1GXsnFsAbIrsFygjyTbDg4CRzrkuwIZKjnsAGO6c6w6cA/j/uT3MbEyG4/cBPgq9Xl/6nsQvqjIGGAUMBdQ3LFlRlnGskpzPcI1zblk1jusJHBhaO3d3M2vqnFsCLIksOsmGSMrYzPoDHznnVphZz+yFK7WQN9/jJCvDLaHtnUB4pfAmoW2jZo20G4AOodf7UvlfLIlOVGV8NDDAzPqWnqeVmU12zg2qU7RSG1GVcexyomtNaaPrZjPb38waAGeGds8FrvQvzKyqhtTFQBczKzCzH1GSks/KdsxSM9ksY+fcMOfcvs65QuAC4DlVhMnL8vc4djlRGZa6HpgDLKKknc+7EjjGzN4wsyLgUqi4rcE5tx24CngeKAKmOOfeiTp4qZaslLHktKyVsZlNAxZSktysN7NfRxm4xiaLiJBbmaGISGJUGYqIoMpQRARQZSgiAtSwn2GbNm1cYWFhRKHkng8++IDi4mKr+sj8oTLOfyrjzGpUGRYWFrJsWXU6m+eH7t27Jx1C7FTG+U9lnJluk0VEUGUoIgKoMhQRAVQZiogAqgxFRABVhiIigCpDEREg2cldRUQA2Lx5MwAffvhhhccUFBQAMHLkSAC6du0KwAEHHADAIYccUqcYlBmKiJBwZvjpp58CcM455wBw9NFHA3DZZZcBJT3ls+GLL74A4KWXXgLglFNOAaBRo0ZZOb+I1MxTTz0FwOzZswF48cUXAXjvvfcq/MyBBx4IlAyvA9i2bVva/p0767ZKqTJDERESyAx92wDAT37yEyDI3Pbaay8g+xnhYYcdBkBxcTFAalzm/vvvn5XrSPV9+eWXAPzpT38CYNWqVQDMnTs3dYwy9vywZs0aAB566CEAxo0bl9q3detWAGoy0/4770S7eocyQxERYswMfVbm2wcBPv/8cwCuvLJk0azRo0dn9Zp33XUXAGvXrgWCv0zKCOM3ZcoUAG666Sag/FNDnzEC7LHHHvEFJpFZv75kPahRo0bV6TwHHXQQEDw9jooyQxERYswMX3vtNSB4ahR2yy23ZO06b775Zmp7xIgRAJx5Zsnyreeee27WriPV47ODa6+9FgjuEMzS59ocMmRIavvBBx8EoHXr1nGEKLXgyxGCzO/YY48Fgt4ajRs3BmDXXXcFoEWLFqnPfP311wD84he/AIKsr0ePHgAceuihqWObNm0KQPPmzbP8W6RTZigigipDEREghttk37H6iSeeKLdv4sSJALRt27bO1/G3x7169Sq3b8CAAQC0bNmyzteRmvFNFf5hWUWmTp2a2n7mmWeA4GGLv4X2t12SnC1btgDp37PXX38dgJkzZ6Yde9RRRwGwfPlyIL3LnH+Atu+++wLQoEHyeVnyEYiI5IDIM8M//OEPQNC1wneABjj77LOzdp2XX34ZgI8//jj13sUXXwzABRdckLXrSNXWrVuX2p40aVLaPj+Y3newf/7558t93neW91nl+eefD0C7du2yH6xUy3fffQfAr371KyDIBgFuvPFGAHr27Jnxs5kGUey3335ZjrDulBmKiBBDZui7UPif++yzT2pfXdqA/HCeu+++GwiG/IS7bPg2SYnXihUrUtu+M/Xxxx8PwIIFCwD49ttvAXjkkUcA+Otf/5r6zOrVq4Egy+/Xrx8QtCWqy018fBcY/z3zEyuE2/mHDh0KQLNmzWKOLruUGYqIkMBEDX7qHoDevXsDsNtuuwEwePDgKj/vO237n4sXL07bn812SKmd8NRKPlP3na69Jk2aAPCb3/wGgMcffzy
1zw/w94P4fcahp8nx80+I77nnHiCYYHXhwoWpY3yn6vpOmaGICDFkhldffTUA8+fPB2Djxo2pfb79yGcATz75ZJXn88eWHc7VsWNHIGjbkOT861//Kvfev//9bwD69++f8TN+WrVMjjzySCB9OJfEY9GiRWmv/TA53z8wnygzFBEhhszw8MMPB2DlypVA+pPGZ599FoDhw4cDsOeeewIwaNCgCs934YUXAnDwwQenve+XDPAZoiTnvPPOS237bP/VV18F4O233waCfw8zZswA0if99W3I/j0/9Zov+y5dukQWu6QLt+VC8ET/9ttvT73Xt29fIH1yhfpImaGICKoMRUQAsJqsQdC9e3dXWUN3HN5//30guB3u1q0bAM899xyQnUkfvO7du7Ns2TKr+sj8kY0y3rRpU2rbl5MfYlfRA7DwwH/fgf70008H4N133wWCVRPHjBlTp/jCVMaVKztoIpOGDRsCcPnllwPBnIQfffQRAJ06dQKCNY/C/Bo4flKHKB7MVLeMlRmKiJDwusm1cccddwDBXyr/8CWbGaHUTXi43LRp0wA466yzgPIZ4lVXXQXAvffem/qM75Dtp17zQ/XmzJkDBJ2yQQ/MovbHP/4RgPvuu6/CY3bs2AEEGb3/WRP+4emJJ54IpE/pFhdlhiIi1JPM0GcXAJMnTwagVatWgFZSy3V+WiffRcNPzOC7z/hM32eDYTfffDMAb731FhB00/GfgeDfg0TDD8Pzq1r66dS2b9+eOsavc+MzxNrwk0D773p4JTw/yW/UlBmKiFBPMkPf0TPstNNOA9Ini5Xc5TPEiiYAzcSviuZXNfSZ4QsvvJA6xj+51rRe0fBPio844gggeLIfNm/ePCDIFm+77TYAli5dWuPr+bbk//znPzX+bF0pMxQRoR5mhn7tVP+US/Kfb6+aNWsWkP6k0a+xnM21t6VmTj755LTXfsitzwwbNWoEBMtwAFx66aUAjBw5EgjakpOkzFBEBFWGIiJAjt8m+2FX4RXv/KpqenDyw+HX1B02bBiQvj6vb6wfOHAgAAcccEC8wUk5fgZ7v2qef7DiZx8CeO+994BgxvqywmslxUWZoYgI9SQzDA8S79OnT9oxX331FRDMfZeL67FKdvhJOe68887Ue/5B2g033AAE63P7bjkSv86dOwNBl6hHH3203DHh7lEAu+xSUhX5LnPh4ZlxUWYoIkKOZ4aZ+L8gPgPwj+b98B0Nz8p/F110UWp77NixAEyfPh0I2qLKzoQu8fFZ+ahRo4Dg7i3ckfqTTz4BoLCwEAjK1LcBJ0GZoYgI9TAzHD9+PAATJkwA4Le//S0QDOqX/Beerm3u3LlAsJ6vn1ggFzrx/tD5nh9+rfR//vOfqX2vvPIKEGSCfgqvJCkzFBEhxzPD0aNHA3Drrbem3jv++OMBGDx4MAC77747AI0bN445OskFvveAXzbAD9krKioCtJJeLvGrG5bdzhXKDEVEyPHM8LjjjgNg/vz5CUciuc5PHnvIIYcAsHr1akCZoVSfMkMREVQZiogAOX6bLFJdfk2ctWvXJhyJ1FfKDEVEUGUoIgKoMhQRAcD8alTVOtjsM2BddOHknALnXNuqD8sfKuP8pzLOrEaVoYhIvtJtsogIqgxFRICI+xma2R7AvNKX7YAdwGelr3/mnPsuwmvvArwGvO+c6x/VdX7okipjM7sOuKT05Rjn3OgoriOJlvF6YHPp9bY553pEcZ3U9eJqMzSz24CvnXMjyrxvpXHszPL1hgHdgGaqDOMRVxmbWTdgMnAk8D3wHPAb55x6XEcszu9xaWXY1Tn3v2ydszKJ3CabWSczKzKzh4FVQAcz+19o/0Azm1C6vZeZTTezZWa21MyOrMb5C4BewKSofgepXMRl3BlY7Jzb6pzbDrwEnBnV7yKZRf09jluSbYYHASOdc12ADZUc9wAw3DnXHTgH8P9ze5jZmAo+MwoYCuhRebKiKuOVwAlm1trMmgOnAh2yG7p
UU5TfYwfMN7P/mNklFRyTNUmOTV7jnFtWjeN6AgeGlgvd3cyaOueWAEvKHmxm/YGPnHMrzKxn9sKVWoikjJ1zb5rZ/cBc4GtgOSXtShK/SMq41JHOuQ1m1g543szecs4tykLMGSVZGW4Jbe8ELPS6SWjbqFkj7dHAADPrW3qeVmY22Tk3qE7RSm1EVcY458YB4wDMbDiwug5xSu1FWcYbSn9+bGZPAj8DIqsMc6JrTWmj62Yz29/MGpDe/jMXuNK/KG08r+xcw5xz+zrnCoELgOdUESYvm2VcesyepT8Lgb7A1GzGKzWXzTI2sxZm1qJ0uzklzwDezH7UgZyoDEtdD8yhpOZfH3r/SuAYM3vDzIqAS6HKtgbJTdks45mlx84ELnfOfRlh3FJ92Srj9sD/mdnrwFJghnNubpSBazieiAi5lRmKiCRGlaGICKoMRUQAVYYiIoAqQxERQJWhiAigylBEBFBlKCICwP8D3P5bzM0W5d8AAAAASUVORK5CYII=\n", 390 | "text/plain": [ 391 | "
" 392 | ] 393 | }, 394 | "metadata": {}, 395 | "output_type": "display_data" 396 | } 397 | ], 398 | "source": [ 399 | "# Get the first images from the test-set.\n", 400 | "images = test_data[0:9]\n", 401 | "\n", 402 | "# Get the true classes for those images.\n", 403 | "cls_true = test_label[0:9]\n", 404 | "\n", 405 | "# Plot the images and labels using our helper-function above.\n", 406 | "plot_images(images=images, cls_true=cls_true)" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "metadata": {}, 412 | "source": [ 413 | "## TensorFlow图\n", 414 | "\n", 415 | "TensorFlow的全部目的就是使用一个称之为计算图(computational graph)的东西,它会比直接在Python中进行相同计算量要高效得多。TensorFlow比Numpy更高效,因为TensorFlow了解整个需要运行的计算图,然而Numpy只知道某个时间点上唯一的数学运算。\n", 416 | "\n", 417 | "TensorFlow也能够自动地计算需要优化的变量的梯度,使得模型有更好的表现。这是由于图是简单数学表达式的结合,因此整个图的梯度可以用链式法则推导出来。\n", 418 | "\n", 419 | "TensorFlow还能利用多核CPU和GPU,Google也为TensorFlow制造了称为TPUs(Tensor Processing Units)的特殊芯片,它比GPU更快。\n", 420 | "\n", 421 | "一个TensorFlow图由下面几个部分组成,后面会详细描述:\n", 422 | "\n", 423 | "* 占位符变量(Placeholder)用来改变图的输入。\n", 424 | "* 模型变量(Model)将会被优化,使得模型表现得更好。\n", 425 | "* 模型本质上就是一些数学函数,它根据Placeholder和模型的输入变量来计算一些输出。\n", 426 | "* 一个cost度量用来指导变量的优化。\n", 427 | "* 一个优化策略会更新模型的变量。\n", 428 | "\n", 429 | "另外,TensorFlow图也包含了一些调试状态,比如用TensorBoard打印log数据。" 430 | ] 431 | }, 432 | { 433 | "cell_type": "markdown", 434 | "metadata": {}, 435 | "source": [ 436 | "### 创建新变量的帮助函数" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": {}, 442 | "source": [ 443 | "函数用来根据给定大小创建TensorFlow变量,并将它们用随机值初始化。需注意的是在此时并未完成初始化工作,仅仅是在TensorFlow图里定义它们。" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 11, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "def new_weights(shape):\n", 453 | " return tf.Variable(tf.truncated_normal(shape, stddev=0.05))" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 12, 459 | "metadata": {}, 460 | "outputs": [], 461 | "source": [ 462 | "def 
new_biases(length):\n", 463 | " return tf.Variable(tf.constant(0.05, shape=[length]))" 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": {}, 469 | "source": [ 470 | "### Helper function for creating a new convolutional layer" 471 | ] 472 | }, 473 | { 474 | "cell_type": "markdown", 475 | "metadata": {}, 476 | "source": [ 477 | "This function creates a new convolutional layer in the TensorFlow computational graph. No computation is actually performed here; we only add the mathematical formulas to the TensorFlow graph.\n", 478 | "\n", 479 | "The input is assumed to be a 4-dim tensor with the following dimensions:\n", 480 | "\n", 481 | "1. Image number.\n", 482 | "2. Y-axis of each image.\n", 483 | "3. X-axis of each image.\n", 484 | "4. Channels of each image.\n", 485 | "\n", 486 | "The input channels may be colour channels, or they may be filter channels if the input was produced by a previous convolutional layer.\n", 487 | "\n", 488 | "The output is another 4-dim tensor with the following dimensions:\n", 489 | "\n", 490 | "1. Image number, same as the input.\n", 491 | "2. Y-axis of each image. If 2x2 pooling is used, this is half the height of the input image.\n", 492 | "3. X-axis of each image. Likewise halved if 2x2 pooling is used.\n", 493 | "4. Channels produced by the convolutional filters, one per filter." 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 20, 499 | "metadata": {}, 500 | "outputs": [], 501 | "source": [ 502 | "def new_conv_layer(input, # The previous layer.\n", 503 | " num_input_channels, # Num. channels in prev. layer.\n", 504 | " filter_size, # Width and height of each filter.\n", 505 | " num_filters, # Number of filters.\n", 506 | " use_pooling=True): # Use 2x2 max-pooling.\n", 507 | "\n", 508 | " # Shape of the filter-weights for the convolution.\n", 509 | " # This format is determined by the TensorFlow API.\n", 510 | " shape = [filter_size, filter_size, num_input_channels, num_filters]\n", 511 | "\n", 512 | " # Create new weights aka. filters with the given shape.\n", 513 | " weights = new_weights(shape=shape)\n", 514 | "\n", 515 | " # Create new biases, one for each filter.\n", 516 | " biases = new_biases(length=num_filters)\n", 517 | "\n", 518 | " # Create the TensorFlow operation for convolution.\n", 519 | " # Note the strides are set to 1 in all dimensions.\n", 520 | " # The first and last stride must always be 1,\n", 521 | " # because the first is for the image-number and\n", 522 | " # the last is for the input-channel.\n", 523 | " # But e.g.
strides=[1, 2, 2, 1] would mean that the filter\n", 524 | " # is moved 2 pixels across the x- and y-axis of the image.\n", 525 | " # The padding is set to 'SAME' which means the input image\n", 526 | " # is padded with zeroes so the size of the output is the same.\n", 527 | " layer = tf.nn.conv2d(input=input,\n", 528 | " filter=weights,\n", 529 | " strides=[1, 1, 1, 1],\n", 530 | " padding='SAME')\n", 531 | "\n", 532 | " # Add the biases to the results of the convolution.\n", 533 | " # A bias-value is added to each filter-channel.\n", 534 | " layer += biases\n", 535 | "\n", 536 | " # Use pooling to down-sample the image resolution?\n", 537 | " if use_pooling:\n", 538 | " # This is 2x2 max-pooling, which means that we\n", 539 | " # consider 2x2 windows and select the largest value\n", 540 | " # in each window. Then we move 2 pixels to the next window.\n", 541 | " layer = tf.nn.max_pool(value=layer,\n", 542 | " ksize=[1, 2, 2, 1],\n", 543 | " strides=[1, 2, 2, 1],\n", 544 | " padding='SAME')\n", 545 | "\n", 546 | " # Rectified Linear Unit (ReLU).\n", 547 | " # It calculates max(x, 0) for each input pixel x.\n", 548 | " # This adds some non-linearity to the formula and allows us\n", 549 | " # to learn more complicated functions.\n", 550 | " layer = tf.nn.relu(layer)\n", 551 | "\n", 552 | " # Note that ReLU is normally executed before the pooling,\n", 553 | " # but since relu(max_pool(x)) == max_pool(relu(x)) we can\n", 554 | " # save 75% of the relu-operations by max-pooling first.\n", 555 | "\n", 556 | " # We return both the resulting layer and the filter-weights\n", 557 | " # because we will plot the weights later.\n", 558 | " return layer, weights" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "### Helper function for flattening a layer\n", 566 | "\n", 567 | "A convolutional layer produces a 4-dim tensor. We will add a fully-connected layer after the convolutional layers, so we need to reshape (flatten) this 4-dim tensor into a 2-dim tensor that the fully-connected layer can use as input." 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": 13, 573 | "metadata": {}, 574 |
"outputs": [], 575 | "source": [ 576 | "def flatten_layer(layer):\n", 577 | " # Get the shape of the input layer.\n", 578 | " layer_shape = layer.get_shape()\n", 579 | "\n", 580 | " # The shape of the input layer is assumed to be:\n", 581 | " # layer_shape == [num_images, img_height, img_width, num_channels]\n", 582 | "\n", 583 | " # The number of features is: img_height * img_width * num_channels\n", 584 | " # We can use a function from TensorFlow to calculate this.\n", 585 | " num_features = layer_shape[1:4].num_elements()\n", 586 | " \n", 587 | " # Reshape the layer to [num_images, num_features].\n", 588 | " # Note that we just set the size of the second dimension\n", 589 | " # to num_features and the size of the first dimension to -1\n", 590 | " # which means the size in that dimension is calculated\n", 591 | " # so the total size of the tensor is unchanged from the reshaping.\n", 592 | " layer_flat = tf.reshape(layer, [-1, num_features])\n", 593 | "\n", 594 | " # The shape of the flattened layer is now:\n", 595 | " # [num_images, img_height * img_width * num_channels]\n", 596 | "\n", 597 | " # Return both the flattened layer and the number of features.\n", 598 | " return layer_flat, num_features" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "metadata": {}, 604 | "source": [ 605 | "### Helper function for creating a new fully-connected layer" 606 | ] 607 | }, 608 | { 609 | "cell_type": "markdown", 610 | "metadata": {}, 611 | "source": [ 612 | "This function creates a new fully-connected layer in the TensorFlow computational graph. Again, no computation is performed here; we only add the mathematical formulas to the TensorFlow graph.\n", 613 | "\n", 614 | "The input is a 2-dim tensor of shape `[num_images, num_inputs]`. The output is a 2-dim tensor of shape `[num_images, num_outputs]`." 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": 14, 620 | "metadata": {}, 621 | "outputs": [], 622 | "source": [ 623 | "def new_fc_layer(input, # The previous layer.\n", 624 | " num_inputs, # Num. inputs from prev. layer.\n", 625 | " num_outputs, # Num.
outputs.\n", 626 | " use_relu=True): # Use Rectified Linear Unit (ReLU)?\n", 627 | "\n", 628 | " # Create new weights and biases.\n", 629 | " weights = new_weights(shape=[num_inputs, num_outputs])\n", 630 | " biases = new_biases(length=num_outputs)\n", 631 | "\n", 632 | " # Calculate the layer as the matrix multiplication of\n", 633 | " # the input and weights, and then add the bias-values.\n", 634 | " layer = tf.matmul(input, weights) + biases\n", 635 | "\n", 636 | " # Use ReLU?\n", 637 | " if use_relu:\n", 638 | " layer = tf.nn.relu(layer)\n", 639 | "\n", 640 | " return layer" 641 | ] 642 | }, 643 | { 644 | "cell_type": "markdown", 645 | "metadata": {}, 646 | "source": [ 647 | "### Placeholder variables" 648 | ] 649 | }, 650 | { 651 | "cell_type": "markdown", 652 | "metadata": {}, 653 | "source": [ 654 | "Placeholder variables serve as the input to the graph, and we may change them each time we execute the graph. This is called feeding the placeholder variables and is described further below.\n", 655 | "\n", 656 | "First we define the placeholder variable for the input images. This allows us to change the images that are input to the TensorFlow graph. This is a so-called tensor, which just means a multi-dimensional vector or matrix. The data type is set to `float32` and the shape is set to `[None, img_size_flat]`, where `None` means that the tensor may hold an arbitrary number of images, each image being a vector of length `img_size_flat`." 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 15, 662 | "metadata": {}, 663 | "outputs": [], 664 | "source": [ 665 | "img_size_flat = IMG_FLAT\n", 666 | "\n", 667 | "x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')" 668 | ] 669 | }, 670 | { 671 | "cell_type": "markdown", 672 | "metadata": {}, 673 | "source": [ 674 | "The convolutional layers expect `x` to be encoded as a 4-dim tensor, so we have to reshape it to `[num_images, img_height, img_width, num_channels]`. Note that `img_height == img_width == img_size`, and `num_images` is inferred automatically when the size of the first dimension is set to -1. The reshape operation is:" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": 16, 680 | "metadata": {}, 681 | "outputs": [], 682 | "source": [ 683 | "img_size = IMG_HEIGHT\n", 684 | "num_channels = NUM_CHANNEL\n", 685 | "\n", 686 | "x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])" 687 | ] 688 | }, 689 | { 690 | "cell_type": "markdown", 691 |
"metadata": {}, 692 | "source": [ 693 | "Next we define the placeholder variable for the true labels associated with the images in the placeholder variable `x`. Its shape is `[None, num_classes]`, which means it may hold an arbitrary number of labels, each label being a vector of length `num_classes`, which is 10 in this case." 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": 17, 699 | "metadata": {}, 700 | "outputs": [], 701 | "source": [ 702 | "num_classes = NUM_CLASS\n", 703 | "\n", 704 | "y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')" 705 | ] 706 | }, 707 | { 708 | "cell_type": "markdown", 709 | "metadata": {}, 710 | "source": [ 711 | "We could also have a placeholder variable for the class-number, but we will instead calculate it using argmax. Note that this only adds an operation to the TensorFlow graph; nothing is calculated at this point." 712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": 18, 717 | "metadata": {}, 718 | "outputs": [], 719 | "source": [ 720 | "y_true_cls = tf.argmax(y_true, axis=1)" 721 | ] 722 | }, 723 | { 724 | "cell_type": "markdown", 725 | "metadata": {}, 726 | "source": [ 727 | "### Convolutional Layer 1\n", 728 | "\n", 729 | "Create the first convolutional layer. It takes `x_image` as input and creates `num_filters1` different filters, each having width and height equal to `filter_size1`. Finally we down-sample the image with 2x2 max-pooling so that it is half the size." 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "execution_count": 21, 735 | "metadata": {}, 736 | "outputs": [], 737 | "source": [ 738 | "layer_conv1, weights_conv1 = \\\n", 739 | " new_conv_layer(input=x_image,\n", 740 | " num_input_channels=num_channels,\n", 741 | " filter_size=filter_size1,\n", 742 | " num_filters=num_filters1,\n", 743 | " use_pooling=True)" 744 | ] 745 | }, 746 | { 747 | "cell_type": "markdown", 748 | "metadata": {}, 749 | "source": [ 750 | "Check the shape of the tensor output by the convolutional layer. It is (?, 14, 14, 16), which means there is an arbitrary number of images (this is what ? stands for), each image is 14 pixels wide and 14 pixels high, and there are 16 different channels, one channel for each of the filters." 751 | ] 752 | }, 753 | { 754 | "cell_type": "code", 755 | "execution_count": 22, 756 | "metadata": {}, 757 | "outputs": [ 758 | { 759 | "data": { 760 | "text/plain": [ 761 | "" 762 | ] 763 | }, 764 | "execution_count": 22, 765 | "metadata": {}, 766 | "output_type": "execute_result" 767 | } 768 | ], 769 | "source": [ 770 | "layer_conv1" 771 | ] 772 | }, 773 | { 774 |
"cell_type": "markdown", 775 | "metadata": {}, 776 | "source": [ 777 | "### Convolutional Layer 2\n", 778 | "\n", 779 | "Create the second convolutional layer, which takes the output of the first convolutional layer as input. The number of input channels corresponds to the number of filters in the first convolutional layer." 780 | ] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "execution_count": 23, 785 | "metadata": {}, 786 | "outputs": [], 787 | "source": [ 788 | "layer_conv2, weights_conv2 = \\\n", 789 | " new_conv_layer(input=layer_conv1,\n", 790 | " num_input_channels=num_filters1,\n", 791 | " filter_size=filter_size2,\n", 792 | " num_filters=num_filters2,\n", 793 | " use_pooling=True)" 794 | ] 795 | }, 796 | { 797 | "cell_type": "markdown", 798 | "metadata": {}, 799 | "source": [ 800 | "Check the shape of the tensor output by this convolutional layer. It is (?, 7, 7, 36), where ? again means an arbitrary number of images, each image is 7 pixels wide and 7 pixels high, and there are 36 channels, one for each filter." 801 | ] 802 | }, 803 | { 804 | "cell_type": "code", 805 | "execution_count": 24, 806 | "metadata": {}, 807 | "outputs": [ 808 | { 809 | "data": { 810 | "text/plain": [ 811 | "" 812 | ] 813 | }, 814 | "execution_count": 24, 815 | "metadata": {}, 816 | "output_type": "execute_result" 817 | } 818 | ], 819 | "source": [ 820 | "layer_conv2" 821 | ] 822 | }, 823 | { 824 | "cell_type": "markdown", 825 | "metadata": {}, 826 | "source": [ 827 | "### Flatten layer\n", 828 | "\n", 829 | "The convolutional layer outputs a 4-dim tensor. We now want to use it as input to a fully-connected network, which requires the tensor to be reshaped (flattened) into a 2-dim tensor." 830 | ] 831 | }, 832 | { 833 | "cell_type": "code", 834 | "execution_count": 25, 835 | "metadata": {}, 836 | "outputs": [], 837 | "source": [ 838 | "layer_flat, num_features = flatten_layer(layer_conv2)" 839 | ] 840 | }, 841 | { 842 | "cell_type": "markdown", 843 | "metadata": {}, 844 | "source": [ 845 | "The tensor now has shape (?, 1764), which means there is an arbitrary number of images, each flattened into a vector of length 1764, where 1764 = 7 x 7 x 36." 846 | ] 847 | }, 848 | { 849 | "cell_type": "code", 850 | "execution_count": 26, 851 | "metadata": {}, 852 | "outputs": [ 853 | { 854 | "data": { 855 | "text/plain": [ 856 | "" 857 | ] 858 | }, 859 | "execution_count": 26, 860 | "metadata": {}, 861 | "output_type": "execute_result" 862 | } 863 | ], 864 | "source": [ 865 | "layer_flat" 866 | ] 867 | }, 868 | { 869 | "cell_type": "code",
870 | "execution_count": 27, 871 | "metadata": {}, 872 | "outputs": [ 873 | { 874 | "data": { 875 | "text/plain": [ 876 | "1764" 877 | ] 878 | }, 879 | "execution_count": 27, 880 | "metadata": {}, 881 | "output_type": "execute_result" 882 | } 883 | ], 884 | "source": [ 885 | "num_features" 886 | ] 887 | }, 888 | { 889 | "cell_type": "markdown", 890 | "metadata": {}, 891 | "source": [ 892 | "### 全连接层 1\n", 893 | "\n", 894 | "往网络中添加一个全连接层。输入是一个前面卷积得到的被转换过的层。全连接层中的神经元或节点数为`fc_size`。我们可以用ReLU来学习非线性关系。" 895 | ] 896 | }, 897 | { 898 | "cell_type": "code", 899 | "execution_count": 28, 900 | "metadata": {}, 901 | "outputs": [], 902 | "source": [ 903 | "layer_fc1 = new_fc_layer(input=layer_flat,\n", 904 | " num_inputs=num_features,\n", 905 | " num_outputs=fc_size,\n", 906 | " use_relu=True)" 907 | ] 908 | }, 909 | { 910 | "cell_type": "markdown", 911 | "metadata": {}, 912 | "source": [ 913 | "全连接层的输出是一个大小为(?,128)的张量,?代表着一定数量的图像,并且`fc_size` == 128。" 914 | ] 915 | }, 916 | { 917 | "cell_type": "code", 918 | "execution_count": 29, 919 | "metadata": {}, 920 | "outputs": [ 921 | { 922 | "data": { 923 | "text/plain": [ 924 | "" 925 | ] 926 | }, 927 | "execution_count": 29, 928 | "metadata": {}, 929 | "output_type": "execute_result" 930 | } 931 | ], 932 | "source": [ 933 | "layer_fc1" 934 | ] 935 | }, 936 | { 937 | "cell_type": "markdown", 938 | "metadata": {}, 939 | "source": [ 940 | "### 全连接层 2\n", 941 | "\n", 942 | "添加另外一个全连接层,它的输出是一个长度为10的向量,它确定了输入图是属于哪个类别。这层并没有用到ReLU。" 943 | ] 944 | }, 945 | { 946 | "cell_type": "code", 947 | "execution_count": 30, 948 | "metadata": {}, 949 | "outputs": [], 950 | "source": [ 951 | "layer_fc2 = new_fc_layer(input=layer_fc1,\n", 952 | " num_inputs=fc_size,\n", 953 | " num_outputs=num_classes,\n", 954 | " use_relu=False)" 955 | ] 956 | }, 957 | { 958 | "cell_type": "code", 959 | "execution_count": 31, 960 | "metadata": {}, 961 | "outputs": [ 962 | { 963 | "data": { 964 | "text/plain": [ 965 | "" 966 | ] 967 | }, 968 | "execution_count": 31, 969 
| "metadata": {}, 970 | "output_type": "execute_result" 971 | } 972 | ], 973 | "source": [ 974 | "layer_fc2" 975 | ] 976 | }, 977 | { 978 | "cell_type": "markdown", 979 | "metadata": {}, 980 | "source": [ 981 | "### Predicted class" 982 | ] 983 | }, 984 | { 985 | "cell_type": "markdown", 986 | "metadata": {}, 987 | "source": [ 988 | "The second fully-connected layer estimates how likely it is that the input image belongs to each of the 10 classes. However, these estimates are rough and hard to interpret because the numbers may be very small or very large, so we normalize them so that each element lies between 0 and 1 and the 10 elements sum to 1. This is computed with the so-called softmax function, and the result is stored in `y_pred`." 989 | ] 990 | }, 991 | { 992 | "cell_type": "code", 993 | "execution_count": 32, 994 | "metadata": {}, 995 | "outputs": [], 996 | "source": [ 997 | "y_pred = tf.nn.softmax(layer_fc2)" 998 | ] 999 | }, 1000 | { 1001 | "cell_type": "markdown", 1002 | "metadata": {}, 1003 | "source": [ 1004 | "The class number is the index of the largest element." 1005 | ] 1006 | }, 1007 | { 1008 | "cell_type": "code", 1009 | "execution_count": 33, 1010 | "metadata": {}, 1011 | "outputs": [], 1012 | "source": [ 1013 | "y_pred_cls = tf.argmax(y_pred, axis=1)" 1014 | ] 1015 | }, 1016 | { 1017 | "cell_type": "markdown", 1018 | "metadata": {}, 1019 | "source": [ 1020 | "### Cost function to be optimized" 1021 | ] 1022 | }, 1023 | { 1024 | "cell_type": "markdown", 1025 | "metadata": {}, 1026 | "source": [ 1027 | "To make the model better at classifying the input images, we must change the `weights` and `biases` variables. First we need to know how well the model currently performs by comparing its predicted output `y_pred` with the desired output `y_true`.\n", 1028 | "\n", 1029 | "Cross-entropy is a performance measure used in classification. It is a continuous function that is never negative, and it equals zero if the model's predicted output exactly matches the desired output. The goal of the optimization is therefore to minimize the cross-entropy by changing the variables of the network layers.\n", 1030 | "\n", 1031 | "TensorFlow has a built-in function for calculating the cross-entropy. It computes the softmax internally, so we must pass it the raw output of `layer_fc2` rather than `y_pred`, which already has the softmax applied." 1032 | ] 1033 | }, 1034 | { 1035 | "cell_type": "code", 1036 | "execution_count": 34, 1037 | "metadata": {}, 1038 | "outputs": [ 1039 | { 1040 | "name": "stdout", 1041 | "output_type": "stream", 1042 | "text": [ 1043 | "WARNING:tensorflow:From :2: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n", 1044 | "Instructions for updating:\n", 1045 | "\n", 1046 | "Future major
versions of TensorFlow will allow gradients to flow\n", 1047 | "into the labels input on backprop by default.\n", 1048 | "\n", 1049 | "See @{tf.nn.softmax_cross_entropy_with_logits_v2}.\n", 1050 | "\n" 1051 | ] 1052 | } 1053 | ], 1054 | "source": [ 1055 | "cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=layer_fc2,\n", 1056 | " labels=y_true)" 1057 | ] 1058 | }, 1059 | { 1060 | "cell_type": "markdown", 1061 | "metadata": {}, 1062 | "source": [ 1063 | "We have now calculated the cross-entropy for each of the image classifications, which measures how well the model performs on each image individually. But to use the cross-entropy to guide the optimization of the model's variables we need a single scalar value, so we simply take the average of the cross-entropy over all the image classifications." 1064 | ] 1065 | }, 1066 | { 1067 | "cell_type": "code", 1068 | "execution_count": 35, 1069 | "metadata": {}, 1070 | "outputs": [], 1071 | "source": [ 1072 | "cost = tf.reduce_mean(cross_entropy)" 1073 | ] 1074 | }, 1075 | { 1076 | "cell_type": "markdown", 1077 | "metadata": {}, 1078 | "source": [ 1079 | "### Optimization method" 1080 | ] 1081 | }, 1082 | { 1083 | "cell_type": "markdown", 1084 | "metadata": {}, 1085 | "source": [ 1086 | "Now that we have a cost measure that must be minimized, we can create an optimizer. In this example we use `AdamOptimizer`, a variant of gradient descent.\n", 1087 | "\n", 1088 | "Note that the optimization is not performed at this point. In fact, nothing is calculated at all yet; we have only added the optimizer to the TensorFlow graph for later execution." 1089 | ] 1090 | }, 1091 | { 1092 | "cell_type": "code", 1093 | "execution_count": 36, 1094 | "metadata": {}, 1095 | "outputs": [], 1096 | "source": [ 1097 | "optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)" 1098 | ] 1099 | }, 1100 | { 1101 | "cell_type": "markdown", 1102 | "metadata": {}, 1103 | "source": [ 1104 | "### Performance measures" 1105 | ] 1106 | }, 1107 | { 1108 | "cell_type": "markdown", 1109 | "metadata": {}, 1110 | "source": [ 1111 | "We need a few more performance measures to display the progress to the user.\n", 1112 | "\n", 1113 | "This is a vector of booleans indicating whether the predicted class equals the true class of each image." 1114 | ] 1115 | }, 1116 | { 1117 | "cell_type": "code", 1118 | "execution_count": 37, 1119 | "metadata": {}, 1120 | "outputs": [], 1121 | "source": [ 1122 | "correct_prediction = tf.equal(y_pred_cls, y_true_cls)\n" 1123 | ] 1124 | }, 1125 | { 1126 | "cell_type": "markdown", 1127 | "metadata": {}, 1128 | "source": [
1129 | "The calculation below first casts the vector of booleans to floats, so that False becomes 0 and True becomes 1, and then takes the average of these numbers; this gives the classification accuracy." 1130 | ] 1131 | }, 1132 | { 1133 | "cell_type": "code", 1134 | "execution_count": 38, 1135 | "metadata": {}, 1136 | "outputs": [ 1137 | { 1138 | "data": { 1139 | "text/plain": [ 1140 | "" 1141 | ] 1142 | }, 1143 | "execution_count": 38, 1144 | "metadata": {}, 1145 | "output_type": "execute_result" 1146 | } 1147 | ], 1148 | "source": [ 1149 | "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", 1150 | "tf.summary.scalar('accuracy',accuracy)" 1151 | ] 1152 | }, 1153 | { 1154 | "cell_type": "markdown", 1155 | "metadata": {}, 1156 | "source": [ 1157 | "## Run TensorFlow" 1158 | ] 1159 | }, 1160 | { 1161 | "cell_type": "markdown", 1162 | "metadata": {}, 1163 | "source": [ 1164 | "### Create a TensorFlow session\n", 1165 | "\n", 1166 | "Once the TensorFlow graph has been built, we need to create a TensorFlow session to execute the graph." 1167 | ] 1168 | }, 1169 | { 1170 | "cell_type": "code", 1171 | "execution_count": 39, 1172 | "metadata": {}, 1173 | "outputs": [], 1174 | "source": [ 1175 | "session = tf.Session()" 1176 | ] 1177 | }, 1178 | { 1179 | "cell_type": "code", 1180 | "execution_count": 47, 1181 | "metadata": {}, 1182 | "outputs": [], 1183 | "source": [ 1184 | "train_log_save_dir = '/Users/honglan/Desktop/tensorflow_estimator_learn/cnn_raw_log'\n", 1185 | "model_save_path = '/Users/honglan/Desktop/tensorflow_estimator_learn/cnn_raw'\n", 1186 | "merged = tf.summary.merge_all()\n", 1187 | "train_writer = tf.summary.FileWriter(train_log_save_dir, session.graph)\n", 1188 | "saver = tf.train.Saver()" 1189 | ] 1190 | }, 1191 | { 1192 | "cell_type": "markdown", 1193 | "metadata": {}, 1194 | "source": [ 1195 | "### Initialize variables\n", 1196 | "\n", 1197 | "The `weights` and `biases` variables must be initialized before we start optimizing them." 1198 | ] 1199 | }, 1200 | { 1201 | "cell_type": "code", 1202 | "execution_count": 41, 1203 | "metadata": {}, 1204 | "outputs": [], 1205 | "source": [ 1206 | "session.run(tf.global_variables_initializer())" 1207 | ] 1208 | }, 1209 | { 1210 | "cell_type":
"markdown", 1211 | "metadata": {}, 1212 | "source": [ 1213 | "### Helper function to perform optimization iterations" 1214 | ] 1215 | }, 1216 | { 1217 | "cell_type": "markdown", 1218 | "metadata": {}, 1219 | "source": [ 1220 | "There are 55,000 images in the training set. Computing the gradient of the model using all of them would take a long time, so we use stochastic gradient descent, which uses only a small batch of images in each iteration of the optimizer.\n", 1221 | "\n", 1222 | "If your computer crashes or becomes very slow because it runs out of memory, try lowering `train_batch_size`, but you may then need to run more optimization iterations." 1223 | ] 1224 | }, 1225 | { 1226 | "cell_type": "code", 1227 | "execution_count": 42, 1228 | "metadata": {}, 1229 | "outputs": [], 1230 | "source": [ 1231 | "train_batch_size = 64" 1232 | ] 1233 | }, 1234 | { 1235 | "cell_type": "markdown", 1236 | "metadata": {}, 1237 | "source": [ 1238 | "This function performs a number of optimization iterations to gradually improve the variables of the network layers. In each iteration, a new batch of data is drawn from the training set and TensorFlow runs the optimizer on those training samples. Progress is printed every 100 iterations." 1239 | ] 1240 | }, 1241 | { 1242 | "cell_type": "code", 1243 | "execution_count": 48, 1244 | "metadata": {}, 1245 | "outputs": [], 1246 | "source": [ 1247 | "# Counter for total number of iterations performed so far.\n", 1248 | "total_iterations = 0\n", 1249 | "\n", 1250 | "\n", 1251 | "def optimize(num_iterations):\n", 1252 | " # Ensure we update the global variable rather than a local copy.\n", 1253 | " global total_iterations\n", 1254 | "\n", 1255 | " # Start-time used for printing time-usage below.\n", 1256 | " start_time = time.time()\n", 1257 | "\n", 1258 | " for i in range(total_iterations,\n", 1259 | " total_iterations + num_iterations):\n", 1260 | " # Get a batch of training examples.\n", 1261 | " # x_batch now holds a batch of images and\n", 1262 | " # y_true_batch are the true labels for those images.\n", 1263 | " x_batch, y_true_batch, _ = batch(batch_size=train_batch_size)\n", 1264 | "\n", 1265 | " # Put the batch into a dict with the proper names\n", 1266 | " # for placeholder variables in the TensorFlow graph.\n", 1267 | " feed_dict_train = {x: x_batch,\n", 1268 | " y_true: y_true_batch}\n", 1269 | "\n", 1270 | " # Run the optimizer using this batch of training data.\n", 1271 | " # TensorFlow assigns the variables in feed_dict_train\n", 1272 | " #
to the placeholder variables and then runs the optimizer.\n", 1273 | " session.run(optimizer, feed_dict=feed_dict_train)\n", 1274 | "\n", 1275 | " # Print status every 100 iterations.\n", 1276 | " if i % 100 == 0:\n", 1277 | " # Calculate the accuracy on the training-set.\n", 1278 | " acc,summary = session.run([accuracy,merged], feed_dict=feed_dict_train)\n", 1279 | " train_writer.add_summary(summary, i) # log at the current iteration i, not the start-of-call counter\n", 1280 | " \n", 1281 | " # Message for printing.\n", 1282 | " msg = \"Optimization Iteration: {0:>6}, Training Accuracy: {1:>6.1%}\"\n", 1283 | " saver.save(sess=session, save_path=model_save_path+'/'+'cnn_raw',global_step=i)\n", 1284 | " \n", 1285 | " # Print it.\n", 1286 | " print(msg.format(i + 1, acc))\n", 1287 | "\n", 1288 | " # Update the total number of iterations performed.\n", 1289 | " total_iterations += num_iterations\n", 1290 | "\n", 1291 | " # Ending time.\n", 1292 | " end_time = time.time()\n", 1293 | "\n", 1294 | " # Difference between start and end-times.\n", 1295 | " time_dif = end_time - start_time\n", 1296 | "\n", 1297 | " # Print the time-usage.\n", 1298 | " print(\"Time usage: \" + str(timedelta(seconds=int(round(time_dif)))))" 1299 | ] 1300 | }, 1301 | { 1302 | "cell_type": "markdown", 1303 | "metadata": {}, 1304 | "source": [ 1305 | "### Finally, let's start training" 1306 | ] 1307 | }, 1308 | { 1309 | "cell_type": "code", 1310 | "execution_count": 49, 1311 | "metadata": {}, 1312 | "outputs": [ 1313 | { 1314 | "name": "stdout", 1315 | "output_type": "stream", 1316 | "text": [ 1317 | "Optimization Iteration: 1, Training Accuracy: 95.3%\n", 1318 | "Optimization Iteration: 101, Training Accuracy: 95.3%\n", 1319 | "Optimization Iteration: 201, Training Accuracy: 95.3%\n", 1320 | "Optimization Iteration: 301, Training Accuracy: 90.6%\n", 1321 | "Optimization Iteration: 401, Training Accuracy: 93.8%\n", 1322 | "Optimization Iteration: 501, Training Accuracy: 93.8%\n", 1323 | "Optimization Iteration: 601, Training Accuracy: 98.4%\n", 1324
| "Optimization Iteration: 701, Training Accuracy: 93.8%\n", 1325 | "Optimization Iteration: 801, Training Accuracy: 95.3%\n", 1326 | "Optimization Iteration: 901, Training Accuracy: 98.4%\n", 1327 | "Time usage: 0:01:31\n" 1328 | ] 1329 | } 1330 | ], 1331 | "source": [ 1332 | "optimize(1000)" 1333 | ] 1334 | }, 1335 | { 1336 | "cell_type": "markdown", 1337 | "metadata": {}, 1338 | "source": [ 1339 | "### Close the TensorFlow session" 1340 | ] 1341 | }, 1342 | { 1343 | "cell_type": "markdown", 1344 | "metadata": {}, 1345 | "source": [ 1346 | "We are now done using TensorFlow, so we close the session to release its resources." 1347 | ] 1348 | }, 1349 | { 1350 | "cell_type": "code", 1351 | "execution_count": 50, 1352 | "metadata": {}, 1353 | "outputs": [], 1354 | "source": [ 1355 | "# Close the session to release its resources. Comment this out if you want to\n", 1356 | "# keep modifying and experimenting with the Notebook without restarting it.\n", 1357 | "session.close()" 1358 | ] 1359 | }, 1360 | { 1361 | "cell_type": "markdown", 1362 | "metadata": {}, 1363 | "source": [ 1364 | "## Summary\n", 1365 | "\n", 1366 | "- By now you should be familiar with the usual steps for training a model in plain TensorFlow\n", 1367 | "\n", 1368 | "- Feels like a bit of a slog, right? And validation and testing haven't even been written yet; I got too lazy to write them, so they are left as an exercise for interested readers\n" 1369 | ] 1370 | }, 1371 | { 1372 | "cell_type": "code", 1373 | "execution_count": null, 1374 | "metadata": {}, 1375 | "outputs": [], 1376 | "source": [] 1377 | } 1378 | ], 1379 | "metadata": { 1380 | "anaconda-cloud": {}, 1381 | "kernelspec": { 1382 | "display_name": "Python 3", 1383 | "language": "python", 1384 | "name": "python3" 1385 | }, 1386 | "language_info": { 1387 | "codemirror_mode": { 1388 | "name": "ipython", 1389 | "version": 3 1390 | }, 1391 | "file_extension": ".py", 1392 | "mimetype": "text/x-python", 1393 | "name": "python", 1394 | "nbconvert_exporter": "python", 1395 | "pygments_lexer": "ipython3", 1396 | "version": "3.5.0" 1397 | } 1398 | }, 1399 | "nbformat": 4, 1400 | "nbformat_minor": 1 1401 | } 1402 | --------------------------------------------------------------------------------
/tensorflow_estimator_learn/DNNClassifier_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# All About TensorFlow: the HELLO WORLD of DL\n", 8 | "\n", 9 | "\n", 10 | "- Using the MNIST dataset, build a simple multi-layer neural network with **tf.estimator.DNNClassifier**, one of the premade estimators in **tf.estimator**, and train, validate and test the model\n", 11 | "\n", 12 | "- Basic use of TensorBoard\n" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "## A look at the MNIST dataset\n" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "### Import libraries" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "data": { 36 | "text/plain": [ 37 | "'1.8.0'" 38 | ] 39 | }, 40 | "execution_count": 1, 41 | "metadata": {}, 42 | "output_type": "execute_result" 43 | } 44 | ], 45 | "source": [ 46 | "%matplotlib inline\n", 47 | "import tensorflow as tf\n", 48 | "import matplotlib.pyplot as plt\n", 49 | "import numpy as np\n", 50 | "import pandas as pd\n", 51 | "import multiprocessing\n", 52 | "\n", 53 | "\n", 54 | "from tensorflow import data\n", 55 | "from tensorflow.python.feature_column import feature_column\n", 56 | "\n", 57 | "tf.__version__" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "### Load the MNIST dataset" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 2, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "TRAIN_DATA_FILES_PATTERN = 'data_csv/mnist_train.csv'\n", 74 | "VAL_DATA_FILES_PATTERN = 'data_csv/mnist_val.csv'\n", 75 | "TEST_DATA_FILES_PATTERN = 'data_csv/mnist_test.csv'\n", 76 | "\n", 77 | "MULTI_THREADING = True\n", 78 | "RESUME_TRAINING = False\n", 79 | "\n", 80 | "NUM_CLASS = 10\n", 81 | "IMG_SHAPE = [28,28]\n", 82 | "\n", 83 | "IMG_WIDTH = 28\n", 84 | "IMG_HEIGHT = 28\n", 85 | "BATCH_SIZE = 128"
86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "test_data (10000, 784)\n", 98 | "test_label (10000,)\n", 99 | "val_data (5000, 784)\n", 100 | "val_label (5000,)\n", 101 | "train_data (55000, 784)\n", 102 | "train_label (55000,)\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN)\n", 108 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None, names=HEADER )\n", 109 | "train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None)\n", 110 | "test_data = pd.read_csv(TEST_DATA_FILES_PATTERN, header=None)\n", 111 | "val_data = pd.read_csv(VAL_DATA_FILES_PATTERN, header=None)\n", 112 | "\n", 113 | "train_values = train_data.values\n", 114 | "train_data = train_values[:,1:]/255.0\n", 115 | "train_label = train_values[:,0:1].squeeze()\n", 116 | "\n", 117 | "val_values = val_data.values\n", 118 | "val_data = val_values[:,1:]/255.0\n", 119 | "val_label = val_values[:,0:1].squeeze()\n", 120 | "\n", 121 | "test_values = test_data.values\n", 122 | "test_data = test_values[:,1:]/255.0\n", 123 | "test_label = test_values[:,0:1].squeeze()\n", 124 | "\n", 125 | "print('test_data',np.shape(test_data))\n", 126 | "print('test_label',np.shape(test_label))\n", 127 | "\n", 128 | "print('val_data',np.shape(val_data))\n", 129 | "print('val_label',np.shape(val_label))\n", 130 | "\n", 131 | "print('train_data',np.shape(train_data))\n", 132 | "print('train_label',np.shape(train_label))\n", 133 | "\n", 134 | "# train_data.head(10)\n", 135 | "# test_data.head(10)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 4, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "img_shape = IMG_SHAPE\n", 145 | "\n", 146 | "def plot_images(images, cls_true, cls_pred=None):\n", 147 | " assert len(images) == len(cls_true) == 9\n", 148 | " \n", 149 | " # 
Create figure with 3x3 sub-plots.\n", 150 | " fig, axes = plt.subplots(3, 3)\n", 151 | " fig.subplots_adjust(hspace=0.3, wspace=0.3)\n", 152 | "\n", 153 | " for i, ax in enumerate(axes.flat):\n", 154 | " # Plot image.\n", 155 | " ax.imshow(images[i].reshape(img_shape), cmap='binary')\n", 156 | "\n", 157 | " # Show true and predicted classes.\n", 158 | " if cls_pred is None:\n", 159 | " xlabel = \"True: {0}\".format(cls_true[i])\n", 160 | " else:\n", 161 | " xlabel = \"True: {0}, Pred: {1}\".format(cls_true[i], cls_pred[i])\n", 162 | "\n", 163 | " # Show the classes as the label on the x-axis.\n", 164 | " ax.set_xlabel(xlabel)\n", 165 | " \n", 166 | " # Remove ticks from the plot.\n", 167 | " ax.set_xticks([])\n", 168 | " ax.set_yticks([])\n", 169 | " \n", 170 | " # Ensure the plot is shown correctly with multiple plots\n", 171 | " # in a single Notebook cell.\n", 172 | " plt.show()" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "## The main event: how to use tf.estimator.DNNClassifier" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "### First, the input_fn: creating the input function\n", 187 | "\n", 188 | "- Build the input function with the **tf.data (Dataset) API**" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 5, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "# validate tf.data.TextLineDataset() using make_one_shot_iterator()\n", 198 | "\n", 199 | "def decode_line(line):\n", 200 | " # Decode one CSV line: column 0 is the label, columns 1-784 are the pixels.\n", 201 | " record_defaults = [[1.0] for col in range(785)]\n", 202 | " items = tf.decode_csv(line, record_defaults)\n", 203 | " features = items[1:785]\n", 204 | " label = items[0]\n", 205 | "\n", 206 | " features = tf.cast(features, tf.float32)\n", 207 | " features = tf.reshape(features,[28,28,1])\n", 208 | " label = tf.cast(label, tf.int64)\n", 209 | "# label = tf.one_hot(label,num_class)\n", 210 | " return features,label" 211 | ] 212 | }, 213 | { 214 | "cell_type":
"code", 215 | "execution_count": 6, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "def csv_input_fn(files_name_pattern, mode=tf.estimator.ModeKeys.EVAL, \n", 220 | " skip_header_lines=0, \n", 221 | " num_epochs=None, \n", 222 | " batch_size=128):\n", 223 | " shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False\n", 224 | " \n", 225 | " num_threads = multiprocessing.cpu_count() if MULTI_THREADING else 1\n", 226 | " \n", 227 | " print(\"\")\n", 228 | " print(\"* data input_fn:\")\n", 229 | " print(\"================\")\n", 230 | " print(\"Input file(s): {}\".format(files_name_pattern))\n", 231 | " print(\"Batch size: {}\".format(batch_size))\n", 232 | " print(\"Epoch Count: {}\".format(num_epochs))\n", 233 | " print(\"Mode: {}\".format(mode))\n", 234 | " print(\"Thread Count: {}\".format(num_threads))\n", 235 | " print(\"Shuffle: {}\".format(shuffle))\n", 236 | " print(\"================\")\n", 237 | " print(\"\")\n", 238 | "\n", 239 | " file_names = tf.matching_files(files_name_pattern)\n", 240 | " dataset = data.TextLineDataset(filenames=file_names).skip(skip_header_lines)\n", 241 | "# dataset = tf.data.TextLineDataset(filenames).skip(1)\n", 242 | " print(\"DATASET\",dataset)\n", 243 | "\n", 244 | " # Use `Dataset.map()` to decode each CSV line into a (features, label) pair,\n", 245 | " # running `num_threads` parsing calls in parallel.\n", 246 | " dataset = dataset.map(decode_line, num_parallel_calls=num_threads)\n", 247 | " print(\"DATASET_1\",dataset)\n", 248 | " if shuffle: dataset = dataset.shuffle(buffer_size=10000) # shuffle only in TRAIN mode\n", 249 | " print(\"DATASET_2\",dataset)\n", 250 | " dataset = dataset.batch(batch_size)\n", 251 | " print(\"DATASET_3\",dataset)\n", 252 | " dataset = dataset.repeat(num_epochs)\n", 253 | " print(\"DATASET_4\",dataset)\n", 254 | " iterator = dataset.make_one_shot_iterator()\n", 255 | " \n", 256 | " # `features` is a dictionary in which each value is a batch of values for\n", 257 | " # that feature; `labels` is a batch of labels.\n", 258 | " features, labels = iterator.get_next()\n", 259 | " \n", 260 | "
features = {'images':features}\n", 261 | " \n", 262 | " return features,labels\n" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 7, 268 | "metadata": {}, 269 | "outputs": [ 270 | { 271 | "name": "stdout", 272 | "output_type": "stream", 273 | "text": [ 274 | "\n", 275 | "* data input_fn:\n", 276 | "================\n", 277 | "Input file(s): data_csv/mnist_train.csv\n", 278 | "Batch size: 128\n", 279 | "Epoch Count: None\n", 280 | "Mode: eval\n", 281 | "Thread Count: 4\n", 282 | "Shuffle: False\n", 283 | "================\n", 284 | "\n", 285 | "DATASET \n", 286 | "DATASET_1 \n", 287 | "DATASET_2 \n", 288 | "DATASET_3 \n", 289 | "DATASET_4 \n", 290 | "Features in CSV: ['images']\n", 291 | "Target in CSV: Tensor(\"IteratorGetNext:1\", shape=(?,), dtype=int64)\n" 292 | ] 293 | } 294 | ], 295 | "source": [ 296 | "features, target = csv_input_fn(files_name_pattern=TRAIN_DATA_FILES_PATTERN)\n", 297 | "print(\"Features in CSV: {}\".format(list(features.keys())))\n", 298 | "print(\"Target in CSV: {}\".format(target))" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": {}, 304 | "source": [ 305 | "### Define the feature_columns" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 8, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "feature_x = tf.feature_column.numeric_column('images', shape=[28,28])\n", 315 | "# print((feature_x))\n", 316 | "\n", 317 | "feature_columns = [feature_x]\n", 318 | "# print((feature_columns))" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 9, 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "num_hidden_units = [512, 256, 128]" 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "metadata": {}, 333 | "source": [ 334 | "### Here comes the DNNClassifier" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 11, 340 | "metadata": {}, 341 | "outputs": [ 342 | { 343 | "name": "stdout", 344 |
"output_type": "stream", 345 | "text": [ 346 | "INFO:tensorflow:Using default config.\n", 347 | "INFO:tensorflow:Using config: {'_master': '', '_num_worker_replicas': 1, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_tf_random_seed': None, '_service': None, '_cluster_spec': , '_model_dir': './simple_dnn_dataset', '_num_ps_replicas': 0, '_save_checkpoints_steps': None, '_evaluation_master': '', '_save_summary_steps': 100, '_log_step_count_steps': 100, '_global_id_in_cluster': 0, '_train_distribute': None, '_is_chief': True, '_task_id': 0, '_save_checkpoints_secs': 600, '_task_type': 'worker'}\n" 348 | ] 349 | } 350 | ], 351 | "source": [ 352 | "num_class = NUM_CLASS\n", 353 | "\n", 354 | "model = tf.estimator.DNNClassifier(feature_columns = feature_columns,\n", 355 | " hidden_units = num_hidden_units,\n", 356 | " activation_fn = tf.nn.relu,\n", 357 | " n_classes = num_class,\n", 358 | " model_dir = './simple_dnn_dataset')" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "### Happy training!" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 14, 371 | "metadata": {}, 372 | "outputs": [ 373 | { 374 | "name": "stdout", 375 | "output_type": "stream", 376 | "text": [ 377 | "\n", 378 | "* data input_fn:\n", 379 | "================\n", 380 | "Input file(s): data_csv/mnist_train.csv\n", 381 | "Batch size: 128\n", 382 | "Epoch Count: None\n", 383 | "Mode: train\n", 384 | "Thread Count: 4\n", 385 | "Shuffle: True\n", 386 | "================\n", 387 | "\n", 388 | "DATASET \n", 389 | "DATASET_1 \n", 390 | "DATASET_2 \n", 391 | "DATASET_3 \n", 392 | "DATASET_4 \n", 393 | "INFO:tensorflow:Calling model_fn.\n", 394 | "INFO:tensorflow:Done calling model_fn.\n", 395 | "INFO:tensorflow:Create CheckpointSaverHook.\n", 396 | "INFO:tensorflow:Graph was finalized.\n", 397 | "INFO:tensorflow:Restoring parameters from ./simple_dnn_dataset/model.ckpt-200\n", 398 |
"INFO:tensorflow:Running local_init_op.\n", 399 | "INFO:tensorflow:Done running local_init_op.\n", 400 | "INFO:tensorflow:Saving checkpoints for 201 into ./simple_dnn_dataset/model.ckpt.\n", 401 | "INFO:tensorflow:loss = 26.95501, step = 201\n", 402 | "INFO:tensorflow:global_step/sec: 15.0276\n", 403 | "INFO:tensorflow:loss = 23.322294, step = 301 (6.655 sec)\n", 404 | "INFO:tensorflow:global_step/sec: 13.8421\n", 405 | "INFO:tensorflow:loss = 17.458122, step = 401 (7.225 sec)\n", 406 | "INFO:tensorflow:global_step/sec: 13.3083\n", 407 | "INFO:tensorflow:loss = 21.524231, step = 501 (7.517 sec)\n", 408 | "INFO:tensorflow:global_step/sec: 14.5015\n", 409 | "INFO:tensorflow:loss = 21.863522, step = 601 (6.892 sec)\n", 410 | "INFO:tensorflow:global_step/sec: 13.2937\n", 411 | "INFO:tensorflow:loss = 12.238069, step = 701 (7.524 sec)\n", 412 | "INFO:tensorflow:global_step/sec: 13.6131\n", 413 | "INFO:tensorflow:loss = 19.554596, step = 801 (7.345 sec)\n", 414 | "INFO:tensorflow:global_step/sec: 12.5833\n", 415 | "INFO:tensorflow:loss = 4.9210396, step = 901 (7.948 sec)\n", 416 | "INFO:tensorflow:global_step/sec: 13.2139\n", 417 | "INFO:tensorflow:loss = 8.347723, step = 1001 (7.566 sec)\n", 418 | "INFO:tensorflow:global_step/sec: 14.5858\n", 419 | "INFO:tensorflow:loss = 17.034126, step = 1101 (6.856 sec)\n", 420 | "INFO:tensorflow:global_step/sec: 14.5617\n", 421 | "INFO:tensorflow:loss = 21.071743, step = 1201 (6.866 sec)\n", 422 | "INFO:tensorflow:global_step/sec: 14.7257\n", 423 | "INFO:tensorflow:loss = 11.271985, step = 1301 (6.791 sec)\n", 424 | "INFO:tensorflow:global_step/sec: 14.9258\n", 425 | "INFO:tensorflow:loss = 7.7849083, step = 1401 (6.700 sec)\n", 426 | "INFO:tensorflow:global_step/sec: 14.8296\n", 427 | "INFO:tensorflow:loss = 7.3179874, step = 1501 (6.743 sec)\n", 428 | "INFO:tensorflow:global_step/sec: 15.3108\n", 429 | "INFO:tensorflow:loss = 5.9724092, step = 1601 (6.532 sec)\n", 430 | "INFO:tensorflow:global_step/sec: 111.22\n", 431 |
"INFO:tensorflow:loss = 23.16468, step = 1701 (0.899 sec)\n", 432 | "INFO:tensorflow:global_step/sec: 165.726\n", 433 | "INFO:tensorflow:loss = 15.113611, step = 1801 (0.603 sec)\n", 434 | "INFO:tensorflow:global_step/sec: 164.038\n", 435 | "INFO:tensorflow:loss = 17.828293, step = 1901 (0.610 sec)\n", 436 | "INFO:tensorflow:global_step/sec: 3.84192\n", 437 | "INFO:tensorflow:loss = 10.36054, step = 2001 (26.032 sec)\n", 438 | "INFO:tensorflow:global_step/sec: 13.1081\n", 439 | "INFO:tensorflow:loss = 10.766257, step = 2101 (7.626 sec)\n", 440 | "INFO:tensorflow:Saving checkpoints for 2200 into ./simple_dnn_dataset/model.ckpt.\n", 441 | "INFO:tensorflow:Loss for final step: 10.364952.\n" 442 | ] 443 | }, 444 | { 445 | "data": { 446 | "text/plain": [ 447 | "" 448 | ] 449 | }, 450 | "execution_count": 14, 451 | "metadata": {}, 452 | "output_type": "execute_result" 453 | } 454 | ], 455 | "source": [ 456 | "input_fn = lambda: csv_input_fn(\\\n", 457 | " files_name_pattern= TRAIN_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.TRAIN)\n", 458 | "\n", 459 | "model.train(input_fn, steps = 2000)" 460 | ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "metadata": {}, 465 | "source": [ 466 | "### Let's validate" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": 15, 472 | "metadata": {}, 473 | "outputs": [ 474 | { 475 | "name": "stdout", 476 | "output_type": "stream", 477 | "text": [ 478 | "\n", 479 | "* data input_fn:\n", 480 | "================\n", 481 | "Input file(s): data_csv/mnist_val.csv\n", 482 | "Batch size: 128\n", 483 | "Epoch Count: None\n", 484 | "Mode: eval\n", 485 | "Thread Count: 4\n", 486 | "Shuffle: False\n", 487 | "================\n", 488 | "\n", 489 | "DATASET \n", 490 | "DATASET_1 \n", 491 | "DATASET_2 \n", 492 | "DATASET_3 \n", 493 | "DATASET_4 \n", 494 | "INFO:tensorflow:Calling model_fn.\n", 495 | "INFO:tensorflow:Done calling model_fn.\n", 496 | "INFO:tensorflow:Starting evaluation at 2018-10-25-03:38:01\n", 497 |
"INFO:tensorflow:Graph was finalized.\n", 498 | "INFO:tensorflow:Restoring parameters from ./simple_dnn_dataset/model.ckpt-2200\n", 499 | "INFO:tensorflow:Running local_init_op.\n", 500 | "INFO:tensorflow:Done running local_init_op.\n", 501 | "INFO:tensorflow:Evaluation [1/1]\n", 502 | "INFO:tensorflow:Finished evaluation at 2018-10-25-03:38:10\n", 503 | "INFO:tensorflow:Saving dict for global step 2200: accuracy = 0.9375, average_loss = 0.14245859, global_step = 2200, loss = 4.558675\n" 504 | ] 505 | }, 506 | { 507 | "data": { 508 | "text/plain": [ 509 | "{'accuracy': 0.9375,\n", 510 | " 'average_loss': 0.14245859,\n", 511 | " 'global_step': 2200,\n", 512 | " 'loss': 4.558675}" 513 | ] 514 | }, 515 | "execution_count": 15, 516 | "metadata": {}, 517 | "output_type": "execute_result" 518 | } 519 | ], 520 | "source": [ 521 | "input_fn = lambda: csv_input_fn(files_name_pattern= VAL_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.EVAL)\n", 522 | "\n", 523 | "model.evaluate(input_fn,steps=1)" 524 | ] 525 | }, 526 | { 527 | "cell_type": "markdown", 528 | "metadata": {}, 529 | "source": [ 530 | "### Let's test it" 531 | ] 532 | }, 533 | { 534 | "cell_type": "code", 535 | "execution_count": 19, 536 | "metadata": {}, 537 | "outputs": [ 538 | { 539 | "name": "stdout", 540 | "output_type": "stream", 541 | "text": [ 542 | "\n", 543 | "* data input_fn:\n", 544 | "================\n", 545 | "Input file(s): data_csv/mnist_test.csv\n", 546 | "Batch size: 10\n", 547 | "Epoch Count: None\n", 548 | "Mode: infer\n", 549 | "Thread Count: 4\n", 550 | "Shuffle: False\n", 551 | "================\n", 552 | "\n", 553 | "DATASET \n", 554 | "DATASET_1 \n", 555 | "DATASET_2 \n", 556 | "DATASET_3 \n", 557 | "DATASET_4 \n", 558 | "INFO:tensorflow:Calling model_fn.\n", 559 | "INFO:tensorflow:Done calling model_fn.\n", 560 | "INFO:tensorflow:Graph was finalized.\n", 561 | "INFO:tensorflow:Restoring parameters from ./simple_dnn_dataset/model.ckpt-2200\n", 562 | "INFO:tensorflow:Running local_init_op.\n", 563
| "INFO:tensorflow:Done running local_init_op.\n", 564 | "\n", 565 | "* Predicted Classes: [b'0', b'7', b'6', b'3', b'4', b'6', b'5', b'8', b'5', b'6']\n" 566 | ] 567 | } 568 | ], 569 | "source": [ 570 | "import itertools\n", 571 | "\n", 572 | "input_fn = lambda: csv_input_fn(\\\n", 573 | " files_name_pattern= TEST_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.PREDICT,batch_size=10)\n", 574 | "\n", 575 | "predictions = list(itertools.islice(model.predict(input_fn=input_fn),10))\n", 576 | "# print('PREDICTIONS',predictions)\n", 577 | "print(\"\")\n", 578 | "print(\"* Predicted Classes: {}\".format(list(map(lambda item: item[\"classes\"][0]\n", 579 | " ,predictions))))\n", 580 | "\n" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": null, 586 | "metadata": {}, 587 | "outputs": [], 588 | "source": [] 589 | } 590 | ], 591 | "metadata": { 592 | "kernelspec": { 593 | "display_name": "Python 3", 594 | "language": "python", 595 | "name": "python3" 596 | }, 597 | "language_info": { 598 | "codemirror_mode": { 599 | "name": "ipython", 600 | "version": 3 601 | }, 602 | "file_extension": ".py", 603 | "mimetype": "text/x-python", 604 | "name": "python", 605 | "nbconvert_exporter": "python", 606 | "pygments_lexer": "ipython3", 607 | "version": "3.5.0" 608 | } 609 | }, 610 | "nbformat": 4, 611 | "nbformat_minor": 2 612 | } 613 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/data_csv/mnist_test.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:51c292478d94ec3a01461bdfa82eb0885d262eb09e615679b2d69dedb6ad09e7 3 | size 18289443 4 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/data_csv/mnist_train.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:fb1eb744fad41aefc48109fa1694b8385a541b9ff13c5954a646e37ffd4b87f6 3 | size 100447555 4 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/data_csv/mnist_val.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:46106534c381da5f74701cdddc5cde0834abd3a9634d70f3d5beb0d15c8d1675 3 | size 9128008 4 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/02_convolution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/02_convolution.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/02_network_flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/02_network_flowchart.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/0_TF_HELLO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/0_TF_HELLO.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/dataset_classes.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/dataset_classes.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/estimator_types.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/estimator_types.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/feed_tf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/feed_tf.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/feed_tf_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/feed_tf_out.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/inputs_to_model_bridge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/inputs_to_model_bridge.jpg -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/pt_sum_code.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/pt_sum_code.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/pt_sum_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/pt_sum_output.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tensorflow_programming_environment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tensorflow_programming_environment.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tensors_flowing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tensors_flowing.gif -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_feed_out_wrong2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_feed_out_wrong2.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_feed_wrong.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_feed_wrong.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_feed_wrong_out_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_feed_wrong_out_1.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_graph.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_sess_code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sess_code.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_sess_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sess_output.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_sum_graph.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sum_graph.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_sum_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sum_output.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_sum_sess.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sum_sess.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_sum_sess_code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sum_sess_code.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tf_sum_sess_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sum_sess_out.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tfe_sum_code.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tfe_sum_code.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/images/tfe_sum_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tfe_sum_output.png -------------------------------------------------------------------------------- /tensorflow_estimator_learn/tmp/basic_pt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | two_node_pt = torch.tensor([2]) 4 | three_node_pt = torch.tensor([3]) 5 | sum_node_pt = two_node_pt + three_node_pt 6 | 7 | print('TWO_NODE_PT',two_node_pt) 8 | print('THREEE_NODE_PT',three_node_pt) 9 | print('SUM_NODE_PT',sum_node_pt) 10 | 11 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/tmp/basic_tf.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | two_node = tf.constant(2) 4 | three_node = tf.constant(3) 5 | sum_node = two_node + three_node 6 | 7 | sess = tf.Session() 8 | two_node,three_node,sum_node = \ 9 | sess.run([two_node,three_node,sum_node]) 10 | 11 | print('TWO_NODE',two_node) 12 | print('THREE_NODE',three_node) 13 | print('SUM_NODE',sum_node) 14 | 15 | # print('TWO_NODE',two_node) 16 | # print('THREEE_NODE',three_node) 17 | # print('SUM_NODE',sum_node) 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/tmp/basic_tfe.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.eager as tfe 3 
| tfe.enable_eager_execution() 4 | 5 | two_node_tfe = tf.constant(2) 6 | three_node_tfe = tf.constant(3) 7 | sum_node_tfe = two_node_tfe + three_node_tfe 8 | 9 | print('TWO_NODE_TFE',two_node_tfe) 10 | print('THREE_NODE_TFE',three_node_tfe) 11 | print('SUM_NODE_TFE',sum_node_tfe) 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/tmp/feed_tf.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | input_placeholder = tf.placeholder(tf.int32) 4 | sess = tf.Session() 5 | input = sess.run(\ 6 | input_placeholder, feed_dict={input_placeholder: 2}) 7 | 8 | print('INPUT',input) 9 | 10 | -------------------------------------------------------------------------------- /tensorflow_estimator_learn/tmp/feed_tf_wrong.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | input_placeholder = tf.placeholder(tf.int32) 4 | three_node = tf.constant(3) 5 | sum_node = input_placeholder + three_node 6 | sess = tf.Session() 7 | 8 | print('THREE_NODE',sess.run(three_node)) 9 | print('SUM_NODE',sess.run(sum_node)) 10 | 11 | --------------------------------------------------------------------------------
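The scripts in `tmp/` contrast TensorFlow 1.x graph mode (`basic_tf.py`, `feed_tf.py`, `feed_tf_wrong.py`) with eager execution (`basic_tfe.py`) and PyTorch (`basic_pt.py`). In graph mode, `two_node + three_node` only adds a node to the graph. Values appear only when `sess.run` evaluates that node, and every `tf.placeholder` the fetched node depends on must be supplied through `feed_dict`. That is why, in `feed_tf_wrong.py`, `sess.run(three_node)` succeeds while `sess.run(sum_node)` fails. The following is a toy pure-Python model of this deferred evaluation, written only to illustrate the idea. `Constant`, `Placeholder`, `Add` and `Session` are hypothetical classes, not TensorFlow's API, and the `ValueError` stands in for the error TensorFlow itself raises for an unfed placeholder.

```python
# Toy model of TensorFlow 1.x graph-mode evaluation, for illustration only.
# All classes here are hypothetical; they are NOT part of TensorFlow.


class Node:
    """Base graph node: building expressions records nodes, computes nothing."""

    def __add__(self, other):
        # Like `two_node + three_node` in basic_tf.py: returns a new node only.
        return Add(self, other)


class Constant(Node):
    def __init__(self, value):
        self.value = value

    def evaluate(self, feed):
        return self.value


class Placeholder(Node):
    def evaluate(self, feed):
        # A placeholder has no value of its own; it must be fed at run time.
        if self not in feed:
            raise ValueError("You must feed a value for placeholder")
        return feed[self]


class Add(Node):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def evaluate(self, feed):
        # Evaluation walks the graph recursively, only when run() is called.
        return self.left.evaluate(feed) + self.right.evaluate(feed)


class Session:
    def run(self, fetches, feed_dict=None):
        feed = feed_dict or {}
        if isinstance(fetches, list):
            return [node.evaluate(feed) for node in fetches]
        return fetches.evaluate(feed)


if __name__ == "__main__":
    sess = Session()

    # Mirrors basic_tf.py: nodes are built first, values come from run().
    two_node, three_node = Constant(2), Constant(3)
    sum_node = two_node + three_node
    print(sess.run([two_node, three_node, sum_node]))  # [2, 3, 5]

    # Mirrors feed_tf_wrong.py: a node without placeholder dependencies runs,
    # but a node depending on an unfed placeholder fails.
    input_placeholder = Placeholder()
    sum_with_input = input_placeholder + three_node
    print(sess.run(three_node))  # 3
    try:
        sess.run(sum_with_input)
    except ValueError as err:
        print("error:", err)

    # The fix: supply the placeholder value through feed_dict, as in feed_tf.py.
    print(sess.run(sum_with_input, feed_dict={input_placeholder: 2}))  # 5
```

In this toy, as in graph mode, the failure surfaces at `run` time rather than when `input_placeholder + three_node` is written, because nothing is computed until a node is fetched. Under eager execution (`basic_tfe.py`) or PyTorch (`basic_pt.py`), the addition is computed immediately instead.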