├── .gitattributes
├── README.md
└── tensorflow_estimator_learn
├── CNNClassifier.ipynb
├── CNNClassifier_dataset.ipynb
├── CNN_raw.ipynb
├── DNNClassifier.ipynb
├── DNNClassifier_dataset.ipynb
├── data_csv
    │   ├── mnist_test.csv
    │   ├── mnist_train.csv
    │   └── mnist_val.csv
├── images
    │   ├── 02_convolution.png
    │   ├── 02_convolution.svg
    │   ├── 02_network_flowchart.png
    │   ├── 02_network_flowchart.svg
    │   ├── 0_TF_HELLO.png
    │   ├── dataset_classes.png
    │   ├── estimator_types.png
    │   ├── feed_tf.png
    │   ├── feed_tf_out.png
    │   ├── inputs_to_model_bridge.jpg
    │   ├── pt_sum_code.png
    │   ├── pt_sum_output.png
    │   ├── tensorflow_programming_environment.png
    │   ├── tensors_flowing.gif
    │   ├── tf_feed_out_wrong2.png
    │   ├── tf_feed_wrong.png
    │   ├── tf_feed_wrong_out_1.png
    │   ├── tf_graph.png
    │   ├── tf_sess_code.png
    │   ├── tf_sess_output.png
    │   ├── tf_sum_graph.png
    │   ├── tf_sum_output.png
    │   ├── tf_sum_sess.png
    │   ├── tf_sum_sess_code.png
    │   ├── tf_sum_sess_out.png
    │   ├── tfe_sum_code.png
    │   └── tfe_sum_output.png
└── tmp
        ├── basic_pt.py
        ├── basic_tf.py
        ├── basic_tfe.py
        ├── feed_tf.py
        └── feed_tf_wrong.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.csv filter=lfs diff=lfs merge=lfs -text
2 | *.tfrecords filter=lfs diff=lfs merge=lfs -text
3 |
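4 | # The patterns above store CSV and TFRecord data files via Git LFS; install git-lfs before cloning to fetch them.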
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tensorflow_estimator_tutorial
2 | - The TensorFlow version used here is out of date (the notebooks were written against TF 1.8), so mind version compatibility.
3 | **Enjoy tf.estimator**
4 | 
5 | ## Code structure
6 | ```
7 | |--tensorflow_estimator_learn
8 |    |--data_csv
9 |       |--mnist_test.csv
10 |       |--mnist_train.csv
11 |       |--mnist_val.csv
12 | 
13 |    |--images
14 |       |--ZJUAI_2018_AUT
15 | 
16 |    |--tmp
17 |       |--ZJUAI_2018_AUT
18 | 
19 |    |--CNNClassifier.ipynb
20 | 
21 |    |--CNNClassifier_dataset.ipynb
22 | 
23 |    |--CNN_raw.ipynb
24 | 
25 |    |--DNNClassifier.ipynb
26 | 
27 |    |--DNNClassifier_dataset.ipynb
28 | ```
29 | ## File descriptions
30 | ### data_csv
31 | data_csv holds the raw **MNIST** CSV files, split into validation, training, and test sets
32 | ### images
33 | images holds the pictures referenced in the **Jupyter notebooks**
34 | ### tmp
35 | tmp holds some scratch code
36 | ### CNNClassifier.ipynb
37 | A custom-estimator implementation that does not use the `tf.data` API
38 | ### CNNClassifier_dataset.ipynb
39 | A custom-estimator implementation that uses the `tf.data` API
40 | ### CNN_raw.ipynb
41 | **Builds a CNN for MNIST classification** without the high-level APIs
42 | ### DNNClassifier.ipynb
43 | A pre-made-estimator implementation that does not use the `tf.data` API
44 | ### DNNClassifier_dataset.ipynb
45 | A pre-made-estimator implementation that uses the `tf.data` API
46 | 
47 | 
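48 | ## Minimal usage sketch
49 | A minimal sketch (not taken from the notebooks) of the pre-made-estimator plus `tf.data` workflow the files above demonstrate; it targets the TF 1.x API, and the random arrays, hidden-layer sizes, and step count are illustrative stand-ins for the MNIST CSV data:
50 | ```python
51 | import numpy as np
52 | import tensorflow as tf
53 | 
54 | # Stand-in data with MNIST shapes: 784 pixel features, labels 0-9.
55 | train_x = np.random.rand(1000, 784).astype(np.float32)
56 | train_y = np.random.randint(0, 10, size=1000)
57 | 
58 | # A single numeric feature column covering all 784 pixels.
59 | feature_columns = [tf.feature_column.numeric_column('images', shape=[784])]
60 | 
61 | # Input function built on the tf.data API, as in the *_dataset notebooks.
62 | def train_input_fn():
63 |     dataset = tf.data.Dataset.from_tensor_slices(({'images': train_x}, train_y))
64 |     return dataset.shuffle(1000).repeat().batch(128)
65 | 
66 | # Pre-made estimator, as in the DNNClassifier notebooks.
67 | model = tf.estimator.DNNClassifier(feature_columns=feature_columns,
68 |                                    hidden_units=[128, 64],
69 |                                    n_classes=10)
70 | model.train(input_fn=train_input_fn, steps=100)
71 | ```
72 | The custom-estimator notebooks swap `DNNClassifier` for `tf.estimator.Estimator(model_fn=...)` while keeping the same input functions.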
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/CNNClassifier.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# TensorFlow 那些事儿之DL中的 HELLO WORLD\n",
8 | "\n",
9 | "- 基于MNIST数据集,运用TensorFlow中的 **tf.estimator** 中的 **tf.estimator.Estimator** 搭建一个简单的卷积神经网络,实现模型的训练,验证和测试\n",
10 | "\n",
11 | "- TensorBoard的简单使用\n"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "## 导入各个库"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 1,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "data": {
28 | "text/plain": [
29 | "'1.8.0'"
30 | ]
31 | },
32 | "execution_count": 1,
33 | "metadata": {},
34 | "output_type": "execute_result"
35 | }
36 | ],
37 | "source": [
38 | "%matplotlib inline\n",
39 | "import tensorflow as tf\n",
40 | "import matplotlib.pyplot as plt\n",
41 | "import numpy as np\n",
42 | "import pandas as pd\n",
43 | "import multiprocessing\n",
44 | "\n",
45 | "\n",
46 | "from tensorflow import data\n",
47 | "from tensorflow.python.feature_column import feature_column\n",
48 | "\n",
49 | "tf.__version__"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "## MNIST数据集载入"
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {},
62 | "source": [
63 | "### 看看MNIST数据长什么样子的\n",
64 | "\n",
65 | "\n",
66 | "\n",
67 | "More info: http://yann.lecun.com/exdb/mnist/"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "- MNIST数据集包含70000张图像和对应的标签(图像的分类)。数据集被划为3个子集:训练集,验证集和测试集。\n",
75 | "\n",
76 | "- 定义**MNIST**数据的基本信息"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 2,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "TRAIN_DATA_FILES_PATTERN = 'data_csv/mnist_train.csv'\n",
86 | "VAL_DATA_FILES_PATTERN = 'data_csv/mnist_val.csv'\n",
87 | "TEST_DATA_FILES_PATTERN = 'data_csv/mnist_test.csv'\n",
88 | "\n",
89 | "MULTI_THREADING = True\n",
90 | "RESUME_TRAINING = False\n",
91 | "\n",
92 | "NUM_CLASS = 10\n",
93 | "IMG_SHAPE = [28,28]\n",
94 | "\n",
95 | "IMG_WIDTH = 28\n",
96 | "IMG_HEIGHT = 28\n",
97 | "IMG_FLAT = 784\n",
98 | "NUM_CHANNEL = 1\n",
99 | "\n",
100 | "BATCH_SIZE = 128\n",
101 | "NUM_TRAIN = 55000\n",
102 | "NUM_VAL = 5000\n",
103 | "NUM_TEST = 10000"
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "metadata": {},
109 | "source": [
110 | "### 读取csv文件并查看数据信息"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": 3,
116 | "metadata": {},
117 | "outputs": [
118 | {
119 | "name": "stdout",
120 | "output_type": "stream",
121 | "text": [
122 | "test_data (10000, 784)\n",
123 | "test_label (10000,)\n",
124 | "val_data (5000, 784)\n",
125 | "val_label (5000,)\n",
126 | "train_data (55000, 784)\n",
127 | "train_label (55000,)\n"
128 | ]
129 | }
130 | ],
131 | "source": [
132 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN)\n",
133 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None, names=HEADER )\n",
134 | "train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None)\n",
135 | "test_data = pd.read_csv(TEST_DATA_FILES_PATTERN, header=None)\n",
136 | "val_data = pd.read_csv(VAL_DATA_FILES_PATTERN, header=None)\n",
137 | "\n",
138 | "train_values = train_data.values\n",
139 | "train_data = train_values[:,1:]/255.0\n",
140 | "train_label = train_values[:,0:1].squeeze()\n",
141 | "\n",
142 | "val_values = val_data.values\n",
143 | "val_data = val_values[:,1:]/255.0\n",
144 | "val_label = val_values[:,0:1].squeeze()\n",
145 | "\n",
146 | "test_values = test_data.values\n",
147 | "test_data = test_values[:,1:]/255.0\n",
148 | "test_label = test_values[:,0:1].squeeze()\n",
149 | "\n",
150 | "print('test_data',np.shape(test_data))\n",
151 | "print('test_label',np.shape(test_label))\n",
152 | "\n",
153 | "print('val_data',np.shape(val_data))\n",
154 | "print('val_label',np.shape(val_label))\n",
155 | "\n",
156 | "print('train_data',np.shape(train_data))\n",
157 | "print('train_label',np.shape(train_label))\n",
158 | "\n",
159 | "# train_data.head(10)\n",
160 | "# test_data.head(10)"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {},
166 | "source": [
167 | "## 试试自己写一个estimator\n",
168 | "\n",
169 | "- 基于**MNIST数据集**,运用TensorFlow中的 **tf.estimator** 中的 **tf.estimator.Estimator** 搭建一个简单的卷积神经网络,实现模型的训练,验证和测试\n",
170 | "\n",
171 | "- [官网API](https://tensorflow.google.cn/api_docs/python/tf/estimator/Estimator)\n",
172 | "\n",
173 | "- 看看有哪些参数\n",
174 | "\n",
175 | "```python\n",
176 | "__init__(\n",
177 | " model_fn,\n",
178 | " model_dir=None,\n",
179 | " config=None,\n",
180 | " params=None,\n",
181 | " warm_start_from=None\n",
182 | ")\n",
183 | "```\n",
184 | "- 本例中,重点在 **tf.estimator.Estimator** 中的 `model_fn`\n"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {},
190 | "source": [
191 | "### 先简单看看数据流\n",
192 | "\n",
193 | "下面的图表直接显示了本次MNIST例子的数据流向,共有**2个卷积层**,每一层卷积之后采用最大池化进行下采样(图中并未画出),最后接**2个全连接层**,实现对MNIST数据集的分类\n",
194 | "\n",
195 | ""
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | "### 先看看input_fn之创建输入函数"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 5,
208 | "metadata": {},
209 | "outputs": [],
210 | "source": [
211 | "batch_size = BATCH_SIZE\n",
212 | "\n",
213 | "# Define the input function for training\n",
214 | "train_input_fn = tf.estimator.inputs.numpy_input_fn(\n",
215 | " x = {'images': np.array(train_data)},\n",
216 | " y = np.array(train_label),\n",
217 | " batch_size=batch_size,\n",
218 | " num_epochs=None, \n",
219 | " shuffle=True)"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": 6,
225 | "metadata": {},
226 | "outputs": [],
227 | "source": [
228 | "# Evaluate the Model\n",
229 | "# Define the input function for evaluating\n",
230 | "eval_input_fn = tf.estimator.inputs.numpy_input_fn(\n",
231 | " x = {'images': np.array(test_data)},\n",
232 | " y = np.array(test_label),\n",
233 | " batch_size=batch_size, \n",
234 | " shuffle=False)"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": 7,
240 | "metadata": {},
241 | "outputs": [
242 | {
243 | "name": "stdout",
244 | "output_type": "stream",
245 | "text": [
246 | "some images (9, 784)\n"
247 | ]
248 | }
249 | ],
250 | "source": [
251 | "# Predict some images\n",
252 | "some_images = test_data[0:9]\n",
253 | "print('some images',np.shape(some_images))\n",
254 | "\n",
255 | "# Define the input function for predicting\n",
256 | "test_input_fn = tf.estimator.inputs.numpy_input_fn(\n",
257 | " x={'images': some_images},\n",
258 | " num_epochs=1,\n",
259 | " shuffle=False)"
260 | ]
261 | },
262 | {
263 | "cell_type": "markdown",
264 | "metadata": {},
265 | "source": [
266 | "### 定义feature_columns"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": 10,
272 | "metadata": {},
273 | "outputs": [],
274 | "source": [
275 | "feature_x = tf.feature_column.numeric_column('images', shape=IMG_SHAPE)\n",
276 | "\n",
277 | "feature_columns = [feature_x]"
278 | ]
279 | },
280 | {
281 | "cell_type": "markdown",
282 | "metadata": {},
283 | "source": [
284 | "### 重点在这里——model_fn\n",
285 | "\n",
286 | "\n",
287 | "#### model_fn: Model function. Follows the signature:\n",
288 | "\n",
289 | "* Args:\n",
290 | " * `features`: This is the first item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single `tf.Tensor` or `dict` of same.\n",
291 | " * `labels`: This is the second item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single `tf.Tensor` or `dict` of same (for multi-head models).If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will be passed. If the `model_fn`'s signature does not accept `mode`, the `model_fn` must still be able to handle `labels=None`.\n",
292 | " * `mode`: Optional. Specifies if this training, evaluation or prediction. See `tf.estimator.ModeKeys`.\n",
293 | " * `params`: Optional `dict` of hyperparameters. Will receive what is passed to Estimator in `params` parameter. This allows to configure Estimators from hyper parameter tuning.\n",
294 | " * `config`: Optional `estimator.RunConfig` object. Will receive what is passed to Estimator as its `config` parameter, or a default value. Allows setting up things in your `model_fn` based on configuration such as `num_ps_replicas`, or `model_dir`.\n",
295 | "* Returns:\n",
296 | " `tf.estimator.EstimatorSpec`\n",
297 | " \n",
298 | "#### 注意model_fn返回的tf.estimator.EstimatorSpec\n",
299 | "
\n",
300 | "\n",
301 | "\n"
302 | ]
303 | },
304 | {
305 | "cell_type": "markdown",
306 | "metadata": {},
307 | "source": [
308 | "### 定义我们自己的model_fn"
309 | ]
310 | },
311 | {
312 | "cell_type": "code",
313 | "execution_count": 19,
314 | "metadata": {},
315 | "outputs": [],
316 | "source": [
317 | "def model_fn(features, labels, mode, params):\n",
318 | " # Args:\n",
319 | " #\n",
320 | " # features: This is the x-arg from the input_fn.\n",
321 | " # labels: This is the y-arg from the input_fn,\n",
322 | " # see e.g. train_input_fn for these two.\n",
323 | " # mode: Either TRAIN, EVAL, or PREDICT\n",
324 | " # params: User-defined hyper-parameters, e.g. learning-rate.\n",
325 | " \n",
326 | " # Reference to the tensor named \"x\" in the input-function.\n",
327 | "# x = features[\"images\"]\n",
328 | " x = tf.feature_column.input_layer(features, params['feature_columns'])\n",
329 | " # The convolutional layers expect 4-rank tensors\n",
330 | " # but x is a 2-rank tensor, so reshape it.\n",
331 | " net = tf.reshape(x, [-1, IMG_HEIGHT, IMG_WIDTH, NUM_CHANNEL]) \n",
332 | "\n",
333 | " # First convolutional layer.\n",
334 | " net = tf.layers.conv2d(inputs=net, name='layer_conv1',\n",
335 | " filters=16, kernel_size=5,\n",
336 | " padding='same', activation=tf.nn.relu)\n",
337 | " net = tf.layers.max_pooling2d(inputs=net, pool_size=2, strides=2)\n",
338 | "\n",
339 | " # Second convolutional layer.\n",
340 | " net = tf.layers.conv2d(inputs=net, name='layer_conv2',\n",
341 | " filters=36, kernel_size=5,\n",
342 | " padding='same', activation=tf.nn.relu)\n",
343 | " net = tf.layers.max_pooling2d(inputs=net, pool_size=2, strides=2) \n",
344 | "\n",
345 | " # Flatten to a 2-rank tensor.\n",
346 | " net = tf.contrib.layers.flatten(net)\n",
347 | " # Eventually this should be replaced with:\n",
348 | " # net = tf.layers.flatten(net)\n",
349 | "\n",
350 | " # First fully-connected / dense layer.\n",
351 | " # This uses the ReLU activation function.\n",
352 | " net = tf.layers.dense(inputs=net, name='layer_fc1',\n",
353 | " units=128, activation=tf.nn.relu) \n",
354 | "\n",
355 | " # Second fully-connected / dense layer.\n",
356 | " # This is the last layer so it does not use an activation function.\n",
357 | " net = tf.layers.dense(inputs=net, name='layer_fc2',\n",
358 | " units=10)\n",
359 | "\n",
360 | " # Logits output of the neural network.\n",
361 | " logits = net\n",
362 | "\n",
363 | " # Softmax output of the neural network.\n",
364 | " y_pred = tf.nn.softmax(logits=logits)\n",
365 | " \n",
366 | " # Classification output of the neural network.\n",
367 | " y_pred_cls = tf.argmax(y_pred, axis=1)\n",
368 | "\n",
369 | " if mode == tf.estimator.ModeKeys.PREDICT:\n",
370 | " # If the estimator is supposed to be in prediction-mode\n",
371 | " # then use the predicted class-number that is output by\n",
372 | " # the neural network. Optimization etc. is not needed.\n",
373 | " spec = tf.estimator.EstimatorSpec(mode=mode,\n",
374 | " predictions=y_pred_cls)\n",
375 | " else:\n",
376 | " # Otherwise the estimator is supposed to be in either\n",
377 | " # training or evaluation-mode. Note that the loss-function\n",
378 | " # is also required in Evaluation mode.\n",
379 | " \n",
380 | " # Define the loss-function to be optimized, by first\n",
381 | " # calculating the cross-entropy between the output of\n",
382 | " # the neural network and the true labels for the input data.\n",
383 | " # This gives the cross-entropy for each image in the batch.\n",
384 | " cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,\n",
385 | " logits=logits)\n",
386 | "\n",
387 | " # Reduce the cross-entropy batch-tensor to a single number\n",
388 | " # which can be used in optimization of the neural network.\n",
389 | " loss = tf.reduce_mean(cross_entropy)\n",
390 | "\n",
391 | " # Define the optimizer for improving the neural network.\n",
392 | " optimizer = tf.train.AdamOptimizer(learning_rate=params[\"learning_rate\"])\n",
393 | "\n",
394 | " # Get the TensorFlow op for doing a single optimization step.\n",
395 | " train_op = optimizer.minimize(\n",
396 | " loss=loss, global_step=tf.train.get_global_step())\n",
397 | "\n",
398 | " # Define the evaluation metrics,\n",
399 | " # in this case the classification accuracy.\n",
400 | " metrics = \\\n",
401 | " {\n",
402 | " \"accuracy\": tf.metrics.accuracy(labels, y_pred_cls)\n",
403 | " }\n",
404 | "\n",
405 | " # Wrap all of this in an EstimatorSpec.\n",
406 | " spec = tf.estimator.EstimatorSpec(\n",
407 | " mode=mode,\n",
408 | " loss=loss,\n",
409 | " train_op=train_op,\n",
410 | " eval_metric_ops=metrics)\n",
411 | " \n",
412 | " return spec"
413 | ]
414 | },
415 | {
416 | "cell_type": "markdown",
417 | "metadata": {},
418 | "source": [
419 | "### 自建的estimator在这里\n",
420 | "\n",
421 | "我们可以指定超参数,例如优化器的学习率。"
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": 23,
427 | "metadata": {},
428 | "outputs": [],
429 | "source": [
430 | "params = {\"learning_rate\": 1e-4,\n",
431 | " 'feature_columns': feature_columns}"
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": 24,
437 | "metadata": {},
438 | "outputs": [
439 | {
440 | "name": "stdout",
441 | "output_type": "stream",
442 | "text": [
443 | "INFO:tensorflow:Using default config.\n",
444 | "INFO:tensorflow:Using config: {'_evaluation_master': '', '_session_config': None, '_save_checkpoints_secs': 600, '_log_step_count_steps': 100, '_master': '', '_is_chief': True, '_task_type': 'worker', '_keep_checkpoint_max': 5, '_service': None, '_tf_random_seed': None, '_cluster_spec': , '_model_dir': './cnn_classifer/', '_task_id': 0, '_num_worker_replicas': 1, '_num_ps_replicas': 0, '_save_checkpoints_steps': None, '_global_id_in_cluster': 0, '_train_distribute': None, '_keep_checkpoint_every_n_hours': 10000, '_save_summary_steps': 100}\n"
445 | ]
446 | }
447 | ],
448 | "source": [
449 | "model = tf.estimator.Estimator(model_fn=model_fn,\n",
450 | " params=params,\n",
451 | " model_dir=\"./cnn_classifer/\")"
452 | ]
453 | },
454 | {
455 | "cell_type": "markdown",
456 | "metadata": {},
457 | "source": [
458 | "### 训练训练看看"
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": 26,
464 | "metadata": {},
465 | "outputs": [
466 | {
467 | "name": "stdout",
468 | "output_type": "stream",
469 | "text": [
470 | "INFO:tensorflow:Calling model_fn.\n",
471 | "INFO:tensorflow:Done calling model_fn.\n",
472 | "INFO:tensorflow:Create CheckpointSaverHook.\n",
473 | "INFO:tensorflow:Graph was finalized.\n",
474 | "INFO:tensorflow:Running local_init_op.\n",
475 | "INFO:tensorflow:Done running local_init_op.\n",
476 | "INFO:tensorflow:Saving checkpoints for 1 into ./cnn_classifer/model.ckpt.\n",
477 | "INFO:tensorflow:step = 1, loss = 2.3124514\n",
478 | "INFO:tensorflow:global_step/sec: 5.13683\n",
479 | "INFO:tensorflow:step = 101, loss = 1.004812 (19.469 sec)\n",
480 | "INFO:tensorflow:global_step/sec: 4.44593\n",
481 | "INFO:tensorflow:step = 201, loss = 0.40566427 (22.492 sec)\n",
482 | "INFO:tensorflow:global_step/sec: 5.59063\n",
483 | "INFO:tensorflow:step = 301, loss = 0.28785554 (17.887 sec)\n",
484 | "INFO:tensorflow:global_step/sec: 5.88434\n",
485 | "INFO:tensorflow:step = 401, loss = 0.23790869 (16.994 sec)\n",
486 | "INFO:tensorflow:global_step/sec: 5.28758\n",
487 | "INFO:tensorflow:step = 501, loss = 0.2865603 (18.912 sec)\n",
488 | "INFO:tensorflow:global_step/sec: 5.55467\n",
489 | "INFO:tensorflow:step = 601, loss = 0.27893203 (18.004 sec)\n",
490 | "INFO:tensorflow:global_step/sec: 5.46903\n",
491 | "INFO:tensorflow:step = 701, loss = 0.13836136 (18.286 sec)\n",
492 | "INFO:tensorflow:global_step/sec: 5.41053\n",
493 | "INFO:tensorflow:step = 801, loss = 0.12664635 (18.480 sec)\n",
494 | "INFO:tensorflow:global_step/sec: 5.21324\n",
495 | "INFO:tensorflow:step = 901, loss = 0.22681555 (19.184 sec)\n",
496 | "INFO:tensorflow:global_step/sec: 5.59755\n",
497 | "INFO:tensorflow:step = 1001, loss = 0.19516315 (17.862 sec)\n",
498 | "INFO:tensorflow:global_step/sec: 5.7998\n",
499 | "INFO:tensorflow:step = 1101, loss = 0.15528539 (17.242 sec)\n",
500 | "INFO:tensorflow:global_step/sec: 5.63879\n",
501 | "INFO:tensorflow:step = 1201, loss = 0.07765657 (17.734 sec)\n",
502 | "INFO:tensorflow:global_step/sec: 5.5637\n",
503 | "INFO:tensorflow:step = 1301, loss = 0.11297858 (17.974 sec)\n",
504 | "INFO:tensorflow:global_step/sec: 5.15256\n",
505 | "INFO:tensorflow:step = 1401, loss = 0.13372605 (19.412 sec)\n",
506 | "INFO:tensorflow:global_step/sec: 5.43482\n",
507 | "INFO:tensorflow:step = 1501, loss = 0.13708562 (18.397 sec)\n",
508 | "INFO:tensorflow:global_step/sec: 5.36527\n",
509 | "INFO:tensorflow:step = 1601, loss = 0.050685763 (18.639 sec)\n",
510 | "INFO:tensorflow:global_step/sec: 5.23113\n",
511 | "INFO:tensorflow:step = 1701, loss = 0.06853628 (19.115 sec)\n",
512 | "INFO:tensorflow:global_step/sec: 5.25113\n",
513 | "INFO:tensorflow:step = 1801, loss = 0.11101746 (19.058 sec)\n",
514 | "INFO:tensorflow:global_step/sec: 5.02226\n",
515 | "INFO:tensorflow:step = 1901, loss = 0.091775164 (19.900 sec)\n",
516 | "INFO:tensorflow:Saving checkpoints for 2000 into ./cnn_classifer/model.ckpt.\n",
517 | "INFO:tensorflow:Loss for final step: 0.08684543.\n"
518 | ]
519 | },
520 | {
521 | "data": {
522 | "text/plain": [
523 | ""
524 | ]
525 | },
526 | "execution_count": 26,
527 | "metadata": {},
528 | "output_type": "execute_result"
529 | }
530 | ],
531 | "source": [
532 | "model.train(input_fn=train_input_fn, steps=2000)"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": 27,
538 | "metadata": {},
539 | "outputs": [
540 | {
541 | "name": "stdout",
542 | "output_type": "stream",
543 | "text": [
544 | "INFO:tensorflow:Calling model_fn.\n",
545 | "INFO:tensorflow:Done calling model_fn.\n",
546 | "INFO:tensorflow:Starting evaluation at 2018-10-25-04:44:07\n",
547 | "INFO:tensorflow:Graph was finalized.\n",
548 | "INFO:tensorflow:Restoring parameters from ./cnn_classifer/model.ckpt-2000\n",
549 | "INFO:tensorflow:Running local_init_op.\n",
550 | "INFO:tensorflow:Done running local_init_op.\n",
551 | "INFO:tensorflow:Finished evaluation at 2018-10-25-04:44:14\n",
552 | "INFO:tensorflow:Saving dict for global step 2000: accuracy = 0.9761, global_step = 2000, loss = 0.07788641\n"
553 | ]
554 | },
555 | {
556 | "data": {
557 | "text/plain": [
558 | "{'accuracy': 0.9761, 'global_step': 2000, 'loss': 0.07788641}"
559 | ]
560 | },
561 | "execution_count": 27,
562 | "metadata": {},
563 | "output_type": "execute_result"
564 | }
565 | ],
566 | "source": [
567 | "# Use the Estimator 'evaluate' method\n",
568 | "model.evaluate(eval_input_fn)"
569 | ]
570 | },
571 | {
572 | "cell_type": "markdown",
573 | "metadata": {},
574 | "source": [
575 | "### 测试一下瞅瞅"
576 | ]
577 | },
578 | {
579 | "cell_type": "code",
580 | "execution_count": 28,
581 | "metadata": {},
582 | "outputs": [
583 | {
584 | "name": "stdout",
585 | "output_type": "stream",
586 | "text": [
587 | "INFO:tensorflow:Calling model_fn.\n",
588 | "INFO:tensorflow:Done calling model_fn.\n",
589 | "INFO:tensorflow:Graph was finalized.\n",
590 | "INFO:tensorflow:Restoring parameters from ./cnn_classifer/model.ckpt-2000\n",
591 | "INFO:tensorflow:Running local_init_op.\n",
592 | "INFO:tensorflow:Done running local_init_op.\n"
593 | ]
594 | },
595 | {
596 | "data": {
597 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADO5JREFUeJzt3V2IXfW5x/Hf76QpiOlFYjUMNpqeogerSKKjCMYS9VhyYiEWg9SLkkLJ9CJKCyVU7EVzWaQv1JvAlIbGkmMrpNUoYmNjMQ1qcSJqEmNiElIzMW9lhCaCtNGnF7Nsp3H2f+/st7XH5/uBYfZez3p52Mxv1lp77bX/jggByOe/6m4AQD0IP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpD7Vz43Z5uOEQI9FhFuZr6M9v+1ltvfZPmD7gU7WBaC/3O5n+23PkrRf0h2SxiW9LOneiHijsAx7fqDH+rHnv1HSgYg4FBF/l/RrSSs6WB+APuok/JdKOjLl+Xg17T/YHrE9Znusg20B6LKev+EXEaOSRiUO+4FB0sme/6ikBVOef66aBmAG6CT8L0u6wvbnbX9a0tckbelOWwB6re3D/og4a/s+Sb+XNEvShojY07XOAPRU25f62toY5/xAz/XlQz4AZi7CDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkmp7iG5Jsn1Y0mlJH0g6GxHD3WgKQO91FP7KrRHx1y6sB0AfcdgPJNVp+EPSVts7bY90oyEA/dHpYf+SiDhq+xJJz9p+MyK2T52h+qfAPwZgwDgiurMie52kMxHxo8I83dkYgIYiwq3M1/Zhv+0LbX/mo8eSvixpd7vrA9BfnRz2z5f0O9sfref/I+KZrnQFoOe6dtjf0sY47Ad6rueH/QBmNsIPJEX4gaQIP5AU4QeSIvxAUt24qy+FlStXNqytXr26uOw777xTrL///vvF+qZNm4r148ePN6wdOHCguCzyYs8PJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lxS2+LDh061LC2cOHC/jUyjdOnTzes7dmzp4+dDJbx8fGGtYceeqi47NjYWLfb6Rtu6QVQRPiBpAg/kBThB5Ii/EBShB9IivADSXE/f4tK9+xfe+21xWX37t1brF911VXF+nXXXVesL126tGHtpptuKi575MiRYn3BggXFeifOnj1brJ86dapYHxoaanvbb7/9drE+k6/zt4o9P5AU4QeSIvxAUoQfSIrwA0kRfiApwg8k1fR+ftsbJH1F0smIuKaaNk/SbyQtlHRY0j0R8W7Tjc3g+/kH2dy5cxvWFi1aVFx2586dxfoNN9zQVk+taDZewf79+4v1Zp+fmDdvXsPamjVrisuuX7++WB9k3byf/5eSlp0z7QFJ2yLiCknbqucAZpCm4Y+I7ZImzpm8QtLG6vFGSXd1uS8APdbuOf/8iDhWPT4uaX6X+gHQJx1/tj8ionQub3tE0kin2wHQXe3u+U/YHpKk6vfJRjNGxGhEDEfEcJvbAtAD7YZ/i6RV1eNVkp7oTjsA+qVp+G0/KulFSf9je9z2NyX9UNIdtt+S9L/VcwAzCN/bj4F19913F+uPPfZYsb579+6GtVtvvbW47MTEuRe4Zg6+tx9AEeEHkiL8QFKEH0iK8ANJEX4gKS71oTaXXHJJsb5r166Oll+5cmXD2ubNm4vLzmRc6gNQRPiBpAg/kBThB5Ii/EBShB9IivADSTFEN2rT7OuzL7744mL93XfL3xa/b9++8+4pE/b8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU9/Ojp26++eaGteeee6647OzZs4v1pUuXFuvbt28v1j+puJ8fQBHhB5Ii/EBShB9IivADSRF+ICnCDyTV9H5+2xskfUXSyYi4ppq2TtJqSaeq2R6MiKd71SRmruXLlzesNbuOv23btmL9xRdfbKsnTGplz/9LScummf7TiFhU/RB8YIZpGv6I2C5pog+9AOijTs7577P9uu0Ntud2rSMAfdFu+NdL+oKkRZKOSfpxoxltj9gesz3W5rYA9EBb4Y+IExHxQUR8KOnnkm4szDsaEcMRMdxukwC6r63w2x6a8vSrknZ3px0A/dLKpb5HJS2V9Fnb45J+IGmp7UWSQtJhSd/qYY8AeoD7+dGRCy64oFjfsWNHw9rVV19dXPa2224r1l944YViPSvu5wdQRPiBpAg/kBThB5Ii/EBShB9IiiG60ZG1a9cW64sXL25Ye+aZZ4rLcimvt9jzA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBS3NKLojvvvLNYf/zxx4v19957r2Ft2bLpvhT631566aViHdPjll4ARYQfSIrwA0kRfiApwg8kRfiBpAg/kBT38yd30UUXFesPP/xwsT5r1qxi/emnGw/gzHX8erHnB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkmt7Pb3uBpEckzZcUkkYj4me250n6jaSFkg5Luici3m2yLu7n77Nm1+GbXWu//vrri/WDBw8W66V79psti/Z0837+s5K+GxFflHSTpDW2vyjpAUnbIuIKSduq5wBmiKbhj4hjEfFK9fi0pL2SLpW0QtLGaraNku7qVZMAuu+8zvltL5S0WNKfJc2PiGNV6bgmTwsAzBAtf7bf9hxJmyV9JyL+Zv/7tCIiotH5vO0RSSOdNgqgu1ra89uercngb4qI31aTT9gequpDkk5Ot2xEjEbEcEQMd6NhAN3RNPye3MX/QtLeiPjJlNIWSauqx6skPdH99gD0SiuX+pZI+pOkXZI+rCY/qMnz/sckXSbpL5q81DfRZF1c6uuzK6+8slh/8803O1r/ihUrivUnn3yyo/Xj/LV6qa/pOX9E7JDUaGW3n09TAAYHn/ADkiL8QFKEH0iK8ANJEX4gKcIPJMVXd38CXH755Q1rW7du7Wjda9euLdafeuqpjtaP+rDnB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkuM7/CTAy0vhb0i677LKO1v38888X682+DwKDiz0/kBThB5Ii/EBShB9IivADSRF+ICnCDyTFdf4ZYMmSJcX6/fff36dO8EnCnh9IivADSRF+ICnCDyRF+IGkCD+QFOEHkmp6nd/2AkmPSJovKSSNRsTPbK+TtFrSqWrWByPi6V41mtktt9xSrM+ZM6ftdR88eLBYP3PmTNvrxmBr5UM+ZyV9NyJesf0ZSTttP1vVfhoRP+pdewB6pWn4I+KYpGPV49O290q6tNeNAeit8zrnt71Q0mJJf64m3Wf7ddsbbM9tsMyI7THbYx11CqCrWg6/7TmSNkv6TkT8TdJ6SV+QtEiTRwY/nm65iBiNiOGIGO5CvwC6pKXw256tyeBviojfSlJEnIiIDyLiQ0k/l3Rj79oE0G1Nw2/bkn4haW9
E/GTK9KEps31V0u7utwegV1p5t/9mSV+XtMv2q9W0ByXda3uRJi//HZb0rZ50iI689tprxfrtt99erE9MTHSzHQyQVt7t3yHJ05S4pg/MYHzCD0iK8ANJEX4gKcIPJEX4gaQIP5CU+znEsm3GcwZ6LCKmuzT/Mez5gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiCpfg/R/VdJf5ny/LPVtEE0qL0Nal8SvbWrm71d3uqMff2Qz8c2bo8N6nf7DWpvg9qXRG/tqqs3DvuBpAg/kFTd4R+tefslg9rboPYl0Vu7aumt1nN+APWpe88PoCa1hN/2Mtv7bB+w/UAdPTRi+7DtXbZfrXuIsWoYtJO2d0+ZNs/2s7bfqn5PO0xaTb2ts320eu1etb28pt4W2P6j7Tds77H97Wp6ra9doa9aXre+H/bbniVpv6Q7JI1LelnSvRHxRl8bacD2YUnDEVH7NWHbX5J0RtIjEXFNNe0hSRMR8cPqH+fciPjegPS2TtKZukdurgaUGZo6srSkuyR9QzW+doW+7lENr1sde/4bJR2IiEMR8XdJv5a0ooY+Bl5EbJd07qgZKyRtrB5v1OQfT9816G0gRMSxiHilenxa0kcjS9f62hX6qkUd4b9U0pEpz8c1WEN+h6SttnfaHqm7mWnMr4ZNl6TjkubX2cw0mo7c3E/njCw9MK9dOyNedxtv+H3ckoi4TtL/SVpTHd4OpJg8ZxukyzUtjdzcL9OMLP0vdb527Y543W11hP+opAVTnn+umjYQIuJo9fukpN9p8EYfPvHRIKnV75M19/MvgzRy83QjS2sAXrtBGvG6jvC/LOkK25+3/WlJX5O0pYY+Psb2hdUbMbJ9oaQva/BGH94iaVX1eJWkJ2rs5T8MysjNjUaWVs2v3cCNeB0Rff+RtFyT7/gflPT9Onpo0Nd/S3qt+tlTd2+SHtXkYeA/NPneyDclXSRpm6S3JP1B0rwB6u1XknZJel2TQRuqqbclmjykf13Sq9XP8rpfu0JftbxufMIPSIo3/ICkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJPVP82g/p9/JjhUAAAAASUVORK5CYII=\n",
598 | "text/plain": [
599 | ""
600 | ]
601 | },
602 | "metadata": {},
603 | "output_type": "display_data"
604 | },
605 | {
606 | "name": "stdout",
607 | "output_type": "stream",
608 | "text": [
609 | "Model prediction: 7\n"
610 | ]
611 | },
612 | {
613 | "data": {
614 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADXZJREFUeJzt3X+IHPUZx/HPU5uAaFGT0uMwttGohSj+CKcUCaVFjVZiYkA0wT9SWnr9o0LF+ItUUChiKf1B/wpEDCba2jRcjFFL0zZUTSEJOSVGo1ETuWjCJdcQ0QSRmuTpHzvXXvXmu5uZ2Z29PO8XHLc7z+7Mw3Kfm5md3e/X3F0A4vlS3Q0AqAfhB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8Q1Jc7uTEz4+OEQJu5u7XyuFJ7fjO70czeNrPdZvZAmXUB6Cwr+tl+MztN0juSrpe0T9I2SYvc/c3Ec9jzA23WiT3/1ZJ2u/t77v5vSX+UNL/E+gB0UJnwnyvpgzH392XL/o+Z9ZvZoJkNltgWgIq1/Q0/d18uabnEYT/QTcrs+fdLOm/M/WnZMgATQJnwb5N0kZmdb2aTJS2UtL6atgC0W+HDfnc/ZmZ3Stog6TRJK9x9Z2WdAWirwpf6Cm2Mc36g7TryIR8AExfhB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBdXTobhRzzz33JOunn356bu2yyy5LPvfWW28t1NOoZcuWJeubN2/OrT355JOlto1y2PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM3tsFVq9enayXvRZfpz179uTWrrvuuuRz33///arbCYHRewEkEX4gKMIPBEX4gaAIPxAU4QeCIvxAUKW+z29mQ5KOSDou6Zi791XR1Kmmzuv4u3btStY3bNiQrF9wwQXJ+s0335ysz5gxI7d2xx13JJ/76KOPJusop4rBPL7r7ocqWA+ADuKwHwiqbPhd0l/N7BUz66+iIQCdUfawf7a77zezr0n6m5ntcveXxz4g+6fAPwagy5Ta87v7/uz3iKRnJF09zmOWu3sfbwYC3aVw+M3sDDP7yuhtSXMkvVFVYwDaq8xhf4+kZ8xsdD1/cPe/VNIVgLYrHH53f0/S5RX2MmH19aXPaBYsWFBq/Tt37kzW582bl1s7dCh9Ffbo0aPJ+uTJk5P1LVu2JOuXX57/JzJ16tTkc9FeXOoDgiL8QFCEHwiK8ANBEX4gKMIPBMUU3RXo7e1N1rPPQuRqdinvhhtuSNaHh4eT9TKWLFmSrM+cObPwul944YXCz0V57PmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICiu81fgueeeS9YvvPDCZP3IkSPJ+uHDh0+6p6osXLgwWZ80aVKHOkHV2PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBc5++AvXv31t1CrnvvvTdZv/jii0utf+vWrYVqaD/2/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QlLl7+gFmKyTNlTTi7pdmy6ZIWi1puqQhSbe5+4dNN2aW3hgqN3fu3GR9zZo1yXqzKbpHRkaS9dR4AC+99FLyuSjG3dMTRWRa2fM/IenGzy17QNJGd79I0sbsPoAJpGn43f1lSZ8fSma+pJXZ7ZWSbqm4LwBtVvScv8fdR+eIOiCpp6J+AHRI6c/2u7unzuXNrF9Sf9ntAKhW0T3/QTPrlaTsd+67Pu6+3N373L2v4LYAtEHR8K+XtDi7vVjSs9W0A6BTmobfzJ6WtFnSN81sn5n9UNIvJF1vZu9Kui67D2ACaXrO7+6LckrXVtwL2qCvL3221ew6fjOrV69O1rmW3734hB8QFOEHgiL8QFCEHwiK8ANBEX4gKIbuPgWsW7cutzZnzpxS6161alWy/uCDD5ZaP+rDnh8IivADQRF+ICjCDwRF+IGgCD8QFOEHgmo6dHelG2Po7kJ6e3uT9ddeey23NnXq1ORzDx06lKxfc801yfqePXuSdXRelUN3AzgFEX4gKMIPBEX4gaAIPxAU4QeCIvxAUHyffwIYGBhI1ptdy0956qmnknWu45+62PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBNr/Ob2QpJcyWNuPul2bKHJf1I0r+yhy119z+3q8lT3bx585L1WbNmFV73iy++mKw/9NBDhdeNia2VPf8Tkm4cZ/lv3f2K7IfgAxNM0/C7+8uSDnegFwAdVOac/04z22FmK8zsnMo6AtARRcO/TNIMSVdIGpb067wHmlm/mQ2a2WDBbQFog0Lhd/eD7n7c3U9IekzS1YnHLnf3PnfvK9okgOoVCr+ZjR1OdoGkN6ppB0CntHKp72lJ35H0VTPbJ+khSd8xsyskuaQhST9uY48A2qBp+N190TiLH29DL6esZt+3X7p0abI+adKkwtvevn17sn706NHC68bExif8gKAIPxAU4QeCIvxAUIQfCIrwA0ExdHcHLFmyJFm/6qqrSq1/3bp1uTW+sos87PmBoAg/EBThB4Ii/EBQhB8IivADQRF+IChz985tzKxzG+sin376abJe5iu7kjRt2rTc2vDwcKl1Y+Jxd2vlcez5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAovs9/CpgyZUpu7bPPPutgJ1/00Ucf5daa9dbs8w9nnXVWoZ4k6eyzz07W77777sLrbsXx48dza/fff3/yuZ988kklPbDnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgml7nN7PzJK2S1CPJJS1399+Z2RRJqyVNlzQk6TZ3/7B9rSLPjh076m4h15o1a3JrzcYa6OnpSdZvv/32Qj11uwMHDiTrjzzySCXbaWXPf0zSEnefKelbkn5iZjMlPSBpo7tfJGljdh/ABNE0/O4+7O6vZrePSHpL0rmS5ktamT1spaRb2tUkgOqd1Dm/mU2XdKWkrZJ63H30uO2AGqcFACaIlj/bb2ZnShqQdJe7f2z2v2HC3N3zxuczs35J/WUbBVCtlvb8ZjZJjeD/3t3XZosPmllvVu+VNDLec919ubv3uXtfFQ0DqEbT8FtjF/+4pLfc/TdjSuslLc5uL5b0bPXtAWiXpkN3m9lsSZskvS7pRLZ4qRrn/X+S9HVJe9W41He4ybpCDt29du3aZH3+/Pkd6iSWY8eO5dZOnDiRW2vF+vXrk/XBwcHC6960aVOyvmXLlmS91aG7m57zu/s/JeWt7NpWNgKg+/AJPyAowg8ERfiBoAg/EBThB4Ii/EBQTNHdBe67775kvewU3imXXHJJst7Or82uWLEiWR8aGiq1/oGBgdzarl27Sq27mzFFN4Akwg8ERfiBoAg/EBThB4Ii/EBQhB8Iiuv8wCmG6/wAkgg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqKbhN7PzzOwfZvamme00s59myx82s/1mtj37uan97QKoStPBPMysV1Kvu79
qZl+R9IqkWyTdJumou/+q5Y0xmAfQdq0O5vHlFlY0LGk4u33EzN6SdG659gDU7aTO+c1suqQrJW3NFt1pZjvMbIWZnZPznH4zGzSzwVKdAqhUy2P4mdmZkl6S9Ii7rzWzHkmHJLmkn6txavCDJuvgsB9os1YP+1sKv5lNkvS8pA3u/ptx6tMlPe/ulzZZD+EH2qyyATzNzCQ9LumtscHP3ggctUDSGyfbJID6tPJu/2xJmyS9LulEtnippEWSrlDjsH9I0o+zNwdT62LPD7RZpYf9VSH8QPsxbj+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQTQfwrNghSXvH3P9qtqwbdWtv3dqXRG9FVdnbN1p9YEe/z/+FjZsNuntfbQ0kdGtv3dqXRG9F1dUbh/1AUIQfCKru8C+vefsp3dpbt/Yl0VtRtfRW6zk/gPrUvecHUJNawm9mN5rZ22a228weqKOHPGY2ZGavZzMP1zrFWDYN2oiZvTFm2RQz+5uZvZv9HneatJp664qZmxMzS9f62nXbjNcdP+w3s9MkvSPpekn7JG2TtMjd3+xoIznMbEhSn7vXfk3YzL4t6aikVaOzIZnZLyUddvdfZP84z3H3+7ukt4d1kjM3t6m3vJmlv68aX7sqZ7yuQh17/qsl7Xb399z935L+KGl+DX10PXd/WdLhzy2eL2lldnulGn88HZfTW1dw92F3fzW7fUTS6MzStb52ib5qUUf4z5X0wZj7+9RdU367pL+a2Stm1l93M+PoGTMz0gFJPXU2M46mMzd30udmlu6a167IjNdV4w2/L5rt7rMkfU/ST7LD267kjXO2brpcs0zSDDWmcRuW9Os6m8lmlh6QdJe7fzy2VudrN05ftbxudYR/v6Tzxtyfli3rCu6+P/s9IukZNU5TusnB0UlSs98jNffzX+5+0N2Pu/sJSY+pxtcum1l6QNLv3X1ttrj21268vup63eoI/zZJF5nZ+WY2WdJCSetr6OMLzOyM7I0YmdkZkuao+2YfXi9pcXZ7saRna+zl/3TLzM15M0ur5teu62a8dveO/0i6SY13/PdI+lkdPeT0dYGk17KfnXX3JulpNQ4DP1PjvZEfSpoqaaOkdyX9XdKULurtSTVmc96hRtB6a+ptthqH9Dskbc9+bqr7tUv0Vcvrxif8gKB4ww8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFD/Abw9Wv8QfFP9AAAAAElFTkSuQmCC\n",
615 | "text/plain": [
616 | ""
617 | ]
618 | },
619 | "metadata": {},
620 | "output_type": "display_data"
621 | },
622 | {
623 | "name": "stdout",
624 | "output_type": "stream",
625 | "text": [
626 | "Model prediction: 2\n"
627 | ]
628 | },
629 | {
630 | "data": {
631 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADCRJREFUeJzt3X/oXfV9x/Hne1n6h2n/MKvGYMV0RaclYjK+iGCYHdXiRND8I1UYkcnSPxqwsD8m7o8JYyCydgz/KKQ0NJXOZkSDWqdtJ8N0MKpRM383OvmWJsREUahVpDN574/viXzV7z33m3vPvecm7+cDLt9zz+eee94c8srn/LrnE5mJpHr+oO8CJPXD8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKuoPp7myiPB2QmnCMjOW87mxev6IuCYifhURr0XE7eN8l6TpilHv7Y+IFcAB4GrgIPAUcFNmvtSyjD2/NGHT6PkvA17LzNcz8/fAj4Hrx/g+SVM0TvjPBX6z6P3BZt7HRMTWiNgXEfvGWJekjk38hF9mbge2g7v90iwZp+c/BJy36P0XmnmSTgHjhP8p4IKI+GJEfAb4OvBQN2VJmrSRd/sz88OI2Ab8FFgB7MjMFzurTNJEjXypb6SVecwvTdxUbvKRdOoy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqmoqQ7RrXouvPDCgW2vvPJK67K33XZba/s999wzUk1aYM8vFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0WNdZ0/IuaBd4FjwIeZOddFUTp9bNy4cWDb8ePHW5c9ePBg1+VokS5u8vnzzHyrg++RNEXu9ktFjRv+BH4WEU9HxNYuCpI0HePu9m/KzEMRcTbw84h4JTP3Lv5A85+C/zFIM2asnj8zDzV/jwJ7gMuW+Mz2zJzzZKA0W0YOf0SsiojPnZgGvga80FVhkiZrnN3+NcCeiDjxPf+amY91UpWkiRs5/Jn5OnBph7XoNLRhw4aBbe+9917rsnv27Om6HC3ipT6pKMMvFWX4paIMv1SU4ZeKMvxSUT66W2NZv359a/u2bdsGtt17771dl6OTYM8vFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0V5nV9jueiii1rbV61aNbBt165dXZejk2DPLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFRWZOb2UR01uZpuLJJ59sbT/rrLMGtg17FsCwR3traZkZy/mcPb9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFTX09/wRsQO4DjiameubeauBXcA6YB64MTPfmVyZ6su6deta2+fm5lrbDxw4MLDN6/j9Wk7P/wPgmk/Mux14PDMvAB5v3ks6hQwNf2buBd7+xOzrgZ3N9E7gho7rkjRhox7zr8nMw830G8CajuqRNCVjP8MvM7Ptnv2I2ApsHXc9kro1as9/JCLWAjR/jw76YGZuz8y5zGw/MyRpqkYN/0PAlmZ6C/BgN+VImpah4Y+I+4D/Bv4kIg5GxK3AXcDVEfEqcFXzXtIpZOgxf2beNKDpqx3Xohl05ZVXjrX8m2++2VEl6pp3+ElFGX6pKMMvFWX4paIMv1SU4ZeKcohutbrkkkvGWv7uu+/uqBJ1zZ5fKsrwS0UZfqkowy8VZfilogy/VJThl4pyiO7iLr/88tb2Rx55pLV9fn6+tf2KK64Y2PbBBx+0LqvROES3pFaGXyrK8EtFGX6pKMMvFWX4paIMv1SUv+cv7qqrrmptX716dWv7Y4891trutfzZZc8vFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UNvc4fETuA64Cjmbm+mXcn8NfAifGX78jMf59UkZqcSy+9tLV92PMedu/e3WU5mqLl9Pw/AK5ZYv4/Z+aG5mXwpVPM0PBn5l7g7SnUImmKxjnm3xYRz0XEjog4s7OKJE3FqOH/LvAlYANwGPj2oA9GxNaI2BcR+0Zcl6QJGCn8mXkkM49l5nHge8BlLZ/dnplzmTk3apGSujdS+CNi7aK3m4EXuilH0rQs51LffcBXgM9HxEHg74GvRMQGIIF54BsTrFHSBPjc/tPcOeec09q+f//+1vZ33nmntf3iiy8+6Zo0WT63X1Irwy8VZfilogy/VJThl4oy/FJRPrr7NHfLLbe0tp999tmt7Y8++miH1WiW2PNLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlFe5z/NnX/++WMtP+wnvTp12fNLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlFe5z/NXXfddWMt//DDD3dUiWaNPb9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFTX0On9EnAf8EFgDJLA9M/8lIlYDu4B1wDxwY2b64+8ebNq0aWDbsCG6Vddyev4Pgb/JzC8DlwPfjIgvA7cDj2fmBcDjzXtJp4ih4c/Mw5n5TDP9LvAycC5wPbCz+dhO4IZJFSmpeyd1zB8R64CNwC+BNZl5uGl6g4XDAkmniGXf2x8RnwXuB76Vmb+NiI/aMjMjIgcstxXYOm6hkrq1rJ4/IlayEPwfZeYDzewjEbG2aV8LHF1q2czcnplzmTnXRcGSujE0/LHQxX8feDkzv7Oo6SFgSzO9BXiw+/IkTcpydvuvAP4SeD4i9jfz7gDuAv4tIm4Ffg3cOJkSNczmzZsHtq1YsaJ12Weffba1fe/evSPVpNk3NPyZ+V9ADGj+arflSJoW7/CTijL8UlGGXyrK8EtFGX6pKMMvFeWju08BZ5xxRmv7tddeO/J37969u7X92LFjI3+3Zps9v1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VFZlLPn1rMisb8KgvtVu5cmVr+xNPPDGw7ejRJR+w9JGbb765tf39999vbdfsycxBP8H/GHt+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK6/zSacbr/JJaGX6pKMMvFWX4paIMv1SU4ZeKMvxSUUPDHxHnRcR/RsRLEfFiRNzWzL8zIg5FxP7mNfrD4yVN3dCbfCJiLbA2M5+JiM8BTwM3ADcCv8vMf1r2yrzJR5q45d7kM3TEnsw8DBxupt+NiJeBc8crT1LfTuqYPyLWARuBXzaztkXEcxGxIyLOHLDM1ojYFxH7xqpUUqeWfW9/RHwWeAL4x8x8ICLWAG8BCfwDC4cGfzXkO9ztlyZsubv9ywp/RKwEfgL8NDO/s0T7OuAnmbl+yPcYfmnCOvthT0QE8H3g5cXBb04EnrAZeOFki5TUn+Wc7d8E/AJ4HjjezL4DuAnYwMJu/zzwjebkYNt32fNLE9bpbn9XDL80ef6eX1Irwy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlFDH+DZsbeAXy96//lm3iya1dpmtS6wtlF1Wdv5y/3gVH/P/6mVR+zLzLneCmgxq7XNal1gbaPqqzZ3+6W
iDL9UVN/h397z+tvMam2zWhdY26h6qa3XY35J/em755fUk17CHxHXRMSvIuK1iLi9jxoGiYj5iHi+GXm41yHGmmHQjkbEC4vmrY6In0fEq83fJYdJ66m2mRi5uWVk6V633ayNeD313f6IWAEcAK4GDgJPATdl5ktTLWSAiJgH5jKz92vCEfFnwO+AH54YDSki7gbezsy7mv84z8zMv52R2u7kJEdunlBtg0aWvoUet12XI153oY+e/zLgtcx8PTN/D/wYuL6HOmZeZu4F3v7E7OuBnc30Thb+8UzdgNpmQmYezsxnmul3gRMjS/e67Vrq6kUf4T8X+M2i9weZrSG/E/hZRDwdEVv7LmYJaxaNjPQGsKbPYpYwdOTmafrEyNIzs+1GGfG6a57w+7RNmfmnwF8A32x2b2dSLhyzzdLlmu8CX2JhGLfDwLf7LKYZWfp+4FuZ+dvFbX1uuyXq6mW79RH+Q8B5i95/oZk3EzLzUPP3KLCHhcOUWXLkxCCpzd+jPdfzkcw8kpnHMvM48D163HbNyNL3Az/KzAea2b1vu6Xq6mu79RH+p4ALIuKLEfEZ4OvAQz3U8SkRsao5EUNErAK+xuyNPvwQsKWZ3gI82GMtHzMrIzcPGlmanrfdzI14nZlTfwHXsnDG/3+Bv+ujhgF1/THwP83rxb5rA+5jYTfw/1g4N3Ir8EfA48CrwH8Aq2eotntZGM35ORaCtran2jaxsEv/HLC/eV3b97ZrqauX7eYdflJRnvCTijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1TU/wNRj+er2ohshAAAAABJRU5ErkJggg==\n",
632 | "text/plain": [
633 | ""
634 | ]
635 | },
636 | "metadata": {},
637 | "output_type": "display_data"
638 | },
639 | {
640 | "name": "stdout",
641 | "output_type": "stream",
642 | "text": [
643 | "Model prediction: 1\n"
644 | ]
645 | },
646 | {
647 | "data": {
648 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADbdJREFUeJzt3W+MFPUdx/HPF2qfYB9ouRL8U7DFYIhJpTmxDwi2thowGvCBijGGRtNDg2KTPqiBxGKaJo22NE0kkGskPRtrbYLGCyGVlphSE9J4mPrvrv7NQSEniDQqIaYI3z7YufaU298suzM7c3zfr+Ryu/Pdnf068rmZ3d/M/szdBSCeaVU3AKAahB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFBf6OaLmRmnEwIlc3dr5XEd7fnNbKmZvWFmb5vZA52sC0B3Wbvn9pvZdElvSrpW0gFJL0q6zd2HE89hzw+UrBt7/kWS3nb3d939P5L+IGl5B+sD0EWdhP9CSf+acP9AtuwzzKzPzIbMbKiD1wJQsNI/8HP3fkn9Eof9QJ10suc/KOniCfcvypYBmAI6Cf+Lki41s0vM7IuSVkoaLKYtAGVr+7Df3T81s3slPSdpuqSt7v56YZ0BKFXbQ31tvRjv+YHSdeUkHwBTF+EHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQXV1im5034wZM5L1Rx55JFlfvXp1sr53795k/eabb25a27dvX/K5KBd7fiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IqqNZes1sVNLHkk5K+tTde3Mezyy9XTZv3rxkfWRkpKP1T5uW3n+sXbu2aW3Tpk0dvTYm1+osvUWc5PMddz9SwHoAdBGH/UBQnYbfJe00s71m1ldEQwC6o9PD/sXuftDMviLpz2b2T3ffPfEB2R8F/jAANdPRnt/dD2a/D0t6RtKiSR7T7+69eR8GAuiutsNvZjPM7EvjtyVdJ+m1ohoDUK5ODvtnSXrGzMbX83t3/1MhXQEoXdvhd/d3JX2jwF7Qpp6enqa1gYGBLnaCqYShPiAowg8ERfiBoAg/EBThB4Ii/EBQfHX3FJC6LFaSVqxY0bS2aNFpJ1121ZIlS5rW8i4Hfvnll5P13bt3J+tIY88PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0F19NXdZ/xifHV3W06ePJmsnzp1qkudnC5vrL6T3vKm8L711luT9bzpw89WrX51N3t+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf4a2LFjR7K+bNmyZL3Kcf4PPvggWT927FjT2pw5c4pu5zOmT59e6vrrinF+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBU7vf2m9lWSTdIOuzul2fLzpf0lKS5kkYl3eLu/y6vzant6quvTtbnz5+frOeN45c5zr9ly5ZkfefOncn6hx9+2LR2zTXXJJ+7fv36ZD3PPffc07S2efPmjtZ9Nmhlz/9bSUs/t+wBSbvc/VJJu7L7AKaQ3PC7+25JRz+3eLmkgez2gKTmU8YAqKV23/PPcvex7PZ7kmYV1A+ALul4rj5399Q5+2bWJ6mv09cBUKx29/yHzGy2JGW/Dzd7oLv3u3uvu/e2+VoAStBu+Aclrcpur5L0bDHtAOiW3PCb2ZOS9kiab2YHzOwuST+XdK2ZvSXpe9l9AFMI1/MXYO7cucn6nj17kvWZM2cm6518N37ed99v27YtWX/ooYeS9ePHjyfrKXnX8+dtt56enmT9k08+aVp78MEHk8999NFHk/UTJ04k61Xien4ASYQfCIrwA0ERfiAowg8ERfiBoBjqK8C8efOS9ZGRkY7WnzfU9/zzzzetrVy5MvncI0eOtNVTN9x3333J+saNG5P11HbLuwz6sssuS9bfeeedZL1KDPUBSCL8QFCEHwiK8ANBEX4gKMIPBEX4gaA6/hovlG9oaChZv/POO5vW6jyOn2dwcDBZv/3225P1K6+8ssh2zjrs+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMb5uyDvevw8V111VUGdTC1m6cvS87ZrJ9t9w4YNyfodd9zR9rrrgj0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwSVO85vZlsl3SDpsLtfni3bIOkHkt7PHrbO3XeU1WTd3X333cl63nfEY3I33nhjsr5w4cJkPbXd8/6f5I3znw1a2fP/VtLSSZb/yt2vyH7CBh+YqnLD7+67JR3tQi8AuqiT9/z3mtkrZrbVzM4rrCMAXdFu+DdL+rqkKySNSfplsweaWZ+ZDZlZ+ovoAHRVW+F390PuftLdT0n6jaRFicf2u3uvu/e22ySA4rUVfjObPeHuTZJeK6YdAN3SylDfk5K+LWmmmR2Q9BNJ3zazKyS5pFFJq0vsEUAJcsPv7rdNsvixEnqZsvLGoyPr6elpWluwYEHyuevWrSu6nf95//33k/UTJ06U9tp1wRl+QFCEHwiK8ANBEX4gKMIPBEX4gaD46m6Uav369U1ra9asKfW1R0dHm9ZWrVqVfO7+/fsL7qZ+2PMDQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM86MjO3akv7h5/vz5XerkdMPDw01rL7zwQhc7qSf2/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8BTCzZH3atM7+xi5btqzt5/b39yfrF1xwQdvrlvL/26qcnpyvVE9jzw8ERfiBoAg/EBThB4Ii/EBQhB8IivADQeWO85vZxZIelzRLkkvqd/dfm9n5kp6SNFfSqKRb3P3f5bVaX5s3b07WH3744Y7Wv3379mS9k7H0ssfhy1z/li1bSlt3BK3s+T+V9CN3XyDpW5LWmNkCSQ9I2uXul0rald0HMEXkht/dx9z9pez2x5JGJF0oabmkgexhA5JWlNUkgOKd0Xt+M5sraaGkv0ua5e5jWek9Nd4WAJgiWj6338zOlbRN0g/d/aOJ57O7u5uZN3len6S+ThsFUKyW9vxmdo4awX/C3Z/OFh8ys9lZfbakw5M919373b3X3XuLaBhAMXLDb41d/GOSRtx944TSoKTxqU5XSXq2+PYAlMXcJz1a//8DzBZL+pukVyWNj9usU+N9/x8lfVXSPjWG+o7mrCv9YlPUnDlzkvU9e/Yk6z09Pcl6nS+bzevt0KFDTWsjIyPJ5/b1pd8tjo2NJevHjx9P1s9W7p6+xjyT+57f3V+Q1Gxl3z2TpgDUB2f4AUERfiAowg8ERfiBoAg/EBThB4LKHecv9MXO0nH+PEuWLEnWV6xIXxN1//33J+t1Hudfu3Zt09qmTZuKbgdqfZyfPT8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMU4/xSwdOnSZD113XveNNWDg4PJet4U33nTkw8PDzet7d+/P/lctIdxfgBJhB8IivADQRF+ICjCDwRF+IGgCD8QFOP8wFmGcX4
ASYQfCIrwA0ERfiAowg8ERfiBoAg/EFRu+M3sYjN73syGzex1M7s/W77BzA6a2T+yn+vLbxdAUXJP8jGz2ZJmu/tLZvYlSXslrZB0i6Rj7v6Lll+Mk3yA0rV6ks8XWljRmKSx7PbHZjYi6cLO2gNQtTN6z29mcyUtlPT3bNG9ZvaKmW01s/OaPKfPzIbMbKijTgEUquVz+83sXEl/lfQzd3/azGZJOiLJJf1UjbcGd+asg8N+oGStHva3FH4zO0fSdknPufvGSepzJW1398tz1kP4gZIVdmGPNb6e9TFJIxODn30QOO4mSa+daZMAqtPKp/2LJf1N0quSxueCXifpNklXqHHYPyppdfbhYGpd7PmBkhV62F8Uwg+Uj+v5ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgsr9As+CHZG0b8L9mdmyOqprb3XtS6K3dhXZ25xWH9jV6/lPe3GzIXfvrayBhLr2Vte+JHprV1W9cdgPBEX4gaCqDn9/xa+fUtfe6tqXRG/tqqS3St/zA6hO1Xt+ABWpJPxmttTM3jCzt83sgSp6aMbMRs3s1Wzm4UqnGMumQTtsZq9NWHa+mf3ZzN7Kfk86TVpFvdVi5ubEzNKVbru6zXjd9cN+M5su6U1J10o6IOlFSbe5+3BXG2nCzEYl9bp75WPCZrZE0jFJj4/PhmRmD0s66u4/z/5wnufuP65Jbxt0hjM3l9Rbs5mlv68Kt12RM14XoYo9/yJJb7v7u+7+H0l/kLS8gj5qz913Szr6ucXLJQ1ktwfU+MfTdU16qwV3H3P3l7LbH0san1m60m2X6KsSVYT/Qkn/mnD/gOo15bdL2mlme82sr+pmJjFrwsxI70maVWUzk8idubmbPjezdG22XTszXheND/xOt9jdvylpmaQ12eFtLXnjPVudhms2S/q6GtO4jUn6ZZXNZDNLb5P0Q3f/aGKtym03SV+VbLcqwn9Q0sUT7l+ULasFdz+Y/T4s6Rk13qbUyaHxSVKz34cr7ud/3P2Qu59091OSfqMKt102s/Q2SU+4+9PZ4sq33WR9VbXdqgj/i5IuNbNLzOyLklZKGqygj9OY2YzsgxiZ2QxJ16l+sw8PSlqV3V4l6dkKe/mMuszc3GxmaVW87Wo347W7d/1H0vVqfOL/jqT1VfTQpK+vSXo5+3m96t4kPanGYeAJNT4buUvSlyXtkvSWpL9IOr9Gvf1OjdmcX1EjaLMr6m2xGof0r0j6R/ZzfdXbLtFXJduNM/yAoPjADwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUP8FAfaK+yOWZZUAAAAASUVORK5CYII=\n",
649 | "text/plain": [
650 | ""
651 | ]
652 | },
653 | "metadata": {},
654 | "output_type": "display_data"
655 | },
656 | {
657 | "name": "stdout",
658 | "output_type": "stream",
659 | "text": [
660 | "Model prediction: 0\n"
661 | ]
662 | },
663 | {
664 | "data": {
665 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADXVJREFUeJzt3W+oXPWdx/HPZ00bMQ2Su8FwScPeGmUlBDfViygb1krXmI2VWPxDQliyKr19UGGL+2BFhRV1QWSbpU8MpBgal27aRSOGWvpnQ1xXWEpuJKvRu60xpCQh5o9paCKBau53H9wTuSZ3ztzMnJkzc7/vF1zuzPmeM/PlJJ/7O2fOzPwcEQKQz5/U3QCAehB+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJzermk9nm7YRAh0WEp7NeWyO/7ZW2f2N7n+1H23ksAN3lVt/bb/sySb+VdLukQ5J2SVobEe+VbMPID3RYN0b+myTti4j9EfFHST+WtLqNxwPQRe2Ef6Gkg5PuHyqWfY7tEdujtkfbeC4AFev4C34RsUnSJonDfqCXtDPyH5a0aNL9LxfLAPSBdsK/S9K1tr9i+4uS1kjaXk1bADqt5cP+iPjU9sOSfiHpMkmbI+LdyjoD0FEtX+pr6ck45wc6ritv8gHQvwg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IquUpuiXJ9gFJpyWdk/RpRAxX0RQ+74Ybbiitb9u2rWFtaGio4m56x4oVK0rrY2NjDWsHDx6sup2+01b4C7dFxIkKHgdAF3HYDyTVbvhD0i9t77Y9UkVDALqj3cP+5RFx2PZVkn5l+/8i4o3JKxR/FPjDAPSYtkb+iDhc/D4m6RVJN02xzqaIGObFQKC3tBx+23Nszz1/W9IKSXuragxAZ7Vz2L9A0iu2zz/Ov0fEzyvpCkDHtRz+iNgv6S8q7AUN3HHHHaX12bNnd6mT3nLXXXeV1h988MGGtTVr1lTdTt/hUh+QFOEHkiL8QFKEH0iK8ANJEX4gqSo+1Yc2zZpV/s+watWqLnXSX3bv3l1af+SRRxrW5syZU7rtxx9/3FJP/YSRH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeS4jp/D7jttttK67fccktp/bnnnquynb4xb9680vqSJUsa1q644orSbbnOD2DGIvxAUoQfSIrwA0kRfiApwg8kRfiBpBwR3Xsyu3tP1kOWLl1aWn/99ddL6x999FFp/cYbb2xYO3PmTOm2/azZflu+fHnD2uDgYOm2x48fb6WlnhARns56jPxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kFTTz/Pb3izpG5KORcTSYtmApJ9IGpJ0QNL9EfH7zrXZ35544onSerPvkF+5cmVpfaZeyx8YGCit33rrraX18fHxKtuZcaYz8v9Q0oX/+x6VtCMirpW0o7gPoI80DX9EvCHp5AWLV0vaUtzeIunuivsC0GGtnvMviIgjxe0PJS2oqB8AXdL2d/hFRJS9Z9/2iKSRdp8HQLVaHfmP2h6UpOL3sUYrRsSmiBiOiOEWnwtAB7Qa/u2S1he310t6tZp2AHRL0/Db3irpfyT9ue1Dth+S9Kyk222/L+mvi/sA+kjTc/6IWNug9PWKe+lb9957b2l91apVpfV9+/aV1kdHRy+5p5ng8ccfL603u45f9nn/U6dOtdLSjMI7/ICkCD+QFOEHkiL8QFKEH0iK8ANJMUV3Be67777SerPpoJ9//vkq2+kbQ0NDpfV169aV1s+dO1daf+aZZxrWPvnkk9JtM2DkB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkuM4/TVdeeWXD2s0339zWY2/cuLGt7fvVyEj5t7vNnz+/tD42NlZa37lz5yX3lAkjP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kxXX+aZo9e3bD2sKFC0u33bp1a9XtzAiLFy9ua/u9e/dW1ElOjPxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kFTT6/y2N0v6hqRjEbG0WPakpG9JOl6s9lhE/KxTTfaC06dPN6zt2bOndNvrr7++tD4wMFBaP3nyZGm9l1111VUNa82mNm/mzTffbGv77KYz8v9Q0soplv9rRCwrfmZ08IGZqGn4I+INSf079ACYUjvn/A/bftv2ZtvzKusIQFe0Gv6NkhZLWibpiKTvNVrR9ojtUdujLT4XgA5oKfwRcTQizkXEuKQfSLqpZN1NETEcEcOtNgmgei2F3/bgpLvflMTHq4A+M51LfVslfU3SfNuHJP2TpK/ZXiYpJB2Q9O0O9gigA5qGPyLWTrH4hQ700tPOnj3bsPbBBx+UbnvPPfeU1l977bXS+oYNG0rrnbR06dLS+tVXX11aHxoaaliLiFZa+sz4+Hhb22fHO/yApAg/kBThB5Ii/EBShB9IivADSbndyy2X9GR2956si6677rrS+lNPPVVav/POO0vrZV8b3mknTpworTf7/1M2zbbtlno6b+7cuaX1ssuzM1lETGvHMvIDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFJc5+8By5YtK61fc801XerkYi+99FJb22/ZsqVhbd26dW099qxZzDA/Fa7zAyhF+IGkCD+QFOEHkiL8QFKEH0iK8ANJcaG0BzSb4rtZvZft37+/Y4/d7GvF9+5lLpkyjPxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kFTT6/y2F0l6UdICSSFpU0R83/aApJ9IGpJ0QNL9EfH7zrWKflT23fztfm8/1/HbM52R/1NJ/xARSyTdLOk7tpdIelTSjoi4VtKO4j6APtE0/BFxJCLeKm6fljQmaaGk1ZLOf03LFkl3d6pJANW7pHN+20OSvirp15IWRMSRovShJk4LAPSJab+33/aXJL0s6bsR8YfJ52sREY2+n8/2iKSRdhsFUK1pjfy2v6CJ4P8oIrYVi4/aHizqg5KOTbVtRGyKiOGIGK6iYQDVaBp+TwzxL0gai4gNk0rbJa0vbq+X9Gr17QHolOkc9v+lpL+V9I7t858tfUzSs5L+w/ZDkn4n6f7OtIh+VvbV8N382nhcrGn4I+JNSY0uyH692nYAdAvv8AOSIvxAUoQfSIrwA0kRfiApwg8kxVd3o6Muv/zylrc9e/ZshZ3gQoz8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU1/nRUQ888EDD2qlTp0q3ffrpp6tuB5Mw8gNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUlznR0ft2rWrYW3Dhg0Na5K0c+fOqtvBJIz8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5CUm82RbnuRpBclLZAUkjZFxPdtPynpW5KOF6s+FhE/a/JYTMgOdFhEeDrrTSf8g5IGI+It23Ml7ZZ0t6T7JZ2JiH+ZblO
EH+i86Ya/6Tv8IuKIpCPF7dO2xyQtbK89AHW7pHN+20OSvirp18Wih22/bXuz7XkNthmxPWp7tK1OAVSq6WH/ZyvaX5L0X5L+OSK22V4g6YQmXgd4WhOnBg82eQwO+4EOq+ycX5Jsf0HSTyX9IiIu+jRGcUTw04hY2uRxCD/QYdMNf9PDftuW9IKkscnBL14IPO+bkvZeapMA6jOdV/uXS/pvSe9IGi8WPyZpraRlmjjsPyDp28WLg2WPxcgPdFilh/1VIfxA51V22A9gZiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8k1e0puk9I+t2k+/OLZb2oV3vr1b4kemtVlb392XRX7Orn+S96cns0IoZra6BEr/bWq31J9NaqunrjsB9IivADSdUd/k01P3+ZXu2tV/uS6K1VtfRW6zk/gPrUPfIDqEkt4be90vZvbO+z/WgdPTRi+4Dtd2zvqXuKsWIatGO2905aNmD7V7bfL35POU1aTb09aftwse/22F5VU2+LbO+0/Z7td23/fbG81n1X0lct+63rh/22L5P0W0m3SzokaZektRHxXlcbacD2AUnDEVH7NWHbfyXpjKQXz8+GZPs5SScj4tniD+e8iPjHHuntSV3izM0d6q3RzNJ/pxr3XZUzXlehjpH/Jkn7ImJ/RPxR0o8lra6hj54XEW9IOnnB4tWSthS3t2jiP0/XNeitJ0TEkYh4q7h9WtL5maVr3XclfdWijvAvlHRw0v1D6q0pv0PSL23vtj1SdzNTWDBpZqQPJS2os5kpNJ25uZsumFm6Z/ZdKzNeV40X/C62PCJukPQ3kr5THN72pJg4Z+ulyzUbJS3WxDRuRyR9r85mipmlX5b03Yj4w+Ranftuir5q2W91hP+wpEWT7n+5WNYTIuJw8fuYpFc0cZrSS46enyS1+H2s5n4+ExFHI+JcRIxL+oFq3HfFzNIvS/pRRGwrFte+76bqq679Vkf4d0m61vZXbH9R0hpJ22vo4yK25xQvxMj2HEkr1HuzD2+XtL64vV7SqzX28jm9MnNzo5mlVfO+67kZryOi6z+SVmniFf8PJD1eRw8N+rpa0v8WP+/W3ZukrZo4DPxEE6+NPCTpTyXtkPS+pP+UNNBDvf2bJmZzflsTQRusqbflmjikf1vSnuJnVd37rqSvWvYb7/ADkuIFPyApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSf0/fhI1ni26LDgAAAAASUVORK5CYII=\n",
666 | "text/plain": [
667 | ""
668 | ]
669 | },
670 | "metadata": {},
671 | "output_type": "display_data"
672 | },
673 | {
674 | "name": "stdout",
675 | "output_type": "stream",
676 | "text": [
677 | "Model prediction: 4\n"
678 | ]
679 | },
680 | {
681 | "data": {
682 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADGdJREFUeJzt3X/oXfV9x/HnW5cK2v5hUhaCCUsXZFAU7PiqIwvSsVmdVGJRpP4xMiZN/2hghf0xMX9MGAMZa0f+iqYYGqVLO/BXKGVNFoauMkoSyTTqWrOS2ISYNPijFgwxyXt/fE/ct/q95369v8795v18wJd77/mce86bQ175nB/3nE9kJpLquazrAiR1w/BLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrqdya5sojw54TSmGVmLGS+oXr+iLg9In4WEYcj4oFhliVpsmLQ3/ZHxOXAz4FbgWPAPuC+zHy15Tv2/NKYTaLnvwk4nJm/yMyzwPeB9UMsT9IEDRP+a4Bfzvl8rJn2WyJiY0Tsj4j9Q6xL0oiN/YRfZm4DtoG7/dI0GabnPw6smvN5ZTNN0iIwTPj3AddGxOci4lPAV4FdoylL0rgNvNufmeciYhPwY+ByYHtmvjKyyiSN1cCX+gZamcf80thN5Ec+khYvwy8VZfilogy/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VZfilogy/VJThl4oaeIhugIg4ArwHnAfOZebMKIrSpePOO+/s2bZr167W727atKm1/ZFHHmltP3/+fGt7dUOFv/EnmXl6BMuRNEHu9ktFDRv+BHZHxIGI2DiKgiRNxrC7/esy83hE/C6wJyL+JzOfnztD85+C/zFIU2aonj8zjzevp4CngZvmmWdbZs54MlCaLgOHPyKuiojPXHwPfAk4NKrCJI3XMLv9y4GnI+Licv4lM/9tJFVJGrvIzMmtLGJyK9NELFu2rLX94MGDPdtWrlw51LqvvPLK1vb3339/qOUvVpkZC5nPS31SUYZfKsrwS0UZfqkowy8VZfilokZxV58Ku+WWW1rbh7mct3Pnztb2M2fODLxs2fNLZRl+qSjDLxVl+KWiDL9UlOGXijL8UlFe51erK664orV98+bNY1v3E0880do+ydvRL0X2/FJRhl8qyvBLRRl+qSjDLxVl+KWiDL9UlI/uVquZmfaBlvbt2zfwss+dO9favmTJkoGXXZmP7pbUyvBLRRl+qSjDLxVl+KWiDL9UlOGXiup7P39EbAe+DJzKzOuaaUuBHwCrgSPAvZn59vjKVFfuvvvusS179+7dY1u2+ltIz/9d4PaPTHsA2JuZ1wJ7m8+SFpG+4c/M54G3PjJ5PbCjeb8DuGvEdUkas0GP+Zdn5onm/ZvA8hHVI2lChn6GX2Zm22/2I2IjsHHY9UgarUF7/pMRsQKgeT3Va8bM3JaZM5nZfoeIpIkaNPy7gA3N+w3As6MpR9Kk9A1/ROwE/gv4g4g4FhH3Aw8Dt0bE68CfNZ8lLSLez69WL7zwQmv72rVrW9vPnj3bs+3mm29u/e7Bgwdb2zU/7+eX1MrwS0UZfqkowy8VZfilogy/VJSX+orrd6mu36W+ft5+u/ed3kuXLh1q2Zqfl/oktTL8UlGGXyrK8EtFGX6pKMMvFWX4paKGfoyXFrcbb7xxrMvfunXrWJevwdnzS0UZfqkowy8VZfilogy/VJThl4oy/FJRXucvbmZmuIGU3nnnndZ2r/NPL3t+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyqq73P7I2I78GXgVGZe10x7CPga8Ktmtgcz80d9V+Zz+ydu3bp1re3PPfdca/tll7X3D0ePHm1tX716dWu7Rm+Uz+3/LnD7PNP/OTNvaP76Bl/SdOkb/sx8HnhrArVImqBhjvk3RcRLEbE9Iq4eWUWSJmLQ8G8F1gA3ACeAb/WaMSI2RsT+iNg/4LokjcFA4c/Mk5l5PjMvAN8BbmqZd1tmzmTmcHeQSBqpgcIfESvmfPwKcGg05UialL639EbETuCLwGcj4hjwd8AXI+IGIIEjwNfHWKOkMegb/sy8b57Jj42hFo3BsmXLWtv7XcfvZ8+ePUN9X93xF35SUYZfKsrwS0UZfqkowy8VZfilonx09yXunnvuGer7/R7N/eijjw61fHXHnl8qyvBLRRl+qSjDLxVl+KWiDL9UlOGXiur76O6RrsxHd4/FypUre7b1e7R2v1t6Dx1qf07L9ddf39quyRvlo7slXYIMv1SU4ZeKMvxSUYZfKsrwS0UZfqko7+e/BKxdu7Zn27CP5n7mmWeG+r6mlz2/VJThl4oy/FJRhl8qyvBLRRl+qSjDLxXV9zp/RKwCHgeWAwlsy8wtEbEU+AGwGjgC3JuZb4+vVPXSbxjuNqdPn25t37Jly8DL1nRbSM9/DvibzPw88EfANyLi88ADwN7MvBbY23yWtEj0DX9mnsjMF5v37wGvAdcA64EdzWw7gLvGVaSk0ftEx/wRsRr4AvBTYHlmnmia3mT2sEDSIrHg3/ZHxKeBJ4FvZuavI/7/MWGZmb2ezxcRG4GNwxYqabQW1PNHxBJmg/+9zHyqmXwyIlY07SuAU/N9NzO3ZeZMZs6MomBJo9E3/DHbxT8GvJaZ357TtAvY0LzfADw7+vIkjctCdvv/GPgL4OWIONhMexB4GPjXiLgfOArcO54S1c9tt9028HffeOON1vZ333134GVruvUNf2b+BOj1HPA/HW05kibFX/hJRRl+qSjDLxVl+KWiDL9UlOGXivLR3YvAkiVLWtvXrFkz8LLPnDnT2v7BBx8MvGxNN3t+qSjDLxVl+KWiDL9UlOGXijL8UlGGXyrK6/yLwIULF1rb9+/f37Ptuuuua/3u4cOHB6pJi589v1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8V5XX+ReD8+fOt7Zs3b+7ZljnvKGofOnDgwEA1afGz55eKMvxSUYZfKsrwS0UZfqkowy8VZfiloqLfdeCIWAU8DiwHEtiWmVsi4iHga8CvmlkfzMwf9VlW+8okDS0zYyHzLST8K4AVmfliRHwGOADcBdwL/CYz/2mhRRl+afwWGv6+v/DLzBPAieb9exHxGnDNcOVJ6tonOuaPiNXAF4CfNpM2RcRLEbE9Iq7u8Z2NEbE/Ino/a0rSxPXd7f9wxohPA88B/5CZT0XEcuA0s+cB/p7ZQ4O/6rMMd/ulMRvZMT9ARCwBfgj8ODO/PU/7auCHmdn6tEjDL43fQsPfd7c/IgJ4DHhtbvCbE4EXfQU49EmLlNSdhZztXwf8J/AycPEZ0g8C9wE3MLvbfwT4enNysG1Z9vzSmI10t39UDL80fiPb7Zd0aTL8UlGGXyrK8EtFGX6pKMMvFWX4paIMv1SU4ZeKMvxSUYZfKsrwS0UZfqkowy8VNekhuk8DR+d8/mw
zbRpNa23TWhdY26BGWdvvLXTGid7P/7GVR+zPzJnOCmgxrbVNa11gbYPqqjZ3+6WiDL9UVNfh39bx+ttMa23TWhdY26A6qa3TY35J3em655fUkU7CHxG3R8TPIuJwRDzQRQ29RMSRiHg5Ig52PcRYMwzaqYg4NGfa0ojYExGvN6/zDpPWUW0PRcTxZtsdjIg7OqptVUT8R0S8GhGvRMRfN9M73XYtdXWy3Sa+2x8RlwM/B24FjgH7gPsy89WJFtJDRBwBZjKz82vCEXEL8Bvg8YujIUXEPwJvZebDzX+cV2fm305JbQ/xCUduHlNtvUaW/ks63HajHPF6FLro+W8CDmfmLzLzLPB9YH0HdUy9zHweeOsjk9cDO5r3O5j9xzNxPWqbCpl5IjNfbN6/B1wcWbrTbddSVye6CP81wC/nfD7GdA35ncDuiDgQERu7LmYey+eMjPQmsLzLYubRd+TmSfrIyNJTs+0GGfF61Dzh93HrMvMPgT8HvtHs3k6lnD1mm6bLNVuBNcwO43YC+FaXxTQjSz8JfDMzfz23rcttN09dnWy3LsJ/HFg15/PKZtpUyMzjzesp4GlmD1OmycmLg6Q2r6c6rudDmXkyM89n5gXgO3S47ZqRpZ8EvpeZTzWTO99289XV1XbrIvz7gGsj4nMR8Sngq8CuDur4mIi4qjkRQ0RcBXyJ6Rt9eBewoXm/AXi2w1p+y7SM3NxrZGk63nZTN+J1Zk78D7iD2TP+/wts7qKGHnX9PvDfzd8rXdcG7GR2N/ADZs+N3A8sA/YCrwP/DiydotqeYHY055eYDdqKjmpbx+wu/UvAwebvjq63XUtdnWw3f+EnFeUJP6kowy8VZfilogy/VJThl4oy/FJRhl8qyvBLRf0f7V4JFFPw3M8AAAAASUVORK5CYII=\n",
683 | "text/plain": [
684 | ""
685 | ]
686 | },
687 | "metadata": {},
688 | "output_type": "display_data"
689 | },
690 | {
691 | "name": "stdout",
692 | "output_type": "stream",
693 | "text": [
694 | "Model prediction: 1\n"
695 | ]
696 | },
697 | {
698 | "data": {
699 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADbxJREFUeJzt3X+MFPUZx/HPU6uJEWKOoicqKWhME38Vy8WYFIXGiqhN0BiNROsZiYfxR6ppDIYaazRNTFNs/EeSMxDOH1X8hRL8hZKmtKExAjnA06onOQU8OVSM518oPP1jh/bE2+8uu7M7ezzvV3K53Xl2Zp4MfG5md2b2a+4uAPH8qOgGABSD8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCOrHzVyZmXE5IdBg7m7VvK6uPb+ZzTGz982s38zurmdZAJrLar2238yOkPSBpIsk7ZD0tqR57v5uYh72/ECDNWPPf66kfnff5u57JT0taW4dywPQRPWE/yRJ20c835FN+x4z6zKzDWa2oY51AchZwz/wc/duSd0Sh/1AK6lnz79T0uQRz0/OpgEYA+oJ/9uSTjOzqWZ2lKRrJK3Kpy0AjVbzYb+7f2dmt0l6XdIRkpa5e19unQFoqJpP9dW0Mt7zAw3XlIt8AIxdhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRV8xDdkmRmA5KGJe2T9J27d+TRVDTHH398sv7MM88k6+vXry9b6+7uTs47MDCQrB+ujj322GT9ggsuSNZfe+21ZP3bb7895J6ara7wZ37l7p/nsBwATcRhPxBUveF3SWvMbKOZdeXREIDmqPewf4a77zSz4yW9YWb/cfd1I1+Q/VHgDwPQYura87v7zuz3kKSVks4d5TXd7t7Bh4FAa6k5/GZ2jJmNP/BY0mxJ7+TVGIDGquewv13SSjM7sJy/uXv6/AeAllFz+N19m6Sf59jLYautrS1Z7+vrS9YrnZPetWtX2VrU8/hSertt3LgxOe9xxx2XrE+fPj1Z7+/vT9ZbAaf6gKAIPxAU4QeCIvxAUIQfCIrwA0HlcVdfeBMnTkzWV6xYkaxPmDAhWX/kkUeS9dtvvz1Zj+qee+4pW5s6dWpy3gULFiTrY+FUXiXs+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKHP35q3MrHkra6LZs2cn66+++mpdyz/hhBOS9d27d9e1/LHqjDPOSNa3bt1atrZy5crkvDfccEOyPjw8nKwXyd2tmtex5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoLifv0qpYbSvvPLKupY9f/78ZJ3z+KN78803a152pfP8rXwePy/s+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqIrn+c1smaTfSBpy9zOzaRMkrZA0RdKApKvdfU/j2ize4sWLy9auu+665LyVhoN+9tlna+rpcHf++ecn6+3t7cn68uXLy9aeeOKJWlo6rFSz518uac5B0+6WtNbdT5O0NnsOYAypGH53Xyfpy4Mmz5XUkz3ukXR5zn0BaLBa3/O3u/tg9vgzSenjLwAtp+5r+93dU9/NZ2ZdkrrqXQ+AfNW6599lZpMkKfs9VO6F7t7t7h3u3lHjugA0QK3hXyWpM3vcKemlfNoB0CwVw29mT0n6t6SfmdkOM5sv6UFJF5nZh5J+nT0HMIZUfM/v7vPKlC7MuZeWlhrfYP/+/cl5P/3002R97969NfU0Fhx99NFla4sWLUrOe8sttyTrlcacuPHGG5P16LjCDwiK8ANBEX4gKMIPBEX4gaAIPxAUX93dBJdddlmyvmbNmmT9q6++StaXLFlyyD3lZebMmcn6rFmzytbOO++8utb93HPP1TV/dOz5gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAoq3RbZK4rS3zdV6ubPn162dqLL76YnPfEE0+sa91mlqw389/wYI3sbdu2bcn6nDkHf6n093300Uc1r3ssc/f0P0qGPT8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMX9/FVKDbN99tlnJ+edNm1asl7pfPVdd92VrO/evbtsraenp2wtD48//niyvnnz5pqXvX79+mQ96nn8vLDnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgKt7Pb2bLJP1G0pC7n5lNu0/STZIOnGBe5O6vVFzZGL6fH6M75ZRTkvX+/v6ytd7e3uS8F198cbKeur4hsjzv518uabSrUP7q7tOyn4rBB9BaKobf3ddJ+rIJvQBoonre899mZlvMbJmZteXWEYCmqDX8SySdKmmapEFJi8u90My6zGyDmW2ocV0AGqCm8Lv7Lnff5+77JT0q6dzEa7vdvcPdO2ptEkD+agq/mU0a8fQKSe/k0w6AZql4S6+ZPSVplqSJZrZD0h8lzTKzaZJc0oCkBQ3sEUADVAy/u88bZfLSBvSCMejee+9N1lPXkSxcuDA5L+fxG4sr/ICgCD8QFOEHgiL8QFCEHwiK8ANB8dXdSLrqqquS9euvvz5ZHx4eLlv74osvauoJ+WDPDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBcZ4fSZdcckld869evbpsbdOmTXUtG/Vhzw8ERfiBoAg/EBThB4Ii/EBQhB8IivADQVUcojvXlTFE95gzODiYrI8bNy5ZnzlzZtka5/kbI88hugEchgg/EBThB4Ii/EBQhB8IivADQRF+IKiK9/Ob2WRJj0lql+SSut39YTObIGmFpCmSBiRd7e57GtcqGuHmm29O1tvb25P1oaGhZJ1z+a2rmj3/d5J+7+6nSzpP0q1mdrqkuyWtdffTJK3NngMYIyqG390H3X1T9nhY0nuSTpI0V1JP9rIeSZc3qkkA+Tuk9/xmNkXSOZLektTu7geu/fxMpbcFAMaIqr/Dz8zGSXpe0h3u/rXZ/y8fdncvd92+mXVJ6qq3UQD5qmrPb2ZHqhT8J939hWzyLjOblNUnSRr1kx9373b3DnfvyKNhAPmoGH4r7eKXSnrP3R8aUVolqTN73CnppfzbA9Ao1Rz2/1LSbyVtNbPebNoiSQ9KesbM5kv6WNLVjWkRjVTpVF+lW75ffvnlmtc9fvz4ZL2trS1Z/+STT2peN6oIv7v/S1K5+4MvzLcdAM3CFX5AUIQfCIrwA0ERfiAowg8ERfiBoBiiG3XZt29fsn7ttdeWrd15553Jefv6+pL1zs7OZB1p7PmBoAg/EBThB4Ii/EBQhB8IivADQRF+ICiG6A6ut7c3WT/rrLOS9ZFf5zaa1P+vpUuXJud94IEHkvXt27cn61ExRDeAJMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrz/MHNmDEjWb///vuT9XXr1iXrS5YsKVvbsyc9ovvevXuTdYyO8/wAkgg/EBThB4Ii/EBQhB8IivADQRF
+IKiK5/nNbLKkxyS1S3JJ3e7+sJndJ+kmSbuzly5y91cqLIvz/ECDVXuev5rwT5I0yd03mdl4SRslXS7paknfuPtfqm2K8AONV234K47Y4+6Dkgazx8Nm9p6kk+prD0DRDuk9v5lNkXSOpLeySbeZ2RYzW2ZmbWXm6TKzDWa2oa5OAeSq6mv7zWycpH9I+pO7v2Bm7ZI+V+lzgAdUemtwY4VlcNgPNFhu7/klycyOlLRa0uvu/tAo9SmSVrv7mRWWQ/iBBsvtxh4rfT3rUknvjQx+9kHgAVdIeudQmwRQnGo+7Z8h6Z+Stkran01eJGmepGkqHfYPSFqQfTiYWhZ7fqDBcj3szwvhBxqP+/kBJBF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCqvgFnjn7XNLHI55PzKa1olbtrVX7kuitVnn29tNqX9jU+/l/sHKzDe7eUVgDCa3aW6v2JdFbrYrqjcN+ICjCDwRVdPi7C15/Sqv21qp9SfRWq0J6K/Q9P4DiFL3nB1CQQsJvZnPM7H0z6zezu4vooRwzGzCzrWbWW/QQY9kwaENm9s6IaRPM7A0z+zD7PeowaQX1dp+Z7cy2Xa+ZXVpQb5PN7O9m9q6Z9ZnZ77LphW67RF+FbLemH/ab2RGSPpB0kaQdkt6WNM/d321qI2WY2YCkDncv/JywmV0g6RtJjx0YDcnM/izpS3d/MPvD2ebuC1ukt/t0iCM3N6i3ciNL36ACt12eI17noYg9/7mS+t19m7vvlfS0pLkF9NHy3H2dpC8PmjxXUk/2uEel/zxNV6a3luDug+6+KXs8LOnAyNKFbrtEX4UoIvwnSdo+4vkOtdaQ3y5pjZltNLOuopsZRfuIkZE+k9ReZDOjqDhyczMdNLJ0y2y7Wka8zhsf+P3QDHf/haRLJN2aHd62JC+9Z2ul0zVLJJ2q0jBug5IWF9lMNrL085LucPevR9aK3Haj9FXIdisi/DslTR7x/ORsWktw953Z7yFJK1V6m9JKdh0YJDX7PVRwP//j7rvcfZ+775f0qArcdtnI0s9LetLdX8gmF77tRuurqO1WRPjflnSamU01s6MkXSNpVQF9/ICZHZN9ECMzO0bSbLXe6MOrJHVmjzslvVRgL9/TKiM3lxtZWgVvu5Yb8drdm/4j6VKVPvH/SNIfiuihTF+nSNqc/fQV3Zukp1Q6DPxWpc9G5kv6iaS1kj6U9KakCS3U2+Mqjea8RaWgTSqotxkqHdJvkdSb/Vxa9LZL9FXIduMKPyAoPvADgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxDUfwEJEYHZ+iI4owAAAABJRU5ErkJggg==\n",
700 | "text/plain": [
701 | ""
702 | ]
703 | },
704 | "metadata": {},
705 | "output_type": "display_data"
706 | },
707 | {
708 | "name": "stdout",
709 | "output_type": "stream",
710 | "text": [
711 | "Model prediction: 4\n"
712 | ]
713 | },
714 | {
715 | "data": {
716 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADa9JREFUeJzt3XuMVOUZx/Hfo1YxggasRSJ4AUm1wWRpVq0JqTZi4y0iiReQGJoYVhMwNeEPCU0smnhJbYuGP0yWiKLilkZRiGlalDSRmtqIt0WxRWyWCAHWSrUSJXh5+scc2lV33rPMnJlzluf7STY7c545c57M8uOcmXfOec3dBSCeI8puAEA5CD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaCOaufGzIyvEwIt5u42lMc1tec3s0vN7B9mts3MFjXzXADayxr9br+ZHSlpq6RLJO2Q9Iqk2e6+JbEOe36gxdqx5z9P0jZ3/6e7H5D0O0kzmng+AG3UTPhPkfT+gPs7smVfY2ZdZrbJzDY1sS0ABWv5B37u3i2pW+KwH6iSZvb8OyVNGHB/fLYMwDDQTPhfkTTZzM4ws6MlzZK0rpi2ALRaw4f97v6FmS2Q9CdJR0pa4e5vF9YZgJZqeKivoY3xnh9oubZ8yQfA8EX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUA1P0S1JZtYn6RNJX0r6wt07i2gKQOs1Ff7MT9z9XwU8D4A24rAfCKrZ8Luk9Wb2qpl1FdEQgPZo9rB/mrvvNLPvSXrezP7u7i8OfED2nwL/MQAVY+5ezBOZLZG0z91/nXhMMRsDUJe721Ae1/Bhv5kdZ2ajDt6W9FNJbzX6fADaq5nD/rGSnjGzg8/zpLv/sZCuALRcYYf9Q9oYh/0tcfzxx9et3Xvvvcl1p0yZkqxPnz49Wf/888+TdbRfyw/7AQxvhB8IivADQRF+ICjCDwRF+IGgijirDy02Z86cZP3uu++uW5swYUJT204NI0rShx9+2NTzozzs+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKE7prYDx48cn66+//nqyfuKJJ9atNfv3Xb16dbK+YMGCZH3v3r1NbR+HjlN6ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQjPNXwAMPPJCs33rrrcl6NnfCoFr99/3444+T9dS1BpYtW5Zc98CBAw31FB3j/ACSCD8QFOEHgiL8QFCEHwiK8ANBEX4gqNxxfjNbIelKSf3uPiVbNkbSakmnS+qTdJ27/zt3Y0HH+U877bRkvbe3N1kfOXJksr558+a6tT179iTXzZuCu1n9/f11a1OnTk2uu3v37qLbCaHIcf5HJV36jWWLJG1w98mSNmT3AQwjueF39xclffNyLDMkrcxur5R0dcF9AWixRt/zj3X3Xdnt3ZLGFtQPgDZpeq4+d/fUe3kz65LU1ex2ABSr0T3/HjMbJ0nZ77qf6rh7t7t3untng9sC0AKNhn+dpLnZ7bmS1hbTDoB2yQ2/mfVI+quk75vZDjO7SdJ9ki4xs3clTc/uAxhGct/zu/vsOqWLC+7lsNXR0ZGsjxo1KlnfuHFjsn7hhRfWrY0YMSK57uzZ9f68NYsXL07WJ02alKyffPLJdWtr16YPGC+77LJknTkBmsM3/ICgCD8QFOEHgiL8QFCEHwiK8ANBNf31XuQ75phjkvW806qXLl3a8Lb379+frD/yyCPJ+rXXXpusT5w48ZB7OujTTz9N1rl0d2ux5weCIvxAUIQfCIrwA0ERfiAowg8ERfiBoBjnb4O802bzXHHFFcn6s88+29Tzp3R2tu4CTC+//HKyvm/fvpZtG+z5gbAIPxAU4QeCIvxAUIQfCIrwA0ERfiAoxvnboKenJ1m/6qqrkvVzzz03WT/rrLPq1s4555zkujNnzkzWR48enax/9NFHDa8/b9685LqPP/54sr5ly5ZkHWns+YGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMu7ZryZrZB0paR+d5+SLVsiaZ6kD7KHLXb3P+RuzCy9scPUmDFjkvVt27Yl6yeccEKybmZ1a3l/3zwvvPBCsj5//vxk/bnnnqtbmzx5cnLd5cuXJ+u33HJLsh6Vu9f/BzHAUPb8j0q6dJDlS929I/vJDT6AaskNv7u/KGlvG3oB0EbNvOdfYGa9ZrbCzNLfAQVQOY2G/yFJkyR1SNol6Tf1HmhmXWa2ycw2NbgtAC3QUPjdfY+7f+nuX0laLum8xGO73b3T3Vt3JUgAh6yh8JvZuAF3Z0p6q5h2ALRL7im9ZtYj6SJJ3zWzHZJ+KekiM+uQ5JL6JN3cwh4BtEDuOH+hGws6zp9n+vTpyfpTTz2VrKe+B5D39122bFmyfvvttyfr+/fvT9bvueeeurVFixYl192+fXuynve6vffee8n64arIcX4AhyHCDwRF+IGgCD8QFOEHgiL8QFAM9Q0DeUNaN9xwQ91a3qW177jjjmS92Wmyjz322Lq1J598Mrlu3iXNn3jiiWR97ty5yfrhiqE+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4/wozaxZs5L1VatWJes7d+5M1js6OurW9u49fK9Jyzg/gCTCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcX6U5ogj0vuevPP1r7/++mT9zjvvrFu76667kusOZ4zzA0gi/EBQhB8IivADQRF+ICjCDwRF+IGgcsf5zWyCpMckjZXkkrrd/UEzGyNptaTTJfVJus7d/53zXIzzY8hS5+NL0ksvvZSsjxgxom7t7LPPTq67devWZL3Kihzn/0LSQnf/gaQfSZpvZj+QtEjSBnefLGlDdh/AMJEbfnff5e6vZbc/kfSOpFMkzZC0MnvYSklXt6pJAMU7pPf8Zna6pKmS/iZprLvvykq7VXtbAGCYOGqoDzSzkZKelnSbu//H7P9vK9zd672fN7MuSV3NNgqgWEPa85vZd1QL/ip3X5Mt3mNm47L6OEn9g63r7t3u3ununUU0DKAYueG32i7+YUnvuPtvB5TWSTo4DepcSWuLbw9AqwxlqG+apI2SNkv6Klu8WLX3/b+XdKqk7aoN9SWvh8xQH4q0cOHCZP3++++vW1uzZk3dmiTdeOONyfpnn32WrJdpqEN9ue/53f0vkuo92cWH0hSA6uAbfkBQhB8IivADQRF+ICjCDwRF+IGguHQ3hq2TTjopWU+d8nvmmWcm1807nbi3tzdZLxOX7gaQRPiBoAg/EBThB4Ii/EBQhB8IivADQTHOj8PWqaeeWrfW19eXXLenpydZnzNnTiMttQXj/ACSCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMb5EdL69euT9QsuuCBZP//885P1LVu2HHJPRWGcH0AS4QeCIvxAUIQfCIrwA0ERfiAowg8
ElTtFt5lNkPSYpLGSXFK3uz9oZkskzZP0QfbQxe7+h1Y1ChTpmmuuSdbffPPNZD3vuv9ljvMPVW74JX0haaG7v2ZmoyS9ambPZ7Wl7v7r1rUHoFVyw+/uuyTtym5/YmbvSDql1Y0BaK1Des9vZqdLmirpb9miBWbWa2YrzGx0nXW6zGyTmW1qqlMAhRpy+M1spKSnJd3m7v+R9JCkSZI6VDsy+M1g67l7t7t3untnAf0CKMiQwm9m31Et+KvcfY0kufsed//S3b+StFzSea1rE0DRcsNvZibpYUnvuPtvBywfN+BhMyW9VXx7AFol95ReM5smaaOkzZK+yhYvljRbtUN+l9Qn6ebsw8HUc3FKL9BiQz2ll/P5gcMM5/MDSCL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ENZSr9xbpX5K2D7j/3WxZFVW1t6r2JdFbo4rs7bShPrCt5/N/a+Nmm6p6bb+q9lbVviR6a1RZvXHYDwRF+IGgyg5/d8nbT6lqb1XtS6K3RpXSW6nv+QGUp+w9P4CSlBJ+M7vUzP5hZtvMbFEZPdRjZn1mttnM3ih7irFsGrR+M3trwLIxZva8mb2b/R50mrSSeltiZjuz1+4NM7u8pN4mmNmfzWyLmb1tZj/Plpf62iX6KuV1a/thv5kdKWmrpEsk7ZD0iqTZ7l6JOY3NrE9Sp7uXPiZsZj+WtE/SY+4+JVv2K0l73f2+7D/O0e5+e0V6WyJpX9kzN2cTyowbOLO0pKsl/UwlvnaJvq5TCa9bGXv+8yRtc/d/uvsBSb+TNKOEPirP3V+UtPcbi2dIWpndXqnaP562q9NbJbj7Lnd/Lbv9iaSDM0uX+tol+ipFGeE/RdL7A+7vULWm/HZJ683sVTPrKruZQYwdMDPSbkljy2xmELkzN7fTN2aWrsxr18iM10XjA79vm+buP5R0maT52eFtJXntPVuVhmuGNHNzuwwys/T/lPnaNTrjddHKCP9OSRMG3B+fLasEd9+Z/e6X9IyqN/vwnoOTpGa/+0vu53+qNHPzYDNLqwKvXZVmvC4j/K9ImmxmZ5jZ0ZJmSVpXQh/fYmbHZR/EyMyOk/RTVW/24XWS5ma350paW2IvX1OVmZvrzSytkl+7ys147e5t/5F0uWqf+L8n6Rdl9FCnr4mS3sx+3i67N0k9qh0Gfq7aZyM3STpR0gZJ70p6QdKYCvX2uGqzOfeqFrRxJfU2TbVD+l5Jb2Q/l5f92iX6KuV14xt+QFB84AcERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+IKj/AlLXkc59O3KwAAAAAElFTkSuQmCC\n",
717 | "text/plain": [
718 | ""
719 | ]
720 | },
721 | "metadata": {},
722 | "output_type": "display_data"
723 | },
724 | {
725 | "name": "stdout",
726 | "output_type": "stream",
727 | "text": [
728 | "Model prediction: 9\n"
729 | ]
730 | },
731 | {
732 | "data": {
733 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAADbFJREFUeJzt3W+MVPW9x/HP1xUMgT5AiRsirPSCNKkmwnU1xmBD47XxaiPwhKDR0LRhfYCJ6H1w0fvgYq6aeu2f9FENWCw1xfYmaiC1sVRSKzVKXAWV9Q9ym8UuQVZCYy0x9MJ++2AON1vc8zvDzJk5Z/m+X8lmZ853zpwvEz57zszvzPmZuwtAPOdV3QCAahB+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBnd/NjZkZpxMCHebu1szj2trzm9lNZvaBmR0ws/XtPBeA7rJWz+03sx5J+yXdKGlE0uuSbnP3dxPrsOcHOqwbe/5rJB1w9z+6+98k/ULSsjaeD0AXtRP+SyT9adz9kWzZPzCzATMbNLPBNrYFoGQd/8DP3TdK2ihx2A/USTt7/kOS5o67PydbBmASaCf8r0u6zMy+bGZTJa2StL2ctgB0WsuH/e5+0szulvQbST2SNrv7UGmdAeiolof6WtoY7/mBjuvKST4AJi/CDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Lq6hTdwHgzZ85M1vv6+jq27YMHDybr9957b7K+b9++ZH3//v3J+ltvvZWsdwN7fiAowg8ERfiBoAg/EBThB4Ii/EBQhB8Iqq1xfjMblvSZpFOSTrp7fxlNYfK45ZZbkvVbb701t7Z06dLkugsWLGilpaYUjcNfeumlyfoFF1zQ1vZ7enraWr8MZZzk83V3P1rC8wDoIg77gaDaDb9L2mFmb5jZQBkNAeiOdg/7l7j7ITO7WNJvzex9d395/AOyPwr8YQBqpq09v7sfyn6PSnpO0jUTPGaju/fzYSBQLy2H38ymm9mXTt+W9A1J6a86AaiNdg77eyU9Z2ann2eru79QSlcAOs7cvXsbM+vexiBJmj9/frK+du3aZH3NmjXJ+rRp05L1bOeAM3RynN/dm3rRGeoDgiL8QFCEHwiK8ANBEX4gKMIPBMWlu89xc+bMSdbvueeeLnXSfe+//35ubWhoqIud1BN7fiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IinH+Lpg1a1ayXjTW/sorryTrL7yQfxmFEydOJNf99NNPk/Xjx48n69OnT0/Wd+zYkVsrmuZ69+7dyfqePXuS9c8//zy3VvTvioA9PxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ExaW7S1A01r1r165k/corr0zWV6xYkaxv3749WU+ZN29esj48PJys9/X1JesjIyO5tbGxseS6aA2X7gaQRPiBoAg/EBThB4Ii/EBQhB8IivADQRV+n9/MNkv6pqRRd78iW3ahpF9KmidpWNJKd/9z59qs3tSpU3NrW7duTa5bNI7/yCOPJOsvvvhist6OonH8Ih999FE5jaDrmtnz/1TSTWcsWy9pp7tfJmlndh/AJFIYfnd/WdKxMxYvk7Qlu71F0vKS+wLQYa2+5+9198PZ7Y8l9ZbUD4Auafsafu7uqXP2zWxA0kC72wFQrlb3/EfMbLYkZb9H8x7o7hvdvd/d+1vcFoAOaDX82yWtzm6vlrStnHYAdEth+M3saUmvSvqKmY2Y2XckfVfSjWb2oaR/ye4DmET4Pn9mxowZyfr999+fW1u/Pj3SefTo0WR94cKFyXrRtfWB8fg+P4Akwg8ERfiBoAg/EBThB4Ii/EBQTNGdWb48/d2k1HBe0ddar7/++mSdoTxUgT0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwTFOH/muuuua3ndPXv2JOupaaqBqrDnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGguHR3ZnQ0d9IhSdJFF12UWztx4kRy3UcffTRZ37YtPefJ3r17k3VgPC7dDSCJ8ANBEX4gKMIPBEX4gaAIPxAU4QeCKhznN7PNkr4padTdr8iWbZC0RtIn2cMecPdfF26sxuP8Ra/D2NhYx7Zd9NyPP/54sv7aa6/l1vr6+pLrHjhwIFkfGhpK1otcfvnlubVXX301uS7XQWhNmeP8P5V00wTLf+jui7KfwuADqJfC8Lv7y5KOdaEXAF3Uznv+u83sbTPbbGYzS+sIQFe0Gv4fS5ovaZGkw5K+n/dAMxsws0EzG2xxWwA6oKXwu/sRdz/l7mOSNkm6JvHYje7e7+79rTYJoHwthd/MZo+7u0LSvnLaAdAthZfuNrOnJS2VNMvMRiT9p6SlZrZIkksalnRXB3sE0AF8nz/z2GOPJev33XdflzqJ45NPPknWX3rppWR91apVJXZz7uD7/ACSCD8QFOEHgiL8QFCEHwiK8ANBMdSX6enpSdYXL16cW9u6dWty3fPPT59OMXfu3GT9vPNi/o0u+r+5YcOGZP2hhx4qsZvJg6E+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBU4ff5ozh16lSyPjiYfxWyhQsXtrXtG264IVmfMmVKsp4a77766qtbaakWzNLD1VdddVWXOjk3secHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY56+BnTt3trX+okWLcmtF4/wnT55M1p988slkfdOmTcn6unXrcmu33357cl10Fnt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiqcJzfzOZK+pmkXkkuaaO7/8jMLpT0S0nzJA1LWunuf+5cq8izY8eO3NrDDz+cXLdoToE1a9Yk6wsWLEjWly5dmqy3Y2RkpGPPHUEze/6Tkv7N3b8q6VpJa83sq5LWS9rp7pdJ2pndBzBJFIbf3Q+7+5vZ7c8kvSfpEknLJG3JHrZF0vJONQmgfGf1nt/M5klaLGm3pF53P5yVPlbjbQGASaLpc/vNbIakZyStc/e/jL++mrt73jx8ZjYgaaDdRgGUq6k9v5lNUSP4P3f3Z7PFR8xsdlafLWl0onXdfaO797t7fxkNAyhHYfitsYv/iaT33P0H40rbJa3Obq+WtK389gB0SuEU3Wa2RNIuSe9IGssWP6DG+/7/kdQn6aAaQ33HCp6rtlN0T2bTpk3LrW3evDm57sqVK8tup2lFl0t//vnnk/U77rgjWT9+/PhZ93QuaHaK7sL3/O7+B0l5T5a+4DyA2uIMPyAowg8ERfiBoAg/EBThB4Ii/EBQheP8pW6Mcf6u6+1Nf+XiiSeeSNb7+9MnZl588cXJ+vDwcG7tqaeeSq6bmnoc+Zod52fPDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBMc6PpDvvvDNZv/baa5P1Bx98MLc2OjrhxZ/QJsb5ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQjPMD5xjG+QEkEX4gKMIPBEX
4gaAIPxAU4QeCIvxAUIXhN7O5ZvY7M3vXzIbM7J5s+QYzO2Rme7OfmzvfLoCyFJ7kY2azJc129zfN7EuS3pC0XNJKSX919+81vTFO8gE6rtmTfM5v4okOSzqc3f7MzN6TdEl77QGo2lm95zezeZIWS9qdLbrbzN42s81mNjNnnQEzGzSzwbY6BVCqps/tN7MZkn4v6WF3f9bMeiUdleSS/kuNtwbfLngODvuBDmv2sL+p8JvZFEm/kvQbd//BBPV5kn7l7lcUPA/hBzqstC/2mJlJ+omk98YHP/sg8LQVkvadbZMAqtPMp/1LJO2S9I6ksWzxA5Juk7RIjcP+YUl3ZR8Opp6LPT/QYaUe9peF8AOdx/f5ASQRfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiq8gGfJjko6OO7+rGxZHdW1t7r2JdFbq8rs7dJmH9jV7/N/YeNmg+7eX1kDCXXtra59SfTWqqp647AfCIrwA0FVHf6NFW8/pa691bUvid5aVUlvlb7nB1Cdqvf8ACpSSfjN7CYz+8DMDpjZ+ip6yGNmw2b2TjbzcKVTjGXToI2a2b5xyy40s9+a2YfZ7wmnSauot1rM3JyYWbrS165uM153/bDfzHok7Zd0o6QRSa9Lus3d3+1qIznMbFhSv7tXPiZsZl+T9FdJPzs9G5KZ/bekY+7+3ewP50x3//ea9LZBZzlzc4d6y5tZ+luq8LUrc8brMlSx579G0gF3/6O7/03SLyQtq6CP2nP3lyUdO2PxMklbsttb1PjP03U5vdWCux929zez259JOj2zdKWvXaKvSlQR/ksk/Wnc/RHVa8pvl7TDzN4ws4Gqm5lA77iZkT6W1FtlMxMonLm5m86YWbo2r10rM16XjQ/8vmiJu/+zpH+VtDY7vK0lb7xnq9NwzY8lzVdjGrfDkr5fZTPZzNLPSFrn7n8ZX6vytZugr0petyrCf0jS3HH352TLasHdD2W/RyU9p8bblDo5cnqS1Oz3aMX9/D93P+Lup9x9TNImVfjaZTNLPyPp5+7+bLa48tduor6qet2qCP/rki4zsy+b2VRJqyRtr6CPLzCz6dkHMTKz6ZK+ofrNPrxd0urs9mpJ2yrs5R/UZebmvJmlVfFrV7sZr9296z+SblbjE///lfQfVfSQ09c/SXor+xmqujdJT6txGPh/anw28h1JF0naKelDSS9KurBGvT2lxmzOb6sRtNkV9bZEjUP6tyXtzX5urvq1S/RVyevGGX5AUHzgBwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gqL8DmYaFlMuCxPsAAAAASUVORK5CYII=\n",
734 | "text/plain": [
735 | ""
736 | ]
737 | },
738 | "metadata": {},
739 | "output_type": "display_data"
740 | },
741 | {
742 | "name": "stdout",
743 | "output_type": "stream",
744 | "text": [
745 | "Model prediction: 5\n"
746 | ]
747 | }
748 | ],
749 | "source": [
750 | "# Use the model to predict the images class\n",
751 | "preds = list(model.predict(test_input_fn))\n",
752 | "\n",
753 | "n_images = 9\n",
754 | "# Display\n",
755 | "for i in range(n_images):\n",
756 | " plt.imshow(np.reshape(some_images[i], [28, 28]), cmap='gray')\n",
757 | " plt.show()\n",
758 | " print(\"Model prediction:\", preds[i])"
759 | ]
760 | },
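Note that `Estimator.predict` returns a generator, so the `list(...)` call above materializes predictions for the whole test input even though only nine are displayed. To look at just a few examples, the generator can be sliced instead; a minimal sketch, mirroring the `itertools.islice` pattern used in CNNClassifier_dataset.ipynb below:

```python
import itertools

# Take only the first 9 predictions instead of exhausting the generator.
preds = list(itertools.islice(model.predict(test_input_fn), 9))
```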
761 | {
762 | "cell_type": "code",
763 | "execution_count": null,
764 | "metadata": {},
765 | "outputs": [],
766 | "source": []
767 | }
768 | ],
769 | "metadata": {
770 | "kernelspec": {
771 | "display_name": "Python 3",
772 | "language": "python",
773 | "name": "python3"
774 | },
775 | "language_info": {
776 | "codemirror_mode": {
777 | "name": "ipython",
778 | "version": 3
779 | },
780 | "file_extension": ".py",
781 | "mimetype": "text/x-python",
782 | "name": "python",
783 | "nbconvert_exporter": "python",
784 | "pygments_lexer": "ipython3",
785 | "version": "3.5.0"
786 | }
787 | },
788 | "nbformat": 4,
789 | "nbformat_minor": 2
790 | }
791 |
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/CNNClassifier_dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# TensorFlow 那些事儿之DL中的 HELLO WORLD\n",
8 | "\n",
9 | "- 基于MNIST数据集,运用TensorFlow中的 **tf.estimator** 中的 **tf.estimator.Estimator** 搭建一个简单的卷积神经网络,实现模型的训练,验证和测试\n",
10 | "\n",
11 | "- TensorBoard的简单使用\n"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "## 导入各个库"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 1,
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "data": {
28 | "text/plain": [
29 | "'1.11.0'"
30 | ]
31 | },
32 | "execution_count": 1,
33 | "metadata": {},
34 | "output_type": "execute_result"
35 | }
36 | ],
37 | "source": [
38 | "%matplotlib inline\n",
39 | "import tensorflow as tf\n",
40 | "import matplotlib.pyplot as plt\n",
41 | "import numpy as np\n",
42 | "import pandas as pd\n",
43 | "import multiprocessing\n",
44 | "\n",
45 | "\n",
46 | "from tensorflow import data\n",
47 | "from tensorflow.python.feature_column import feature_column\n",
48 | "\n",
49 | "tf.__version__"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "## MNIST数据集载入"
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {},
62 | "source": [
63 | "### 看看MNIST数据长什么样子的\n",
64 | "\n",
65 | "\n",
66 | "\n",
67 | "More info: http://yann.lecun.com/exdb/mnist/"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "- MNIST数据集包含70000张图像和对应的标签(图像的分类)。数据集被划为3个子集:训练集,验证集和测试集。\n",
75 | "\n",
76 | "- 定义**MNIST**数据的相关信息"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 2,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "TRAIN_DATA_FILES_PATTERN = 'data_csv/mnist_train.csv'\n",
86 | "VAL_DATA_FILES_PATTERN = 'data_csv/mnist_val.csv'\n",
87 | "TEST_DATA_FILES_PATTERN = 'data_csv/mnist_test.csv'\n",
88 | "\n",
89 | "MULTI_THREADING = True\n",
90 | "RESUME_TRAINING = False\n",
91 | "\n",
92 | "NUM_CLASS = 10\n",
93 | "IMG_SHAPE = [28,28]\n",
94 | "\n",
95 | "IMG_WIDTH = 28\n",
96 | "IMG_HEIGHT = 28\n",
97 | "IMG_FLAT = 784\n",
98 | "NUM_CHANNEL = 1\n",
99 | "\n",
100 | "BATCH_SIZE = 128\n",
101 | "NUM_TRAIN = 55000\n",
102 | "NUM_VAL = 5000\n",
103 | "NUM_TEST = 10000"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 3,
109 | "metadata": {},
110 | "outputs": [
111 | {
112 | "name": "stdout",
113 | "output_type": "stream",
114 | "text": [
115 | "test_data (10000, 784)\n",
116 | "test_label (10000,)\n",
117 | "val_data (5000, 784)\n",
118 | "val_label (5000,)\n",
119 | "train_data (55000, 784)\n",
120 | "train_label (55000,)\n"
121 | ]
122 | }
123 | ],
124 | "source": [
125 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN)\n",
126 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None, names=HEADER )\n",
127 | "train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None)\n",
128 | "test_data = pd.read_csv(TEST_DATA_FILES_PATTERN, header=None)\n",
129 | "val_data = pd.read_csv(VAL_DATA_FILES_PATTERN, header=None)\n",
130 | "\n",
131 | "train_values = train_data.values\n",
132 | "train_data = train_values[:,1:]/255.0\n",
133 | "train_label = train_values[:,0:1].squeeze()\n",
134 | "\n",
135 | "val_values = val_data.values\n",
136 | "val_data = val_values[:,1:]/255.0\n",
137 | "val_label = val_values[:,0:1].squeeze()\n",
138 | "\n",
139 | "test_values = test_data.values\n",
140 | "test_data = test_values[:,1:]/255.0\n",
141 | "test_label = test_values[:,0:1].squeeze()\n",
142 | "\n",
143 | "print('test_data',np.shape(test_data))\n",
144 | "print('test_label',np.shape(test_label))\n",
145 | "\n",
146 | "print('val_data',np.shape(val_data))\n",
147 | "print('val_label',np.shape(val_label))\n",
148 | "\n",
149 | "print('train_data',np.shape(train_data))\n",
150 | "print('train_label',np.shape(train_label))\n",
151 | "\n",
152 | "# train_data.head(10)\n",
153 | "# test_data.head(10)"
154 | ]
155 | },
156 | {
157 | "cell_type": "markdown",
158 | "metadata": {},
159 | "source": [
160 | "## 试试自己写一个estimator\n",
161 | "\n",
162 | "- 基于MNIST数据集,运用TensorFlow中的 **tf.estimator** 中的 **tf.estimator.Estimator** 搭建一个简单的卷积神经网络,实现模型的训练,验证和测试\n",
163 | "\n",
164 | "- [官网API](https://tensorflow.google.cn/api_docs/python/tf/estimator/Estimator)\n",
165 | "\n",
166 | "- 看看有哪些参数\n",
167 | "\n",
168 | "```python\n",
169 | "__init__(\n",
170 | " model_fn,\n",
171 | " model_dir=None,\n",
172 | " config=None,\n",
173 | " params=None,\n",
174 | " warm_start_from=None\n",
175 | ")\n",
176 | "```\n",
177 | "- 本例中,主要用了 **tf.estimator.Estimator** 中的 `model_fn`,`model_dir`\n"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | "### 先简单看看数据流\n",
185 | "\n",
186 | "下面的图表直接显示了本次MNIST例子的数据流向,共有**2个卷积层**,每一层卷积之后采用最大池化进行下采样(图中并未画出),最后接**2个全连接层**,实现对MNIST数据集的分类\n",
187 | "\n",
188 | ""
189 | ]
190 | },
191 | {
192 | "cell_type": "markdown",
193 | "metadata": {},
194 | "source": [
195 | "### 先看看input_fn之创建输入函数"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 4,
201 | "metadata": {},
202 | "outputs": [],
203 | "source": [
204 | "# validate tf.data.TextLineDataset() using make_one_shot_iterator()\n",
205 | "\n",
206 | "def decode_line(line):\n",
207 | " # Decode the csv_line to tensor.\n",
208 | " record_defaults = [[1.0] for col in range(785)]\n",
209 | " items = tf.decode_csv(line, record_defaults)\n",
210 | " features = items[1:785]\n",
211 | " label = items[0]\n",
212 | "\n",
213 | " features = tf.cast(features, tf.float32)\n",
214 | " features = tf.reshape(features,[28,28,1])\n",
215 | " features = tf.image.flip_left_right(features)\n",
216 | "# print('features_aug',features_aug)\n",
217 | " label = tf.cast(label, tf.int64)\n",
218 | "# label = tf.one_hot(label,num_class)\n",
219 | " return features,label"
220 | ]
221 | },
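As a quick sanity check of `decode_line` (a sketch, not a cell from the original notebook), a one-off TF 1.x session can parse a synthetic CSV record and confirm the expected shapes:

```python
import tensorflow as tf

# A fake CSV record: label 3 followed by 784 zero-valued pixels.
fake_line = tf.constant("3," + ",".join(["0.0"] * 784))
features, label = decode_line(fake_line)

with tf.Session() as sess:
    f, l = sess.run([features, label])
    print(f.shape, l)  # expected: (28, 28, 1) 3
```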
222 | {
223 | "cell_type": "code",
224 | "execution_count": 5,
225 | "metadata": {},
226 | "outputs": [],
227 | "source": [
228 | "def csv_input_fn(files_name_pattern, mode=tf.estimator.ModeKeys.TRAIN, \n",
229 | " skip_header_lines=1, \n",
230 | " num_epochs=None, \n",
231 | " batch_size=128):\n",
232 | " shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False\n",
233 | " \n",
234 | " num_threads = multiprocessing.cpu_count() if MULTI_THREADING else 1\n",
235 | " \n",
236 | " print(\"\")\n",
237 | " print(\"* data input_fn:\")\n",
238 | " print(\"================\")\n",
239 | " print(\"Input file(s): {}\".format(files_name_pattern))\n",
240 | " print(\"Batch size: {}\".format(batch_size))\n",
241 | " print(\"Epoch Count: {}\".format(num_epochs))\n",
242 | " print(\"Mode: {}\".format(mode))\n",
243 | " print(\"Thread Count: {}\".format(num_threads))\n",
244 | " print(\"Shuffle: {}\".format(shuffle))\n",
245 | " print(\"================\")\n",
246 | " print(\"\")\n",
247 | "\n",
248 | " file_names = tf.matching_files(files_name_pattern)\n",
249 | " dataset = data.TextLineDataset(filenames=file_names).skip(1)\n",
250 | "# dataset = tf.data.TextLineDataset(filenames).skip(1)\n",
251 | " print(\"DATASET\",dataset)\n",
252 | "\n",
253 | " # Use `Dataset.map()` to build a pair of a feature dictionary and a label\n",
254 | " # tensor for each example.\n",
255 | " dataset = dataset.map(decode_line)\n",
256 | " print(\"DATASET_1\",dataset)\n",
257 | " dataset = dataset.shuffle(buffer_size=10000)\n",
258 | " print(\"DATASET_2\",dataset)\n",
259 | " dataset = dataset.batch(32)\n",
260 | " print(\"DATASET_3\",dataset)\n",
261 | " dataset = dataset.repeat(num_epochs)\n",
262 | " print(\"DATASET_4\",dataset)\n",
263 | " iterator = dataset.make_one_shot_iterator()\n",
264 | " \n",
265 | " # `features` is a dictionary in which each value is a batch of values for\n",
266 | " # that feature; `labels` is a batch of labels.\n",
267 | " features, labels = iterator.get_next()\n",
268 | " \n",
269 | " features = {'images':features}\n",
270 | " \n",
271 | " return features,labels\n"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": 6,
277 | "metadata": {},
278 | "outputs": [
279 | {
280 | "name": "stdout",
281 | "output_type": "stream",
282 | "text": [
283 | "\n",
284 | "* data input_fn:\n",
285 | "================\n",
286 | "Input file(s): data_csv/mnist_train.csv\n",
287 | "Batch size: 128\n",
288 | "Epoch Count: None\n",
289 | "Mode: train\n",
290 | "Thread Count: 4\n",
291 | "Shuffle: True\n",
292 | "================\n",
293 | "\n",
294 | "DATASET \n",
295 | "features_aug Tensor(\"flip_left_right/ReverseV2:0\", shape=(28, 28, 1), dtype=float32)\n",
296 | "DATASET_1 \n",
297 | "DATASET_2 \n",
298 | "DATASET_3 \n",
299 | "DATASET_4 \n",
300 | "Features in CSV: ['images']\n",
301 | "Target in CSV: Tensor(\"IteratorGetNext:1\", shape=(?,), dtype=int64)\n"
302 | ]
303 | }
304 | ],
305 | "source": [
306 | "features, target = csv_input_fn(files_name_pattern=TRAIN_DATA_FILES_PATTERN)\n",
307 | "print(\"Features in CSV: {}\".format(list(features.keys())))\n",
308 | "print(\"Target in CSV: {}\".format(target))"
309 | ]
310 | },
311 | {
312 | "cell_type": "markdown",
313 | "metadata": {},
314 | "source": [
315 | "### 定义feature_columns"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": 10,
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "feature_x = tf.feature_column.numeric_column('images', shape=IMG_SHAPE)\n",
325 | "\n",
326 | "feature_columns = [feature_x]"
327 | ]
328 | },
329 | {
330 | "cell_type": "markdown",
331 | "metadata": {},
332 | "source": [
333 | "### 重点在这里——model_fn\n",
334 | "\n",
335 | "\n",
336 | "#### model_fn: Model function. Follows the signature:\n",
337 | "\n",
338 | "* Args:\n",
339 | " * `features`: This is the first item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single `tf.Tensor` or `dict` of same.\n",
340 | " * `labels`: This is the second item returned from the `input_fn` passed to `train`, `evaluate`, and `predict`. This should be a single `tf.Tensor` or `dict` of same (for multi-head models).If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will be passed. If the `model_fn`'s signature does not accept `mode`, the `model_fn` must still be able to handle `labels=None`.\n",
341 | " * `mode`: Optional. Specifies if this training, evaluation or prediction. See `tf.estimator.ModeKeys`.\n",
342 | " * `params`: Optional `dict` of hyperparameters. Will receive what is passed to Estimator in `params` parameter. This allows to configure Estimators from hyper parameter tuning.\n",
343 | " * `config`: Optional `estimator.RunConfig` object. Will receive what is passed to Estimator as its `config` parameter, or a default value. Allows setting up things in your `model_fn` based on configuration such as `num_ps_replicas`, or `model_dir`.\n",
344 | "* Returns:\n",
345 | " `tf.estimator.EstimatorSpec`\n",
346 | " \n",
347 | "#### 注意model_fn返回的tf.estimator.EstimatorSpec\n",
348 | "
\n",
349 | "\n",
350 | "\n"
351 | ]
352 | },
353 | {
354 | "cell_type": "markdown",
355 | "metadata": {},
356 | "source": [
357 | "### 定义我们自己的model_fn"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": 11,
363 | "metadata": {},
364 | "outputs": [],
365 | "source": [
366 | "def model_fn(features, labels, mode, params):\n",
367 | " # Args:\n",
368 | " #\n",
369 | " # features: This is the x-arg from the input_fn.\n",
370 | " # labels: This is the y-arg from the input_fn,\n",
371 | " # see e.g. train_input_fn for these two.\n",
372 | " # mode: Either TRAIN, EVAL, or PREDICT\n",
373 | " # params: User-defined hyper-parameters, e.g. learning-rate.\n",
374 | " \n",
375 | " # Reference to the tensor named \"x\" in the input-function.\n",
376 | "# x = features[\"images\"]\n",
377 | " x = tf.feature_column.input_layer(features, params['feature_columns'])\n",
378 | " # The convolutional layers expect 4-rank tensors\n",
379 | " # but x is a 2-rank tensor, so reshape it.\n",
380 | " net = tf.reshape(x, [-1, IMG_HEIGHT, IMG_WIDTH, NUM_CHANNEL]) \n",
381 | "\n",
382 | " # First convolutional layer.\n",
383 | " net = tf.layers.conv2d(inputs=net, name='layer_conv1',\n",
384 | " filters=16, kernel_size=5,\n",
385 | " padding='same', activation=tf.nn.relu)\n",
386 | " net = tf.layers.max_pooling2d(inputs=net, pool_size=2, strides=2)\n",
387 | "\n",
388 | " # Second convolutional layer.\n",
389 | " net = tf.layers.conv2d(inputs=net, name='layer_conv2',\n",
390 | " filters=36, kernel_size=5,\n",
391 | " padding='same', activation=tf.nn.relu)\n",
392 | " net = tf.layers.max_pooling2d(inputs=net, pool_size=2, strides=2) \n",
393 | "\n",
394 | " # Flatten to a 2-rank tensor.\n",
395 | " net = tf.contrib.layers.flatten(net)\n",
396 | " # Eventually this should be replaced with:\n",
397 | " # net = tf.layers.flatten(net)\n",
398 | "\n",
399 | " # First fully-connected / dense layer.\n",
400 | " # This uses the ReLU activation function.\n",
401 | " net = tf.layers.dense(inputs=net, name='layer_fc1',\n",
402 | " units=128, activation=tf.nn.relu) \n",
403 | "\n",
404 | " # Second fully-connected / dense layer.\n",
405 | " # This is the last layer so it does not use an activation function.\n",
406 | " net = tf.layers.dense(inputs=net, name='layer_fc2',\n",
407 | " units=10)\n",
408 | "\n",
409 | " # Logits output of the neural network.\n",
410 | " logits = net\n",
411 | "\n",
412 | " # Softmax output of the neural network.\n",
413 | " y_pred = tf.nn.softmax(logits=logits)\n",
414 | " \n",
415 | " # Classification output of the neural network.\n",
416 | " y_pred_cls = tf.argmax(y_pred, axis=1)\n",
417 | "\n",
418 | " if mode == tf.estimator.ModeKeys.PREDICT:\n",
419 | " # If the estimator is supposed to be in prediction-mode\n",
420 | " # then use the predicted class-number that is output by\n",
421 | " # the neural network. Optimization etc. is not needed.\n",
422 | " spec = tf.estimator.EstimatorSpec(mode=mode,\n",
423 | " predictions=y_pred_cls)\n",
424 | " else:\n",
425 | " # Otherwise the estimator is supposed to be in either\n",
426 | " # training or evaluation-mode. Note that the loss-function\n",
427 | " # is also required in Evaluation mode.\n",
428 | " \n",
429 | " # Define the loss-function to be optimized, by first\n",
430 | " # calculating the cross-entropy between the output of\n",
431 | " # the neural network and the true labels for the input data.\n",
432 | " # This gives the cross-entropy for each image in the batch.\n",
433 | " cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,\n",
434 | " logits=logits)\n",
435 | "\n",
436 | " # Reduce the cross-entropy batch-tensor to a single number\n",
437 | " # which can be used in optimization of the neural network.\n",
438 | " loss = tf.reduce_mean(cross_entropy)\n",
439 | "\n",
440 | " # Define the optimizer for improving the neural network.\n",
441 | " optimizer = tf.train.AdamOptimizer(learning_rate=params[\"learning_rate\"])\n",
442 | "\n",
443 | " # Get the TensorFlow op for doing a single optimization step.\n",
444 | " train_op = optimizer.minimize(\n",
445 | " loss=loss, global_step=tf.train.get_global_step())\n",
446 | "\n",
447 | " # Define the evaluation metrics,\n",
448 | " # in this case the classification accuracy.\n",
449 | " metrics = \\\n",
450 | " {\n",
451 | " \"accuracy\": tf.metrics.accuracy(labels, y_pred_cls)\n",
452 | " }\n",
453 | "\n",
454 | " # Wrap all of this in an EstimatorSpec.\n",
455 | " spec = tf.estimator.EstimatorSpec(\n",
456 | " mode=mode,\n",
457 | " loss=loss,\n",
458 | " train_op=train_op,\n",
459 | " eval_metric_ops=metrics)\n",
460 | " \n",
461 | " return spec"
462 | ]
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "metadata": {},
467 | "source": [
468 | "### 自建的estimator在这里"
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": 12,
474 | "metadata": {},
475 | "outputs": [],
476 | "source": [
477 | "params = {\"learning_rate\": 1e-4,\n",
478 | " 'feature_columns': feature_columns}"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 13,
484 | "metadata": {},
485 | "outputs": [
486 | {
487 | "name": "stdout",
488 | "output_type": "stream",
489 | "text": [
490 | "INFO:tensorflow:Using default config.\n",
491 | "INFO:tensorflow:Using config: {'_model_dir': './cnn_classifer_dataset/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
492 | "graph_options {\n",
493 | " rewrite_options {\n",
494 | " meta_optimizer_iterations: ONE\n",
495 | " }\n",
496 | "}\n",
497 | ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n"
498 | ]
499 | }
500 | ],
501 | "source": [
502 | "model = tf.estimator.Estimator(model_fn=model_fn,\n",
503 | " params=params,\n",
504 | " model_dir=\"./cnn_classifer_dataset/\")"
505 | ]
506 | },
507 | {
508 | "cell_type": "markdown",
509 | "metadata": {},
510 | "source": [
511 | "### 训练训练看看"
512 | ]
513 | },
514 | {
515 | "cell_type": "code",
516 | "execution_count": null,
517 | "metadata": {},
518 | "outputs": [
519 | {
520 | "name": "stdout",
521 | "output_type": "stream",
522 | "text": [
523 | "\n",
524 | "* data input_fn:\n",
525 | "================\n",
526 | "Input file(s): data_csv/mnist_train.csv\n",
527 | "Batch size: 128\n",
528 | "Epoch Count: None\n",
529 | "Mode: train\n",
530 | "Thread Count: 4\n",
531 | "Shuffle: True\n",
532 | "================\n",
533 | "\n",
534 | "DATASET \n",
535 | "DATASET_1 \n",
536 | "DATASET_2 \n",
537 | "DATASET_3 \n",
538 | "DATASET_4 \n",
539 | "INFO:tensorflow:Calling model_fn.\n",
540 | "INFO:tensorflow:Done calling model_fn.\n",
541 | "INFO:tensorflow:Create CheckpointSaverHook.\n",
542 | "INFO:tensorflow:Graph was finalized.\n",
543 | "INFO:tensorflow:Running local_init_op.\n",
544 | "INFO:tensorflow:Done running local_init_op.\n",
545 | "INFO:tensorflow:Saving checkpoints for 0 into ./cnn_classifer_dataset/model.ckpt.\n",
546 | "INFO:tensorflow:loss = 40.126976, step = 1\n",
547 | "INFO:tensorflow:global_step/sec: 11.057\n",
548 | "INFO:tensorflow:loss = 1.1582, step = 101 (9.049 sec)\n",
549 | "INFO:tensorflow:global_step/sec: 12.9123\n",
550 | "INFO:tensorflow:loss = 0.778288, step = 201 (7.743 sec)\n",
551 | "INFO:tensorflow:global_step/sec: 13.889\n",
552 | "INFO:tensorflow:loss = 1.0873605, step = 301 (7.200 sec)\n",
553 | "INFO:tensorflow:global_step/sec: 14.1931\n",
554 | "INFO:tensorflow:loss = 0.07414566, step = 401 (7.045 sec)\n",
555 | "INFO:tensorflow:global_step/sec: 14.2251\n",
556 | "INFO:tensorflow:loss = 0.32521993, step = 501 (7.029 sec)\n",
557 | "INFO:tensorflow:global_step/sec: 12.7967\n",
558 | "INFO:tensorflow:loss = 0.2568686, step = 601 (7.815 sec)\n",
559 | "INFO:tensorflow:global_step/sec: 12.4253\n",
560 | "INFO:tensorflow:loss = 0.54189134, step = 701 (8.048 sec)\n",
561 | "INFO:tensorflow:global_step/sec: 12.5796\n",
562 | "INFO:tensorflow:loss = 0.15989298, step = 801 (7.949 sec)\n",
563 | "INFO:tensorflow:global_step/sec: 13.7096\n",
564 | "INFO:tensorflow:loss = 0.90422636, step = 901 (7.295 sec)\n",
565 | "INFO:tensorflow:global_step/sec: 13.8366\n",
566 | "INFO:tensorflow:loss = 0.20136827, step = 1001 (7.227 sec)\n",
567 | "INFO:tensorflow:global_step/sec: 13.5184\n",
568 | "INFO:tensorflow:loss = 0.53505665, step = 1101 (7.398 sec)\n",
569 | "INFO:tensorflow:global_step/sec: 12.8457\n",
570 | "INFO:tensorflow:loss = 0.22107196, step = 1201 (7.784 sec)\n",
571 | "INFO:tensorflow:global_step/sec: 13.0342\n",
572 | "INFO:tensorflow:loss = 0.31935138, step = 1301 (7.672 sec)\n"
573 | ]
574 | }
575 | ],
576 | "source": [
577 | "input_fn = lambda: csv_input_fn(\\\n",
578 | " files_name_pattern= TRAIN_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.TRAIN)\n",
579 | "# Train the Model\n",
580 | "model.train(input_fn, steps=2000)"
581 | ]
582 | },
583 | {
584 | "cell_type": "markdown",
585 | "metadata": {},
586 | "source": [
587 | "### 验证一下瞅瞅"
588 | ]
589 | },
590 | {
591 | "cell_type": "code",
592 | "execution_count": 68,
593 | "metadata": {},
594 | "outputs": [
595 | {
596 | "name": "stdout",
597 | "output_type": "stream",
598 | "text": [
599 | "\n",
600 | "* data input_fn:\n",
601 | "================\n",
602 | "Input file(s): data_csv/mnist_val.csv\n",
603 | "Batch size: 128\n",
604 | "Epoch Count: None\n",
605 | "Mode: eval\n",
606 | "Thread Count: 4\n",
607 | "Shuffle: False\n",
608 | "================\n",
609 | "\n",
610 | "DATASET \n",
611 | "DATASET_1 \n",
612 | "DATASET_2 \n",
613 | "DATASET_3 \n",
614 | "DATASET_4 \n",
615 | "INFO:tensorflow:Calling model_fn.\n",
616 | "INFO:tensorflow:Done calling model_fn.\n",
617 | "INFO:tensorflow:Starting evaluation at 2018-10-23-12:36:20\n",
618 | "INFO:tensorflow:Graph was finalized.\n",
619 | "INFO:tensorflow:Restoring parameters from trained_models/simple_cnn/model.ckpt-4000\n",
620 | "INFO:tensorflow:Running local_init_op.\n",
621 | "INFO:tensorflow:Done running local_init_op.\n",
622 | "INFO:tensorflow:Evaluation [1/1]\n",
623 | "INFO:tensorflow:Finished evaluation at 2018-10-23-12:36:29\n",
624 | "INFO:tensorflow:Saving dict for global step 4000: accuracy = 0.96875, global_step = 4000, loss = 0.1153331\n"
625 | ]
626 | },
627 | {
628 | "data": {
629 | "text/plain": [
630 | "{'accuracy': 0.96875, 'global_step': 4000, 'loss': 0.1153331}"
631 | ]
632 | },
633 | "execution_count": 68,
634 | "metadata": {},
635 | "output_type": "execute_result"
636 | }
637 | ],
638 | "source": [
639 | "input_fn = lambda: csv_input_fn(files_name_pattern= VAL_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.EVAL)\n",
640 | "\n",
641 | "model.evaluate(input_fn,steps=1)"
642 | ]
643 | },
644 | {
645 | "cell_type": "code",
646 | "execution_count": 69,
647 | "metadata": {},
648 | "outputs": [
649 | {
650 | "name": "stdout",
651 | "output_type": "stream",
652 | "text": [
653 | "\n",
654 | "* data input_fn:\n",
655 | "================\n",
656 | "Input file(s): data_csv/mnist_test.csv\n",
657 | "Batch size: 10\n",
658 | "Epoch Count: None\n",
659 | "Mode: infer\n",
660 | "Thread Count: 4\n",
661 | "Shuffle: False\n",
662 | "================\n",
663 | "\n",
664 | "DATASET \n",
665 | "DATASET_1 \n",
666 | "DATASET_2 \n",
667 | "DATASET_3 \n",
668 | "DATASET_4 \n",
669 | "INFO:tensorflow:Calling model_fn.\n",
670 | "INFO:tensorflow:Done calling model_fn.\n",
671 | "INFO:tensorflow:Graph was finalized.\n",
672 | "INFO:tensorflow:Restoring parameters from trained_models/simple_cnn/model.ckpt-4000\n",
673 | "INFO:tensorflow:Running local_init_op.\n",
674 | "INFO:tensorflow:Done running local_init_op.\n",
675 | "PREDICTIONS [6, 1, 2, 7, 0, 8, 0, 3, 0, 0]\n"
676 | ]
677 | }
678 | ],
679 | "source": [
680 | "import itertools\n",
681 | "\n",
682 | "input_fn = lambda: csv_input_fn(\\\n",
683 | " files_name_pattern= TEST_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.PREDICT,batch_size=10)\n",
684 | "\n",
685 | "predictions = list(itertools.islice(model.predict(input_fn=input_fn),10))\n",
686 | "print('PREDICTIONS',predictions)\n",
687 | "# print(\"\")\n",
688 | "# print(\"* Predicted Classes: {}\".format(list(map(lambda item: item[\"classes\"][0]\n",
689 | "# ,predictions))))"
690 | ]
691 | }
692 | ],
693 | "metadata": {
694 | "kernelspec": {
695 | "display_name": "Python 3",
696 | "language": "python",
697 | "name": "python3"
698 | },
699 | "language_info": {
700 | "codemirror_mode": {
701 | "name": "ipython",
702 | "version": 3
703 | },
704 | "file_extension": ".py",
705 | "mimetype": "text/x-python",
706 | "name": "python",
707 | "nbconvert_exporter": "python",
708 | "pygments_lexer": "ipython3",
709 | "version": "3.6.6"
710 | }
711 | },
712 | "nbformat": 4,
713 | "nbformat_minor": 2
714 | }
715 |
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/CNN_raw.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# TensorFlow 那些事儿之DL中的 HELLO WORLD\n",
8 | "\n",
9 | "- 基于MNIST数据集,运用TensorFlow搭建一个简单的卷积神经网络,并实现模型训练/验证/测试\n",
10 | "\n",
11 | "- TensorBoard的简单使用\n"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "## 看看MNIST数据长什么样子的\n",
19 | "\n",
20 | "\n",
21 | "\n",
22 | "More info: http://yann.lecun.com/exdb/mnist/"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {},
28 | "source": [
29 | "## 流程图"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "下面的图表直接显示了之后实现的卷积神经网络中数据的传递。\n",
37 | "\n",
38 | ""
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {},
44 | "source": [
45 | "输入图像在第一层卷积层里使用权重过滤器处理。结果在16张新图里,每张代表了卷积层里一个过滤器(的处理结果)。图像经过降采样,分辨率从28x28减少到14x14。\n",
46 | "\n",
47 | "16张小图在第二个卷积层中处理。这16个通道以及这层输出的每个通道都需要一个过滤权重。总共有36个输出,所以在第二个卷积层有16 x 36 = 576个滤波器。输出图再一次降采样到7x7个像素。\n",
48 | "\n",
49 | "第二个卷积层的输出是36张7x7像素的图像。它们被转换到一个长为7 x 7 x 36 = 1764的向量中去,它作为一个有128个神经元(或元素)的全连接网络的输入。这些又输入到另一个有10个神经元的全连接层中,每个神经元代表一个类别,用来确定图像的类别,即图像上的数字。\n",
50 | "\n",
51 | "卷积滤波一开始是随机挑选的,因此分类也是随机完成的。根据交叉熵(cross-entropy)来测量输入图预测值和真实类别间的错误。然后优化器用链式法则自动地将这个误差在卷积网络中传递,更新滤波权重来提升分类质量。这个过程迭代了几千次,直到分类误差足够低。\n",
52 | "\n",
53 | "这些特定的滤波权重和中间图像是一个优化结果,和你执行代码所看到的可能会有所不同。\n",
54 | "\n",
55 | "注意,这些在TensorFlow上的计算是在一部分图像上执行,而非单独的一张图,这使得计算更有效。也意味着在TensorFlow上实现时,这个流程图实际上会有更多的数据维度。\n"
56 | ]
57 | },
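To make the arithmetic above concrete, a few lines of plain Python (an added check, not part of the original notebook) reproduce the quoted numbers:

```python
# Two 5x5 conv layers with 2x2 max-pooling on 28x28 inputs.
size1 = 28 // 2                  # resolution after the first pooling: 14
size2 = size1 // 2               # resolution after the second pooling: 7
filters_conv2 = 16 * 36          # one filter per (input, output) channel pair: 576
flat_len = size2 * size2 * 36    # flattened vector length: 1764

print(size1, size2, filters_conv2, flat_len)  # 14 7 576 1764
```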
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "## 各种库导入"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 2,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "%matplotlib inline\n",
72 | "import matplotlib.pyplot as plt\n",
73 | "import tensorflow as tf\n",
74 | "import numpy as np\n",
75 | "import pandas as pd\n",
76 | "\n",
77 | "import time\n",
78 | "from datetime import timedelta\n",
79 | "import math"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "使用Python3.6(Anaconda)开发,TensorFlow版本是:"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 4,
92 | "metadata": {},
93 | "outputs": [
94 | {
95 | "data": {
96 | "text/plain": [
97 | "'1.8.0'"
98 | ]
99 | },
100 | "execution_count": 4,
101 | "metadata": {},
102 | "output_type": "execute_result"
103 | }
104 | ],
105 | "source": [
106 | "tf.__version__"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "## MNIST数据集导入\n",
114 | "\n",
115 | "- 现在已经载入了MNIST数据集,它由70,000张图像和对应的标签(比如图像的类别)组成。数据集分成三份互相独立的子集。\n",
116 | "\n",
117 | "- 定义**MNIST**数据的相关信息"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 3,
123 | "metadata": {},
124 | "outputs": [],
125 | "source": [
126 | "TRAIN_DATA_FILES_PATTERN = 'data_csv/mnist_train.csv'\n",
127 | "VAL_DATA_FILES_PATTERN = 'data_csv/mnist_val.csv'\n",
128 | "TEST_DATA_FILES_PATTERN = 'data_csv/mnist_test.csv'\n",
129 | "\n",
130 | "MULTI_THREADING = True\n",
131 | "RESUME_TRAINING = False\n",
132 | "\n",
133 | "NUM_CLASS = 10\n",
134 | "IMG_SHAPE = [28,28]\n",
135 | "\n",
136 | "IMG_WIDTH = 28\n",
137 | "IMG_HEIGHT = 28\n",
138 | "IMG_FLAT = 784\n",
139 | "NUM_CHANNEL = 1\n",
140 | "\n",
141 | "BATCH_SIZE = 128\n",
142 | "NUM_TRAIN = 55000\n",
143 | "NUM_VAL = 5000\n",
144 | "NUM_TEST = 10000"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 4,
150 | "metadata": {},
151 | "outputs": [
152 | {
153 | "name": "stdout",
154 | "output_type": "stream",
155 | "text": [
156 | "test_data (10000, 784)\n",
157 | "test_label (10000,)\n",
158 | "val_data (5000, 784)\n",
159 | "val_label (5000,)\n",
160 | "train_data (55000, 784)\n",
161 | "train_label (55000,)\n"
162 | ]
163 | }
164 | ],
165 | "source": [
166 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN)\n",
167 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None, names=HEADER )\n",
168 | "train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None)\n",
169 | "test_data = pd.read_csv(TEST_DATA_FILES_PATTERN, header=None)\n",
170 | "val_data = pd.read_csv(VAL_DATA_FILES_PATTERN, header=None)\n",
171 | "\n",
172 | "train_values = train_data.values\n",
173 | "train_data = train_values[:,1:]/255.0\n",
174 | "train_label = train_values[:,0:1].squeeze()\n",
175 | "\n",
176 | "val_values = val_data.values\n",
177 | "val_data = val_values[:,1:]/255.0\n",
178 | "val_label = val_values[:,0:1].squeeze()\n",
179 | "\n",
180 | "test_values = test_data.values\n",
181 | "test_data = test_values[:,1:]/255.0\n",
182 | "test_label = test_values[:,0:1].squeeze()\n",
183 | "\n",
184 | "print('test_data',np.shape(test_data))\n",
185 | "print('test_label',np.shape(test_label))\n",
186 | "\n",
187 | "print('val_data',np.shape(val_data))\n",
188 | "print('val_label',np.shape(val_label))\n",
189 | "\n",
190 | "print('train_data',np.shape(train_data))\n",
191 | "print('train_label',np.shape(train_label))\n",
192 | "\n",
193 | "# train_data.head(10)\n",
194 | "# test_data.head(10)"
195 | ]
196 | },
197 | {
198 | "cell_type": "markdown",
199 | "metadata": {},
200 | "source": [
201 | "### one-hot编码"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 5,
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "def one_hot_encoded(class_numbers, num_classes=None):\n",
211 | " \"\"\"\n",
212 | " Generate the One-Hot encoded class-labels from an array of integers.\n",
213 | "\n",
214 | " For example, if class_number=2 and num_classes=4 then\n",
215 | " the one-hot encoded label is the float array: [0. 0. 1. 0.]\n",
216 | "\n",
217 | " :param class_numbers:\n",
218 | " Array of integers with class-numbers.\n",
219 | " Assume the integers are from zero to num_classes-1 inclusive.\n",
220 | "\n",
221 | " :param num_classes:\n",
222 | " Number of classes. If None then use max(class_numbers)+1.\n",
223 | "\n",
224 | " :return:\n",
225 | " 2-dim array of shape: [len(class_numbers), num_classes]\n",
226 | " \"\"\"\n",
227 | "\n",
228 | " # Find the number of classes if None is provided.\n",
229 | " # Assumes the lowest class-number is zero.\n",
230 | " if num_classes is None:\n",
231 | " num_classes = np.max(class_numbers) + 1\n",
232 | "\n",
233 | " return np.eye(num_classes, dtype=float)[class_numbers]"
234 | ]
235 | },
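For example, calling the helper on a small label array (a quick check, not in the original notebook) gives:

```python
import numpy as np

labels = np.array([2, 0, 3])
print(one_hot_encoded(labels, num_classes=4))
# [[0. 0. 1. 0.]
#  [1. 0. 0. 0.]
#  [0. 0. 0. 1.]]
```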
236 | {
237 | "cell_type": "markdown",
238 | "metadata": {},
239 | "source": [
240 | "### 数据集batch处理"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 45,
246 | "metadata": {},
247 | "outputs": [],
248 | "source": [
249 | "# idx = 0\n",
250 | "\n",
251 | "def batch(batch_size=32):\n",
252 | " \"\"\"\n",
253 | " Create a random batch of training-data.\n",
254 | "\n",
255 | " :param batch_size: Number of images in the batch.\n",
256 | " :return: 3 numpy arrays (x, y, y_cls)\n",
257 | " \"\"\"\n",
258 | "# global idx\n",
259 | " # Create a random index into the training-set.\n",
260 | " idx = np.random.randint(low=0, high=NUM_TRAIN, size=batch_size)\n",
261 | "# idx = iterations\n",
262 | " # Use the index to lookup random training-data.\n",
263 | " x_batch = train_data[idx]\n",
264 | " y_batch = train_label_onehot[idx]\n",
265 | " y_batch_cls = train_label[idx]\n",
266 | "# idx = idx + batch_size\n",
267 | "# print('IDX',idx)\n",
268 | " return x_batch, y_batch, y_batch_cls\n"
269 | ]
270 | },
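Once `train_label_onehot` has been built in the next cell, the sampler can be exercised like this (a minimal sketch, not a cell from the original notebook):

```python
# Draw a small random batch and check the shapes.
x_batch, y_batch, y_batch_cls = batch(batch_size=4)
print(x_batch.shape)       # (4, 784)
print(y_batch.shape)       # (4, 10)
print(y_batch_cls.shape)   # (4,)
```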
271 | {
272 | "cell_type": "code",
273 | "execution_count": 7,
274 | "metadata": {},
275 | "outputs": [
276 | {
277 | "name": "stdout",
278 | "output_type": "stream",
279 | "text": [
280 | "shape (55000, 10)\n",
281 | "[[0. 0. 0. ... 0. 0. 0.]\n",
282 | " [1. 0. 0. ... 0. 0. 0.]\n",
283 | " [0. 0. 0. ... 0. 0. 0.]\n",
284 | " ...\n",
285 | " [1. 0. 0. ... 0. 0. 0.]\n",
286 | " [0. 0. 0. ... 0. 0. 0.]\n",
287 | " [1. 0. 0. ... 0. 0. 0.]]\n"
288 | ]
289 | }
290 | ],
291 | "source": [
292 | "train_label_onehot = one_hot_encoded(train_label.T,10).squeeze()\n",
293 | "print('shape',np.shape(train_label_onehot))\n",
294 | "print(train_label_onehot)"
295 | ]
296 | },
297 | {
298 | "cell_type": "markdown",
299 | "metadata": {},
300 | "source": [
301 | "## 神经网络的配置\n",
302 | "\n",
303 | "方便起见,在这里定义神经网络的配置,你可以很容易找到或改变这些数值,然后重新运行Notebook。"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": 8,
309 | "metadata": {},
310 | "outputs": [],
311 | "source": [
312 | "# Convolutional Layer 1.\n",
313 | "filter_size1 = 5 # Convolution filters are 5 x 5 pixels.\n",
314 | "num_filters1 = 16 # There are 16 of these filters.\n",
315 | "\n",
316 | "# Convolutional Layer 2.\n",
317 | "filter_size2 = 5 # Convolution filters are 5 x 5 pixels.\n",
318 | "num_filters2 = 36 # There are 36 of these filters.\n",
319 | "\n",
320 | "# Fully-connected layer.\n",
321 | "fc_size = 128 # Number of neurons in fully-connected layer."
322 | ]
323 | },
324 | {
325 | "cell_type": "markdown",
326 | "metadata": {},
327 | "source": [
328 | "### 用来绘制图片的帮助函数"
329 | ]
330 | },
331 | {
332 | "cell_type": "markdown",
333 | "metadata": {},
334 | "source": [
335 | "这个函数用来在3x3的栅格中画9张图像,然后在每张图像下面写出真实类别和预测类别。"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": 9,
341 | "metadata": {},
342 | "outputs": [],
343 | "source": [
344 | "img_shape = IMG_SHAPE\n",
345 | "\n",
346 | "def plot_images(images, cls_true, cls_pred=None):\n",
347 | " assert len(images) == len(cls_true) == 9\n",
348 | " \n",
349 | " # Create figure with 3x3 sub-plots.\n",
350 | " fig, axes = plt.subplots(3, 3)\n",
351 | " fig.subplots_adjust(hspace=0.3, wspace=0.3)\n",
352 | "\n",
353 | " for i, ax in enumerate(axes.flat):\n",
354 | " # Plot image.\n",
355 | " ax.imshow(images[i].reshape(img_shape), cmap='binary')\n",
356 | "\n",
357 | " # Show true and predicted classes.\n",
358 | " if cls_pred is None:\n",
359 | " xlabel = \"True: {0}\".format(cls_true[i])\n",
360 | " else:\n",
361 | " xlabel = \"True: {0}, Pred: {1}\".format(cls_true[i], cls_pred[i])\n",
362 | "\n",
363 | " # Show the classes as the label on the x-axis.\n",
364 | " ax.set_xlabel(xlabel)\n",
365 | " \n",
366 | " # Remove ticks from the plot.\n",
367 | " ax.set_xticks([])\n",
368 | " ax.set_yticks([])\n",
369 | " \n",
370 | " # Ensure the plot is shown correctly with multiple plots\n",
371 | " # in a single Notebook cell.\n",
372 | " plt.show()"
373 | ]
374 | },
375 | {
376 | "cell_type": "markdown",
377 | "metadata": {},
378 | "source": [
379 | "### 绘制几张图像来看看数据是否正确"
380 | ]
381 | },
382 | {
383 | "cell_type": "code",
384 | "execution_count": 10,
385 | "metadata": {},
386 | "outputs": [
387 | {
388 | "data": {
389 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUMAAAD5CAYAAAC9FVegAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAHihJREFUeJzt3XmUFNXZx/HvA0LYVQQFFWdOwAVCFBWDu0aBKCogccG4EGM0osEtAaNx1xglKBzRE7YD4QQNigKCUVFAEV8EJIIi4wYiCsRlhLggIsJ9/5i5XdUzPXtXVU/7+5zjmequ6qpnvPSdp27dxZxziIj80DVIOgARkVygylBEBFWGIiKAKkMREUCVoYgIoMpQRARQZSgiAqgyFBEBVBmKiACwS00ObtOmjSssLIwolNzzwQcfUFxcbEnHESeVcf5TGWdWo8qwsLCQZcuW1T6qeqZ79+5JhxA7lXH+UxlnpttkERFUGYqIAKoMRUQAVYYiIoAqQxERoIZPk0Vqa8SIEQBs3boVgDfeeAOAxx9/vNyxgwcPBuCoo44C4MILL4wjRPmBU2YoIoIyQ4nYueeeC8C0adMy7jcr3xd2zJgxAMydOxeAE044AYD99tsvihAlQe+++y4ABx54IAAPPPAAAEOGDIk9FmWGIiIoM5QI+GwQKs4IDzroIABOOeUUAN5///3UvlmzZgGwevVqAKZMmQLAjTfemP1gJVHLly8HoEGDkrxsn332SSwWZYYiIigzlCzy411nzJhRbl/Xrl2BIOtr06YNAC1atADgu+++Sx3bo0cPAF5//XUAPv/884gilqStWLECCP4dDBgwILFYlBmKiBBDZuj7kY0fPx6AvffeO7WvSZMmAJx//vkAtGvXDoBOnTpFHZZE4L///S8AzrnUez4jnDNnDgDt27fP+FnfDxHgrbfeStt3+umnZzVOSd7KlSsBGD16NAAXXXRRkuEAygxFRIAYMsOhQ4cCJRMsVsT3K2vVqhUAXbp0ycq1O3ToAMCwYcOAH+bcdXE644wzgOApMEDLli0BaN26daWfffTRR1Pb4fZDyU/vvPMOAFu2bAHSeyAkRZmhiAiqDEVEgBhukydMmAAE3STCt8BFRUVA0PHyxRdfBGDx4sVAMPzqww8/rPD8jRo1AoKuGr4RP3wef7us2+R4FBQUVPvYv/3tb0AwLCvMd7HxPyV/DB8+HChZggBy47upzFBEhBgyw5NPPjntZ5gfiuVt3rwZCDJF/9fi1VdfrfD8P/rRj4BgoLcf5gWwadMmADp27Fir2CU6Tz31FAC33HILANu2bUvt22uvvQC45557AGjWrFnM0UkUwg9R/Xfaf2+bN2+eREhplBmKiJBjw/F23313AE466aS09zNllWU98cQTQJBdAhx88MEADBw4MFshSpb4oXvhjNDz3Sz81F2SHxYsWFDuvbZt2yYQSWbKDEVEyLHMsDY+/fRTAK644gogfSiYb4+qqsOvxKd///5AMDzPGzRoUGr7rrvuijUmiYdf6iHMD4jIBcoMRUTIg8zwoYceAoIMcbfddkvt80+qJHm+/+eiRYuAoK3QtxnddNNNqWP9dE6SH1555RUAJk2alHrv0EMPBaBXr16JxJSJMkMREepxZvjyyy8DQV8078knn0xt++mjJHl+0s7i4uK09/30beoLmr/mzZsHpPf08H2M/TR+uUCZoYgIqgxFRIB6fJv89NNPA8Hcdz179gTgqKOOSiwmKc+veeKHWHonnngiAHfccUfcIUnM/CQtYWeffXYCkVROmaGICPUwM9y6dSsAzz77LBBM1HD77bcDwZRekpzwanZ33303UH726m7dugHqRpPPPv74YwAWLlwIpE+icuaZZyYSU2WUGYqIUA8zQz8ZqG+DOvXUUwE4+uijE4tJ0t13332p7aVLl6bt88Px1FaY//7xj38A8MknnwDBdzVXKTMUEaGeZIZ+IlCAO++8E4Bdd90VgJtvvjmRmKRi999/f4X7/PBJtRXmv3Xr1qW99lP05SplhiIi5Hhm6J9KXnXVVan3vv/+ewD69OkDqF9hfePLtDpP/X3274/dvn07AF988UW5Y/1Qr5EjR2Y8V8OGDVPb9957L6DlBKI2e/bstNenn356QpFUjzJDERFUGYqIADl6m7xjxw4gmNli7dq1qX2dOnUCggcpUr/4dWmq45xzzgGgffv2QNBFY+rUqXWKwa++F55DUbLHd7L25VVfKDMUESFHM8M1a9YAwQpqYb7bhua/y13+4RbAzJkza32exx57rMpj/MOVBg3S/6737dsXCNbeDjv22GNrHZNUbcaMGUDwsNPPap3rqx0qMxQRIccyQ99Js3fv3mnvjxgxIrWd64/nBaZPn57aHj58OFB+ogavqKgIqLwd8JJLLgGgoKCg3L5f/vKXAHTu3Ll2wUrWfPPNNwA888wzae/76brC3ZtykTJDERFyLDMcO3YsUH4YT7itwcxijUnqprrr4j7yyCMRRyJR8+23foXKfv36AXD11VcnFlNNKDMUESFHMkPfL+nBBx9MOBIRqS2fGfp1kusbZYYiIuRIZujXQP7qq6/S3vejTTTdk4hETZmhiAiqDEVEgBy5TS7Lr5w2b948AFq3bp1kOCLyA6DMUESEHMkMb7jhhrSfIiJxU2YoIgKYc676B5t9Bqyr8sD8UeCca5t0EHFSGec/lXFmNaoMRUTylW6TRURQZSgiAkT8NNnM9gDmlb5sB+wAPit9/TPnXOYZP+t2zS5AeD6ojsANzjnNAhGBhMq4AJgM7Ak44O8q3+gkUcal150M9AE2OOe6RXGNtOvF1WZoZrcBXzvnRpR530rj2BnBNRsBG4DDnHPrs31+SRdXGZvZ3sCezrkVZtYKWA6c6px7Nxvnl4rF+T02sxOArcC4OCrDRG6TzayTmRWZ2cPAKqCDmf0vtH+gmU0o3d7LzKab2TIzW2pmR9bgUr2At1QRxi/KMnbObXTOrSjd/hJ4G9gnut9GMon6e+ycWwBsiuwXKCPJNsODgJHOuS6UZG8VeQAY7pzrDpwD+P+5PcxsTBXXGAj8KxvBSq1EXsZm9mOgK/BqdkKWGorjexyLJEegrHHOlV8LtLyewIGh6f53N7OmzrklwJKKPmRmTYDTgOvqHKnUVtRl3Ap4AhjinPu6ztFKbURaxnFKsjLcEtreCYQXN2kS2jZq10h7GrDEOVdcy/ik7iIrYzNrDEwHJjnnZtUpSqmLqL/HscmJrjWlja6bzWx/M2sAnBnaPRe40r8ws+o2pJ6HbpFzRjbLuLSx/h/ACufcAxGEK7UQ0fc4NjlRGZa6HpgDLALCDzyuBI4xszfMrAi4FCpvazCzlsDPgZnRhiw1lK0yPoGSP3a9zGxF6X+/iDh2qZ5sfo+nAQuBLma23sx+HWXgGo4nIkJuZYYiIolRZSgigipDERFAlaGICKDKUEQEqGGn6zZt2rjCwsKIQsk9H3zwAcXFxVb1kflDZZz/VMaZ1agyLCwsZNmy6oy8yQ/
du3dPOoTYqYzzn8o4M90mi4igylBEBFBlKCICqDIUEQFUGYqIAKoMRUSAZCd3rdCWLSXzRQ4dOhSAMWOCGX78Y/Jp06YBUFBQEHN0IpKPlBmKiJCjmeHGjRsBGD9+PAANGzZM7fOdRWfPng3A73//+5ijk9p47bXXABgwYABQMiqgtp577rnUdufOnQHo0KFD7YOTxPjvcd++fQEYPXo0AIMHD04dE/7+R0mZoYgIOZYZfvbZZwAMGjQo4Ugk2+bMmQPAtm3b6nyuWbOC9Z8mTpwIwNSpU+t8XonP559/DqRngABDhgwB4JJLLkm917Rp01hiUmYoIkKOZIYPPFCywNnMmSXrN736atXrgS9cuBAAv4bLIYccAsDxxx8fRYhSS99//z0ATz/9dNbOGR54f//99wNBD4TmzZtn7ToSnZdeegmADRvS150/77zzAGjSpEm5z0RNmaGICDmSGV5zzTVAzZ4aTZ8+Pe3nfvvtB8Bjjz2WOubwww/PVohSSy+88AIAixYtAuD666+v8zk3bdqU2l61ahUA33zzDaDMMJeF24vvuuuujMdceOGFAJQsjR0vZYYiIqgyFBEBEr5N7tOnDxA8BNmxY0eVn2nTpg0Q3A6tW7cOgLVr1wJwxBFHpI7duXNn9oKValu5cmVqe+DAgQB06tQJgBtvvLHO5w93rZH644033kht+0743i67lFRFp556aqwxhSkzFBEhgcxwwYIFqe23334bCBpLK3qAcvnll6e2e/fuDcCuu+4KwPz58wH4y1/+Uu5zf//734HyHTslWuGy8A82pkyZAkCLFi1qfV7/4CT8byiJhnapHf+wM5NevXrFGElmygxFRIgxM/QD830bEkBxcXHGY303mbPOOguAW2+9NbWvWbNmacf6KbzGjh1b7pzDhg0D4NtvvwWCSR0aNWpUu19CKvX4448D6R2sfVthuC23tnx3jHA2eOKJJwKw22671fn8Eq1wRu81btwYgLvvvjvucMpRZigiQoyZ4fbt24GKs0EIhtI9+uijQPDkuDI+M/RPKa+77rrUPj9Ey2eIfpqgjh071ih2qR4/4a7//w7Zaa/1dxWPPPIIEDx5BLjpppsAZfu5zHe4f+WVV8rt83d63bp1izWmTJQZioiQI8PxfHvSpEmTgOplhGX5rO/hhx9Ovbd06dIsRCdV+eKLLwBYvHhxuX1XXHFFnc8/btw4IJjirUuXLql9J510Up3PL9GqbOKVXOrpocxQRIQEMsNMo0yWLFlS5/P6USzhUSdlR7b4p9K+z5tkhx+Av379eiCYhilb1qxZk/a6a9euWT2/RCtTZuif/mfjziFblBmKiKDKUEQEiPE22a99HNVKV36VreXLl6feKzvM7/bbb4/k2j90LVu2BILuEeGJGvwQutatW9f4vJ9++ikQdNnxjjnmmFrFKfF6+eWXgaBLVJgfTrvvvvvGGlNllBmKiBBjZvjUU09l9Xy+m0VRURFQ+XAe31VHHXOj4Vcv80Pv/LA8gNNOOw1I7wyfyZtvvpna9g9M/PRsZSdjaNBAf8PrA78Cnn+QGZYLEzOUpX9VIiLkSKfr2vDTRD300EMVHlNYWAjA5MmTgWACCInGbbfdBqRnAv6OIDxBRyZt27ZNbftMsKKhmxdffHFdwpSYlG3rDU+mcdlll8UdTpWUGYqIUA8zQ79UgJ8YtjJ+2NZxxx0XaUxSonPnzkD6CoX+6X7ZjtNl+enawgYNGgSU7yTv2yglN/nO92WfIoefHGdjSrdsU2YoIkKMmWFliz4988wzaa8vvfRSADZu3Fjheaoz3Xu2n2BLzR166KFpP2vixz/+ccb3w/0Yf/rTn9YuMImMn7Kr7FPkfv36JRFOtSkzFBFBlaGICBDjbbKft8zPOh3mO+aWHaqXaeiev82uzkp6Ur/526yyt1u6Nc5tvrO15wc9XHPNNUmEU23KDEVEiDEzHDBgAADDhw9PvVfZeihV8X9tfHeO8ePHA9C+fftan1Nyi39IprWR65c5c+akve7QoQMQTM6Qq5QZiogQY2boV7HzK98BzJw5E4BRo0bV+Hx//vOfgWAtZMk/fr1rT52tc5tfAXP16tVp7zdp0gTI/YlSlBmKiJDAcDy/NnJ4u3fv3kCwCpqfqPWMM84A4He/+13qM/7JYniFNMlPfrVEP8D/lltuSTIcqYKfWs0PtVu1ahUA+++/f2Ix1YQyQxERcmSihlNOOSXtpwgEGca1114LaI3kXOf7/vrp9XwvgMMOOyyxmGpCmaGICDmSGYpk4tuOpX7Ze++9AZg4cWLCkdSMMkMREVQZiogAqgxFRABVhiIigCpDERFAlaGICACWabX7Cg82+wxYF104OafAOde26sPyh8o4/6mMM6tRZSgikq90mywigipDERFAlaGICBDx2GQz2wOYV/qyHbAD+Kz09c+cc99FdN0+wEigITDWOfe3KK4jyZVx6bV3AV4D3nfO9Y/qOj90CX6PJwN9gA3OuW5RXCPtenE9QDGz24CvnXMjyrxvpXHszNJ1GgHvAD8HPgaWAb90zr2bjfNLxeIq49B5hwHdgGaqDOMRZxmb2QnAVmBcHJVhIrfJZtbJzIrM7GFgFdDBzP4X2j/QzCaUbu9lZtPNbJmZLTWzI6s4/ZHAW865dc65bcBjQL+ofhfJLOIyxswKgF7ApKh+B6lc1GXsnFsAbIrsFygjyTbDg4CRzrkuwIZKjnsAGO6c6w6cA/j/uT3MbEyG4/cBPgq9Xl/6nsQvqjIGGAUMBdQ3LFlRlnGskpzPcI1zblk1jusJHBhaO3d3M2vqnFsCLIksOsmGSMrYzPoDHznnVphZz+yFK7WQN9/jJCvDLaHtnUB4pfAmoW2jZo20G4AOodf7UvlfLIlOVGV8NDDAzPqWnqeVmU12zg2qU7RSG1GVcexyomtNaaPrZjPb38waAGeGds8FrvQvzKyqhtTFQBczKzCzH1GSks/KdsxSM9ksY+fcMOfcvs65QuAC4DlVhMnL8vc4djlRGZa6HpgDLKKknc+7EjjGzN4wsyLgUqi4rcE5tx24CngeKAKmOOfeiTp4qZaslLHktKyVsZlNAxZSktysN7NfRxm4xiaLiJBbmaGISGJUGYqIoMpQRARQZSgiAtSwn2GbNm1cYWFhRKHkng8++IDi4mKr+sj8oTLOfyrjzGpUGRYWFrJsWXU6m+eH7t27Jx1C7FTG+U9lnJluk0VEUGUoIgKoMhQRAVQZiogAqgxFRABVhiIigCpDEREg2cldRUQA2Lx5MwAffvhhhccUFBQAMHLkSAC6du0KwAEHHADAIYccUqcYlBmKiJBwZvjpp58CcM455wBw9NFHA3DZZZcBJT3ls+GLL74A4KWXXgLglFNOAaBRo0ZZOb+I1MxTTz0FwOzZswF48cUXAXjvvfcq/MyBBx4IlAyvA9i2bVva/p0767ZKqTJDERESyAx92wDAT37yEyDI3Pbaay8g+xnhYYcdBkBxcTFAalzm/vvvn5XrSPV9+eWXAPzpT38CYNWqVQDMnTs3dYwy9vywZs0aAB566C
EAxo0bl9q3detWAGoy0/4770S7eocyQxERYswMfVbm2wcBPv/8cwCuvLJk0azRo0dn9Zp33XUXAGvXrgWCv0zKCOM3ZcoUAG666Sag/FNDnzEC7LHHHvEFJpFZv75kPahRo0bV6TwHHXQQEDw9jooyQxERYswMX3vtNSB4ahR2yy23ZO06b775Zmp7xIgRAJx5Zsnyreeee27WriPV47ODa6+9FgjuEMzS59ocMmRIavvBBx8EoHXr1nGEKLXgyxGCzO/YY48Fgt4ajRs3BmDXXXcFoEWLFqnPfP311wD84he/AIKsr0ePHgAceuihqWObNm0KQPPmzbP8W6RTZigigipDEREghttk37H6iSeeKLdv4sSJALRt27bO1/G3x7169Sq3b8CAAQC0bNmyzteRmvFNFf5hWUWmTp2a2n7mmWeA4GGLv4X2t12SnC1btgDp37PXX38dgJkzZ6Yde9RRRwGwfPlyIL3LnH+Atu+++wLQoEHyeVnyEYiI5IDIM8M//OEPQNC1wneABjj77LOzdp2XX34ZgI8//jj13sUXXwzABRdckLXrSNXWrVuX2p40aVLaPj+Y3newf/7558t93neW91nl+eefD0C7du2yH6xUy3fffQfAr371KyDIBgFuvPFGAHr27Jnxs5kGUey3335ZjrDulBmKiBBDZui7UPif++yzT2pfXdqA/HCeu+++GwiG/IS7bPg2SYnXihUrUtu+M/Xxxx8PwIIFCwD49ttvAXjkkUcA+Otf/5r6zOrVq4Egy+/Xrx8QtCWqy018fBcY/z3zEyuE2/mHDh0KQLNmzWKOLruUGYqIkMBEDX7qHoDevXsDsNtuuwEwePDgKj/vO237n4sXL07bn812SKmd8NRKPlP3na69Jk2aAPCb3/wGgMcffzy1zw/w94P4fcahp8nx80+I77nnHiCYYHXhwoWpY3yn6vpOmaGICDFkhldffTUA8+fPB2Djxo2pfb79yGcATz75ZJXn88eWHc7VsWNHIGjbkOT861//Kvfev//9bwD69++f8TN+WrVMjjzySCB9OJfEY9GiRWmv/TA53z8wnygzFBEhhszw8MMPB2DlypVA+pPGZ599FoDhw4cDsOeeewIwaNCgCs934YUXAnDwwQenve+XDPAZoiTnvPPOS237bP/VV18F4O233waCfw8zZswA0if99W3I/j0/9Zov+y5dukQWu6QLt+VC8ET/9ttvT73Xt29fIH1yhfpImaGICKoMRUQAsJqsQdC9e3dXWUN3HN5//30guB3u1q0bAM899xyQnUkfvO7du7Ns2TKr+sj8kY0y3rRpU2rbl5MfYlfRA7DwwH/fgf70008H4N133wWCVRPHjBlTp/jCVMaVKztoIpOGDRsCcPnllwPBnIQfffQRAJ06dQKCNY/C/Bo4flKHKB7MVLeMlRmKiJDwusm1cccddwDBXyr/8CWbGaHUTXi43LRp0wA466yzgPIZ4lVXXQXAvffem/qM75Dtp17zQ/XmzJkDBJ2yQQ/MovbHP/4RgPvuu6/CY3bs2AEEGb3/WRP+4emJJ54IpE/pFhdlhiIi1JPM0GcXAJMnTwagVatWgFZSy3V+WiffRcNPzOC7z/hM32eDYTfffDMAb731FhB00/GfgeDfg0TDD8Pzq1r66dS2b9+eOsavc+MzxNrwk0D773p4JTw/yW/UlBmKiFBPMkPf0TPstNNOA9Ini5Xc5TPEiiYAzcSviuZXNfSZ4QsvvJA6xj+51rRe0fBPio844gggeLIfNm/ePCDIFm+77TYAli5dWuPr+bbk//znPzX+bF0pMxQRoR5mhn7tVP+US/Kfb6+aNWsWkP6k0a+xnM21t6VmTj755LTXfsitzwwbNWoEBMtwAFx66aUAjBw5EgjakpOkzFBEBFWGIiJAjt8m+2FX4RXv/KpqenDyw+HX1B02bBiQvj6vb6wfOHAgAAcccEC8wUk5fgZ7v2qef7DiZx8CeO+994BgxvqywmslxUWZoYgI9SQzDA8S79OnT9oxX331FRDMfZeL67FKdvhJOe68887Ue/5B2g033AAE63P7bjkSv86dOwNBl6hHH3203DHh7lEAu+xSUhX5LnPh4ZlxUWYoIkKOZ4aZ+L8gPgPwj+b98B0Nz8p/F110UWp77NixAEyfPh0I2qLKzoQu8fFZ+ahRo4Dg7i3ckfqTTz4BoLCwEAjK1LcBJ0GZoYgI9TAzHD9+PAATJkwA4Le//S0QDOqX/Beerm3u3LlAsJ6vn1ggFzrx/tD5nh9+rfR//vOfqX2vvPIKEGSCfgqvJCkzFBEhxzPD0aNHA3Drrbem3jv++OMBGDx4MAC77747AI0bN445OskFvveAXzbAD9krKioCtJJeLvGrG5bdzhXKDEVEyPHM8LjjjgNg/vz5CUciuc5PHnvIIYcAsHr1akCZoVSfMkMREVQZiogAOX6bLFJdfk2ctWvXJhyJ1FfKDEVEUGUoIgKoMhQRAcD8alTVOtjsM2BddOHknALnXNuqD8sfKuP8pzLOrEaVoYhIvtJtsogIqgxFRICI+xma2R7AvNKX7YAdwGelr3/mnPsuwmvvArwGvO+c6x/VdX7okipjM7sOuKT05Rjn3OgoriOJlvF6YHPp9bY553pEcZ3U9eJqMzSz24CvnXMjyrxvpXHszPL1hgHdgGaqDOMRVxmbWTdgMnAk8D3wHPAb55x6XEcszu9xaWXY1Tn3v2ydszKJ3CabWSczKzKzh4FVQAcz+19o/0Azm1C6vZeZTTezZWa21MyOrMb5C4BewKSofgepXMRl3BlY7Jzb6pzbDrwEnBnV7yKZRf09jluSbYYHASOdc12ADZUc9wAw3DnXHTgH8P9ze5jZmAo+MwoYCuhRebKiKuOVwAlm1trMmgOnAh2yG7pUU5TfYwfMN7P/mNklFRyTNUmOTV7jnFtWjeN6AgeGlgvd3cyaOueWAEvKHmxm/YGPnHMrzKxn9sKVWoikjJ1zb5rZ/cBc4GtgOSXtShK/SMq41JHOuQ1m1g543szecs4tykLMGSVZGW4Jbe8ELPS6SWjbqFkj7dHAADPrW3qeVmY22Tk3qE7RSm1EVcY458YB4wDMbDiwug5xSu1FWcYbSn9+bGZPAj8DIqsMc6JrTWmj62Yz29/MGpDe/jMXuNK/KG08r+xcw5xz+zrnCoELgOdUESYvm2VcesyepT8Lgb7A1GzGKzWXzTI2sxZm1qJ0uzklzwDezH7UgZyoDEtdD8yhpOZfH3r/SuAYM3vDzIqAS6HKtgbJTdks45mlx84ELnfOfRlh3FJ92Srj9sD/mdnrwFJghnNubpSBazieiAi5lRmKiCRGlaGICKoMRUQAVYYiIoAqQxERQJWhiAigylBEBFBlKCICwP8D3P5bzM0W5d8AAAAASUVORK5CYII=\n",
390 | "text/plain": [
391 | ""
392 | ]
393 | },
394 | "metadata": {},
395 | "output_type": "display_data"
396 | }
397 | ],
398 | "source": [
399 | "# Get the first images from the test-set.\n",
400 | "images = test_data[0:9]\n",
401 | "\n",
402 | "# Get the true classes for those images.\n",
403 | "cls_true = test_label[0:9]\n",
404 | "\n",
405 | "# Plot the images and labels using our helper-function above.\n",
406 | "plot_images(images=images, cls_true=cls_true)"
407 | ]
408 | },
409 | {
410 | "cell_type": "markdown",
411 | "metadata": {},
412 | "source": [
413 | "## TensorFlow图\n",
414 | "\n",
415 | "TensorFlow的全部目的就是使用一个称之为计算图(computational graph)的东西,它会比直接在Python中进行相同计算量要高效得多。TensorFlow比Numpy更高效,因为TensorFlow了解整个需要运行的计算图,然而Numpy只知道某个时间点上唯一的数学运算。\n",
416 | "\n",
417 | "TensorFlow也能够自动地计算需要优化的变量的梯度,使得模型有更好的表现。这是由于图是简单数学表达式的结合,因此整个图的梯度可以用链式法则推导出来。\n",
418 | "\n",
419 | "TensorFlow还能利用多核CPU和GPU,Google也为TensorFlow制造了称为TPUs(Tensor Processing Units)的特殊芯片,它比GPU更快。\n",
420 | "\n",
421 | "一个TensorFlow图由下面几个部分组成,后面会详细描述:\n",
422 | "\n",
423 | "* 占位符变量(Placeholder)用来改变图的输入。\n",
424 | "* 模型变量(Model)将会被优化,使得模型表现得更好。\n",
425 | "* 模型本质上就是一些数学函数,它根据Placeholder和模型的输入变量来计算一些输出。\n",
426 | "* 一个cost度量用来指导变量的优化。\n",
427 | "* 一个优化策略会更新模型的变量。\n",
428 | "\n",
429 | "另外,TensorFlow图也包含了一些调试状态,比如用TensorBoard打印log数据。"
430 | ]
431 | },
432 | {
433 | "cell_type": "markdown",
434 | "metadata": {},
435 | "source": [
436 | "### 创建新变量的帮助函数"
437 | ]
438 | },
439 | {
440 | "cell_type": "markdown",
441 | "metadata": {},
442 | "source": [
443 | "函数用来根据给定大小创建TensorFlow变量,并将它们用随机值初始化。需注意的是在此时并未完成初始化工作,仅仅是在TensorFlow图里定义它们。"
444 | ]
445 | },
446 | {
447 | "cell_type": "code",
448 | "execution_count": 11,
449 | "metadata": {},
450 | "outputs": [],
451 | "source": [
452 | "def new_weights(shape):\n",
453 | " return tf.Variable(tf.truncated_normal(shape, stddev=0.05))"
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "execution_count": 12,
459 | "metadata": {},
460 | "outputs": [],
461 | "source": [
462 | "def new_biases(length):\n",
463 | " return tf.Variable(tf.constant(0.05, shape=[length]))"
464 | ]
465 | },
466 | {
467 | "cell_type": "markdown",
468 | "metadata": {},
469 | "source": [
470 | "### 创建卷积层的帮助函数"
471 | ]
472 | },
473 | {
474 | "cell_type": "markdown",
475 | "metadata": {},
476 | "source": [
477 | "这个函数为TensorFlow在计算图里创建了新的卷积层。这里并没有执行什么计算,只是在TensorFlow图里添加了数学公式。\n",
478 | "\n",
479 | "假设输入的是四维的张量,各个维度如下:\n",
480 | "\n",
481 | "1. 图像数量\n",
482 | "2. 每张图像的Y轴\n",
483 | "3. 每张图像的X轴\n",
484 | "4. 每张图像的通道数\n",
485 | "\n",
486 | "输入通道可能是彩色通道,当输入是前面的卷积层生成的时候,它也可能是滤波通道。\n",
487 | "\n",
488 | "输出是另外一个4通道的张量,如下:\n",
489 | "\n",
490 | "1. 图像数量,与输入相同\n",
491 | "2. 每张图像的Y轴。如果用到了2x2的池化,是输入图像宽高的一半。\n",
492 | "3. 每张图像的X轴。同上。\n",
493 | "4. 卷积滤波生成的通道数。"
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": 20,
499 | "metadata": {},
500 | "outputs": [],
501 | "source": [
502 | "def new_conv_layer(input, # The previous layer.\n",
503 | " num_input_channels, # Num. channels in prev. layer.\n",
504 | " filter_size, # Width and height of each filter.\n",
505 | " num_filters, # Number of filters.\n",
506 | " use_pooling=True): # Use 2x2 max-pooling.\n",
507 | "\n",
508 | " # Shape of the filter-weights for the convolution.\n",
509 | " # This format is determined by the TensorFlow API.\n",
510 | " shape = [filter_size, filter_size, num_input_channels, num_filters]\n",
511 | "\n",
512 | " # Create new weights aka. filters with the given shape.\n",
513 | " weights = new_weights(shape=shape)\n",
514 | "\n",
515 | " # Create new biases, one for each filter.\n",
516 | " biases = new_biases(length=num_filters)\n",
517 | "\n",
518 | " # Create the TensorFlow operation for convolution.\n",
519 | " # Note the strides are set to 1 in all dimensions.\n",
520 | " # The first and last stride must always be 1,\n",
521 | " # because the first is for the image-number and\n",
522 | " # the last is for the input-channel.\n",
523 | " # But e.g. strides=[1, 2, 2, 1] would mean that the filter\n",
524 | " # is moved 2 pixels across the x- and y-axis of the image.\n",
525 | " # The padding is set to 'SAME' which means the input image\n",
526 | " # is padded with zeroes so the size of the output is the same.\n",
527 | " layer = tf.nn.conv2d(input=input,\n",
528 | " filter=weights,\n",
529 | " strides=[1, 1, 1, 1],\n",
530 | " padding='SAME')\n",
531 | "\n",
532 | " # Add the biases to the results of the convolution.\n",
533 | " # A bias-value is added to each filter-channel.\n",
534 | " layer += biases\n",
535 | "\n",
536 | " # Use pooling to down-sample the image resolution?\n",
537 | " if use_pooling:\n",
538 | " # This is 2x2 max-pooling, which means that we\n",
539 | " # consider 2x2 windows and select the largest value\n",
540 | " # in each window. Then we move 2 pixels to the next window.\n",
541 | " layer = tf.nn.max_pool(value=layer,\n",
542 | " ksize=[1, 2, 2, 1],\n",
543 | " strides=[1, 2, 2, 1],\n",
544 | " padding='SAME')\n",
545 | "\n",
546 | " # Rectified Linear Unit (ReLU).\n",
547 | " # It calculates max(x, 0) for each input pixel x.\n",
548 | " # This adds some non-linearity to the formula and allows us\n",
549 | " # to learn more complicated functions.\n",
550 | " layer = tf.nn.relu(layer)\n",
551 | "\n",
552 | " # Note that ReLU is normally executed before the pooling,\n",
553 | " # but since relu(max_pool(x)) == max_pool(relu(x)) we can\n",
554 | " # save 75% of the relu-operations by max-pooling first.\n",
555 | "\n",
556 | " # We return both the resulting layer and the filter-weights\n",
557 | " # because we will plot the weights later.\n",
558 | " return layer, weights"
559 | ]
560 | },
561 | {
562 | "cell_type": "markdown",
563 | "metadata": {},
564 | "source": [
565 | "### 转换一个层的帮助函数\n",
566 | "\n",
567 | "卷积层生成了4维的张量。我们会在卷积层之后添加一个全连接层,因此我们需要将这个4维的张量转换成可被全连接层使用的2维张量。"
568 | ]
569 | },
570 | {
571 | "cell_type": "code",
572 | "execution_count": 13,
573 | "metadata": {},
574 | "outputs": [],
575 | "source": [
576 | "def flatten_layer(layer):\n",
577 | " # Get the shape of the input layer.\n",
578 | " layer_shape = layer.get_shape()\n",
579 | "\n",
580 | " # The shape of the input layer is assumed to be:\n",
581 | " # layer_shape == [num_images, img_height, img_width, num_channels]\n",
582 | "\n",
583 | " # The number of features is: img_height * img_width * num_channels\n",
584 | " # We can use a function from TensorFlow to calculate this.\n",
585 | " num_features = layer_shape[1:4].num_elements()\n",
586 | " \n",
587 | " # Reshape the layer to [num_images, num_features].\n",
588 | " # Note that we just set the size of the second dimension\n",
589 | " # to num_features and the size of the first dimension to -1\n",
590 | " # which means the size in that dimension is calculated\n",
591 | " # so the total size of the tensor is unchanged from the reshaping.\n",
592 | " layer_flat = tf.reshape(layer, [-1, num_features])\n",
593 | "\n",
594 | " # The shape of the flattened layer is now:\n",
595 | " # [num_images, img_height * img_width * num_channels]\n",
596 | "\n",
597 | " # Return both the flattened layer and the number of features.\n",
598 | " return layer_flat, num_features"
599 | ]
600 | },
601 | {
602 | "cell_type": "markdown",
603 | "metadata": {},
604 | "source": [
605 | "### 创建一个全连接层的帮助函数"
606 | ]
607 | },
608 | {
609 | "cell_type": "markdown",
610 | "metadata": {},
611 | "source": [
612 | "这个函数为TensorFlow在计算图中创建了一个全连接层。这里也不进行任何计算,只是往TensorFlow图中添加数学公式。\n",
613 | "\n",
614 | "输入是大小为`[num_images, num_inputs]`的二维张量。输出是大小为`[num_images, num_outputs]`的2维张量。"
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "execution_count": 14,
620 | "metadata": {},
621 | "outputs": [],
622 | "source": [
623 | "def new_fc_layer(input, # The previous layer.\n",
624 | " num_inputs, # Num. inputs from prev. layer.\n",
625 | " num_outputs, # Num. outputs.\n",
626 | " use_relu=True): # Use Rectified Linear Unit (ReLU)?\n",
627 | "\n",
628 | " # Create new weights and biases.\n",
629 | " weights = new_weights(shape=[num_inputs, num_outputs])\n",
630 | " biases = new_biases(length=num_outputs)\n",
631 | "\n",
632 | " # Calculate the layer as the matrix multiplication of\n",
633 | " # the input and weights, and then add the bias-values.\n",
634 | " layer = tf.matmul(input, weights) + biases\n",
635 | "\n",
636 | " # Use ReLU?\n",
637 | " if use_relu:\n",
638 | " layer = tf.nn.relu(layer)\n",
639 | "\n",
640 | " return layer"
641 | ]
642 | },
643 | {
644 | "cell_type": "markdown",
645 | "metadata": {},
646 | "source": [
647 | "### 占位符 (Placeholder)变量"
648 | ]
649 | },
650 | {
651 | "cell_type": "markdown",
652 | "metadata": {},
653 | "source": [
654 | "Placeholder是作为图的输入,每次我们运行图的时候都可能会改变它们。将这个过程称为feeding placeholder变量,后面将会描述它。\n",
655 | "\n",
656 | "首先我们为输入图像定义placeholder变量。这让我们可以改变输入到TensorFlow图中的图像。这也是一个张量(tensor),代表一个多维向量或矩阵。数据类型设置为float32,形状设为`[None, img_size_flat]`,`None`代表tensor可能保存着任意数量的图像,每张图象是一个长度为`img_size_flat`的向量。"
657 | ]
658 | },
659 | {
660 | "cell_type": "code",
661 | "execution_count": 15,
662 | "metadata": {},
663 | "outputs": [],
664 | "source": [
665 | "img_size_flat = IMG_FLAT\n",
666 | "\n",
667 | "x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')"
668 | ]
669 | },
670 | {
671 | "cell_type": "markdown",
672 | "metadata": {},
673 | "source": [
674 | "卷积层希望`x`被编码为4维张量,因此我们需要将它的形状转换至`[num_images, img_height, img_width, num_channels]`。注意`img_height == img_width == img_size`,如果第一维的大小设为-1, `num_images`的大小也会被自动推导出来。转换运算如下:"
675 | ]
676 | },
677 | {
678 | "cell_type": "code",
679 | "execution_count": 16,
680 | "metadata": {},
681 | "outputs": [],
682 | "source": [
683 | "img_size = IMG_HEIGHT\n",
684 | "num_channels = NUM_CHANNEL\n",
685 | "\n",
686 | "x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])"
687 | ]
688 | },
689 | {
690 | "cell_type": "markdown",
691 | "metadata": {},
692 | "source": [
693 | "接下来我们为输入变量`x`中的图像所对应的真实标签定义placeholder变量。变量的形状是`[None, num_classes]`,这代表着它保存了任意数量的标签,每个标签是长度为`num_classes`的向量,本例中长度为10。"
694 | ]
695 | },
696 | {
697 | "cell_type": "code",
698 | "execution_count": 17,
699 | "metadata": {},
700 | "outputs": [],
701 | "source": [
702 | "num_classes = NUM_CLASS\n",
703 | "\n",
704 | "y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')"
705 | ]
706 | },
707 | {
708 | "cell_type": "markdown",
709 | "metadata": {},
710 | "source": [
711 | "我们也可以为class-number提供一个placeholder,但这里用argmax来计算它。这里只是TensorFlow中的一些操作,没有执行什么运算。"
712 | ]
713 | },
714 | {
715 | "cell_type": "code",
716 | "execution_count": 18,
717 | "metadata": {},
718 | "outputs": [],
719 | "source": [
720 | "y_true_cls = tf.argmax(y_true, axis=1)"
721 | ]
722 | },
723 | {
724 | "cell_type": "markdown",
725 | "metadata": {},
726 | "source": [
727 | "### 卷积层 1\n",
728 | "\n",
729 | "创建第一个卷积层。将`x_image`当作输入,创建`num_filters1`个不同的滤波器,每个滤波器的宽高都与 `filter_size1`相等。最终我们会用2x2的max-pooling将图像降采样,使它的尺寸减半。"
730 | ]
731 | },
732 | {
733 | "cell_type": "code",
734 | "execution_count": 21,
735 | "metadata": {},
736 | "outputs": [],
737 | "source": [
738 | "layer_conv1, weights_conv1 = \\\n",
739 | " new_conv_layer(input=x_image,\n",
740 | " num_input_channels=num_channels,\n",
741 | " filter_size=filter_size1,\n",
742 | " num_filters=num_filters1,\n",
743 | " use_pooling=True)"
744 | ]
745 | },
746 | {
747 | "cell_type": "markdown",
748 | "metadata": {},
749 | "source": [
750 | "检查卷积层输出张量的大小。它是(?,14, 14, 16),这代表着有任意数量的图像(?代表数量),每张图像有14个像素的宽和高,有16个不同的通道,每个滤波器各有一个通道。"
751 | ]
752 | },
753 | {
754 | "cell_type": "code",
755 | "execution_count": 22,
756 | "metadata": {},
757 | "outputs": [
758 | {
759 | "data": {
760 | "text/plain": [
761 | ""
762 | ]
763 | },
764 | "execution_count": 22,
765 | "metadata": {},
766 | "output_type": "execute_result"
767 | }
768 | ],
769 | "source": [
770 | "layer_conv1"
771 | ]
772 | },
773 | {
774 | "cell_type": "markdown",
775 | "metadata": {},
776 | "source": [
777 | "### 卷积层 2\n",
778 | "\n",
779 | "创建第二个卷积层,它将第一个卷积层的输出作为输入。输入通道的数量对应着第一个卷积层的滤波数。"
780 | ]
781 | },
782 | {
783 | "cell_type": "code",
784 | "execution_count": 23,
785 | "metadata": {},
786 | "outputs": [],
787 | "source": [
788 | "layer_conv2, weights_conv2 = \\\n",
789 | " new_conv_layer(input=layer_conv1,\n",
790 | " num_input_channels=num_filters1,\n",
791 | " filter_size=filter_size2,\n",
792 | " num_filters=num_filters2,\n",
793 | " use_pooling=True)"
794 | ]
795 | },
796 | {
797 | "cell_type": "markdown",
798 | "metadata": {},
799 | "source": [
800 | "核对一下这个卷积层输出张量的大小。它的大小是(?, 7, 7, 36),其中?也代表着任意数量的图像,每张图有7像素的宽高,每个滤波器有36个通道。"
801 | ]
802 | },
803 | {
804 | "cell_type": "code",
805 | "execution_count": 24,
806 | "metadata": {},
807 | "outputs": [
808 | {
809 | "data": {
810 | "text/plain": [
811 | ""
812 | ]
813 | },
814 | "execution_count": 24,
815 | "metadata": {},
816 | "output_type": "execute_result"
817 | }
818 | ],
819 | "source": [
820 | "layer_conv2"
821 | ]
822 | },
823 | {
824 | "cell_type": "markdown",
825 | "metadata": {},
826 | "source": [
827 | "### 转换层\n",
828 | "\n",
829 | "这个卷积层输出一个4维张量。现在我们想将它作为一个全连接网络的输入,这就需要将它转换成2维张量。"
830 | ]
831 | },
832 | {
833 | "cell_type": "code",
834 | "execution_count": 25,
835 | "metadata": {},
836 | "outputs": [],
837 | "source": [
838 | "layer_flat, num_features = flatten_layer(layer_conv2)"
839 | ]
840 | },
841 | {
842 | "cell_type": "markdown",
843 | "metadata": {},
844 | "source": [
845 | "这个张量的大小是(?, 1764),意味着共有一定数量的图像,每张图像被转换成长为1764的向量。其中1764 = 7 x 7 x 36。"
846 | ]
847 | },
848 | {
849 | "cell_type": "code",
850 | "execution_count": 26,
851 | "metadata": {},
852 | "outputs": [
853 | {
854 | "data": {
855 | "text/plain": [
856 | ""
857 | ]
858 | },
859 | "execution_count": 26,
860 | "metadata": {},
861 | "output_type": "execute_result"
862 | }
863 | ],
864 | "source": [
865 | "layer_flat"
866 | ]
867 | },
868 | {
869 | "cell_type": "code",
870 | "execution_count": 27,
871 | "metadata": {},
872 | "outputs": [
873 | {
874 | "data": {
875 | "text/plain": [
876 | "1764"
877 | ]
878 | },
879 | "execution_count": 27,
880 | "metadata": {},
881 | "output_type": "execute_result"
882 | }
883 | ],
884 | "source": [
885 | "num_features"
886 | ]
887 | },
888 | {
889 | "cell_type": "markdown",
890 | "metadata": {},
891 | "source": [
892 | "### 全连接层 1\n",
893 | "\n",
894 | "往网络中添加一个全连接层。输入是一个前面卷积得到的被转换过的层。全连接层中的神经元或节点数为`fc_size`。我们可以用ReLU来学习非线性关系。"
895 | ]
896 | },
897 | {
898 | "cell_type": "code",
899 | "execution_count": 28,
900 | "metadata": {},
901 | "outputs": [],
902 | "source": [
903 | "layer_fc1 = new_fc_layer(input=layer_flat,\n",
904 | " num_inputs=num_features,\n",
905 | " num_outputs=fc_size,\n",
906 | " use_relu=True)"
907 | ]
908 | },
909 | {
910 | "cell_type": "markdown",
911 | "metadata": {},
912 | "source": [
913 | "全连接层的输出是一个大小为(?,128)的张量,?代表着一定数量的图像,并且`fc_size` == 128。"
914 | ]
915 | },
916 | {
917 | "cell_type": "code",
918 | "execution_count": 29,
919 | "metadata": {},
920 | "outputs": [
921 | {
922 | "data": {
923 | "text/plain": [
924 | ""
925 | ]
926 | },
927 | "execution_count": 29,
928 | "metadata": {},
929 | "output_type": "execute_result"
930 | }
931 | ],
932 | "source": [
933 | "layer_fc1"
934 | ]
935 | },
936 | {
937 | "cell_type": "markdown",
938 | "metadata": {},
939 | "source": [
940 | "### 全连接层 2\n",
941 | "\n",
942 | "添加另外一个全连接层,它的输出是一个长度为10的向量,它确定了输入图是属于哪个类别。这层并没有用到ReLU。"
943 | ]
944 | },
945 | {
946 | "cell_type": "code",
947 | "execution_count": 30,
948 | "metadata": {},
949 | "outputs": [],
950 | "source": [
951 | "layer_fc2 = new_fc_layer(input=layer_fc1,\n",
952 | " num_inputs=fc_size,\n",
953 | " num_outputs=num_classes,\n",
954 | " use_relu=False)"
955 | ]
956 | },
957 | {
958 | "cell_type": "code",
959 | "execution_count": 31,
960 | "metadata": {},
961 | "outputs": [
962 | {
963 | "data": {
964 | "text/plain": [
965 | ""
966 | ]
967 | },
968 | "execution_count": 31,
969 | "metadata": {},
970 | "output_type": "execute_result"
971 | }
972 | ],
973 | "source": [
974 | "layer_fc2"
975 | ]
976 | },
977 | {
978 | "cell_type": "markdown",
979 | "metadata": {},
980 | "source": [
981 | "### 预测类别"
982 | ]
983 | },
984 | {
985 | "cell_type": "markdown",
986 | "metadata": {},
987 | "source": [
988 | "第二个全连接层估算了输入图有多大的可能属于10个类别中的其中一个。然而,这是很粗略的估计并且很难解释,因为数值可能很小或很大,因此我们会对它们做归一化,将每个元素限制在0到1之间,并且相加为1。这用一个称为softmax的函数来计算的,结果保存在`y_pred`中。"
989 | ]
990 | },
991 | {
992 | "cell_type": "code",
993 | "execution_count": 32,
994 | "metadata": {},
995 | "outputs": [],
996 | "source": [
997 | "y_pred = tf.nn.softmax(layer_fc2)"
998 | ]
999 | },
1000 | {
1001 | "cell_type": "markdown",
1002 | "metadata": {},
1003 | "source": [
1004 | "类别数字是最大元素的索引。"
1005 | ]
1006 | },
1007 | {
1008 | "cell_type": "code",
1009 | "execution_count": 33,
1010 | "metadata": {},
1011 | "outputs": [],
1012 | "source": [
1013 | "y_pred_cls = tf.argmax(y_pred, axis=1)"
1014 | ]
1015 | },
1016 | {
1017 | "cell_type": "markdown",
1018 | "metadata": {},
1019 | "source": [
1020 | "### 优化损失函数"
1021 | ]
1022 | },
1023 | {
1024 | "cell_type": "markdown",
1025 | "metadata": {},
1026 | "source": [
1027 | "为了使模型更好地对输入图像进行分类,我们必须改变`weights`和`biases`变量。首先我们需要对比模型`y_pred`的预测输出和期望输出的`y_true`,来了解目前模型的性能如何。\n",
1028 | "\n",
1029 | "交叉熵(cross-entropy)是在分类中使用的性能度量。交叉熵是一个常为正值的连续函数,如果模型的预测值精准地符合期望的输出,它就等于零。因此,优化的目的就是通过改变网络层的变量来最小化交叉熵。\n",
1030 | "\n",
1031 | "TensorFlow有一个内置的计算交叉熵的函数。这个函数内部计算了softmax,所以我们要用`layer_fc2`的输出而非直接用`y_pred`,因为`y_pred`上已经计算了softmax。"
1032 | ]
1033 | },
1034 | {
1035 | "cell_type": "code",
1036 | "execution_count": 34,
1037 | "metadata": {},
1038 | "outputs": [
1039 | {
1040 | "name": "stdout",
1041 | "output_type": "stream",
1042 | "text": [
1043 | "WARNING:tensorflow:From :2: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
1044 | "Instructions for updating:\n",
1045 | "\n",
1046 | "Future major versions of TensorFlow will allow gradients to flow\n",
1047 | "into the labels input on backprop by default.\n",
1048 | "\n",
1049 | "See @{tf.nn.softmax_cross_entropy_with_logits_v2}.\n",
1050 | "\n"
1051 | ]
1052 | }
1053 | ],
1054 | "source": [
1055 | "cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=layer_fc2,\n",
1056 | " labels=y_true)"
1057 | ]
1058 | },
1059 | {
1060 | "cell_type": "markdown",
1061 | "metadata": {},
1062 | "source": [
1063 | "我们为每个图像分类计算了交叉熵,所以有一个当前模型在每张图上表现的度量。但是为了用交叉熵来指导模型变量的优化,我们需要一个额外的标量值,因此简单地利用所有图像分类交叉熵的均值。"
1064 | ]
1065 | },
1066 | {
1067 | "cell_type": "code",
1068 | "execution_count": 35,
1069 | "metadata": {},
1070 | "outputs": [],
1071 | "source": [
1072 | "cost = tf.reduce_mean(cross_entropy)"
1073 | ]
1074 | },
1075 | {
1076 | "cell_type": "markdown",
1077 | "metadata": {},
1078 | "source": [
1079 | "### 优化方法"
1080 | ]
1081 | },
1082 | {
1083 | "cell_type": "markdown",
1084 | "metadata": {},
1085 | "source": [
1086 | "既然我们有一个需要被最小化的损失度量,接着就可以建立优化一个优化器。这个例子中,我们使用的是梯度下降的变体`AdamOptimizer`。\n",
1087 | "\n",
1088 | "优化过程并不是在这里执行。实际上,还没计算任何东西,我们只是往TensorFlow图中添加了优化器,以便之后的操作。"
1089 | ]
1090 | },
1091 | {
1092 | "cell_type": "code",
1093 | "execution_count": 36,
1094 | "metadata": {},
1095 | "outputs": [],
1096 | "source": [
1097 | "optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)"
1098 | ]
1099 | },
1100 | {
1101 | "cell_type": "markdown",
1102 | "metadata": {},
1103 | "source": [
1104 | "### 性能度量"
1105 | ]
1106 | },
1107 | {
1108 | "cell_type": "markdown",
1109 | "metadata": {},
1110 | "source": [
1111 | "我们需要另外一些性能度量,来向用户展示这个过程。\n",
1112 | "\n",
1113 | "这是一个布尔值向量,代表预测类型是否等于每张图片的真实类型。"
1114 | ]
1115 | },
1116 | {
1117 | "cell_type": "code",
1118 | "execution_count": 37,
1119 | "metadata": {},
1120 | "outputs": [],
1121 | "source": [
1122 | "correct_prediction = tf.equal(y_pred_cls, y_true_cls)\n"
1123 | ]
1124 | },
1125 | {
1126 | "cell_type": "markdown",
1127 | "metadata": {},
1128 | "source": [
1129 | "上面的计算先将布尔值向量类型转换成浮点型向量,这样子False就变成0,True变成1,然后计算这些值的平均数,以此来计算分类的准确度。"
1130 | ]
1131 | },
1132 | {
1133 | "cell_type": "code",
1134 | "execution_count": 38,
1135 | "metadata": {},
1136 | "outputs": [
1137 | {
1138 | "data": {
1139 | "text/plain": [
1140 | ""
1141 | ]
1142 | },
1143 | "execution_count": 38,
1144 | "metadata": {},
1145 | "output_type": "execute_result"
1146 | }
1147 | ],
1148 | "source": [
1149 | "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
1150 | "tf.summary.scalar('accuracy',accuracy)"
1151 | ]
1152 | },
1153 | {
1154 | "cell_type": "markdown",
1155 | "metadata": {},
1156 | "source": [
1157 | "## 运行TensorFlow"
1158 | ]
1159 | },
1160 | {
1161 | "cell_type": "markdown",
1162 | "metadata": {},
1163 | "source": [
1164 | "### 创建TensorFlow会话(session)\n",
1165 | "\n",
1166 | "一旦创建了TensorFlow图,我们需要创建一个TensorFlow会话,用来运行图。"
1167 | ]
1168 | },
1169 | {
1170 | "cell_type": "code",
1171 | "execution_count": 39,
1172 | "metadata": {},
1173 | "outputs": [],
1174 | "source": [
1175 | "session = tf.Session()"
1176 | ]
1177 | },
1178 | {
1179 | "cell_type": "code",
1180 | "execution_count": 47,
1181 | "metadata": {},
1182 | "outputs": [],
1183 | "source": [
1184 | "train_log_save_dir = '/Users/honglan/Desktop/tensorflow_estimator_learn/cnn_raw_log'\n",
1185 | "model_save_path = '/Users/honglan/Desktop/tensorflow_estimator_learn/cnn_raw'\n",
1186 | "merged = tf.summary.merge_all()\n",
1187 | "train_writer = tf.summary.FileWriter(train_log_save_dir, session.graph)\n",
1188 | "saver = tf.train.Saver()"
1189 | ]
1190 | },
1191 | {
1192 | "cell_type": "markdown",
1193 | "metadata": {},
1194 | "source": [
1195 | "### 初始化变量\n",
1196 | "\n",
1197 | "我们需要在开始优化weights和biases变量之前对它们进行初始化。"
1198 | ]
1199 | },
1200 | {
1201 | "cell_type": "code",
1202 | "execution_count": 41,
1203 | "metadata": {},
1204 | "outputs": [],
1205 | "source": [
1206 | "session.run(tf.global_variables_initializer())"
1207 | ]
1208 | },
1209 | {
1210 | "cell_type": "markdown",
1211 | "metadata": {},
1212 | "source": [
1213 | "### 用来优化迭代的帮助函数"
1214 | ]
1215 | },
1216 | {
1217 | "cell_type": "markdown",
1218 | "metadata": {},
1219 | "source": [
1220 | "在训练集中有50,000张图。用这些图像计算模型的梯度会花很多时间。因此我们利用随机梯度下降的方法,它在优化器的每次迭代里只用到了一小部分的图像。\n",
1221 | "\n",
1222 | "如果内存耗尽导致电脑死机或变得很慢,你应该试着减少这些数量,但同时可能还需要更优化的迭代。"
1223 | ]
1224 | },
1225 | {
1226 | "cell_type": "code",
1227 | "execution_count": 42,
1228 | "metadata": {},
1229 | "outputs": [],
1230 | "source": [
1231 | "train_batch_size = 64"
1232 | ]
1233 | },
1234 | {
1235 | "cell_type": "markdown",
1236 | "metadata": {},
1237 | "source": [
1238 | "函数执行了多次的优化迭代来逐步地提升网络层的变量。在每次迭代中,从训练集中选择一批新的数据,然后TensorFlow用这些训练样本来执行优化器。每100次迭代会打印出相关信息。"
1239 | ]
1240 | },
1241 | {
1242 | "cell_type": "code",
1243 | "execution_count": 48,
1244 | "metadata": {},
1245 | "outputs": [],
1246 | "source": [
1247 | "# Counter for total number of iterations performed so far.\n",
1248 | "total_iterations = 0\n",
1249 | "\n",
1250 | "\n",
1251 | "def optimize(num_iterations):\n",
1252 | " # Ensure we update the global variable rather than a local copy.\n",
1253 | " global total_iterations\n",
1254 | "\n",
1255 | " # Start-time used for printing time-usage below.\n",
1256 | " start_time = time.time()\n",
1257 | "\n",
1258 | " for i in range(total_iterations,\n",
1259 | " total_iterations + num_iterations):\n",
1260 | " # Get a batch of training examples.\n",
1261 | " # x_batch now holds a batch of images and\n",
1262 | " # y_true_batch are the true labels for those images.\n",
1263 | " x_batch, y_true_batch, _ = batch(batch_size=train_batch_size)\n",
1264 | "\n",
1265 | " # Put the batch into a dict with the proper names\n",
1266 | " # for placeholder variables in the TensorFlow graph.\n",
1267 | " feed_dict_train = {x: x_batch,\n",
1268 | " y_true: y_true_batch}\n",
1269 | "\n",
1270 | " # Run the optimizer using this batch of training data.\n",
1271 | " # TensorFlow assigns the variables in feed_dict_train\n",
1272 | " # to the placeholder variables and then runs the optimizer.\n",
1273 | " session.run(optimizer, feed_dict=feed_dict_train)\n",
1274 | "\n",
1275 | " # Print status every 100 iterations.\n",
1276 | " if i % 100 == 0:\n",
1277 | " # Calculate the accuracy on the training-set.\n",
1278 | " acc,summary = session.run([accuracy,merged], feed_dict=feed_dict_train)\n",
1279 | " train_writer.add_summary(summary,total_iterations)\n",
1280 | " \n",
1281 | " # Message for printing.\n",
1282 | " msg = \"Optimization Iteration: {0:>6}, Training Accuracy: {1:>6.1%}\"\n",
1283 | " saver.save(sess=session, save_path=model_save_path+'/'+'cnn_raw',global_step=total_iterations)\n",
1284 | " \n",
1285 | " # Print it.\n",
1286 | " print(msg.format(i + 1, acc))\n",
1287 | "\n",
1288 | " # Update the total number of iterations performed.\n",
1289 | " total_iterations += num_iterations\n",
1290 | "\n",
1291 | " # Ending time.\n",
1292 | " end_time = time.time()\n",
1293 | "\n",
1294 | " # Difference between start and end-times.\n",
1295 | " time_dif = end_time - start_time\n",
1296 | "\n",
1297 | " # Print the time-usage.\n",
1298 | " print(\"Time usage: \" + str(timedelta(seconds=int(round(time_dif)))))"
1299 | ]
1300 | },
1301 | {
1302 | "cell_type": "markdown",
1303 | "metadata": {},
1304 | "source": [
1305 | "### 终于可以开始训练了"
1306 | ]
1307 | },
1308 | {
1309 | "cell_type": "code",
1310 | "execution_count": 49,
1311 | "metadata": {},
1312 | "outputs": [
1313 | {
1314 | "name": "stdout",
1315 | "output_type": "stream",
1316 | "text": [
1317 | "Optimization Iteration: 1, Training Accuracy: 95.3%\n",
1318 | "Optimization Iteration: 101, Training Accuracy: 95.3%\n",
1319 | "Optimization Iteration: 201, Training Accuracy: 95.3%\n",
1320 | "Optimization Iteration: 301, Training Accuracy: 90.6%\n",
1321 | "Optimization Iteration: 401, Training Accuracy: 93.8%\n",
1322 | "Optimization Iteration: 501, Training Accuracy: 93.8%\n",
1323 | "Optimization Iteration: 601, Training Accuracy: 98.4%\n",
1324 | "Optimization Iteration: 701, Training Accuracy: 93.8%\n",
1325 | "Optimization Iteration: 801, Training Accuracy: 95.3%\n",
1326 | "Optimization Iteration: 901, Training Accuracy: 98.4%\n",
1327 | "Time usage: 0:01:31\n"
1328 | ]
1329 | }
1330 | ],
1331 | "source": [
1332 | "optimize(1000)"
1333 | ]
1334 | },
1335 | {
1336 | "cell_type": "markdown",
1337 | "metadata": {},
1338 | "source": [
1339 | "### 关闭TensorFlow会话"
1340 | ]
1341 | },
1342 | {
1343 | "cell_type": "markdown",
1344 | "metadata": {},
1345 | "source": [
1346 | "现在我们已经用TensorFlow完成了任务,关闭session,释放资源。"
1347 | ]
1348 | },
1349 | {
1350 | "cell_type": "code",
1351 | "execution_count": 50,
1352 | "metadata": {},
1353 | "outputs": [],
1354 | "source": [
1355 | "# This has been commented out in case you want to modify and experiment\n",
1356 | "# with the Notebook without having to restart it.\n",
1357 | "session.close()"
1358 | ]
1359 | },
1360 | {
1361 | "cell_type": "markdown",
1362 | "metadata": {},
1363 | "source": [
1364 | "## 总结\n",
1365 | "\n",
1366 | "- 常规TensorFlow模型训练的步骤,大家应该都了解\n",
1367 | "\n",
1368 | "- 有没有觉得有点费劲,训练和验证还没有写,我已经懒得写了(留给大家有兴趣的自己写一写)\n"
1369 | ]
1370 | },
1371 | {
1372 | "cell_type": "code",
1373 | "execution_count": null,
1374 | "metadata": {},
1375 | "outputs": [],
1376 | "source": []
1377 | }
1378 | ],
1379 | "metadata": {
1380 | "anaconda-cloud": {},
1381 | "kernelspec": {
1382 | "display_name": "Python 3",
1383 | "language": "python",
1384 | "name": "python3"
1385 | },
1386 | "language_info": {
1387 | "codemirror_mode": {
1388 | "name": "ipython",
1389 | "version": 3
1390 | },
1391 | "file_extension": ".py",
1392 | "mimetype": "text/x-python",
1393 | "name": "python",
1394 | "nbconvert_exporter": "python",
1395 | "pygments_lexer": "ipython3",
1396 | "version": "3.5.0"
1397 | }
1398 | },
1399 | "nbformat": 4,
1400 | "nbformat_minor": 1
1401 | }
1402 |
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/DNNClassifier_dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# TensorFlow 那些事儿之DL中的 HELLO WORLD\n",
8 | "\n",
9 | "\n",
10 | "- 基于MNIST数据集,运用TensorFlow中 **tf.estimator** 预制的 **tf.estimator.DNNClassifier** 搭建一个简单的多层神经网络,实现模型的训练,验证和测试\n",
11 | "\n",
12 | "- TensorBoard的简单使用\n"
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {},
18 | "source": [
19 | "## 看看MNIST数据集的样子\n"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 | "### 导入各个库"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 1,
32 | "metadata": {},
33 | "outputs": [
34 | {
35 | "data": {
36 | "text/plain": [
37 | "'1.8.0'"
38 | ]
39 | },
40 | "execution_count": 1,
41 | "metadata": {},
42 | "output_type": "execute_result"
43 | }
44 | ],
45 | "source": [
46 | "%matplotlib inline\n",
47 | "import tensorflow as tf\n",
48 | "import matplotlib.pyplot as plt\n",
49 | "import numpy as np\n",
50 | "import pandas as pd\n",
51 | "import multiprocessing\n",
52 | "\n",
53 | "\n",
54 | "from tensorflow import data\n",
55 | "from tensorflow.python.feature_column import feature_column\n",
56 | "\n",
57 | "tf.__version__"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "### MNIST数据集载入"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 2,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "TRAIN_DATA_FILES_PATTERN = 'data_csv/mnist_train.csv'\n",
74 | "VAL_DATA_FILES_PATTERN = 'data_csv/mnist_val.csv'\n",
75 | "TEST_DATA_FILES_PATTERN = 'data_csv/mnist_test.csv'\n",
76 | "\n",
77 | "MULTI_THREADING = True\n",
78 | "RESUME_TRAINING = False\n",
79 | "\n",
80 | "NUM_CLASS = 10\n",
81 | "IMG_SHAPE = [28,28]\n",
82 | "\n",
83 | "IMG_WIDTH = 28\n",
84 | "IMG_HEIGHT = 28\n",
85 | "BATCH_SIZE = 128"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 3,
91 | "metadata": {},
92 | "outputs": [
93 | {
94 | "name": "stdout",
95 | "output_type": "stream",
96 | "text": [
97 | "test_data (10000, 784)\n",
98 | "test_label (10000,)\n",
99 | "val_data (5000, 784)\n",
100 | "val_label (5000,)\n",
101 | "train_data (55000, 784)\n",
102 | "train_label (55000,)\n"
103 | ]
104 | }
105 | ],
106 | "source": [
107 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN)\n",
108 | "# train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None, names=HEADER )\n",
109 | "train_data = pd.read_csv(TRAIN_DATA_FILES_PATTERN, header=None)\n",
110 | "test_data = pd.read_csv(TEST_DATA_FILES_PATTERN, header=None)\n",
111 | "val_data = pd.read_csv(VAL_DATA_FILES_PATTERN, header=None)\n",
112 | "\n",
113 | "train_values = train_data.values\n",
114 | "train_data = train_values[:,1:]/255.0\n",
115 | "train_label = train_values[:,0:1].squeeze()\n",
116 | "\n",
117 | "val_values = val_data.values\n",
118 | "val_data = val_values[:,1:]/255.0\n",
119 | "val_label = val_values[:,0:1].squeeze()\n",
120 | "\n",
121 | "test_values = test_data.values\n",
122 | "test_data = test_values[:,1:]/255.0\n",
123 | "test_label = test_values[:,0:1].squeeze()\n",
124 | "\n",
125 | "print('test_data',np.shape(test_data))\n",
126 | "print('test_label',np.shape(test_label))\n",
127 | "\n",
128 | "print('val_data',np.shape(val_data))\n",
129 | "print('val_label',np.shape(val_label))\n",
130 | "\n",
131 | "print('train_data',np.shape(train_data))\n",
132 | "print('train_label',np.shape(train_label))\n",
133 | "\n",
134 | "# train_data.head(10)\n",
135 | "# test_data.head(10)"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 4,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "img_shape = IMG_SHAPE\n",
145 | "\n",
146 | "def plot_images(images, cls_true, cls_pred=None):\n",
147 | " assert len(images) == len(cls_true) == 9\n",
148 | " \n",
149 | " # Create figure with 3x3 sub-plots.\n",
150 | " fig, axes = plt.subplots(3, 3)\n",
151 | " fig.subplots_adjust(hspace=0.3, wspace=0.3)\n",
152 | "\n",
153 | " for i, ax in enumerate(axes.flat):\n",
154 | " # Plot image.\n",
155 | " ax.imshow(images[i].reshape(img_shape), cmap='binary')\n",
156 | "\n",
157 | " # Show true and predicted classes.\n",
158 | " if cls_pred is None:\n",
159 | " xlabel = \"True: {0}\".format(cls_true[i])\n",
160 | " else:\n",
161 | " xlabel = \"True: {0}, Pred: {1}\".format(cls_true[i], cls_pred[i])\n",
162 | "\n",
163 | " # Show the classes as the label on the x-axis.\n",
164 | " ax.set_xlabel(xlabel)\n",
165 | " \n",
166 | " # Remove ticks from the plot.\n",
167 | " ax.set_xticks([])\n",
168 | " ax.set_yticks([])\n",
169 | " \n",
170 | " # Ensure the plot is shown correctly with multiple plots\n",
171 | " # in a single Notebook cell.\n",
172 | " plt.show()"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "metadata": {},
178 | "source": [
179 | "## 重头戏之怎么用 tf.estimator.DNNClassifier "
180 | ]
181 | },
182 | {
183 | "cell_type": "markdown",
184 | "metadata": {},
185 | "source": [
186 | "### 先看看input_fn之创建输入函数\n",
187 | "\n",
188 | "- 采用 **datasetAPI** 构造输入函数"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": 5,
194 | "metadata": {},
195 | "outputs": [],
196 | "source": [
197 | "# validate tf.data.TextLineDataset() using make_one_shot_iterator()\n",
198 | "\n",
199 | "def decode_line(line):\n",
200 | " # Decode the csv_line to tensor.\n",
201 | " record_defaults = [[1.0] for col in range(785)]\n",
202 | " items = tf.decode_csv(line, record_defaults)\n",
203 | " features = items[1:785]\n",
204 | " label = items[0]\n",
205 | "\n",
206 | " features = tf.cast(features, tf.float32)\n",
207 | " features = tf.reshape(features,[28,28,1])\n",
208 | " label = tf.cast(label, tf.int64)\n",
209 | "# label = tf.one_hot(label,num_class)\n",
210 | " return features,label"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": 6,
216 | "metadata": {},
217 | "outputs": [],
218 | "source": [
219 | "def csv_input_fn(files_name_pattern, mode=tf.estimator.ModeKeys.EVAL, \n",
220 | " skip_header_lines=1, \n",
221 | " num_epochs=None, \n",
222 | " batch_size=128):\n",
223 | " shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False\n",
224 | " \n",
225 | " num_threads = multiprocessing.cpu_count() if MULTI_THREADING else 1\n",
226 | " \n",
227 | " print(\"\")\n",
228 | " print(\"* data input_fn:\")\n",
229 | " print(\"================\")\n",
230 | " print(\"Input file(s): {}\".format(files_name_pattern))\n",
231 | " print(\"Batch size: {}\".format(batch_size))\n",
232 | " print(\"Epoch Count: {}\".format(num_epochs))\n",
233 | " print(\"Mode: {}\".format(mode))\n",
234 | " print(\"Thread Count: {}\".format(num_threads))\n",
235 | " print(\"Shuffle: {}\".format(shuffle))\n",
236 | " print(\"================\")\n",
237 | " print(\"\")\n",
238 | "\n",
239 | " file_names = tf.matching_files(files_name_pattern)\n",
240 | " dataset = data.TextLineDataset(filenames=file_names).skip(1)\n",
241 | "# dataset = tf.data.TextLineDataset(filenames).skip(1)\n",
242 | " print(\"DATASET\",dataset)\n",
243 | "\n",
244 | " # Use `Dataset.map()` to build a pair of a feature dictionary and a label\n",
245 | " # tensor for each example.\n",
246 | " dataset = dataset.map(decode_line)\n",
247 | " print(\"DATASET_1\",dataset)\n",
248 | " dataset = dataset.shuffle(buffer_size=10000)\n",
249 | " print(\"DATASET_2\",dataset)\n",
250 | " dataset = dataset.batch(32)\n",
251 | " print(\"DATASET_3\",dataset)\n",
252 | " dataset = dataset.repeat(num_epochs)\n",
253 | " print(\"DATASET_4\",dataset)\n",
254 | " iterator = dataset.make_one_shot_iterator()\n",
255 | " \n",
256 | " # `features` is a dictionary in which each value is a batch of values for\n",
257 | " # that feature; `labels` is a batch of labels.\n",
258 | " features, labels = iterator.get_next()\n",
259 | " \n",
260 | " features = {'images':features}\n",
261 | " \n",
262 | " return features,labels\n"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": 7,
268 | "metadata": {},
269 | "outputs": [
270 | {
271 | "name": "stdout",
272 | "output_type": "stream",
273 | "text": [
274 | "\n",
275 | "* data input_fn:\n",
276 | "================\n",
277 | "Input file(s): data_csv/mnist_train.csv\n",
278 | "Batch size: 128\n",
279 | "Epoch Count: None\n",
280 | "Mode: eval\n",
281 | "Thread Count: 4\n",
282 | "Shuffle: False\n",
283 | "================\n",
284 | "\n",
285 | "DATASET \n",
286 | "DATASET_1 \n",
287 | "DATASET_2 \n",
288 | "DATASET_3 \n",
289 | "DATASET_4 \n",
290 | "Features in CSV: ['images']\n",
291 | "Target in CSV: Tensor(\"IteratorGetNext:1\", shape=(?,), dtype=int64)\n"
292 | ]
293 | }
294 | ],
295 | "source": [
296 | "features, target = csv_input_fn(files_name_pattern=TRAIN_DATA_FILES_PATTERN)\n",
297 | "print(\"Features in CSV: {}\".format(list(features.keys())))\n",
298 | "print(\"Target in CSV: {}\".format(target))"
299 | ]
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "metadata": {},
304 | "source": [
305 | "### 定义feature_columns"
306 | ]
307 | },
308 | {
309 | "cell_type": "code",
310 | "execution_count": 8,
311 | "metadata": {},
312 | "outputs": [],
313 | "source": [
314 | "feature_x = tf.feature_column.numeric_column('images', shape=[28,28])\n",
315 | "# print((feature_x))\n",
316 | "\n",
317 | "feature_columns = [feature_x]\n",
318 | "# print((feature_columns))"
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": 9,
324 | "metadata": {},
325 | "outputs": [],
326 | "source": [
327 | "num_hidden_units = [512, 256, 128]"
328 | ]
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "metadata": {},
333 | "source": [
334 | "### DNNClassifier来啦"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 11,
340 | "metadata": {},
341 | "outputs": [
342 | {
343 | "name": "stdout",
344 | "output_type": "stream",
345 | "text": [
346 | "INFO:tensorflow:Using default config.\n",
347 | "INFO:tensorflow:Using config: {'_master': '', '_num_worker_replicas': 1, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_tf_random_seed': None, '_service': None, '_cluster_spec': , '_model_dir': './simple_dnn_dataset', '_num_ps_replicas': 0, '_save_checkpoints_steps': None, '_evaluation_master': '', '_save_summary_steps': 100, '_log_step_count_steps': 100, '_global_id_in_cluster': 0, '_train_distribute': None, '_is_chief': True, '_task_id': 0, '_save_checkpoints_secs': 600, '_task_type': 'worker'}\n"
348 | ]
349 | }
350 | ],
351 | "source": [
352 | "num_class = NUM_CLASS\n",
353 | "\n",
354 | "model = tf.estimator.DNNClassifier(feature_columns = feature_columns,\n",
355 | " hidden_units = num_hidden_units,\n",
356 | " activation_fn = tf.nn.relu,\n",
357 | " n_classes = num_class,\n",
358 | " model_dir = './simple_dnn_dataset')"
359 | ]
360 | },
361 | {
362 | "cell_type": "markdown",
363 | "metadata": {},
364 | "source": [
365 | "### 愉快滴训练吧"
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "execution_count": 14,
371 | "metadata": {},
372 | "outputs": [
373 | {
374 | "name": "stdout",
375 | "output_type": "stream",
376 | "text": [
377 | "\n",
378 | "* data input_fn:\n",
379 | "================\n",
380 | "Input file(s): data_csv/mnist_train.csv\n",
381 | "Batch size: 128\n",
382 | "Epoch Count: None\n",
383 | "Mode: train\n",
384 | "Thread Count: 4\n",
385 | "Shuffle: True\n",
386 | "================\n",
387 | "\n",
388 | "DATASET \n",
389 | "DATASET_1 \n",
390 | "DATASET_2 \n",
391 | "DATASET_3 \n",
392 | "DATASET_4 \n",
393 | "INFO:tensorflow:Calling model_fn.\n",
394 | "INFO:tensorflow:Done calling model_fn.\n",
395 | "INFO:tensorflow:Create CheckpointSaverHook.\n",
396 | "INFO:tensorflow:Graph was finalized.\n",
397 | "INFO:tensorflow:Restoring parameters from ./simple_dnn_dataset/model.ckpt-200\n",
398 | "INFO:tensorflow:Running local_init_op.\n",
399 | "INFO:tensorflow:Done running local_init_op.\n",
400 | "INFO:tensorflow:Saving checkpoints for 201 into ./simple_dnn_dataset/model.ckpt.\n",
401 | "INFO:tensorflow:loss = 26.95501, step = 201\n",
402 | "INFO:tensorflow:global_step/sec: 15.0276\n",
403 | "INFO:tensorflow:loss = 23.322294, step = 301 (6.655 sec)\n",
404 | "INFO:tensorflow:global_step/sec: 13.8421\n",
405 | "INFO:tensorflow:loss = 17.458122, step = 401 (7.225 sec)\n",
406 | "INFO:tensorflow:global_step/sec: 13.3083\n",
407 | "INFO:tensorflow:loss = 21.524231, step = 501 (7.517 sec)\n",
408 | "INFO:tensorflow:global_step/sec: 14.5015\n",
409 | "INFO:tensorflow:loss = 21.863522, step = 601 (6.892 sec)\n",
410 | "INFO:tensorflow:global_step/sec: 13.2937\n",
411 | "INFO:tensorflow:loss = 12.238069, step = 701 (7.524 sec)\n",
412 | "INFO:tensorflow:global_step/sec: 13.6131\n",
413 | "INFO:tensorflow:loss = 19.554596, step = 801 (7.345 sec)\n",
414 | "INFO:tensorflow:global_step/sec: 12.5833\n",
415 | "INFO:tensorflow:loss = 4.9210396, step = 901 (7.948 sec)\n",
416 | "INFO:tensorflow:global_step/sec: 13.2139\n",
417 | "INFO:tensorflow:loss = 8.347723, step = 1001 (7.566 sec)\n",
418 | "INFO:tensorflow:global_step/sec: 14.5858\n",
419 | "INFO:tensorflow:loss = 17.034126, step = 1101 (6.856 sec)\n",
420 | "INFO:tensorflow:global_step/sec: 14.5617\n",
421 | "INFO:tensorflow:loss = 21.071743, step = 1201 (6.866 sec)\n",
422 | "INFO:tensorflow:global_step/sec: 14.7257\n",
423 | "INFO:tensorflow:loss = 11.271985, step = 1301 (6.791 sec)\n",
424 | "INFO:tensorflow:global_step/sec: 14.9258\n",
425 | "INFO:tensorflow:loss = 7.7849083, step = 1401 (6.700 sec)\n",
426 | "INFO:tensorflow:global_step/sec: 14.8296\n",
427 | "INFO:tensorflow:loss = 7.3179874, step = 1501 (6.743 sec)\n",
428 | "INFO:tensorflow:global_step/sec: 15.3108\n",
429 | "INFO:tensorflow:loss = 5.9724092, step = 1601 (6.532 sec)\n",
430 | "INFO:tensorflow:global_step/sec: 111.22\n",
431 | "INFO:tensorflow:loss = 23.16468, step = 1701 (0.899 sec)\n",
432 | "INFO:tensorflow:global_step/sec: 165.726\n",
433 | "INFO:tensorflow:loss = 15.113611, step = 1801 (0.603 sec)\n",
434 | "INFO:tensorflow:global_step/sec: 164.038\n",
435 | "INFO:tensorflow:loss = 17.828293, step = 1901 (0.610 sec)\n",
436 | "INFO:tensorflow:global_step/sec: 3.84192\n",
437 | "INFO:tensorflow:loss = 10.36054, step = 2001 (26.032 sec)\n",
438 | "INFO:tensorflow:global_step/sec: 13.1081\n",
439 | "INFO:tensorflow:loss = 10.766257, step = 2101 (7.626 sec)\n",
440 | "INFO:tensorflow:Saving checkpoints for 2200 into ./simple_dnn_dataset/model.ckpt.\n",
441 | "INFO:tensorflow:Loss for final step: 10.364952.\n"
442 | ]
443 | },
444 | {
445 | "data": {
446 | "text/plain": [
447 | ""
448 | ]
449 | },
450 | "execution_count": 14,
451 | "metadata": {},
452 | "output_type": "execute_result"
453 | }
454 | ],
455 | "source": [
456 | "input_fn = lambda: csv_input_fn(\\\n",
457 | " files_name_pattern= TRAIN_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.TRAIN)\n",
458 | "\n",
459 | "model.train(input_fn, steps = 2000)"
460 | ]
461 | },
462 | {
463 | "cell_type": "markdown",
464 | "metadata": {},
465 | "source": [
466 | "### 验证一下呗"
467 | ]
468 | },
469 | {
470 | "cell_type": "code",
471 | "execution_count": 15,
472 | "metadata": {},
473 | "outputs": [
474 | {
475 | "name": "stdout",
476 | "output_type": "stream",
477 | "text": [
478 | "\n",
479 | "* data input_fn:\n",
480 | "================\n",
481 | "Input file(s): data_csv/mnist_val.csv\n",
482 | "Batch size: 128\n",
483 | "Epoch Count: None\n",
484 | "Mode: eval\n",
485 | "Thread Count: 4\n",
486 | "Shuffle: False\n",
487 | "================\n",
488 | "\n",
489 | "DATASET \n",
490 | "DATASET_1 \n",
491 | "DATASET_2 \n",
492 | "DATASET_3 \n",
493 | "DATASET_4 \n",
494 | "INFO:tensorflow:Calling model_fn.\n",
495 | "INFO:tensorflow:Done calling model_fn.\n",
496 | "INFO:tensorflow:Starting evaluation at 2018-10-25-03:38:01\n",
497 | "INFO:tensorflow:Graph was finalized.\n",
498 | "INFO:tensorflow:Restoring parameters from ./simple_dnn_dataset/model.ckpt-2200\n",
499 | "INFO:tensorflow:Running local_init_op.\n",
500 | "INFO:tensorflow:Done running local_init_op.\n",
501 | "INFO:tensorflow:Evaluation [1/1]\n",
502 | "INFO:tensorflow:Finished evaluation at 2018-10-25-03:38:10\n",
503 | "INFO:tensorflow:Saving dict for global step 2200: accuracy = 0.9375, average_loss = 0.14245859, global_step = 2200, loss = 4.558675\n"
504 | ]
505 | },
506 | {
507 | "data": {
508 | "text/plain": [
509 | "{'accuracy': 0.9375,\n",
510 | " 'average_loss': 0.14245859,\n",
511 | " 'global_step': 2200,\n",
512 | " 'loss': 4.558675}"
513 | ]
514 | },
515 | "execution_count": 15,
516 | "metadata": {},
517 | "output_type": "execute_result"
518 | }
519 | ],
520 | "source": [
521 | "input_fn = lambda: csv_input_fn(files_name_pattern= VAL_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.EVAL)\n",
522 | "\n",
523 | "model.evaluate(input_fn,steps=1)"
524 | ]
525 | },
526 | {
527 | "cell_type": "markdown",
528 | "metadata": {},
529 | "source": [
530 | "### 测试测试吧"
531 | ]
532 | },
533 | {
534 | "cell_type": "code",
535 | "execution_count": 19,
536 | "metadata": {},
537 | "outputs": [
538 | {
539 | "name": "stdout",
540 | "output_type": "stream",
541 | "text": [
542 | "\n",
543 | "* data input_fn:\n",
544 | "================\n",
545 | "Input file(s): data_csv/mnist_test.csv\n",
546 | "Batch size: 10\n",
547 | "Epoch Count: None\n",
548 | "Mode: infer\n",
549 | "Thread Count: 4\n",
550 | "Shuffle: False\n",
551 | "================\n",
552 | "\n",
553 | "DATASET \n",
554 | "DATASET_1 \n",
555 | "DATASET_2 \n",
556 | "DATASET_3 \n",
557 | "DATASET_4 \n",
558 | "INFO:tensorflow:Calling model_fn.\n",
559 | "INFO:tensorflow:Done calling model_fn.\n",
560 | "INFO:tensorflow:Graph was finalized.\n",
561 | "INFO:tensorflow:Restoring parameters from ./simple_dnn_dataset/model.ckpt-2200\n",
562 | "INFO:tensorflow:Running local_init_op.\n",
563 | "INFO:tensorflow:Done running local_init_op.\n",
564 | "\n",
565 | "* Predicted Classes: [b'0', b'7', b'6', b'3', b'4', b'6', b'5', b'8', b'5', b'6']\n"
566 | ]
567 | }
568 | ],
569 | "source": [
570 | "import itertools\n",
571 | "\n",
572 | "input_fn = lambda: csv_input_fn(\\\n",
573 | " files_name_pattern= TEST_DATA_FILES_PATTERN,mode=tf.estimator.ModeKeys.PREDICT,batch_size=10)\n",
574 | "\n",
575 | "predictions = list(itertools.islice(model.predict(input_fn=input_fn),10))\n",
576 | "# print('PREDICTIONS',predictions)\n",
577 | "print(\"\")\n",
578 | "print(\"* Predicted Classes: {}\".format(list(map(lambda item: item[\"classes\"][0]\n",
579 | " ,predictions))))\n",
580 | "\n"
581 | ]
582 | },
583 | {
584 | "cell_type": "code",
585 | "execution_count": null,
586 | "metadata": {},
587 | "outputs": [],
588 | "source": []
589 | }
590 | ],
591 | "metadata": {
592 | "kernelspec": {
593 | "display_name": "Python 3",
594 | "language": "python",
595 | "name": "python3"
596 | },
597 | "language_info": {
598 | "codemirror_mode": {
599 | "name": "ipython",
600 | "version": 3
601 | },
602 | "file_extension": ".py",
603 | "mimetype": "text/x-python",
604 | "name": "python",
605 | "nbconvert_exporter": "python",
606 | "pygments_lexer": "ipython3",
607 | "version": "3.5.0"
608 | }
609 | },
610 | "nbformat": 4,
611 | "nbformat_minor": 2
612 | }
613 |
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/data_csv/mnist_test.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:51c292478d94ec3a01461bdfa82eb0885d262eb09e615679b2d69dedb6ad09e7
3 | size 18289443
4 |
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/data_csv/mnist_train.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:fb1eb744fad41aefc48109fa1694b8385a541b9ff13c5954a646e37ffd4b87f6
3 | size 100447555
4 |
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/data_csv/mnist_val.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:46106534c381da5f74701cdddc5cde0834abd3a9634d70f3d5beb0d15c8d1675
3 | size 9128008
4 |
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/02_convolution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/02_convolution.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/02_network_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/02_network_flowchart.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/0_TF_HELLO.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/0_TF_HELLO.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/dataset_classes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/dataset_classes.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/estimator_types.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/estimator_types.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/feed_tf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/feed_tf.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/feed_tf_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/feed_tf_out.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/inputs_to_model_bridge.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/inputs_to_model_bridge.jpg
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/pt_sum_code.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/pt_sum_code.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/pt_sum_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/pt_sum_output.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tensorflow_programming_environment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tensorflow_programming_environment.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tensors_flowing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tensors_flowing.gif
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_feed_out_wrong2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_feed_out_wrong2.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_feed_wrong.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_feed_wrong.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_feed_wrong_out_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_feed_wrong_out_1.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_graph.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_sess_code.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sess_code.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_sess_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sess_output.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_sum_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sum_graph.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_sum_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sum_output.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_sum_sess.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sum_sess.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_sum_sess_code.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sum_sess_code.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tf_sum_sess_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tf_sum_sess_out.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tfe_sum_code.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tfe_sum_code.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/images/tfe_sum_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lanhongvp/tensorflow_estimator_tutorial/70164f2fee715f10fbd8a013e63ea68f7a83c41b/tensorflow_estimator_learn/images/tfe_sum_output.png
--------------------------------------------------------------------------------
/tensorflow_estimator_learn/tmp/basic_pt.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # PyTorch runs eagerly: each tensor holds a concrete value as soon as it is created.
4 | two_node_pt = torch.tensor([2])
5 | three_node_pt = torch.tensor([3])
6 | sum_node_pt = two_node_pt + three_node_pt
7 |
8 | print('TWO_NODE_PT', two_node_pt)
9 | print('THREE_NODE_PT', three_node_pt)
10 | print('SUM_NODE_PT', sum_node_pt)
11 |
--------------------------------------------------------------------------------
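
Because PyTorch executes eagerly, the tensors above can be inspected the moment they are created. A minimal sketch (illustrative, not part of the original repo) confirming that `sum_node_pt` already holds its result without any extra "run" step:

```python
import torch

sum_node_pt = torch.tensor([2]) + torch.tensor([3])
# .item() extracts the Python scalar from a one-element tensor.
assert sum_node_pt.item() == 5
print(sum_node_pt)  # tensor([5])
```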
/tensorflow_estimator_learn/tmp/basic_tf.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | # Graph construction: these nodes describe a computation but hold no values yet.
4 | two_node = tf.constant(2)
5 | three_node = tf.constant(3)
6 | sum_node = two_node + three_node
7 |
8 | # Graph execution: a Session evaluates the requested nodes and returns numpy values.
9 | sess = tf.Session()
10 | two_val, three_val, sum_val = \
11 |     sess.run([two_node, three_node, sum_node])
12 |
13 | print('TWO_NODE', two_val)
14 | print('THREE_NODE', three_val)
15 | print('SUM_NODE', sum_val)
16 |
--------------------------------------------------------------------------------
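
For contrast, a small sketch (assuming the same TF 1.x graph mode as above, not from the repo) of what printing the nodes *without* `sess.run` yields: symbolic `Tensor` handles rather than numbers, which is exactly the graph/execution split the file demonstrates:

```python
import tensorflow as tf

two_node = tf.constant(2)
sum_node = two_node + tf.constant(3)
# Before a Session runs the graph, this is only a symbolic handle:
print(sum_node)  # e.g. Tensor("add:0", shape=(), dtype=int32)
with tf.Session() as sess:
    print(sess.run(sum_node))  # 5
```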
/tensorflow_estimator_learn/tmp/basic_tfe.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib.eager as tfe
3 | # Eager execution must be enabled once, at program start, before any ops are created.
4 | tfe.enable_eager_execution()
5 |
6 | # With eager execution on, TF behaves like PyTorch: values are computed immediately.
7 | two_node_tfe = tf.constant(2)
8 | three_node_tfe = tf.constant(3)
9 | sum_node_tfe = two_node_tfe + three_node_tfe
10 |
11 | print('TWO_NODE_TFE', two_node_tfe)
12 | print('THREE_NODE_TFE', three_node_tfe)
13 | print('SUM_NODE_TFE', sum_node_tfe)
14 |
--------------------------------------------------------------------------------
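
As the README notes, the TF version here is out of date: in TensorFlow 2.x eager execution is the default, so both the `contrib` import and the enable call disappear. A hedged sketch of the equivalent code under TF 2.x:

```python
import tensorflow as tf  # TF 2.x: eager by default, no Session needed

two_node = tf.constant(2)
three_node = tf.constant(3)
sum_node = two_node + three_node
print(sum_node)          # tf.Tensor(5, shape=(), dtype=int32)
print(sum_node.numpy())  # 5, as a plain numpy scalar
```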
/tensorflow_estimator_learn/tmp/feed_tf.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | # A placeholder is a graph node with no value; one must be supplied at run time.
4 | input_placeholder = tf.placeholder(tf.int32)
5 | sess = tf.Session()
6 | # feed_dict maps each placeholder to the concrete value it takes for this run.
7 | input_value = sess.run(
8 |     input_placeholder, feed_dict={input_placeholder: 2})
9 |
10 | print('INPUT', input_value)
11 |
--------------------------------------------------------------------------------
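
Placeholders are not limited to scalars; a minimal sketch (illustrative, not from the repo) feeding a whole array through the same `feed_dict` mechanism, which is how batches of MNIST images reach a graph in practice:

```python
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 2])  # None: batch size decided at run time
doubled = x * 2.0
with tf.Session() as sess:
    batch = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    print(sess.run(doubled, feed_dict={x: batch}))  # [[2. 4.] [6. 8.]]
```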
/tensorflow_estimator_learn/tmp/feed_tf_wrong.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | input_placeholder = tf.placeholder(tf.int32)
4 | three_node = tf.constant(3)
5 | sum_node = input_placeholder + three_node
6 | sess = tf.Session()
7 |
8 | # This works: three_node does not depend on the placeholder.
9 | print('THREE_NODE', sess.run(three_node))
10 | # This fails with InvalidArgumentError: sum_node depends on input_placeholder,
11 | # which was never fed a value. See the corrected call after this file.
12 | print('SUM_NODE', sess.run(sum_node))
13 |
--------------------------------------------------------------------------------
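
The fix for the failing call above is to feed the placeholder in the same `sess.run` that evaluates the node depending on it. A minimal corrected sketch:

```python
import tensorflow as tf

input_placeholder = tf.placeholder(tf.int32)
sum_node = input_placeholder + tf.constant(3)
with tf.Session() as sess:
    # Supplying the placeholder's value makes sum_node computable: prints 5.
    print('SUM_NODE', sess.run(sum_node, feed_dict={input_placeholder: 2}))
```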