├── .gitignore
├── README.md
├── create_sprite.py
├── embedding.npy
├── embedding_viz.py
├── misc
│   ├── label_valid.tsv
│   └── sprite_valid.jpg
├── plain_cnn_completed.py
├── plain_cnn_exercise.ipynb
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

.idea*

# data folders
_*/
data/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tf-image-workshop
Code for "Deep Learning with TensorFlow Workshop 2 - Image Recognition"
--------------------------------------------------------------------------------
/create_sprite.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import os
import pandas as pd
from skimage import io
from utils import *

W = 72
H = 72

sess = tf.InteractiveSession()

# Read paths of images together with their labels
image_list, label_list = read_labeled_image_list('data/test')

images = tf.convert_to_tensor(image_list, dtype=tf.string)
labels = tf.convert_to_tensor(label_list, dtype=tf.int32)

# Make an input queue
input_queue = tf.train.slice_input_producer([images, labels],
                                            shuffle=False,
                                            num_epochs=1,
                                            capacity=len(image_list),
                                            name="file_input_queue")

image, label = read_images_from_disk(input_queue)

# Resize each image to the sprite cell size
image = tf.image.resize_images(image, (H, W), tf.image.ResizeMethod.NEAREST_NEIGHBOR)

# Image and label batching: a single batch holding the whole test set
image_batch, label_batch = tf.train.batch([image, label], batch_size=len(image_list), capacity=len(image_list),
                                          num_threads=1, name="batch_queue",
                                          allow_smaller_final_batch=True)

sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

image_pix, label_val = sess.run([image_batch, label_batch])

# Stop the queue runners once the batch has been materialized
coord.request_stop()
coord.join(threads)

import matplotlib.pyplot as plt
import numpy as np

# Tile the images into a d x d grid (TensorBoard expects a square sprite)
d = np.ceil(len(image_pix) ** 0.5).astype("int")
result = np.zeros((d * H, d * W, 3), dtype="uint8")

for idx, p in enumerate(image_pix):
    col = (idx % d) * W
    row = idx // d * H
    result[row:row + H, col:col + W, :] = p

plt.imshow(result)  # displayed when run in an interactive/IPython session

classes_map = {0: 'env', 1: 'food', 2: 'front', 3: 'menu', 4: 'profile'}

io.imsave(os.getcwd() + '/misc/sprite_valid.jpg', result)
label_df = pd.DataFrame(list(map(lambda x: classes_map[x], label_val)), columns=['label'])
label_df.to_csv(os.getcwd() + '/misc/label_valid.tsv', sep='\t', index=False, header=False)
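
The tiling loop above places image i at grid row i // d, column i % d, which is exactly the row-major order in which TensorBoard's projector reads sprite cells, so embedding row i always lines up with sprite cell i. A minimal sketch of that arithmetic, assuming a 513-image test set (the length suggested by misc/label_valid.tsv) and the 72x72 cells used above; the numbers are illustrative, not part of the repo:

import numpy as np

# Hypothetical check of the sprite-grid layout from create_sprite.py.
n_images, H, W = 513, 72, 72
d = int(np.ceil(n_images ** 0.5))  # 23: smallest square grid with >= 513 cells

for idx in (0, 22, 23, 512):
    row, col = (idx // d) * H, (idx % d) * W  # top-left pixel of cell `idx`
    print(idx, (row, col))
# 0   -> (0, 0)       first cell
# 22  -> (0, 1584)    last cell of the first row
# 23  -> (72, 0)      wraps to the second row
# 512 -> (1584, 432)  last image, seventh cell of the last row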
--------------------------------------------------------------------------------
/embedding.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tf-dl-workshop/tf-image-workshop/0acf49dde5e9b94b28fb8bc7fe12ea79f21283ff/embedding.npy
--------------------------------------------------------------------------------
/embedding_viz.py:
--------------------------------------------------------------------------------
from tensorflow.contrib.tensorboard.plugins import projector
import tensorflow as tf
import os
import numpy as np

features = np.load(os.getcwd() + '/embedding.npy')

features_tensor = tf.Variable(features, name="features")

writer = tf.summary.FileWriter(os.getcwd() + '/_emb')
config = projector.ProjectorConfig()

emb_conf = config.embeddings.add()
emb_conf.tensor_name = features_tensor.name
emb_conf.sprite.image_path = os.getcwd() + '/misc/sprite_valid.jpg'
emb_conf.metadata_path = os.getcwd() + '/misc/label_valid.tsv'
emb_conf.sprite.single_image_dim.extend((72, 72))
projector.visualize_embeddings(writer, config)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
writer.add_graph(sess.graph)
saver = tf.train.Saver()
saver.save(sess, os.path.join(os.getcwd() + '/_emb', "_model.ckpt"), 1)
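
Before visualizing, it is worth confirming that the projector's three inputs agree, since it pairs embedding row i with metadata line i and sprite cell i. A hedged sanity check, assuming it is run from the repository root after plain_cnn_completed.py has written embedding.npy:

import numpy as np

emb = np.load('embedding.npy')
labels = open('misc/label_valid.tsv').read().splitlines()

# Row count must equal label count; a mismatch (e.g. a final image lost
# to batch truncation) silently misaligns every point in the projector.
print(emb.shape, len(labels))
assert emb.shape[0] == len(labels), "embedding rows != metadata lines"

With the check passing, TensorBoard pointed at the log directory (tensorboard --logdir _emb) shows the embedding under its Projector tab.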
--------------------------------------------------------------------------------
/misc/label_valid.tsv:
--------------------------------------------------------------------------------
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
env
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
food
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
front
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
menu
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
profile
--------------------------------------------------------------------------------
/misc/sprite_valid.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tf-dl-workshop/tf-image-workshop/0acf49dde5e9b94b28fb8bc7fe12ea79f21283ff/misc/sprite_valid.jpg
--------------------------------------------------------------------------------
/plain_cnn_completed.py:
--------------------------------------------------------------------------------
from utils import *
from tensorflow.contrib import learn
from tensorflow.contrib.learn import *
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib

tf.logging.set_verbosity(tf.logging.INFO)


def cnn_model_fn(features, labels, mode, params):
    """
    Model function for CNN
    :param features: images features with shape (batch_size, height, width, channels)
    :param labels: images category with shape (batch_size)
    :param mode: Specifies if this is training, evaluation or
                 prediction. See `model_fn_lib.ModeKeys`
    :param params: dict of hyperparameters
    :return: predictions, loss, train_op, Optional(eval_op). See `model_fn_lib.ModelFnOps`
    """

    # Convolutional Layer #1
    conv1 = tf.layers.conv2d(
        inputs=features,
        filters=32,
        kernel_size=[3, 3],
        padding="same",
        activation=tf.nn.relu)

    # Pooling Layer #1
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

    # Convolutional Layer #2 and Pooling Layer #2
    conv2 = tf.layers.conv2d(
        inputs=pool1,
        filters=64,
        kernel_size=[3, 3],
        padding="same",
        activation=tf.nn.relu)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    # Convolutional Layer #3 and Pooling Layer #3
    conv3 = tf.layers.conv2d(
        inputs=pool2,
        filters=64,
        kernel_size=[3, 3],
        padding="same",
        activation=tf.nn.relu)
    pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

    # Dense Layer
    # Three 2x2/stride-2 poolings take the 256x256 input down to 32x32,
    # with 64 channels coming out of conv3 -> 32 * 32 * 64 units to flatten.
    pool_flat = tf.reshape(pool3, [-1, 32 * 32 * 64])
    dense = tf.layers.dense(inputs=pool_flat, units=512, activation=tf.nn.relu)
    dropout = tf.layers.dropout(inputs=dense, rate=params['drop_out_rate'], training=mode == learn.ModeKeys.TRAIN)

    # Logits Layer
    logits = tf.layers.dense(inputs=dropout, units=5)

    loss = None
    train_op = None

    # Calculate Loss (for both TRAIN and EVAL modes)
    if mode != learn.ModeKeys.INFER:
        onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=5, name="onehot")
        loss = tf.losses.softmax_cross_entropy(
            onehot_labels=onehot_labels, logits=logits)

    # Configure the Training Op (for TRAIN mode)
    if mode == learn.ModeKeys.TRAIN:
        train_op = tf.contrib.layers.optimize_loss(
            loss=loss,
            global_step=tf.contrib.framework.get_global_step(),
            optimizer=tf.train.AdamOptimizer,
            learning_rate=params['learning_rate'],
            summaries=[
                "learning_rate",
                "loss",
                "gradients",
                "gradient_norm",
            ])

    # Generate Predictions
    predictions = {
        "classes": tf.argmax(
            input=logits, axis=1),
        "probabilities": tf.nn.softmax(
            logits, name="softmax_tensor")
    }

    # Return a ModelFnOps object; exposing `dense` through eval_metric_ops
    # lets evaluate() return the dense-layer activations, which are saved
    # below as the embedding for the TensorBoard projector.
    return model_fn_lib.ModelFnOps(
        mode=mode, predictions=predictions, loss=loss, train_op=train_op, eval_metric_ops={'dense': dense})


def feature_engineering_fn(features, labels):
    """
    feature_engineering_fn: Feature engineering function. Takes features and
                            labels which are the output of `input_fn` and
                            returns features and labels which will be fed
                            into `model_fn`
    """

    features = tf.to_float(features)

    # Preprocessing or Data Augmentation
    # tf.image implements most of the standard image augmentations

    # Example
    # Subtract off the mean and divide by the variance of the pixels.
    features = tf.map_fn(tf.image.per_image_standardization, features)

    return features, labels


if __name__ == '__main__':
    params = {'drop_out_rate': 0.2, 'learning_rate': 0.0001}
    cnn_classifier = learn.Estimator(
        model_fn=cnn_model_fn, model_dir="_model/plain_cnn",
        config=RunConfig(save_summary_steps=10, keep_checkpoint_max=2, save_checkpoints_secs=30),
        feature_engineering_fn=feature_engineering_fn, params=params)

    # Configure the accuracy metric for evaluation
    metrics = {
        "accuracy":
            learn.MetricSpec(
                metric_fn=tf.metrics.accuracy, prediction_key="classes"),
    }

    train_input_fn = read_img(data_dir='data/train', batch_size=32, shuffle=True)
    monitor_input_fn = read_img(data_dir='data/validate', batch_size=128, shuffle=True)
    test_input_fn = read_img(data_dir='data/test', batch_size=512, shuffle=False)

    validation_monitor = monitors.ValidationMonitor(input_fn=monitor_input_fn,
                                                    eval_steps=10,
                                                    every_n_steps=50,
                                                    metrics=metrics,
                                                    name='validation')

    cnn_classifier.fit(input_fn=train_input_fn, steps=300, monitors=[validation_monitor])

    # Evaluate the model and save the dense-layer activations as embeddings
    eval_results = cnn_classifier.evaluate(input_fn=test_input_fn, metrics=metrics, steps=1)
    np.save(os.getcwd() + '/embedding.npy', eval_results['dense'])
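
The hard-coded 32 * 32 * 64 in the flatten step is not arbitrary: utils.read_img resizes every image to 256x256, each of the three 2x2/stride-2 max-pools halves the spatial side, and conv3 leaves 64 channels. A few lines of arithmetic make the dependency explicit (illustrative only; the numbers must change if the input size does):

# 256x256 input, three stride-2 poolings, 64 channels out of conv3
side = 256
for _ in range(3):  # pool1, pool2, pool3
    side //= 2
print(side)              # 32
print(side * side * 64)  # 65536 units feeding the 512-unit dense layer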
--------------------------------------------------------------------------------
/plain_cnn_exercise.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from utils import *\n",
    "from tensorflow.contrib import learn\n",
    "from tensorflow.contrib import layers\n",
    "from tensorflow.contrib.learn import *\n",
    "from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib\n",
    "\n",
    "tf.logging.set_verbosity(tf.logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input function for training\n",
    "\n",
    "- Get image data from the given directory\n",
    "- Put the data into a TensorFlow queue\n",
    "- Return (features, label)\n",
    "  - features: a Tensor with shape (batch_size, height, width, channels)\n",
    "  - label: a Tensor with shape (batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_input_fn = read_img(data_dir='data/train', batch_size=32, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<tf.Tensor 'features:0' shape=(?, 256, 256, 3) dtype=uint8>,\n",
       " <tf.Tensor 'label:0' shape=(?,) dtype=int32>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_input_fn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define CNN model function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def cnn_model_fn(features, labels, mode, params):\n",
    "    \"\"\"\n",
    "    Model function for CNN\n",
    "    :param features: images features with shape (batch_size, height, width, channels)\n",
    "    :param labels: images category with shape (batch_size)\n",
    "    :param mode: Specifies if this is training, evaluation or\n",
    "                 prediction. See `model_fn_lib.ModeKeys`\n",
    "    :param params: dict of hyperparameters\n",
    "    :return: predictions, loss, train_op, Optional(eval_op). See `model_fn_lib.ModelFnOps`\n",
    "    \"\"\"\n",
    "\n",
    "    # Convolutional Layer #1\n",
    "    conv1 = tf.layers.conv2d(\n",
    "        inputs=features,\n",
    "        filters=32,\n",
    "        kernel_size=[3, 3],\n",
    "        padding=\"same\",\n",
    "        activation=tf.nn.relu)\n",
    "\n",
    "    # Pooling Layer #1\n",
    "    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)\n",
    "\n",
    "    # Convolutional Layer #2 and Pooling Layer #2\n",
    "    \"\"\"\n",
    "    Your code here\n",
    "    \"\"\"\n",
    "\n",
    "    # Convolutional Layer #3 and Pooling Layer #3\n",
    "    \"\"\"\n",
    "    Your code here\n",
    "    \"\"\"\n",
    "    pool3 = ???\n",
    "\n",
    "    # Dense Layer\n",
    "    pool_flat = tf.reshape(pool3, [-1, ??? * ??? * ???])\n",
    "    dense = tf.layers.dense(...)\n",
    "    dropout = tf.layers.dropout(inputs=dense, rate=params['drop_out_rate'],\n",
    "                                training=mode == learn.ModeKeys.TRAIN)\n",
    "\n",
    "    # Logits Layer, a final layer before applying softmax\n",
    "    logits = ???\n",
    "\n",
    "    loss = None\n",
    "    train_op = None\n",
    "\n",
    "    # Calculate Loss (for both TRAIN and EVAL modes)\n",
    "    if mode != learn.ModeKeys.INFER:\n",
    "        onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=5, name=\"onehot\")\n",
    "        # cross entropy loss\n",
    "        loss = ???\n",
    "\n",
    "    # Configure the Training Op (for TRAIN mode)\n",
    "    if mode == learn.ModeKeys.TRAIN:\n",
    "        train_op = tf.contrib.layers.optimize_loss(...)\n",
    "\n",
    "    # Generate Predictions\n",
    "    predictions = {\n",
    "        \"classes\": ???,\n",
    "        \"probabilities\": ???\n",
    "    }\n",
    "\n",
    "    # Return a ModelFnOps object\n",
    "    return model_fn_lib.ModelFnOps(mode=mode,\n",
    "                                   predictions=predictions,\n",
    "                                   loss=loss,\n",
    "                                   train_op=train_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Data Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def feature_engineering_fn(features, labels):\n",
    "    \"\"\"\n",
    "    feature_engineering_fn: Feature engineering function. Takes features and\n",
    "                            labels which are the output of `input_fn` and\n",
    "                            returns features and labels which will be fed\n",
    "                            into `model_fn`\n",
    "    \"\"\"\n",
    "\n",
    "    features = tf.to_float(features)\n",
    "\n",
    "    # Preprocessing or Data Augmentation\n",
    "    # tf.image implements most of the standard image augmentations\n",
    "\n",
    "    # Example\n",
    "    # Subtract off the mean and divide by the variance of the pixels.\n",
    "    features = tf.map_fn(tf.image.per_image_standardization, features)\n",
    "\n",
    "    return features, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Instantiate an Estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "run_config = RunConfig(save_summary_steps=10, keep_checkpoint_max=2, save_checkpoints_secs=30)\n",
    "# drop_out_rate = 0.2, learning_rate = 0.0001\n",
    "params = ???\n",
    "# use \"_model/plain_cnn\" as model_dir\n",
    "cnn_classifier = learn.Estimator(....)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Configure the accuracy metric for evaluation\n",
    "metrics = {\n",
    "    \"accuracy\":\n",
    "        learn.MetricSpec(\n",
    "            metric_fn=tf.metrics.accuracy, prediction_key=\"classes\")\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define input function for the validation monitor and instantiate a ValidationMonitor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# validation data is in the 'data/validate' folder, batch size = 128\n",
    "validate_input_fn = ???\n",
    "validation_monitor = monitors.ValidationMonitor(input_fn=validate_input_fn,\n",
    "                                                eval_steps=10,\n",
    "                                                every_n_steps=50,\n",
    "                                                metrics=metrics,\n",
    "                                                name='validation')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Start training the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# use the validation monitor defined above to evaluate the model every 50 steps\n",
    "cnn_classifier.fit(...)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Final evaluation on the unseen test data set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# test data is in the 'data/test' folder, batch size = 512\n",
    "test_input_fn = ???\n",
    "# steps = 1\n",
    "cnn_classifier.evaluate(...)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3.0
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
"nbconvert_exporter": "python", 303 | "pygments_lexer": "ipython3", 304 | "version": "3.6.1" 305 | } 306 | }, 307 | "nbformat": 4, 308 | "nbformat_minor": 0 309 | } -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import numpy as np 4 | 5 | 6 | def resnetpreprocessing(x): 7 | x = tf.to_float(x) 8 | # RGB -> BGR and subtract mean image, see 9 | # keras.applications.resnet50 10 | # https://groups.google.com/forum/#!topic/caffe-users/wmOnKLSKfpI 11 | # https://arxiv.org/pdf/1512.03385v1.pdf 12 | # http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf 13 | return x[:, :, :, ::-1] - np.array([103.939, 116.779, 123.68]).reshape(1, 1, 1, -1) 14 | 15 | 16 | def read_labeled_image_list(image_dir): 17 | """ 18 | Returns: 19 | List with all filenames in file image_list_file 20 | """ 21 | dir_list = [x for x in os.walk(os.path.join(os.getcwd(), image_dir))][1:] 22 | 23 | filenames = [] 24 | labels = [] 25 | 26 | for i, d in enumerate(dir_list): 27 | for fname in d[2]: 28 | filename = os.path.join(d[0], fname) 29 | label = i 30 | filenames.append(filename) 31 | labels.append(int(label)) 32 | 33 | return filenames, labels 34 | 35 | 36 | def read_images_from_disk(input_queue): 37 | """Consumes a single filename and label as a ' '-delimited string. 38 | Args: 39 | filename_and_label_tensor: A scalar string tensor. 40 | Returns: 41 | Two tensors: the decoded image, and the string label. 42 | """ 43 | filename = input_queue[0] 44 | label = input_queue[1] 45 | file_contents = tf.read_file(filename) 46 | example = tf.image.decode_image(file_contents, channels=3) 47 | example.set_shape([None, None, 3]) 48 | 49 | return example, label 50 | 51 | 52 | def read_img(data_dir, batch_size, shuffle): 53 | # Reads pfathes of images together with their labels 54 | def input_fn(): 55 | image_list, label_list = read_labeled_image_list(data_dir) 56 | 57 | images = tf.convert_to_tensor(image_list, dtype=tf.string) 58 | labels = tf.convert_to_tensor(label_list, dtype=tf.int32) 59 | 60 | # Makes an input queue 61 | input_queue = tf.train.slice_input_producer([images, labels], 62 | shuffle=shuffle, 63 | capacity=batch_size * 5, 64 | name="file_input_queue") 65 | 66 | image, label = read_images_from_disk(input_queue) 67 | 68 | # resize image 69 | image = tf.image.resize_images(image, (256, 256), tf.image.ResizeMethod.NEAREST_NEIGHBOR) 70 | 71 | # Image and Label Batching 72 | image_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size, capacity=batch_size * 10, 73 | num_threads=1, name="batch_queue", 74 | allow_smaller_final_batch=True) 75 | 76 | return tf.identity(image_batch, name="features"), tf.identity(label_batch, name="label") 77 | 78 | return input_fn 79 | --------------------------------------------------------------------------------