├── .gitattributes
├── LICENSE
├── README.md
├── data
│   ├── train.csv
│   ├── train_output.tfrecords
│   ├── val.csv
│   └── val_output.tfrecords
└── tf_dataset_learn.ipynb

/.gitattributes:
--------------------------------------------------------------------------------
1 | *.csv filter=lfs diff=lfs merge=lfs -text
2 | *.tfrecords filter=lfs diff=lfs merge=lfs -text
3 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Hong Lan
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tensorflow_dataset_learn
2 | Learn the TensorFlow tf.data API: reading TFRecord and CSV input with tf.data.TFRecordDataset and tf.data.TextLineDataset, and switching between training and validation data using one-shot, initializable, reinitializable, and feedable iterators (see tf_dataset_learn.ipynb).
3 | 
--------------------------------------------------------------------------------
/data/train.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b35cc952c291f03d995f48721373e69454f689e299d724bd9a5927bdabddfeed
3 | size 76775041
4 | 
--------------------------------------------------------------------------------
/data/train_output.tfrecords:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2fb651cf308c149364f3c82905678d48fe488e5e7659b1274894e7de01005291
3 | size 47300000
4 | 
--------------------------------------------------------------------------------
/data/val.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b35cc952c291f03d995f48721373e69454f689e299d724bd9a5927bdabddfeed
3 | size 76775041
4 | 
--------------------------------------------------------------------------------
/data/val_output.tfrecords:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:2fb651cf308c149364f3c82905678d48fe488e5e7659b1274894e7de01005291
3 | size 47300000
4 | 
--------------------------------------------------------------------------------
/tf_dataset_learn.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "metadata": {},
7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "import numpy as np" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 3, 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "name": "stdout", 20 | "output_type": "stream", 21 | "text": [ 22 | "DATASET \n", 23 | "IMAGES Tensor(\"Reshape:0\", shape=(28, 28, 1), dtype=uint8)\n", 24 | "LABELS Tensor(\"one_hot:0\", shape=(10,), dtype=float32)\n", 25 | "DATASET_1 \n", 26 | "DATASET_2 \n", 27 | "DATASET_3 \n", 28 | "DATASET_4 \n", 29 | "FEATURES {'image_raw': }\n", 30 | "LABELS Tensor(\"IteratorGetNext_1:1\", shape=(?, 10), dtype=float32)\n", 31 | "SESS_RUN_LABELS \n", 32 | " [[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 33 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 34 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 35 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 36 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 37 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 38 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 39 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 40 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 41 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 42 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 43 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 44 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 45 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 46 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 47 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 48 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 49 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 50 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 51 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 52 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 53 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 54 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 55 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 56 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 57 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 58 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 59 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 60 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 61 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 62 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 63 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 
0.]]\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "# Validate tf.data.TFRecordDataset() using make_one_shot_iterator()\n", 69 | "import tensorflow as tf\n", 70 | "\n", 71 | "num_epochs = 2\n", 72 | "num_class = 10\n", 73 | "sess = tf.Session()\n", 74 | "\n", 75 | "# Use `tf.parse_single_example()` to extract data from a `tf.Example`\n", 76 | "# protocol buffer, and perform any additional per-record preprocessing.\n", 77 | "def parser(record):\n", 78 | " keys_to_features = {\n", 79 | " \"image_raw\": tf.FixedLenFeature((), tf.string, default_value=\"\"),\n", 80 | " \"pixels\": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)),\n", 81 | " \"label\": tf.FixedLenFeature((), tf.int64,\n", 82 | " default_value=tf.zeros([], dtype=tf.int64)),\n", 83 | " }\n", 84 | " parsed = tf.parse_single_example(record, keys_to_features)\n", 85 | "\n", 86 | " # Parse the string into an array of pixels corresponding to the image\n", 87 | " images = tf.decode_raw(parsed[\"image_raw\"],tf.uint8)\n", 88 | " images = tf.reshape(images,[28,28,1])\n", 89 | " labels = tf.cast(parsed['label'], tf.int32)\n", 90 | " labels = tf.one_hot(labels,num_class)\n", 91 | " pixels = tf.cast(parsed['pixels'], tf.int32)\n", 92 | " print(\"IMAGES\",images)\n", 93 | " print(\"LABELS\",labels)\n", 94 | " \n", 95 | " return {\"image_raw\": images}, labels\n", 96 | "\n", 97 | "\n", 98 | "filenames = [\"/Users/honglan/Desktop/train_output.tfrecords\"] \n", 99 | "# replace the filenames with your own path\n", 100 | "dataset = tf.data.TFRecordDataset(filenames)\n", 101 | "print(\"DATASET\",dataset)\n", 102 | "\n", 103 | "# Use `Dataset.map()` to build a pair of a feature dictionary and a label\n", 104 | "# tensor for each example.\n", 105 | "dataset = dataset.map(parser)\n", 106 | "print(\"DATASET_1\",dataset)\n", 107 | "dataset = dataset.shuffle(buffer_size=10000)\n", 108 | "print(\"DATASET_2\",dataset)\n", 109 | "dataset = dataset.batch(32)\n", 110 | "print(\"DATASET_3\",dataset)\n", 111 | "dataset = dataset.repeat(num_epochs)\n", 112 | "print(\"DATASET_4\",dataset)\n", 113 | "iterator = dataset.make_one_shot_iterator()\n", 114 | "\n", 115 | "# `features` is a dictionary in which each value is a batch of values for\n", 116 | "# that feature; `labels` is a batch of labels.\n", 117 | "features, labels = iterator.get_next()\n", 118 | "\n", 119 | "print(\"FEATURES\",features)\n", 120 | "print(\"LABELS\",labels)\n", 121 | "print(\"SESS_RUN_LABELS \\n\",sess.run(labels))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 4, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "IMAGES Tensor(\"Reshape:0\", shape=(28, 28, 1), dtype=uint8)\n", 134 | "LABELS Tensor(\"one_hot:0\", shape=(10,), dtype=float32)\n", 135 | "DATASET \n", 136 | "ITERATOR \n", 137 | "FEATURES {'image_raw': }\n", 138 | "LABELS Tensor(\"IteratorGetNext_2:1\", shape=(?, 10), dtype=float32)\n", 139 | "TRAIN\n", 140 | " [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 141 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 142 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 143 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 144 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 145 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 146 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 147 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 148 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 149 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 150 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 151 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 152 | " [0. 0. 0. 0. 
0. 0. 1. 0. 0. 0.]\n", 153 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 154 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 155 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 156 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 157 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 158 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 159 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 160 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 161 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 162 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 163 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 164 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 165 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 166 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 167 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 168 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 169 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 170 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 171 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]\n", 172 | "VAL\n", 173 | " [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 174 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 175 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 176 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 177 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 178 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 179 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 180 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 181 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 182 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 183 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 184 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 185 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 186 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 187 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 188 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 189 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 190 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 191 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 192 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 193 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 194 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 195 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 196 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 197 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 198 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 199 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 200 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 201 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 202 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 203 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 204 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 
0.]]\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "# Validate tf.data.TFRecordDataset() using make_initializable_iterator()\n", 210 | "# In order to switch between train and validation data\n", 211 | "num_epochs = 2\n", 212 | "num_class = 10\n", 213 | "\n", 214 | "def parser(record):\n", 215 | " keys_to_features = {\n", 216 | " \"image_raw\": tf.FixedLenFeature((), tf.string, default_value=\"\"),\n", 217 | " \"pixels\": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)),\n", 218 | " \"label\": tf.FixedLenFeature((), tf.int64,\n", 219 | " default_value=tf.zeros([], dtype=tf.int64)),\n", 220 | " }\n", 221 | " parsed = tf.parse_single_example(record, keys_to_features)\n", 222 | " \n", 223 | " # Parse the string into an array of pixels corresponding to the image\n", 224 | " images = tf.decode_raw(parsed[\"image_raw\"],tf.uint8)\n", 225 | " images = tf.reshape(images,[28,28,1])\n", 226 | " labels = tf.cast(parsed['label'], tf.int32)\n", 227 | " labels = tf.one_hot(labels,10)\n", 228 | " pixels = tf.cast(parsed['pixels'], tf.int32)\n", 229 | " print(\"IMAGES\",images)\n", 230 | " print(\"LABELS\",labels)\n", 231 | " \n", 232 | " return {\"image_raw\": images}, labels\n", 233 | "\n", 234 | "\n", 235 | "filenames = tf.placeholder(tf.string, shape=[None])\n", 236 | "dataset = tf.data.TFRecordDataset(filenames)\n", 237 | "dataset = dataset.map(parser) # Parse the record into tensors\n", 238 | "# print(\"DATASET\",dataset)\n", 239 | "dataset = dataset.shuffle(buffer_size=10000)\n", 240 | "dataset = dataset.batch(32)\n", 241 | "dataset = dataset.repeat(num_epochs)\n", 242 | "print(\"DATASET\",dataset)\n", 243 | "iterator = dataset.make_initializable_iterator()\n", 244 | "features, labels = iterator.get_next()\n", 245 | "print(\"ITERATOR\",iterator)\n", 246 | "print(\"FEATURES\",features)\n", 247 | "print(\"LABELS\",labels)\n", 248 | "\n", 249 | "\n", 250 | "# Initialize `iterator` with training data.\n", 251 | "training_filenames = [\"/Users/honglan/Desktop/train_output.tfrecords\"] \n", 252 | "# replace the filenames with your own path\n", 253 | "sess.run(iterator.initializer,feed_dict={filenames: training_filenames})\n", 254 | "print(\"TRAIN\\n\",sess.run(labels))\n", 255 | "# print(sess.run(features))\n", 256 | "\n", 257 | "# Initialize `iterator` with validation data.\n", 258 | "validation_filenames = [\"/Users/honglan/Desktop/val_output.tfrecords\"] \n", 259 | "# replace the filenames with your own path\n", 260 | "sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})\n", 261 | "print(\"VAL\\n\",sess.run(labels))\n" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 5, 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "name": "stdout", 271 | "output_type": "stream", 272 | "text": [ 273 | "DATASET \n", 274 | "DATASET_1 \n", 275 | "DATASET_2 \n", 276 | "DATASET_3 \n", 277 | "DATASET_4 \n", 278 | "FEATURES Tensor(\"IteratorGetNext_3:0\", shape=(?, 28, 28, 1), dtype=float32)\n", 279 | "LABELS Tensor(\"IteratorGetNext_3:1\", shape=(?, 10), dtype=float32)\n", 280 | "SESS_RUN_LABELS\n", 281 | " [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 282 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 283 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 284 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 285 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 286 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 287 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 288 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 289 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 290 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 
0.]\n", 291 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 292 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 293 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 294 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 295 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 296 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 297 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 298 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 299 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 300 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 301 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 302 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 303 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 304 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 305 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 306 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 307 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 308 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 309 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 310 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 311 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 312 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]\n" 313 | ] 314 | } 315 | ], 316 | "source": [ 317 | "# validate tf.data.TextLineDataset() using make_one_shot_iterator()\n", 318 | "\n", 319 | "def decode_line(line):\n", 320 | " # Decode the csv_line to tensor.\n", 321 | " record_defaults = [[1.0] for col in range(785)]\n", 322 | " items = tf.decode_csv(line, record_defaults)\n", 323 | " features = items[1:785]\n", 324 | " label = items[0]\n", 325 | "\n", 326 | " features = tf.cast(features, tf.float32)\n", 327 | " features = tf.reshape(features,[28,28,1])\n", 328 | " label = tf.cast(label, tf.int64)\n", 329 | " label = tf.one_hot(label,num_class)\n", 330 | " return features,label\n", 331 | "\n", 332 | "\n", 333 | "filenames = [\"/Users/honglan/Desktop/train.csv\"] \n", 334 | "# replace the filenames with your own path\n", 335 | "dataset = tf.data.TextLineDataset(filenames).skip(1)\n", 336 | "print(\"DATASET\",dataset)\n", 337 | "\n", 338 | "# Use `Dataset.map()` to build a pair of a feature dictionary and a label\n", 339 | "# tensor for each example.\n", 340 | "dataset = dataset.map(decode_line)\n", 341 | "print(\"DATASET_1\",dataset)\n", 342 | "dataset = dataset.shuffle(buffer_size=10000)\n", 343 | "print(\"DATASET_2\",dataset)\n", 344 | "dataset = dataset.batch(32)\n", 345 | "print(\"DATASET_3\",dataset)\n", 346 | "dataset = dataset.repeat(num_epochs)\n", 347 | "print(\"DATASET_4\",dataset)\n", 348 | "iterator = dataset.make_one_shot_iterator()\n", 349 | "\n", 350 | "# `features` is a dictionary in which each value is a batch of values for\n", 351 | "# that feature; `labels` is a batch of labels.\n", 352 | "features, labels = iterator.get_next()\n", 353 | "\n", 354 | "print(\"FEATURES\",features)\n", 355 | "print(\"LABELS\",labels)\n", 356 | "print(\"SESS_RUN_LABELS\\n\",sess.run(labels))" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 6, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "name": "stdout", 366 | "output_type": "stream", 367 | "text": [ 368 | "DATASET \n", 369 | "ITERATOR \n", 370 | "FEATURES Tensor(\"IteratorGetNext_4:0\", shape=(?, 28, 28, 1), dtype=float32)\n", 371 | "LABELS Tensor(\"IteratorGetNext_4:1\", shape=(?, 10), dtype=float32)\n", 372 | "TRAIN\n", 373 | " [[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 374 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 375 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 376 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 377 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 378 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 379 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 380 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 381 | " [0. 0. 0. 0. 0. 0. 
1. 0. 0. 0.]\n", 382 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 383 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 384 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 385 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 386 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 387 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 388 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 389 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 390 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 391 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 392 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 393 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 394 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 395 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 396 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 397 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 398 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 399 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 400 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 401 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 402 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 403 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 404 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]\n", 405 | "VAL\n", 406 | " [[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 407 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 408 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 409 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 410 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 411 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 412 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 413 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 414 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 415 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 416 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 417 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 418 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 419 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 420 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 421 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 422 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 423 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 424 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 425 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 426 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 427 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 428 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 429 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 430 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 431 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 432 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 433 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 434 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 435 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 436 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 437 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 
0.]]\n" 438 | ] 439 | } 440 | ], 441 | "source": [ 442 | "# validate tf.data.TextLineDataset() using make_initializable_iterator()\n", 443 | "# In order to switch between train and validation data\n", 444 | "\n", 445 | "def decode_line(line):\n", 446 | " # Decode the csv_line to tensor.\n", 447 | " record_defaults = [[1.0] for col in range(785)]\n", 448 | " items = tf.decode_csv(line, record_defaults)\n", 449 | " features = items[1:785]\n", 450 | " label = items[0]\n", 451 | "\n", 452 | " features = tf.cast(features, tf.float32)\n", 453 | " features = tf.reshape(features,[28,28,1])\n", 454 | " label = tf.cast(label, tf.int64)\n", 455 | " label = tf.one_hot(label,num_class)\n", 456 | " return features,label\n", 457 | "\n", 458 | "\n", 459 | "filenames = tf.placeholder(tf.string, shape=[None])\n", 460 | "dataset = tf.data.TextLineDataset(filenames).skip(1)\n", 461 | "dataset = dataset.map(decode_line) # Parse the record into tensors\n", 462 | "# print(\"DATASET\",dataset)\n", 463 | "dataset = dataset.shuffle(buffer_size=10000)\n", 464 | "dataset = dataset.batch(32)\n", 465 | "dataset = dataset.repeat()\n", 466 | "print(\"DATASET\",dataset)\n", 467 | "iterator = dataset.make_initializable_iterator()\n", 468 | "features, labels = iterator.get_next()\n", 469 | "print(\"ITERATOR\",iterator)\n", 470 | "print(\"FEATURES\",features)\n", 471 | "print(\"LABELS\",labels)\n", 472 | "\n", 473 | "\n", 474 | "# Initialize `iterator` with training data.\n", 475 | "training_filenames = [\"/Users/honglan/Desktop/train.csv\"]\n", 476 | "sess.run(iterator.initializer,feed_dict={filenames: training_filenames})\n", 477 | "print(\"TRAIN\\n\",sess.run(labels))\n", 478 | "# print(sess.run(features))\n", 479 | "\n", 480 | "# Initialize `iterator` with validation data.\n", 481 | "validation_filenames = [\"/Users/honglan/Desktop/val.csv\"]\n", 482 | "sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})\n", 483 | "print(\"VAL\\n\",sess.run(labels))\n" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 10, 489 | "metadata": {}, 490 | "outputs": [ 491 | { 492 | "name": "stdout", 493 | "output_type": "stream", 494 | "text": [ 495 | "TRAIN\n", 496 | " [[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 497 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 498 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 499 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 500 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 501 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 502 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 503 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 504 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 505 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 506 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 507 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 508 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 509 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 510 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 511 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 512 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 513 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 514 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 515 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 516 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 517 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 518 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 519 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 520 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 521 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 522 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 523 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 524 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 525 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 526 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 
0.]\n", 527 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]\n", 528 | "VAL\n", 529 | " [[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 530 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 531 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 532 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 533 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 534 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 535 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 536 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 537 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 538 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 539 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 540 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 541 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 542 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 543 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 544 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 545 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 546 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 547 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 548 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 549 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 550 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 551 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 552 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 553 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 554 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 555 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 556 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 557 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 558 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 559 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 560 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]\n" 561 | ] 562 | } 563 | ], 564 | "source": [ 565 | "# validate tf.data.TextLineDataset() using Reinitializable iterator\n", 566 | "# In order to switch between train and validation data\n", 567 | "\n", 568 | "def decode_line(line):\n", 569 | " # Decode the csv_line to tensor.\n", 570 | " record_defaults = [[1.0] for col in range(785)]\n", 571 | " items = tf.decode_csv(line, record_defaults)\n", 572 | " features = items[1:785]\n", 573 | " label = items[0]\n", 574 | "\n", 575 | " features = tf.cast(features, tf.float32)\n", 576 | " features = tf.reshape(features,[28,28,1])\n", 577 | " label = tf.cast(label, tf.int64)\n", 578 | " label = tf.one_hot(label,num_class)\n", 579 | " return features,label\n", 580 | "\n", 581 | "\n", 582 | "def create_dataset(filename, batch_size=32, is_shuffle=False, n_repeats=0):\n", 583 | " \"\"\"create dataset for train and validation dataset\"\"\"\n", 584 | " dataset = tf.data.TextLineDataset(filename).skip(1)\n", 585 | " if n_repeats > 0:\n", 586 | " dataset = dataset.repeat(n_repeats) # for train\n", 587 | " # dataset = dataset.map(decode_line).map(normalize) \n", 588 | " dataset = dataset.map(decode_line) \n", 589 | " # decode and normalize\n", 590 | " if is_shuffle:\n", 591 | " dataset = dataset.shuffle(10000) # shuffle\n", 592 | " dataset = dataset.batch(batch_size)\n", 593 | " return dataset\n", 594 | "\n", 595 | "\n", 596 | "training_filenames = [\"/Users/honglan/Desktop/train.csv\"] \n", 597 | "# replace the filenames with your own path\n", 598 | "validation_filenames = [\"/Users/honglan/Desktop/val.csv\"] \n", 599 | "# replace the filenames with your own path\n", 600 | "\n", 601 | "# Create different datasets\n", 602 | "training_dataset = create_dataset(training_filenames, batch_size=32, \\\n", 603 | " is_shuffle=True, n_repeats=num_epochs) # train_filename\n", 604 | "validation_dataset = create_dataset(validation_filenames, batch_size=32, \\\n", 605 | " is_shuffle=True, n_repeats=num_epochs) # val_filename\n", 606 | "\n", 607 | "# A reinitializable iterator is defined by its structure. 
We could use the\n", 608 | "# `output_types` and `output_shapes` properties of either `training_dataset`\n", 609 | "# or `validation_dataset` here, because they are compatible.\n", 610 | "iterator = tf.data.Iterator.from_structure(training_dataset.output_types,\n", 611 | " training_dataset.output_shapes)\n", 612 | "features, labels = iterator.get_next()\n", 613 | "\n", 614 | "training_init_op = iterator.make_initializer(training_dataset)\n", 615 | "validation_init_op = iterator.make_initializer(validation_dataset)\n", 616 | "\n", 617 | "# Using reinitializable iterator to alternate between training and validation.\n", 618 | "sess.run(training_init_op)\n", 619 | "print(\"TRAIN\\n\",sess.run(labels))\n", 620 | "# print(sess.run(features))\n", 621 | "\n", 622 | "# Reinitialize `iterator` with validation data.\n", 623 | "sess.run(validation_init_op)\n", 624 | "print(\"VAL\\n\",sess.run(labels))\n" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": 7, 630 | "metadata": {}, 631 | "outputs": [ 632 | { 633 | "name": "stdout", 634 | "output_type": "stream", 635 | "text": [ 636 | "TRAIN\n", 637 | " [[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 638 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 639 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 640 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 641 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 642 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 643 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 644 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 645 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 646 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 647 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 648 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 649 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 650 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 651 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 652 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 653 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 654 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 655 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 656 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 657 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 658 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 659 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 660 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 661 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 662 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 663 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 664 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 665 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 666 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 667 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 668 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]\n", 669 | "VAL\n", 670 | " [[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 671 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 672 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 673 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 674 | " [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 675 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 676 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 677 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 678 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 679 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 680 | " [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]\n", 681 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 682 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 683 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 684 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 685 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 686 | " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n", 687 | " [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]\n", 688 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 689 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", 690 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 691 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 692 | " [1. 0. 0. 0. 0. 0. 0. 0. 0. 
0.]\n", 693 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 694 | " [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]\n", 695 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 696 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 697 | " [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]\n", 698 | " [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]\n", 699 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]\n", 700 | " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]\n", 701 | " [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]\n" 702 | ] 703 | } 704 | ], 705 | "source": [ 706 | "# validate tf.data.TextLineDataset() using two different iterator\n", 707 | "# In order to switch between train and validation data\n", 708 | "\n", 709 | "def decode_line(line):\n", 710 | " # Decode the csv_line to tensor.\n", 711 | " record_defaults = [[1.0] for col in range(785)]\n", 712 | " items = tf.decode_csv(line, record_defaults)\n", 713 | " features = items[1:785]\n", 714 | " label = items[0]\n", 715 | "\n", 716 | " features = tf.cast(features, tf.float32)\n", 717 | " features = tf.reshape(features,[28,28])\n", 718 | " label = tf.cast(label, tf.int64)\n", 719 | " label = tf.one_hot(label,num_class)\n", 720 | " return features,label\n", 721 | "\n", 722 | "\n", 723 | "def create_dataset(filename, batch_size=32, is_shuffle=False, n_repeats=0):\n", 724 | " \"\"\"create dataset for train and validation dataset\"\"\"\n", 725 | " dataset = tf.data.TextLineDataset(filename).skip(1)\n", 726 | " if n_repeats > 0:\n", 727 | " dataset = dataset.repeat(n_repeats) # for train\n", 728 | " # dataset = dataset.map(decode_line).map(normalize) \n", 729 | " dataset = dataset.map(decode_line) \n", 730 | " # decode and normalize\n", 731 | " if is_shuffle:\n", 732 | " dataset = dataset.shuffle(10000) # shuffle\n", 733 | " dataset = dataset.batch(batch_size)\n", 734 | " return dataset\n", 735 | "\n", 736 | "\n", 737 | "training_filenames = [\"/Users/honglan/Desktop/train.csv\"] \n", 738 | "# replace the filenames with your own path\n", 739 | "validation_filenames = [\"/Users/honglan/Desktop/val.csv\"] \n", 740 | "# replace the filenames with your own path\n", 741 | "\n", 742 | "# Create different datasets\n", 743 | "training_dataset = create_dataset(training_filenames, batch_size=32, \\\n", 744 | " is_shuffle=True, n_repeats=num_epochs) # train_filename\n", 745 | "validation_dataset = create_dataset(validation_filenames, batch_size=32, \\\n", 746 | " is_shuffle=True, n_repeats=num_epochs) # val_filename\n", 747 | "\n", 748 | "# A feedable iterator is defined by a handle placeholder and its structure. 
We\n", 749 | "# could use the `output_types` and `output_shapes` properties of either\n", 750 | "# `training_dataset` or `validation_dataset` here, because they have\n", 751 | "# identical structure.\n", 752 | "handle = tf.placeholder(tf.string, shape=[])\n", 753 | "iterator = tf.data.Iterator.from_string_handle(\n", 754 | " handle, training_dataset.output_types, training_dataset.output_shapes)\n", 755 | "features, labels = iterator.get_next()\n", 756 | "\n", 757 | "# You can use feedable iterators with a variety of different kinds of iterator\n", 758 | "# (such as one-shot and initializable iterators).\n", 759 | "training_iterator = training_dataset.make_one_shot_iterator()\n", 760 | "validation_iterator = validation_dataset.make_initializable_iterator()\n", 761 | "\n", 762 | "# The `Iterator.string_handle()` method returns a tensor that can be evaluated\n", 763 | "# and used to feed the `handle` placeholder.\n", 764 | "training_handle = sess.run(training_iterator.string_handle())\n", 765 | "validation_handle = sess.run(validation_iterator.string_handle())\n", 766 | "\n", 767 | "# Using different handle to alternate between training and validation.\n", 768 | "print(\"TRAIN\\n\",sess.run(labels, feed_dict={handle: training_handle}))\n", 769 | "# print(sess.run(features))\n", 770 | "\n", 771 | "# Initialize `iterator` with validation data.\n", 772 | "sess.run(validation_iterator.initializer)\n", 773 | "print(\"VAL\\n\",sess.run(labels, feed_dict={handle: validation_handle}))\n" 774 | ] 775 | } 776 | ], 777 | "metadata": { 778 | "kernelspec": { 779 | "display_name": "Python 3", 780 | "language": "python", 781 | "name": "python3" 782 | }, 783 | "language_info": { 784 | "codemirror_mode": { 785 | "name": "ipython", 786 | "version": 3 787 | }, 788 | "file_extension": ".py", 789 | "mimetype": "text/x-python", 790 | "name": "python", 791 | "nbconvert_exporter": "python", 792 | "pygments_lexer": "ipython3", 793 | "version": "3.7.0" 794 | } 795 | }, 796 | "nbformat": 4, 797 | "nbformat_minor": 2 798 | } 799 | --------------------------------------------------------------------------------