├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── __init__.py
├── example_project
│   ├── __init__.py
│   ├── convert_to_records.py
│   ├── input_data.py
│   ├── my_cifar.py
│   ├── my_cifar_train.py
│   └── tmp
│       ├── ckpt
│       │   └── .gitkeep
│       ├── log
│       │   └── .gitkeep
│       └── val.txt
├── imageflow
│   ├── __init__.py
│   ├── convert_to_records.py
│   ├── imageflow.py
│   ├── playground.py
│   ├── reader.py
│   └── utils.py
├── requirements.txt
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled python modules.
2 | *.pyc
3 |
4 | # Setuptools distribution folder.
5 | /dist/
6 |
7 | # Python egg metadata, regenerated from source files by setuptools.
8 | /*.egg-info
9 | .idea
10 | build
11 | docs
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Notice - This version of imageflow is no longer under maintenance and a major update is required.
2 | **The TensorFlow version is too old and the library is not working as expected. You are welcome to add your use cases in the Issues as feature requests to be considered in the new versions. Sorry for the inconvenience.**
3 |
4 |
5 | # ImageFlow
6 | A simple TensorFlow wrapper for converting, importing (and, soon, training) images.
7 |
8 | Installation:
9 | ```
10 | pip install imageflow
11 | ```
12 |
13 | Usage:
14 |
15 | ```python
16 | import imageflow
17 | ```
18 |
19 | #### Convert a directory of images and their labels to `.tfrecords`
20 | Calling the following function creates a `filename.tfrecords` file in the `converted_data` directory at your project's root (where you call this method).
21 |
22 | ```python
23 | convert_images(images, labels, filename)
24 | ```
25 |
26 | The `images` should be an array of shape `[-1, height, width, channel]` and have the same number of rows as `labels`.
27 |
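For example, a minimal end-to-end sketch (the random arrays are illustrative stand-ins for real image data, and the `converted_data` directory is assumed to exist; only `convert_images` comes from this library):

```python
import numpy as np
from imageflow import convert_images

# Illustrative stand-in data: 100 random 32x32 RGB images with labels 0-9.
# Keep the images as uint8 -- the raw bytes are what end up in the records.
images = np.random.randint(0, 256, size=(100, 32, 32, 3)).astype(np.uint8)
labels = np.random.randint(0, 10, size=(100,))

# Writes converted_data/my_dataset.tfrecords relative to the working directory.
convert_images(images, labels, 'my_dataset')
```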
28 | #### Read distorted and normal data from `.tfrecords` in a multi-threaded manner:
29 | ```python
30 | # Distorted images for training
31 | images, labels = distorted_inputs(filename='../my_data_raw/train.tfrecords', batch_size=FLAGS.batch_size,
32 |                                   num_epochs=FLAGS.num_epochs,
33 |                                   num_threads=5, imshape=[32, 32, 3])
34 |
35 | # Normal images for validation
36 | val_images, val_labels = inputs(filename='../my_data_raw/validation.tfrecords', batch_size=FLAGS.batch_size,
37 |                                 num_epochs=FLAGS.num_epochs,
38 |                                 num_threads=5, imshape=[32, 32, 3])
39 | ```
40 |
41 |
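These calls only build graph ops; nothing is actually read until you start the queue runners inside a session, mirroring the training loop in `example_project/my_cifar_train.py` (a minimal sketch):

```python
import tensorflow as tf

init = tf.initialize_all_variables()

with tf.Session() as sess:
  sess.run(init)

  # The readers above registered QueueRunners; start their threads.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    # Each run() call dequeues one shuffled batch.
    images_batch, labels_batch = sess.run([images, labels])
  finally:
    coord.request_stop()
    coord.join(threads)
```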
42 | Dependencies:
43 |
44 | * TensorFlow (>= 0.7.0)
45 | * Numpy
46 | * Pillow
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Hamed MP. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 | """
18 | Simple library to read all PNG and JPG/JPEG images in a directory
19 | with TensorFlow built-in functions in a multi-threaded way to boost speed.
20 |
21 | Supported formats by TensorFlow are: PNG, JPG/JPEG
22 |
23 | Hamed MP
24 | Github: @hamedmp
25 | Twitter: @TheHamedMP
26 | """
27 |
28 | from imageflow import read_and_decode
29 | from imageflow import inputs
30 | from imageflow import distorted_inputs
31 |
32 |
33 | __author__ = 'HANEL'
--------------------------------------------------------------------------------
/example_project/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'HANEL'
--------------------------------------------------------------------------------
/example_project/convert_to_records.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Converts CIFAR-10 PNG images to TFRecords file format with Example protos."""
17 |
18 | import os
19 |
20 | import tensorflow as tf
21 |
22 | import input_data
23 |
24 |
25 |
26 | tf.app.flags.DEFINE_string('directory', 'ckpt',
27 |                            'Directory to download data files and write the '
28 |                            'converted result')
29 | tf.app.flags.DEFINE_integer('validation_size', 10000,
30 |                             'Number of examples to separate from the training '
31 |                             'data for the validation set.')
32 | FLAGS = tf.app.flags.FLAGS
33 |
34 |
35 | def _int64_feature(value):
36 |   return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
37 |
38 |
39 | def _bytes_feature(value):
40 |   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
41 |
42 |
43 | def convert_to(images, labels, name):
44 |   num_examples = labels.shape[0]
45 |   print('labels shape is ', labels.shape[0])
46 |   if images.shape[0] != num_examples:
47 |     raise ValueError("Images size %d does not match label size %d." %
48 |                      (images.shape[0], num_examples))
49 |   rows = images.shape[1]
50 |   cols = images.shape[2]
51 |   depth = images.shape[3]
52 |
53 |   filename = os.path.join(FLAGS.directory, name + '.tfrecords')
54 |   print('Writing', filename)
55 |   writer = tf.python_io.TFRecordWriter(filename)
56 |   for index in range(num_examples):
57 |     image_raw = images[index].tostring()
58 |     example = tf.train.Example(features=tf.train.Features(feature={
59 |         'height': _int64_feature(rows),
60 |         'width': _int64_feature(cols),
61 |         'depth': _int64_feature(depth),
62 |         'label': _int64_feature(int(labels[index])),
63 |         'image_raw': _bytes_feature(image_raw)}))
64 |     writer.write(example.SerializeToString())
65 |   writer.close()  # Flush the records to disk.
66 |
67 | def main(argv):
68 |
69 |   # Extract it into numpy arrays.
70 |   train_images = input_data.read_images_from()
71 |   train_labels = input_data.read_labels_from()
72 |   # test_images = input_data.extract_images(test_images_filename)
73 |   # test_labels = input_data.extract_labels(test_labels_filename)
74 |
75 |   print(train_images.shape)
76 |
77 |   # Generate a validation set.
78 |   validation_images = train_images[:FLAGS.validation_size, :, :, :]
79 |   validation_labels = train_labels[:FLAGS.validation_size]
80 |   train_images = train_images[FLAGS.validation_size:, :, :, :]
81 |   train_labels = train_labels[FLAGS.validation_size:]
82 |   # TODO: create test.tfrecords to run tests after training
83 |
84 |   # Convert to Examples and write the result to TFRecords.
85 |   convert_to(train_images, train_labels, 'train')
86 |   convert_to(validation_images, validation_labels, 'validation')
87 |   # convert_to(test_images, test_labels, 'test')
88 |
89 |
90 | if __name__ == '__main__':
91 |   tf.app.run()
92 |
--------------------------------------------------------------------------------
/example_project/input_data.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | # ==============================================================================
13 |
14 | """All I/O jobs, i.e. reading raw images and reading .tfrecords, are done here"""
15 |
16 | __author__ = 'HANEL'
17 |
18 | import csv
19 | import glob
20 | import os
21 |
22 | import numpy as np
23 | import tensorflow as tf
24 | from PIL import Image
25 |
26 | import my_cifar
27 |
28 | Data_PATH = '../../mcifar_data/'
29 |
30 | # Parameters
31 | num_classes = 10
32 | IMAGE_SIZE = 32
33 | IMAGE_SHAPE = [IMAGE_SIZE, IMAGE_SIZE, 3]
34 |
35 | # Basic model parameters as external flags.
36 | flags = tf.app.flags
37 | FLAGS = flags.FLAGS
38 | flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
39 | flags.DEFINE_integer('num_epochs', 50000, 'Number of epochs to run trainer.')
40 | flags.DEFINE_integer('batch_size', 128, 'Batch size.')
41 | flags.DEFINE_string('train_dir', '../my_data_raw', 'Directory with the training data.')
42 |
43 | NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 40000
44 | NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
45 |
46 | # Constants used for dealing with the files, matches convert_to_records.
47 | TRAIN_FILE = 'train.tfrecords'
48 | VALIDATION_FILE = 'validation.tfrecords'
49 |
50 |
51 | def _dense_to_one_hot(labels_dense, num_classes):
52 |   """Convert class labels from scalars to one-hot vectors."""
53 |   num_labels = labels_dense.shape[0]
54 |   index_offset = np.arange(num_labels) * num_classes
55 |   labels_one_hot = np.zeros((num_labels, num_classes))
56 |   labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
57 |   print(labels_one_hot[0])
58 |   return labels_one_hot
59 |
60 |
61 | def _label_to_int(labels):
62 |   categories = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
63 |   new_labels = []
64 |
65 |   for label in labels:
66 |     new_labels.append(categories.index(label[1]))
67 |   return new_labels
68 |
69 |
70 | '''Read Images and Labels normally, with Python '''
71 |
72 |
73 | def read_labels_from(path=Data_PATH):
74 |   print('Reading labels')
75 |   with open(os.path.join(path, 'trainLabels.csv'), 'r') as dest_f:
76 |     data_iter = csv.reader(dest_f)
77 |     train_labels = [data for data in data_iter]
78 |
79 |   # pre-process labels to int
80 |   train_labels = _label_to_int(train_labels)
81 |   train_labels = np.array(train_labels, dtype=np.uint32)
82 |
83 |   return train_labels
84 |
85 |
86 | def read_images_from(path=Data_PATH):
87 |   images = []
88 |   png_files_path = glob.glob(os.path.join(path, 'train/', '*.[pP][nN][gG]'))
89 |   for filename in png_files_path:
90 |     im = Image.open(filename)  # .convert("L")  # Convert to greyscale
91 |     im = np.asarray(im, np.uint8)
92 |     # print(type(im))
93 |     # get only the image name, not the path
94 |     image_name = filename.split('/')[-1].split('.')[0]
95 |     images.append([int(image_name), im])
96 |
97 |   images = sorted(images, key=lambda image: image[0])
98 |
99 |   images_only = [np.asarray(image[1], np.uint8) for image in images]  # Keep uint8: the raw bytes must match tf.decode_raw(..., tf.uint8) at read time.
100 |   images_only = np.array(images_only)
101 |
102 |   print(images_only.shape)
103 |   return images_only
104 |
105 |
106 | ''' Decode TFRecords '''
107 |
108 |
109 | def read_and_decode(filename_queue):
110 |   reader = tf.TFRecordReader()
111 |   _, serialized_example = reader.read(filename_queue)
112 |   features = tf.parse_single_example(serialized_example,
113 |                                      dense_keys=['image_raw', 'label'],
114 |                                      # Defaults are not specified since both keys are required.
115 |                                      dense_types=[tf.string, tf.int64])
116 |
117 |   # Convert from a scalar string tensor (whose single string has
118 |   # length my_cifar.n_input) to a uint8 tensor with shape
119 |   # [my_cifar.n_input].
120 |   image = tf.decode_raw(features['image_raw'], tf.uint8)
121 |
122 |   image = tf.reshape(image, [my_cifar.n_input])
123 |   image.set_shape([my_cifar.n_input])
124 |
125 |
126 |   # Convert label from a scalar int64 tensor to an int32 scalar.
127 |   label = tf.cast(features['label'], tf.int32)
128 |
129 |   return image, label
130 |
131 |
132 | def inputs(train=True, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs):
133 |   """Reads input data num_epochs times.
134 |   Args:
135 |     train: Selects between the training (True) and validation (False) data.
136 |     batch_size: Number of examples per returned batch.
137 |     num_epochs: Number of times to read the input data, or 0/None to
138 |       train forever.
139 |   Returns:
140 |     A tuple (images, labels), where:
141 |     * images is a float tensor with shape [batch_size, my_cifar.n_input]
142 |       in the range [-0.5, 0.5].
143 |     * labels is an int32 tensor with shape [batch_size] with the true label,
144 |       a number in the range [0, num_classes).
145 |     Note that a tf.train.QueueRunner is added to the graph, which
146 |     must be run using e.g. tf.train.start_queue_runners().
147 |   """
148 |   if not num_epochs:
149 |     num_epochs = None
150 |   filename = os.path.join(FLAGS.train_dir,
151 |                           TRAIN_FILE if train else VALIDATION_FILE)
152 |
153 |   with tf.name_scope('input'):
154 |     filename_queue = tf.train.string_input_producer(
155 |         [filename], num_epochs=num_epochs, name='string_input_producer')
156 |
157 |     # Even when reading in multiple threads, share the filename
158 |     # queue.
159 |     image, label = read_and_decode(filename_queue)
160 |
161 |     # Convert from [0, 255] -> [-0.5, 0.5] floats.
162 |
163 |     image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
164 |
165 |     print('1- image shape is ', image.get_shape())
166 |
167 |
168 |     # Shuffle the examples and collect them into batch_size batches.
169 |     # (Internally uses a RandomShuffleQueue.)
170 |     # We run this in two threads to avoid being a bottleneck.
171 |     # Ensure that the random shuffling has good mixing properties.
172 |     min_fraction_of_examples_in_queue = 0.4
173 |     min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
174 |                              min_fraction_of_examples_in_queue)
175 |
176 |     images, sparse_labels = tf.train.shuffle_batch(
177 |         [image, label], batch_size=batch_size, num_threads=5,
178 |         capacity=min_queue_examples + 3 * batch_size, enqueue_many=False,
179 |         # Ensures a minimum amount of shuffling of examples.
180 |         min_after_dequeue=min_queue_examples, name='batching_shuffling')
181 |     print('1.1- label batch shape is ', sparse_labels.get_shape())
182 |
183 |   return images, sparse_labels
184 |
185 |
186 | def distorted_inputs(batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs):
187 |   """Construct distorted inputs for CIFAR training using the Reader ops.
188 |
189 |   Applies random flips, brightness and contrast distortions to each
190 |   training image before batching.
191 |
192 |   Returns:
193 |     images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
194 |     labels: Labels. 1D tensor of [batch_size] size.
195 |   """
196 |   if not num_epochs:
197 |     num_epochs = None
198 |   filename = os.path.join(FLAGS.train_dir,
199 |                           TRAIN_FILE)
200 |
201 |   with tf.name_scope('input'):
202 |     filename_queue = tf.train.string_input_producer(
203 |         [filename], num_epochs=num_epochs, name='string_DISTORTED_input_producer')
204 |
205 |     # Even when reading in multiple threads, share the filename
206 |     # queue.
207 |     image, label = read_and_decode(filename_queue)
208 |
209 |     # Reshape to [32, 32, 3] as distortion methods need this shape
210 |     image = tf.reshape(image, IMAGE_SHAPE)
211 |     # image.set_shape(IMAGE_SHAPE)
212 |
213 |     height = IMAGE_SIZE
214 |     width = IMAGE_SIZE
215 |
216 |     # Image processing for training the network. Note the many random
217 |     # distortions applied to the image.
218 |
219 |     # Randomly crop a [height, width] section of the image.
220 |     distorted_image = tf.image.random_crop(image, [height, width])
221 |     #
222 |     # Randomly flip the image horizontally.
223 |     distorted_image = tf.image.random_flip_left_right(distorted_image)
224 |     #
225 |     # Because these operations are not commutative, consider randomizing
226 |     # the order of their operation.
227 |     distorted_image = tf.image.random_brightness(distorted_image,
228 |                                                  max_delta=63)
229 |     distorted_image = tf.image.random_contrast(distorted_image,
230 |                                                lower=0.2, upper=1.8)
231 |
232 |     # Subtract off the mean and divide by the variance of the pixels.
233 |     float_image = tf.image.per_image_whitening(distorted_image)
234 |
235 |     # Reshape back to the flat shape the rest of the architecture expects.
236 |     image = tf.reshape(float_image, [my_cifar.n_input])
237 |     # image = tf.reshape(image, [my_cifar.n_input])
238 |     # image.set_shape([my_cifar.n_input])
239 |
240 |     # Ensure that the random shuffling has good mixing properties.
241 |     min_fraction_of_examples_in_queue = 0.4
242 |     min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
243 |                              min_fraction_of_examples_in_queue)
244 |     images, sparse_labels = tf.train.shuffle_batch([image, label],
245 |                                                    batch_size=batch_size,
246 |                                                    num_threads=5,
247 |                                                    capacity=min_queue_examples + 3 * batch_size,
248 |                                                    enqueue_many=False,
249 |                                                    # Ensures a minimum amount of shuffling of examples.
250 |                                                    min_after_dequeue=min_queue_examples,
251 |                                                    name='batching_shuffling_distortion')
252 |
253 |   return images, sparse_labels
254 |
255 |
256 | def main(argv=None):
257 |   return 0
258 |
259 |
260 | if __name__ == '__main__':
261 |   tf.app.run()
--------------------------------------------------------------------------------
/example_project/my_cifar.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | # ==============================================================================
13 |
14 | """The network model"""
15 |
16 | __author__ = 'HANEL'
17 |
18 | import tensorflow as tf
19 |
20 | # Data
21 | Data_PATH = '../../mcifar_data/'
22 |
23 |
24 | # Network Parameters
25 | n_input = 32 * 32 * 3  # CIFAR data input (img shape: 32*32*3)
26 |
27 | out_conv_1 = 64
28 | out_conv_2 = 64
29 |
30 | n_hidden_1 = 384
31 | n_hidden_2 = 192
32 |
33 | dropout = 0.90  # Dropout, probability to keep units
34 |
35 | # Global constants describing the CIFAR-10
36 | NUM_CLASSES = 10  # CIFAR-10 total classes
37 | NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 40000
38 | NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
39 |
40 | # Constants describing the training process.
41 | NUM_EPOCHS_PER_DECAY = 10.0  # Epochs after which learning rate decays.
42 | LEARNING_RATE_DECAY_FACTOR = 0.60  # Learning rate decay factor.
43 | INITIAL_LEARNING_RATE = 0.001  # Initial learning rate.
44 |
45 | FLAGS = tf.app.flags.FLAGS
46 |
47 |
48 | # Create model
49 | def conv2d(img, w, b):
50 |   return tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(img, w, strides=[1, 1, 1, 1], padding='SAME'), b))
51 |
52 |
53 | def max_pool(img, k):
54 |   return tf.nn.max_pool(img, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='SAME')
55 |
56 |
57 | def inference(images):
58 |   """Build the CIFAR model up to where it may be used for inference.
59 |   Args:
60 |     images: Image tensor, flattened to [batch_size, n_input].
61 |   Returns:
62 |     logits: Output tensor with the computed logits.
63 |   """
64 |
65 |   # Reshape input picture
66 |   print('In Inference ', images.get_shape(), type(images))
67 |
68 |   images = tf.reshape(images, shape=[-1, 32, 32, 3])
69 |
70 |   _dropout = tf.Variable(dropout)  # dropout (keep probability)
71 |
72 |   # Store layers weight & bias
73 |   _weights = {
74 |     'wc1': tf.Variable(tf.random_normal([5, 5, 3, out_conv_1], stddev=1e-3)),  # 5x5 conv, 3 inputs, 64 outputs
75 |     'wc2': tf.Variable(tf.random_normal([5, 5, out_conv_1, out_conv_2], stddev=1e-3)),
76 |     # 5x5 conv, 64 inputs, 64 outputs
77 |     'wd1': tf.Variable(tf.random_normal([out_conv_2 * 8 * 8, n_hidden_1], stddev=1e-3)),
78 |     'wd2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2], stddev=1e-3)),
79 |     'out': tf.Variable(tf.random_normal([n_hidden_2, NUM_CLASSES], stddev=1e-3))
80 |   }
81 |
82 |   _biases = {
83 |     'bc1': tf.Variable(tf.random_normal([out_conv_1])),
84 |     'bc2': tf.Variable(tf.random_normal([out_conv_2])),
85 |     'bd1': tf.Variable(tf.random_normal([n_hidden_1])),
86 |     'bd2': tf.Variable(tf.random_normal([n_hidden_2])),
87 |     'out': tf.Variable(tf.random_normal([NUM_CLASSES]))
88 |   }
89 |
90 |   # Convolution Layer 1
91 |   with tf.name_scope('Conv1'):
92 |     conv1 = conv2d(images, _weights['wc1'], _biases['bc1'])
93 |     # Max Pooling (down-sampling)
94 |     conv1 = max_pool(conv1, k=2)
95 |     # norm1
96 |     conv1 = tf.nn.lrn(conv1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
97 |                       name='norm1')
98 |     # Apply Dropout
99 |     conv1 = tf.nn.dropout(conv1, _dropout)
100 |
101 |   # Convolution Layer 2
102 |   with tf.name_scope('Conv2'):
103 |     conv2 = conv2d(conv1, _weights['wc2'], _biases['bc2'])
104 |     # norm2
105 |     conv2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
106 |                       name='norm2')
107 |     # Max Pooling (down-sampling)
108 |     conv2 = max_pool(conv2, k=2)
109 |     # Apply Dropout
110 |     conv2 = tf.nn.dropout(conv2, _dropout)
111 |
112 |   # Fully connected layer 1
113 |   with tf.name_scope('Dense1'):
114 |     dense1 = tf.reshape(conv2,
115 |                         [-1, _weights['wd1'].get_shape().as_list()[0]])  # Reshape conv2 output to fit dense layer input
116 |     dense1 = tf.nn.relu_layer(dense1, _weights['wd1'], _biases['bd1'])  # Relu activation
117 |     dense1 = tf.nn.dropout(dense1, _dropout)  # Apply Dropout
118 |
119 |   # Fully connected layer 2
120 |   with tf.name_scope('Dense2'):
121 |     dense2 = tf.nn.relu_layer(dense1, _weights['wd2'], _biases['bd2'])  # Relu activation
122 |
123 |   # Output, class prediction
124 |   logits = tf.add(tf.matmul(dense2, _weights['out']), _biases['out'])
125 |
126 |   return logits
127 |
128 |
129 | def loss(logits, labels):
130 |   """Add L2Loss to all the trainable variables.
131 |
132 |   Add summary for "Loss" and "Loss/avg".
133 |
134 |   Args:
135 |     logits: Logits from inference().
136 |     labels: Labels from distorted_inputs or inputs(). 1-D tensor
137 |             of shape [batch_size]
138 |
139 |   Returns:
140 |     Loss tensor of type float.
141 |   """
142 |   # Reshape the labels into a dense Tensor of
143 |   # shape [batch_size, NUM_CLASSES].
144 |   sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1])
145 |   indices = tf.reshape(tf.range(0, FLAGS.batch_size), [FLAGS.batch_size, 1])
146 |   concated = tf.concat(1, [indices, sparse_labels])
147 |   dense_labels = tf.sparse_to_dense(concated,
148 |                                     [FLAGS.batch_size, NUM_CLASSES],
149 |                                     1.0, 0.0)
150 |
151 |   # Calculate the average cross entropy loss across the batch.
152 |   cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
153 |       logits, dense_labels, name='cross_entropy_per_example')
154 |   cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
155 |   tf.add_to_collection('losses', cross_entropy_mean)
156 |
157 |   # The total loss is defined as the cross entropy loss plus all of the weight
158 |   # decay terms (L2 loss).
159 |   return tf.add_n(tf.get_collection('losses'), name='total_loss')
160 |
161 |
162 | def training(loss, global_step):
163 |   """Sets up the training Ops.
164 |   Creates a summarizer to track the loss over time in TensorBoard.
165 |   Creates an optimizer and applies the gradients to all trainable variables.
166 |   The Op returned by this function is what must be passed to the
167 |   `sess.run()` call to cause the model to train.
168 |   Args:
169 |     loss: Loss tensor, from loss().
170 |     global_step: The global step Variable, incremented once per training step.
171 |   Returns:
172 |     train_op: The Op for training.
173 |   """
174 |   # Variables that affect learning rate.
175 |   num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
176 |   decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
177 |
178 |   print('Decay steps is: ', decay_steps)
179 |   # Decay the learning rate exponentially based on the number of steps.
180 |   lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
181 |                                   global_step,
182 |                                   decay_steps,
183 |                                   LEARNING_RATE_DECAY_FACTOR,
184 |                                   staircase=True)
185 |   tf.scalar_summary('learning_rate', lr)
186 |   # Add a scalar summary for the snapshot loss.
187 |   tf.scalar_summary(loss.op.name, loss)
188 |
189 |   # Create the adam or gradient descent optimizer with the given learning rate.
190 |   optimizer = tf.train.AdamOptimizer(lr)
191 |   # optimizer = tf.train.GradientDescentOptimizer(lr)
192 |
193 |   # Use the optimizer to apply the gradients that minimize the loss
194 |   # (and also increment the global step counter) as a single training step.
195 |   train_op = optimizer.minimize(loss, global_step=global_step)
196 |
197 |   return train_op
198 |
199 |
200 | def evaluation(logits, labels):
201 |   """Evaluate the quality of the logits at predicting the label.
202 |   Args:
203 |     logits: Logits tensor, float - [batch_size, NUM_CLASSES].
204 |     labels: Labels tensor, int32 - [batch_size], with values in the
205 |       range [0, NUM_CLASSES).
206 |   Returns:
207 |     The accuracy percentage and a scalar tensor with the number of
208 |     examples (out of batch_size) that were predicted correctly.
209 |   """
210 |   print('Evaluation..')
211 |   # For a classifier model, we can use the in_top_k Op.
212 |   # It returns a bool tensor with shape [batch_size] that is true for
213 |   # the examples where the label was in the top k (here k=1)
214 |   # of all logits for that example.
215 |   correct = tf.nn.in_top_k(logits, labels, 1)
216 |   num_correct = tf.reduce_sum(tf.cast(correct, tf.float32))
217 |
218 |   acc_percent = num_correct / FLAGS.batch_size
219 |
220 |   # Return the accuracy percentage and the number of true entries.
221 |   return acc_percent * 100.0, num_correct
222 |
223 |
224 | def main(argv=None):
225 |   return 0
226 |
227 |
228 | if __name__ == '__main__':
229 |   tf.app.run()
--------------------------------------------------------------------------------
/example_project/my_cifar_train.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | # ==============================================================================
13 |
14 | """Training the model"""
15 |
16 | __author__ = 'HANEL'
17 |
18 | import os.path
19 | import time
20 | from datetime import datetime
21 |
22 | import numpy as np
23 | import tensorflow as tf
24 | from six.moves import xrange
25 |
26 | import imageflow
27 | # from imageflow import inputs
28 | # from imageflow import distorted_inputs
29 | import my_cifar
30 |
31 |
32 |
33 | # Basic model parameters as external flags.
34 | flags = tf.app.flags
35 | FLAGS = flags.FLAGS
36 | flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
37 | flags.DEFINE_integer('num_epochs', 50000, 'Number of epochs to run trainer.')
38 | flags.DEFINE_integer('batch_size', 128, 'Batch size.')
39 |
40 |
41 | FLAGS = tf.app.flags.FLAGS
42 |
43 | tf.app.flags.DEFINE_string('model_dir', 'tmp/my-model',
44 |                            """Directory where to write model proto """
45 |                            """ to import in c++""")
46 | tf.app.flags.DEFINE_string('train_dirr', 'tmp/log',
47 |                            """Directory where to write event logs """
48 |                            """and checkpoint.""")
49 | tf.app.flags.DEFINE_integer('max_steps', 100000,
50 |                             """Number of batches to run.""")
51 | tf.app.flags.DEFINE_boolean('log_device_placement', False,
52 |                             """Whether to log device placement.""")
53 | tf.app.flags.DEFINE_string('eval_dir', 'tmp/log_eval',
54 |                            """Directory where to write event logs.""")
55 | tf.app.flags.DEFINE_string('eval_data', 'test',
56 |                            """Either 'test' or 'train_eval'.""")
57 | tf.app.flags.DEFINE_string('checkpoint_dir', 'tmp/ckpt',
58 |                            """Directory where to read model checkpoints.""")
59 |
60 | # Parameters
61 | display_step = 1
62 | val_step = 5
63 | save_step = 50
64 | IMAGE_PIXELS = 32 * 32 * 3
65 | NEW_LINE = '\n'
66 |
67 |
68 | def placeholder_inputs(batch_size):
69 |   """Generate placeholder variables to represent the input tensors.
70 |   These placeholders are used as inputs by the rest of the model building
71 |   code and will be fed from the downloaded data in the .run() loop, below.
72 |   Args:
73 |     batch_size: The batch size will be baked into both placeholders.
74 |   Returns:
75 |     images_placeholder: Images placeholder.
76 |     labels_placeholder: Labels placeholder.
77 |   """
78 |   # Note that the shapes of the placeholders match the shapes of the full
79 |   # image and label tensors, except the first dimension is now batch_size
80 |   # rather than the full size of the train or test data sets.
81 |   # batch_size = -1
82 |   images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
83 |                                                          IMAGE_PIXELS))
84 |   # 32, 32, 3))
85 |   labels_placeholder = tf.placeholder(tf.int32, shape=batch_size)
86 |
87 |   return images_placeholder, labels_placeholder
88 |
89 |
90 | def train(re_train=True):
91 |   """Train CIFAR-10 for a number of steps."""
92 |   with tf.Graph().as_default():
93 |     global_step = tf.Variable(0, trainable=False)
94 |
95 |     images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)
96 |
97 |     # Get images and labels for CIFAR-10.
98 |     # images, labels = my_input.inputs()
99 |     images, labels = imageflow.distorted_inputs(filename='../my_data_raw/train.tfrecords', batch_size=FLAGS.batch_size,
100 |                                                 num_epochs=FLAGS.num_epochs,
101 |                                                 num_threads=5, imshape=[32, 32, 3])
102 |     val_images, val_labels = imageflow.inputs(filename='../my_data_raw/validation.tfrecords', batch_size=FLAGS.batch_size,
103 |                                               num_epochs=FLAGS.num_epochs,
104 |                                               num_threads=5, imshape=[32, 32, 3])
105 |
106 |     print(images.get_shape(), val_images.get_shape())
107 |     # Build a Graph that computes the logits predictions from the inference model.
108 |     logits = my_cifar.inference(images_placeholder)
109 |
110 |     # Calculate loss.
111 |     loss = my_cifar.loss(logits, labels_placeholder)
112 |
113 |     # Build a Graph that trains the model with one batch of examples and
114 |     # updates the model parameters.
115 |     train_op = my_cifar.training(loss, global_step)
116 |
117 |     # Calculate accuracy #
118 |     acc, n_correct = my_cifar.evaluation(logits, labels_placeholder)
119 |
120 |     # Create a saver.
121 |     saver = tf.train.Saver()
122 |
123 |     tf.scalar_summary('Acc', acc)
124 |     # tf.scalar_summary('Val Acc', acc_val)
125 |     tf.scalar_summary('Loss', loss)
126 |     tf.image_summary('Images', tf.reshape(images, shape=[-1, 32, 32, 3]), max_images=10)
127 |     tf.image_summary('Val Images', tf.reshape(val_images, shape=[-1, 32, 32, 3]), max_images=10)
128 |
129 |     # Build the summary operation based on the TF collection of Summaries.
130 |     summary_op = tf.merge_all_summaries()
131 |
132 |     # Build an initialization operation to run below.
133 |     init = tf.initialize_all_variables()
134 |
135 |     # Start running operations on the Graph.
136 |     # NUM_CORES = 2  # Choose how many cores to use.
137 |     sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement, ))
138 |     # inter_op_parallelism_threads=NUM_CORES,
139 |     # intra_op_parallelism_threads=NUM_CORES))
140 |     sess.run(init)
141 |
142 |     # Write all terminal output results here
143 |     val_f = open("tmp/val.txt", "ab")
144 |
145 |     # Start the queue runners.
146 |     coord = tf.train.Coordinator()
147 |     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
148 |
149 |     summary_writer = tf.train.SummaryWriter(FLAGS.train_dirr,
150 |                                             graph=sess.graph)
151 |
152 |     if re_train:
153 |
154 |       # Export graph to import it later in c++
155 |       # tf.train.write_graph(sess.graph, FLAGS.model_dir, 'train.pbtxt')  # TODO: uncomment to get graph and use in c++
156 |
157 |       continue_from_pre = False
158 |
159 |       if continue_from_pre:
160 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
161 |         print(ckpt.model_checkpoint_path)
162 |         if ckpt and ckpt.model_checkpoint_path:
163 |           saver.restore(sess, ckpt.model_checkpoint_path)
164 |           print('Session Restored!')
165 |
166 |       try:
167 |         while not coord.should_stop():
168 |
169 |           for step in xrange(FLAGS.max_steps):
170 |
171 |             images_r, labels_r = sess.run([images, labels])
172 |             images_val_r, labels_val_r = sess.run([val_images, val_labels])
173 |
174 |             train_feed = {images_placeholder: images_r,
175 |                           labels_placeholder: labels_r}
176 |
177 |             val_feed = {images_placeholder: images_val_r,
178 |                         labels_placeholder: labels_val_r}
179 |
180 |             start_time = time.time()
181 |
182 |             _, loss_value = sess.run([train_op, loss], feed_dict=train_feed)
183 |             duration = time.time() - start_time
184 |
185 |             assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
186 |
187 |             if step % display_step == 0:
188 |               num_examples_per_step = FLAGS.batch_size
189 |               examples_per_sec = num_examples_per_step / duration
190 |               sec_per_batch = float(duration)
191 |
192 |               format_str = ('%s: step %d, loss = %.6f (%.1f examples/sec; %.3f '
193 |                             'sec/batch)')
194 |               print_str_loss = format_str % (datetime.now(), step, loss_value,
195 |                                              examples_per_sec, sec_per_batch)
196 |               print(print_str_loss)
197 |               val_f.write(print_str_loss + NEW_LINE)
198 |               summary_str = sess.run([summary_op], feed_dict=train_feed)
199 |               summary_writer.add_summary(summary_str[0], step)
200 |
201 |             if step % val_step == 0:
202 |               acc_value, num_correct_value = sess.run([acc, n_correct], feed_dict=train_feed)
203 |
204 |               format_str = '%s: step %d, train acc = %.2f, n_correct= %d'
205 |               print_str_train = format_str % (datetime.now(), step, acc_value, num_correct_value)
206 |               val_f.write(print_str_train + NEW_LINE)
207 |               print(print_str_train)
208 |
209 |             # Save the model checkpoint periodically.
210 |             if step % save_step == 0 or (step + 1) == FLAGS.max_steps:
211 |               val_acc_r, val_n_correct_r = sess.run([acc, n_correct], feed_dict=val_feed)
212 |
213 |               frmt_str = ' step %d, Val Acc = %.2f, num correct = %d'
214 |               print_str_val = frmt_str % (step, val_acc_r, val_n_correct_r)
215 |               val_f.write(print_str_val + NEW_LINE)
216 |               print(print_str_val)
217 |
218 |               checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'model.ckpt')
219 |               saver.save(sess, checkpoint_path, global_step=step)
220 |
221 |
222 |       except tf.errors.OutOfRangeError:
223 |         print('Done training -- epoch limit reached')
224 |
225 |       finally:
226 |         # When done, ask the threads to stop.
227 |         val_f.write(NEW_LINE +
228 |                     NEW_LINE +
229 |                     '############################ FINISHED ############################' +
230 |                     NEW_LINE)
231 |         val_f.close()
232 |         coord.request_stop()
233 |
234 |       # Wait for threads to finish.
235 |       coord.join(threads)
236 |       sess.close()
237 |
238 |     else:
239 |
240 |       ckpt = tf.train.get_checkpoint_state(checkpoint_dir=FLAGS.checkpoint_dir)
241 |       print(ckpt.model_checkpoint_path)
242 |       if ckpt and ckpt.model_checkpoint_path:
243 |         saver.restore(sess, ckpt.model_checkpoint_path)
244 |         print('Restored!')
245 |
246 |       for i in range(100):
247 |         images_val_r, labels_val_r = sess.run([val_images, val_labels])
248 |         val_feed = {images_placeholder: images_val_r,
249 |                     labels_placeholder: labels_val_r}
250 |
251 |         tf.scalar_summary('Acc', acc)
252 |
253 |         print('Calculating Acc: ')
254 |
255 |         acc_r = sess.run(acc, feed_dict=val_feed)
256 |         print(acc_r)
257 |
258 |       coord.join(threads)
259 |       sess.close()
260 |
261 |
262 | def main(argv=None):
263 |   train()
264 |
265 |
266 | if __name__ == '__main__':
267 |   tf.app.run()
--------------------------------------------------------------------------------
/example_project/tmp/ckpt/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HamedMP/ImageFlow/e33ff64e296aef302cfa06d2179dc213ade68a80/example_project/tmp/ckpt/.gitkeep
--------------------------------------------------------------------------------
/example_project/tmp/log/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HamedMP/ImageFlow/e33ff64e296aef302cfa06d2179dc213ade68a80/example_project/tmp/log/.gitkeep
--------------------------------------------------------------------------------
/example_project/tmp/val.txt:
--------------------------------------------------------------------------------
1 | 2016-01-13 03:13:31.320568: step 0, loss = 2.897762 (68.1 examples/sec; 1.879 sec/batch)
2 | 2016-01-13 03:13:32.302118: step 0, train acc = 6.25, n_correct= 8
3 | step 0, Val Acc = 13.28, num correct = 17
4 | 2016-01-13 03:13:34.340002: step 1, loss = 2.564887 (89.6 examples/sec; 1.428 sec/batch)
5 | 2016-01-13 03:13:36.276101: step 2, loss = 2.609227 (89.9 examples/sec; 1.424 sec/batch)
6 | 2016-01-13 03:13:38.182333: step 3, loss = 2.492731 (92.1 examples/sec; 1.390 sec/batch)
7 | 2016-01-13 03:13:40.291217: step 4, loss = 2.373059 (80.4 examples/sec; 1.592 sec/batch)
8 | 2016-01-13 03:13:42.706800: step 5, loss = 2.417678 (67.5 examples/sec; 1.897 sec/batch)
9 | 2016-01-13 03:13:43.957534: step 5, train acc = 14.06, n_correct= 18
10 | 2016-01-13 03:13:45.600259: step 6, loss = 2.297081 (78.3 examples/sec; 1.635 sec/batch)
11 | 2016-01-13 03:13:48.936141: step 7, loss = 2.341561 (47.4 examples/sec; 2.702 sec/batch)
12 |
2016-01-13 03:13:51.198055: step 8, loss = 2.330050 (76.0 examples/sec; 1.685 sec/batch) 13 | 2016-01-13 03:13:53.674682: step 9, loss = 2.336012 (68.6 examples/sec; 1.866 sec/batch) 14 | 2016-01-13 03:13:55.846825: step 10, loss = 2.379518 (78.7 examples/sec; 1.626 sec/batch) 15 | 2016-01-13 03:13:56.974894: step 10, train acc = 5.47, n_correct= 7 16 | 2016-01-13 03:13:58.653115: step 11, loss = 2.343262 (77.0 examples/sec; 1.663 sec/batch) 17 | 2016-01-13 03:14:01.238556: step 12, loss = 2.385437 (73.6 examples/sec; 1.740 sec/batch) 18 | 2016-01-13 03:14:03.738443: step 13, loss = 2.319719 (69.7 examples/sec; 1.836 sec/batch) 19 | 2016-01-13 03:14:06.141940: step 14, loss = 2.319307 (71.5 examples/sec; 1.789 sec/batch) 20 | 2016-01-13 03:14:08.708325: step 15, loss = 2.351548 (64.6 examples/sec; 1.981 sec/batch) 21 | 2016-01-13 03:14:09.885258: step 15, train acc = 11.72, n_correct= 15 22 | 2016-01-13 03:14:11.686827: step 16, loss = 2.328614 (71.5 examples/sec; 1.789 sec/batch) 23 | 2016-01-13 03:14:13.773844: step 17, loss = 2.311491 (83.4 examples/sec; 1.535 sec/batch) 24 | 2016-01-13 03:14:15.958159: step 18, loss = 2.313294 (78.6 examples/sec; 1.628 sec/batch) 25 | 2016-01-13 03:14:18.423966: step 19, loss = 2.339075 (67.6 examples/sec; 1.893 sec/batch) 26 | 2016-01-13 03:14:20.944807: step 20, loss = 2.301256 (64.7 examples/sec; 1.977 sec/batch) 27 | 2016-01-13 03:14:22.103483: step 20, train acc = 12.50, n_correct= 16 28 | 2016-01-13 03:14:24.086866: step 21, loss = 2.315367 (65.0 examples/sec; 1.968 sec/batch) 29 | 2016-01-13 03:14:26.239897: step 22, loss = 2.325019 (81.6 examples/sec; 1.569 sec/batch) 30 | 31 | 32 | ############################ FINISHED ############################ 33 | 34 | 35 | ############################ FINISHED ############################ 36 | 37 | 38 | ############################ FINISHED ############################ 39 | 40 | 41 | ############################ FINISHED ############################ 42 | 2016-01-13 03:44:32.499830: step 0, loss = 2.787704 (77.0 examples/sec; 1.663 sec/batch) 43 | 2016-01-13 03:44:33.491026: step 0, train acc = 10.16, n_correct= 13 44 | step 0, Val Acc = 8.59, num correct = 11 45 | 46 | 47 | ############################ FINISHED ############################ 48 | 49 | 50 | ############################ FINISHED ############################ 51 | 2016-01-13 03:58:36.801882: step 0, loss = 2.536760 (77.1 examples/sec; 1.659 sec/batch) 52 | 2016-01-13 03:58:37.769191: step 0, train acc = 12.50, n_correct= 16 53 | step 0, Val Acc = 11.72, num correct = 15 54 | 2016-01-13 03:58:39.844276: step 1, loss = 2.540612 (88.4 examples/sec; 1.449 sec/batch) 55 | 2016-01-13 03:58:41.737896: step 2, loss = 2.530640 (92.6 examples/sec; 1.382 sec/batch) 56 | 2016-01-13 03:58:43.897213: step 3, loss = 2.387611 (81.7 examples/sec; 1.567 sec/batch) 57 | 2016-01-13 03:58:45.880616: step 4, loss = 2.480456 (86.1 examples/sec; 1.486 sec/batch) 58 | 2016-01-13 03:58:47.904688: step 5, loss = 2.353584 (84.6 examples/sec; 1.513 sec/batch) 59 | 2016-01-13 03:58:48.856693: step 5, train acc = 10.16, n_correct= 13 60 | 2016-01-13 03:58:50.312882: step 6, loss = 2.349627 (88.4 examples/sec; 1.448 sec/batch) 61 | 2016-01-13 03:58:52.333935: step 7, loss = 2.336494 (85.3 examples/sec; 1.500 sec/batch) 62 | 2016-01-13 03:58:54.352809: step 8, loss = 2.372593 (85.6 examples/sec; 1.496 sec/batch) 63 | 2016-01-13 03:58:56.336635: step 9, loss = 2.314384 (87.2 examples/sec; 1.468 sec/batch) 64 | 2016-01-13 03:58:58.330619: step 10, loss = 2.354839 (87.6 
examples/sec; 1.462 sec/batch)
65 | 2016-01-13 03:58:59.735143: step 10, train acc = 13.28, n_correct= 17
66 |
67 |
68 | ############################ FINISHED ############################
69 | 2016-03-17 16:01:21.182128: step 0, loss = 2.462488 (88.0 examples/sec; 1.454 sec/batch)
70 | 2016-03-17 16:01:22.074818: step 0, train acc = 11.72, n_correct= 15
71 | step 0, Val Acc = 14.84, num correct = 19
72 | 2016-03-17 16:01:24.438499: step 1, loss = 2.473764 (86.8 examples/sec; 1.475 sec/batch)
73 | 2016-03-17 16:01:26.298836: step 2, loss = 2.405185 (96.0 examples/sec; 1.333 sec/batch)
74 | 2016-03-17 16:01:28.069750: step 3, loss = 2.346606 (97.6 examples/sec; 1.311 sec/batch)
75 |
76 |
77 | ############################ FINISHED ############################
--------------------------------------------------------------------------------
/imageflow/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Hamed MP. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 | """
18 | Simple library to read all PNG and JPG/JPEG images in a directory
19 | with TensorFlow built-in functions in a multi-threaded way to boost speed.
20 |
21 | Supported formats by TensorFlow are: PNG, JPG/JPEG
22 |
23 | Hamed MP
24 | Github: @hamedmp
25 | Twitter: @TheHamedMP
26 | """
27 |
28 | from .imageflow import *
29 | from .reader import *
30 |
31 | __author__ = 'HANEL'
--------------------------------------------------------------------------------
/imageflow/convert_to_records.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Hamed MP. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Converts images to TFRecords file format with Example protos."""
17 |
18 | import os
19 | import tensorflow as tf
20 |
21 |
22 | tf.app.flags.DEFINE_string('directory', 'converted_data',
23 |                            'Directory to write the '
24 |                            'converted result')
25 | # tf.app.flags.DEFINE_integer('validation_size', 10000,
26 | #                             'Number of examples to separate from the training '
27 | #                             'data for the validation set.')
28 | FLAGS = tf.app.flags.FLAGS
29 |
30 |
31 | def _int64_feature(value):
32 |   return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
33 |
34 |
35 | def _bytes_feature(value):
36 |   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
37 |
38 |
39 | def convert_to(images, labels, name):
40 |   num_examples = labels.shape[0]
41 |   print('labels shape is ', labels.shape[0])
42 |   if images.shape[0] != num_examples:
43 |     raise ValueError("Images size %d does not match label size %d." %
44 |                      (images.shape[0], num_examples))
45 |   rows = images.shape[1]
46 |   cols = images.shape[2]
47 |   depth = images.shape[3]
48 |
49 |   filename = os.path.join(FLAGS.directory, name + '.tfrecords')
50 |   print('Writing', filename)
51 |   writer = tf.python_io.TFRecordWriter(filename)
52 |   for index in range(num_examples):
53 |     image_raw = images[index].tostring()
54 |     example = tf.train.Example(features=tf.train.Features(feature={
55 |         'height': _int64_feature(rows),
56 |         'width': _int64_feature(cols),
57 |         'depth': _int64_feature(depth),
58 |         'label': _int64_feature(int(labels[index])),
59 |         'image_raw': _bytes_feature(image_raw)}))
60 |     writer.write(example.SerializeToString())
61 |   writer.close()  # Flush the records to disk.
62 |
63 | if __name__ == '__main__':
64 |   tf.app.run()
--------------------------------------------------------------------------------
/imageflow/imageflow.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Hamed MP. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from .convert_to_records import convert_to
17 | from .reader import read_and_decode
18 |
19 | __author__ = 'HANEL'
20 |
21 | '''
22 | Simple library to read all PNG and JPG/JPEG images in a directory
23 | with TensorFlow built-in functions in a multi-threaded way to boost speed.
24 | Supported formats by TensorFlow are: PNG, JPG/JPEG
25 | Hamed MP
26 | Github: @hamedmp
27 | Twitter: @TheHamedMP
28 | '''
29 |
30 | from numpy.random import shuffle
31 | import tensorflow as tf
32 |
33 |
34 | def inputs(filename, batch_size, num_epochs, num_threads,
35 |            imshape, num_examples_per_epoch=128):
36 |   """Reads the input tfrecord file num_epochs times. Use it for validation.
37 |   Args:
38 |     filename: The path to the .tfrecords file to be read
39 |     batch_size: Number of examples per returned batch.
    num_epochs: Number of times to read the input data, or 0/None to
      train forever.
    num_threads: Number of reader workers to enqueue
    imshape: Shape of image in [height, width, n_channels] format
    num_examples_per_epoch: Number of images to use per epoch
  Returns:
    A tuple (images, labels), where:
    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
      in the range [-0.5, 0.5].
    * labels is an int32 tensor with shape [batch_size] with the true label,
      a number in the range [0, mnist.NUM_CLASSES).
    Note that a tf.train.QueueRunner is added to the graph, which
    must be run using e.g. tf.train.start_queue_runners().
  """
  if not num_epochs:
    num_epochs = None

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer(
        [filename], num_epochs=num_epochs, name='string_input_producer')

    # Even when reading in multiple threads, share the filename
    # queue.
    image, label = read_and_decode(filename_queue, imshape, normalize=True)

    # The normalize flag above already converts [0, 255] -> [-0.5, 0.5]
    # floats, so no extra cast is needed here:
    # image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

    # Shuffle the examples and collect them into batch_size batches.
    # (Internally uses a RandomShuffleQueue.)
    # We run this in multiple threads to avoid becoming a bottleneck.
    # Ensure that the random shuffling has good mixing properties.
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(num_examples_per_epoch *
                             min_fraction_of_examples_in_queue)

    images, sparse_labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=num_threads,
        capacity=min_queue_examples + 3 * batch_size, enqueue_many=False,
        # Ensures a minimum amount of shuffling of examples.
        min_after_dequeue=min_queue_examples, name='batching_shuffling')

  return images, sparse_labels


def _random_brightness_helper(image):
  return tf.image.random_brightness(image, max_delta=63)


def _random_contrast_helper(image):
  return tf.image.random_contrast(image, lower=0.2, upper=1.8)


def distorted_inputs(filename, batch_size, num_epochs, num_threads,
                     imshape, num_examples_per_epoch=128, flatten=True):
  """Construct distorted input for training using the Reader ops.
  Args:
    filename: The name of the file containing the images
    batch_size: The number of images per batch
    num_epochs: The number of epochs passed to string_input_producer
    num_threads: The number of threads passed to shuffle_batch
    imshape: Shape of image in [height, width, n_channels] format
    num_examples_per_epoch: Number of images to use per epoch
    flatten: Whether to flatten the image after the image transformations
  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  Raises:
    tf.errors.OutOfRangeError (at run time): once num_epochs epochs of data
      have been consumed from the queue.
  """

  if not num_epochs:
    num_epochs = None

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer(
        [filename], num_epochs=num_epochs, name='string_DISTORTED_input_producer')

    # Even when reading in multiple threads, share the filename
    # queue.
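    # read_and_decode with flatten left at its default returns the image as a
    # flat uint8 tensor, which is why it is reshaped back to imshape below
    # before the pixel-level distortions are applied.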
    image, label = read_and_decode(filename_queue, imshape)

    # Reshape to imshape, as the distortion methods need this shape
    image = tf.reshape(image, imshape)

    # Image processing for training the network. Note the many random
    # distortions applied to the image.

    # random_crop was removed in a newer TensorFlow release.
    # Randomly crop a [height, width] section of the image.
    # distorted_image = tf.image.random_crop(image, [height, width])

    # Randomly flip the image horizontally.
    distorted_image = tf.image.random_flip_left_right(image)

    # Apply the image transformations in random_functions in random order
    random_functions = [_random_brightness_helper, _random_contrast_helper]
    shuffle(random_functions)
    for fcn in random_functions:
      distorted_image = fcn(distorted_image)

    # Subtract off the mean and divide by the variance of the pixels.
    float_image = tf.image.per_image_standardization(distorted_image)

    if flatten:
      num_elements = 1
      for i in imshape:
        num_elements = num_elements * i
      image = tf.reshape(float_image, [num_elements])
    else:
      image = float_image

    # Ensure that the random shuffling has good mixing properties.
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(num_examples_per_epoch *
                             min_fraction_of_examples_in_queue)
    images, sparse_labels = tf.train.shuffle_batch([image, label],
                                                   batch_size=batch_size,
                                                   num_threads=num_threads,
                                                   capacity=min_queue_examples + 3 * batch_size,
                                                   enqueue_many=False,
                                                   # Ensures a minimum amount of shuffling of examples.
                                                   min_after_dequeue=min_queue_examples,
                                                   name='batching_shuffling_distortion')

  return images, sparse_labels


def convert_split_images(images, labels, train_validation_split=10):
  """Splits images/labels into a training and a validation set and writes
  each split to its own TFRecords file ('train' and 'validation').
  Args:
    images: Images. 4D tensor of [num_examples, height, width, n_channels] size.
    labels: Labels. 1D tensor of [num_examples] size.
    train_validation_split: Percentage of examples to hold out for
      validation (default 10).
  Raises:
    AssertionError: if the images and labels counts don't match.
  """
  assert images.shape[0] == labels.shape[0], "Number of images, %d should be equal to number of labels %d" % \
                                             (images.shape[0], labels.shape[0])

  validation_size = images.shape[0] * train_validation_split // 100  # default 10%

  # Generate a validation set.
  validation_images = images[:validation_size, :, :, :]
  validation_labels = labels[:validation_size]
  train_images = images[validation_size:, :, :, :]
  train_labels = labels[validation_size:]

  # Convert to Examples and write the result to TFRecords.
  convert_to(train_images, train_labels, 'train')
  convert_to(validation_images, validation_labels, 'validation')


def convert_images(images, labels, filename):
  """Writes images and labels, unsplit, to a single TFRecords file.
  Args:
    images: Images. 4D tensor of [num_examples, height, width, n_channels] size.
    labels: Labels. 1D tensor of [num_examples] size.
    filename: Name for the output .tfrecords file (without extension).
  Raises:
    AssertionError: if the images and labels counts don't match.
  """
  assert images.shape[0] == labels.shape[0], "Number of images, %d should be equal to number of labels %d" % \
                                             (images.shape[0], labels.shape[0])

  # Convert to Examples and write the result to TFRecords.
  convert_to(images, labels, filename)
--------------------------------------------------------------------------------
/imageflow/playground.py:
--------------------------------------------------------------------------------
__author__ = 'HANEL'

import os
import glob
import tensorflow as tf
from PIL import Image
import numpy as np

# Scratchpad module that exercises the reader functions. Note: read_images is
# defined below in this file (it is not part of the package API), so the old
# module-level `imageflow.read_images()` call was removed; see the direct call
# to read_images() further down.

'''
Simple library to read all PNG and JPG/JPEG images in a directory
with TensorFlow built-in functions to boost speed.

Hamed MP
Github: @hamedmp
Twitter: @hamedpc2002

'''


# def _read_png(filename_queue, num):
#
#   images = []
#   # filename_queue = tf.train.string_input_producer(filename_queue_list)
#   reader = tf.WholeFileReader()
#   key, value = reader.read(filename_queue)
#
#   _img = tf.image.decode_png(value)
#
#   init_op = tf.initialize_all_variables()
#   with tf.Session() as sess:
#     sess.run(init_op)
#
#     # Start populating the filename queue.
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(coord=coord)
#
#     for i in range(1):
#       jpeg = _img.eval()
#       images.append(jpeg)
#       Image._showxv(Image.fromarray(np.asarray(jpeg)))
#
#     coord.request_stop()
#     coord.join(threads)
#
#
#   return images
#
#   print(jpeg.shape)


def read_images(path, is_directory=True):

  images = []
  png_files = []
  jpeg_files = []

  reader = tf.WholeFileReader()

  png_files_path = glob.glob(os.path.join(path, '*.[pP][nN][gG]'))
  jpeg_files_path = glob.glob(os.path.join(path, '*.[jJ][pP][eE][gG]'))
  jpg_files_path = glob.glob(os.path.join(path, '*.[jJ][pP][gG]'))

  print(png_files_path)
  # jpeg_files_path = [glob.glob(path + '*.jpg'), glob.glob(path + '*.jpeg')]

  if is_directory:
    for filename in png_files_path:
      png_files.append(filename)
    for filename in jpeg_files_path:
      jpeg_files.append(filename)
    for filename in jpg_files_path:
      jpeg_files.append(filename)
  else:
    _, extension = os.path.splitext(path)
    print(extension)
    if extension.lower() == '.png':
      # string_input_producer expects a list of filenames, not a bare string.
      key, value = reader.read(tf.train.string_input_producer([path]))
      img = tf.image.decode_png(value)
      print(img)
      # `img` is a graph tensor; it must be evaluated in a session (as in the
      # directory branch below) before it can be converted and displayed.
      return img

  # Decode if there is a PNG file:
  if len(png_files) > 0:
    png_file_queue = tf.train.string_input_producer(png_files)
    pkey, pvalue = reader.read(png_file_queue)
    p_img = tf.image.decode_png(pvalue)

  if len(jpeg_files) > 0:
    jpeg_file_queue = tf.train.string_input_producer(jpeg_files)
    jkey, jvalue = reader.read(jpeg_file_queue)
    j_img = tf.image.decode_jpeg(jvalue)


  with tf.Session() as sess:

    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    if len(png_files) > 0:
      for i in range(len(png_files)):
        png = p_img.eval()
        images.append(png)

        # Image._showxv(Image.fromarray(np.asarray(png)))

    if len(jpeg_files) > 0:
      for i in range(len(jpeg_files)):
        jpeg = j_img.eval()
        images.append(jpeg)

    coord.request_stop()
    coord.join(threads)

  return images

  # all_images = [_read_png(png_file_queue, num=len(png_files)),
  #               _read_jpg(jpeg_file_queue, num=len(jpeg_files))]

  # return all_images


read_images('/Users/HANEL/Desktop/')


# TODO: Remove

def _read_jpg():

  dumm = glob.glob('/Users/HANEL/Desktop/' + '*.png')
  print(len(dumm))
  filename_queue = tf.train.string_input_producer(dumm)
  # filename_queue = tf.train.string_input_producer(['/Users/HANEL/Desktop/tf.png', '/Users/HANEL/Desktop/ft.png'])

  reader = tf.WholeFileReader()
  key, value = reader.read(filename_queue)

  my_img = tf.image.decode_png(value)
  # my_img_flip = tf.image.flip_up_down(my_img)

  init_op = tf.initialize_all_variables()
  with tf.Session() as sess:
    sess.run(init_op)

    # Start populating the filename queue.
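    # Queue-based input pipelines need the queue runners started explicitly;
    # without the Coordinator and start_queue_runners below, my_img.eval()
    # would block forever waiting on the empty filename queue.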
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(1):
      gunel = my_img.eval()

      print(gunel.shape)

      Image._showxv(Image.fromarray(np.asarray(gunel)))
    coord.request_stop()
    coord.join(threads)

#
# _read_jpg()
--------------------------------------------------------------------------------
/imageflow/reader.py:
--------------------------------------------------------------------------------
# Copyright 2016 Hamed MP. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import print_function

__author__ = 'HANEL'

import os
import glob
import csv
import tensorflow as tf
import numpy as np
from PIL import Image
from .utils import dense_to_one_hot


def _read_raw_images(path, is_directory=True):
  """Reads a directory of images with TensorFlow.
  Args:
    path: Path to the directory containing the images.
    is_directory: Whether `path` is a directory (only True is supported).

  Returns:
    Nothing yet; see the TODO at the end of the function.

  """
  images = []
  png_files = []
  jpeg_files = []

  reader = tf.WholeFileReader()

  png_files_path = glob.glob(os.path.join(path, '*.[pP][nN][gG]'))
  jpeg_files_path = glob.glob(os.path.join(path, '*.[jJ][pP][eE][gG]'))
  jpg_files_path = glob.glob(os.path.join(path, '*.[jJ][pP][gG]'))

  if is_directory:
    for filename in png_files_path:
      png_files.append(filename)
    for filename in jpeg_files_path:
      jpeg_files.append(filename)
    for filename in jpg_files_path:
      jpeg_files.append(filename)
  else:
    raise ValueError('Currently only batch read from directory supported')

  # Decode if there is a PNG file:
  if len(png_files) > 0:
    png_file_queue = tf.train.string_input_producer(png_files)
    pkey, pvalue = reader.read(png_file_queue)
    p_img = tf.image.decode_png(pvalue)

  if len(jpeg_files) > 0:
    jpeg_file_queue = tf.train.string_input_producer(jpeg_files)
    jkey, jvalue = reader.read(jpeg_file_queue)
    j_img = tf.image.decode_jpeg(jvalue)

  return  # TODO: return normal thing


def read_and_decode(filename_queue, imshape, normalize=False, flatten=True):
  """Reads and decodes a single Example from a TFRecords filename queue.
  Args:
    filename_queue: Queue of filenames, e.g. from tf.train.string_input_producer.
    imshape: Shape of the image in [height, width, n_channels] format.
    normalize: Whether to rescale pixel values from [0, 255] to [-0.5, 0.5].
    flatten: Whether to return the image as a flat 1-D tensor.

  Returns:
    A tuple (image, label) for one example.

  """
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64)
      })

  # Convert from a scalar string tensor (whose single string has
  # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
  # [mnist.IMAGE_PIXELS].
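  # decode_raw reinterprets the raw byte string that convert_to serialized
  # with tostring(); the uint8 dtype here has to match the dtype of the
  # array that was written at conversion time.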
  image = tf.decode_raw(features['image_raw'], tf.uint8)

  if flatten:
    num_elements = 1
    for i in imshape:
      num_elements = num_elements * i
    print(num_elements)
    image = tf.reshape(image, [num_elements])
    image.set_shape([num_elements])
  else:
    image = tf.reshape(image, imshape)
    image.set_shape(imshape)

  if normalize:
    # Convert from [0, 255] -> [-0.5, 0.5] floats.
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

  # Convert label from a scalar uint8 tensor to an int32 scalar.
  label = tf.cast(features['label'], tf.int32)

  return image, label


# Helper, Examples
def _read_labels_csv_from(path, num_classes, one_hot=False):
  """Reads integer labels from a CSV file.
  Args:
    path: Path to the CSV file, one label per row.
    num_classes: Total number of classes (used for one-hot encoding).
    one_hot: Whether to return the labels as one-hot vectors.

  Returns:
    A numpy array of labels, one-hot encoded if `one_hot` is True.

  """
  print('Reading labels')
  with open(os.path.join(path), 'r') as dest_f:
    data_iter = csv.reader(dest_f)
    train_labels = [data for data in data_iter]

  train_labels = np.array(train_labels, dtype=np.uint32)

  if one_hot:
    labels_one_hot = dense_to_one_hot(train_labels, num_classes)
    labels_one_hot = np.asarray(labels_one_hot)
    return labels_one_hot

  return train_labels


def _read_pngs_from(path):
  """Reads a directory of images.
  Args:
    path: path to the directory

  Returns:
    A numpy array of all images in the directory, sorted by their numeric
    file names (e.g. '7.png' comes before '10.png').
  """
  images = []
  png_files_path = glob.glob(os.path.join(path, '*.[pP][nN][gG]'))
  for filename in png_files_path:
    im = Image.open(filename)
    im = np.asarray(im, np.uint8)

    # Keep only the image's numeric name, not the full path.
    image_name = filename.split('/')[-1].split('.')[0]
    images.append([int(image_name), im])

  images = sorted(images, key=lambda image: image[0])

  images_only = [np.asarray(image[1], np.uint8) for image in images]  # Use uint8 to keep the raw pixel values intact.
  images_only = np.array(images_only)

  print(images_only.shape)
  return images_only
--------------------------------------------------------------------------------
/imageflow/utils.py:
--------------------------------------------------------------------------------
# Copyright 2016 Hamed MP. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

__author__ = 'HANEL'

import numpy as np


def dense_to_one_hot(labels_dense, num_classes):
  """
  Convert class labels from scalars to one-hot vectors.
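
  For example, with num_classes=3, labels [1, 0] map to
  [[0, 1, 0],
   [1, 0, 0]].

  Args:
    labels_dense: 1-D numpy array of integer class labels.
    num_classes: Total number of classes.

  Returns:
    A (num_labels, num_classes) numpy array with exactly one 1 per row.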
  """
  num_labels = labels_dense.shape[0]
  index_offset = np.arange(num_labels) * num_classes
  labels_one_hot = np.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1

  return labels_one_hot
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
Pillow
# TensorFlow is also required; install it separately (see tensorflow.org).
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
__author__ = 'HANEL'

from setuptools import setup

setup(name='imageflow',
      version='0.0.2',

      description='Import, Convert (and Soon Train) images with TensorFlow',

      classifiers=[
          'License :: OSI Approved :: Apache Software License',
          'Programming Language :: Python :: 2.7'
      ],
      keywords='tensorflow image cnn',

      url='http://hamedmp.github.io/ImageFlow/',

      author='Hamed Mohammadpour',
      author_email='hamedmp@my.com',

      license='Apache 2.0',

      packages=['imageflow'],
      zip_safe=False,

      install_requires=['numpy', 'Pillow'],

      include_package_data=True,

      )
--------------------------------------------------------------------------------
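
Taken together, the pieces above form a small pipeline: convert a NumPy array of images and labels into a TFRecords file once, then stream shuffled (and optionally distorted) batches from that file during training. The sketch below is illustrative and not part of the repository: the train_example.py name, the image shapes, and the TF 1.x-style initializer calls are assumptions, while convert_images, distorted_inputs, and the 'converted_data' default directory are the APIs defined above.

# train_example.py -- minimal usage sketch for imageflow (hypothetical file,
# not part of the repository; assumes a TF 1.x queue-based runtime).
import os
import numpy as np
import tensorflow as tf

from imageflow import convert_images, distorted_inputs

# TFRecordWriter does not create directories, so create the default
# FLAGS.directory ('converted_data') up front.
if not os.path.isdir('converted_data'):
  os.makedirs('converted_data')

# 1) One-off conversion: 100 dummy 32x32 RGB images with integer labels.
images = np.random.randint(0, 256, size=(100, 32, 32, 3)).astype(np.uint8)
labels = np.random.randint(0, 10, size=(100,))
convert_images(images, labels, 'my_images')  # -> converted_data/my_images.tfrecords

# 2) Build the input pipeline: shuffled, distorted training batches.
images_batch, labels_batch = distorted_inputs(
    filename=os.path.join('converted_data', 'my_images.tfrecords'),
    batch_size=32, num_epochs=2, num_threads=4,
    imshape=[32, 32, 3], flatten=False)

with tf.Session() as sess:
  # num_epochs makes string_input_producer create local variables, so both
  # initializers are needed before the queue runners start.
  sess.run([tf.global_variables_initializer(),
            tf.local_variables_initializer()])
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)
  try:
    while not coord.should_stop():
      imgs, lbls = sess.run([images_batch, labels_batch])
      print(imgs.shape, lbls[:5])
  except tf.errors.OutOfRangeError:
    print('Done: num_epochs exhausted.')
  finally:
    coord.request_stop()
    coord.join(threads)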