├── .gitignore ├── LICENSE ├── README.md ├── build_image_data.py ├── config.py ├── data_provider ├── __init__.py └── data_provider.py ├── net ├── __init__.py └── densenet.py ├── res ├── graph.png └── test_result.png ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # os 104 | 
*.DS_Store 105 | tfrecord/ 106 | log/ 107 | models/ 108 | *.jpeg 109 | *.jpg 110 | view_image.py 111 | 112 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DenseNet-Tensorflow 2 | An implementation of DenseNet using the low-level TensorFlow API, focusing on performance, scalability and stability. 3 | 4 | No tf.slim 5 | 6 | No tf.layers 7 | 8 | No tf.contrib 9 | 10 | No OpenCV 11 | 12 | The current graph is a variation of DenseNet-BC-121 with a 224x224 input image size; the difference is that this implementation has 52 convolutional layers. 13 | 14 | With this repo, you can train and test the architecture from scratch. 15 | 16 | More sophisticated features are still under construction. 17 | 18 | ## Features 19 | Supports TFRecord input 20 | 21 | Minimal dependencies 22 | 23 | 24 | ## Dependencies 25 | tensorflow (1.x API), numpy 26 | 27 | ## Usage 28 | 1. Clone the repo: 29 | ```bash 30 | git clone https://github.com/yeephycho/densenet-tensorflow.git 31 | ``` 32 | 33 | 2. Download the example TFRecord data: 34 | Click [here](https://drive.google.com/drive/folders/0BwTYOWiLy2btX2RiZHlDYVdiWVE?usp=sharing) to download.
35 | For how to generate TFRecord files, see this [repo](https://github.com/yeephycho/tensorflow_input_image_by_tfrecord) or TensorFlow's [build_image_data.py](https://github.com/tensorflow/models/blob/master/inception/inception/data/build_image_data.py) script (a copy is included in this repo). 36 | 37 | The data comes from the TensorFlow [inception retraining example](https://github.com/tensorflow/models/tree/master/inception), which contains 5 kinds of flowers; click [here](http://download.tensorflow.org/example_images/flower_photos.tgz) to download the original data. 38 | 39 | 3. Train on the example data: 40 | ```bash 41 | cd densenet-tensorflow 42 | ``` 43 | ```bash 44 | python train.py 45 | ``` 46 | 47 | 4. Visualize the training loss: 48 | ```bash 49 | tensorboard --logdir=./log 50 | ``` 51 | 52 | 5. Test the model: 53 | A pre-trained model can be downloaded [here](https://drive.google.com/drive/folders/0BwTYOWiLy2btUmRoT0RvWWJyOWM?usp=sharing). Put the `models` folder in the project root folder, then run: 54 | ```bash 55 | python test.py 56 | ``` 57 | The pre-trained model should give a test accuracy of about 80.3%. 58 | 59 | Expected accuracy should be around 80%. 60 | ![Result](https://github.com/yeephycho/densenet-tensorflow/blob/master/res/test_result.png?raw=true "Show result") 61 | 62 | 63 | ## Reference 64 | [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993) 65 | 66 | -------------------------------------------------------------------------------- /build_image_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Converts image data to TFRecords file format with Example protos. 16 | 17 | The image data set is expected to reside in JPEG files located in the 18 | following directory structure. 19 | 20 | data_dir/label_0/image0.jpeg 21 | data_dir/label_0/image1.jpg 22 | ... 23 | data_dir/label_1/weird-image.jpeg 24 | data_dir/label_1/my-image.jpeg 25 | ... 26 | 27 | where the sub-directory is the unique label associated with these images. 28 | 29 | This TensorFlow script converts the training and evaluation data into 30 | a sharded data set consisting of TFRecord files 31 | 32 | train_directory/train-00000-of-01024 33 | train_directory/train-00001-of-01024 34 | ... 35 | train_directory/train-00127-of-01024 36 | 37 | and 38 | 39 | validation_directory/validation-00000-of-00128 40 | validation_directory/validation-00001-of-00128 41 | ... 42 | validation_directory/validation-00127-of-00128 43 | 44 | where we have selected 1024 and 128 shards for each data set. Each record 45 | within the TFRecord file is a serialized Example proto. 
The Example proto 46 | contains the following fields: 47 | 48 | image/encoded: string containing JPEG encoded image in RGB colorspace 49 | image/height: integer, image height in pixels 50 | image/width: integer, image width in pixels 51 | image/colorspace: string, specifying the colorspace, always 'RGB' 52 | image/channels: integer, specifying the number of channels, always 3 53 | image/format: string, specifying the format, always'JPEG' 54 | 55 | image/filename: string containing the basename of the image file 56 | e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG' 57 | image/class/label: integer specifying the index in a classification layer. 58 | The label ranges from [0, num_labels] where 0 is unused and left as 59 | the background class. 60 | image/class/text: string specifying the human-readable version of the label 61 | e.g. 'dog' 62 | 63 | If you data set involves bounding boxes, please look at build_imagenet_data.py. 64 | """ 65 | from __future__ import absolute_import 66 | from __future__ import division 67 | from __future__ import print_function 68 | 69 | from datetime import datetime 70 | import os 71 | import random 72 | import sys 73 | import threading 74 | 75 | 76 | import numpy as np 77 | import tensorflow as tf 78 | 79 | tf.app.flags.DEFINE_string('train_directory', './', 80 | 'Training data directory') 81 | tf.app.flags.DEFINE_string('validation_directory', './', 82 | 'Validation data directory') 83 | tf.app.flags.DEFINE_string('output_directory', './', 84 | 'Output data directory') 85 | 86 | tf.app.flags.DEFINE_integer('train_shards', 2, 87 | 'Number of shards in training TFRecord files.') 88 | tf.app.flags.DEFINE_integer('validation_shards', 0, 89 | 'Number of shards in validation TFRecord files.') 90 | 91 | tf.app.flags.DEFINE_integer('num_threads', 2, 92 | 'Number of threads to preprocess the images.') 93 | 94 | # The labels file contains a list of valid labels are held in this file. 
95 | # Assumes that the file contains entries as such: 96 | # dog 97 | # cat 98 | # flower 99 | # where each line corresponds to a label. We map each label contained in 100 | # the file to an integer corresponding to the line number starting from 0. 101 | tf.app.flags.DEFINE_string('labels_file', './label.txt', 'Labels file') 102 | 103 | 104 | FLAGS = tf.app.flags.FLAGS 105 | 106 | 107 | def _int64_feature(value): 108 | """Wrapper for inserting int64 features into Example proto.""" 109 | if not isinstance(value, list): 110 | value = [value] 111 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 112 | 113 | 114 | def _bytes_feature(value): 115 | """Wrapper for inserting bytes features into Example proto.""" 116 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 117 | 118 | 119 | def _convert_to_example(filename, image_buffer, label, text, height, width): 120 | """Build an Example proto for an example. 121 | 122 | Args: 123 | filename: string, path to an image file, e.g., '/path/to/example.JPG' 124 | image_buffer: string, JPEG encoding of RGB image 125 | label: integer, identifier for the ground truth for the network 126 | text: string, unique human-readable, e.g. 
'dog' 127 | height: integer, image height in pixels 128 | width: integer, image width in pixels 129 | Returns: 130 | Example proto 131 | """ 132 | 133 | colorspace = 'RGB' 134 | channels = 3 135 | image_format = 'JPEG' 136 | 137 | example = tf.train.Example(features=tf.train.Features(feature={ 138 | 'image/height': _int64_feature(height), 139 | 'image/width': _int64_feature(width), 140 | 'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)), 141 | 'image/channels': _int64_feature(channels), 142 | 'image/class/label': _int64_feature(label), 143 | 'image/class/text': _bytes_feature(tf.compat.as_bytes(text)), 144 | 'image/format': _bytes_feature(tf.compat.as_bytes(image_format)), 145 | 'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))), 146 | 'image/encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))})) 147 | return example 148 | 149 | 150 | class ImageCoder(object): 151 | """Helper class that provides TensorFlow image coding utilities.""" 152 | 153 | def __init__(self): 154 | # Create a single Session to run all image coding calls. 155 | self._sess = tf.Session() 156 | 157 | # Initializes function that converts PNG to JPEG data. 158 | self._png_data = tf.placeholder(dtype=tf.string) 159 | image = tf.image.decode_png(self._png_data, channels=3) 160 | self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100) 161 | 162 | # Initializes function that decodes RGB JPEG data. 
163 | self._decode_jpeg_data = tf.placeholder(dtype=tf.string) 164 | self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) 165 | 166 | def png_to_jpeg(self, image_data): 167 | return self._sess.run(self._png_to_jpeg, 168 | feed_dict={self._png_data: image_data}) 169 | 170 | def decode_jpeg(self, image_data): 171 | image = self._sess.run(self._decode_jpeg, 172 | feed_dict={self._decode_jpeg_data: image_data}) 173 | assert len(image.shape) == 3 174 | assert image.shape[2] == 3 175 | return image 176 | 177 | 178 | def _is_png(filename): 179 | """Determine if a file contains a PNG format image. 180 | 181 | Args: 182 | filename: string, path of the image file. 183 | 184 | Returns: 185 | boolean indicating if the image is a PNG. 186 | """ 187 | return '.png' in filename 188 | 189 | 190 | def _process_image(filename, coder): 191 | """Process a single image file. 192 | 193 | Args: 194 | filename: string, path to an image file e.g., '/path/to/example.JPG'. 195 | coder: instance of ImageCoder to provide TensorFlow image coding utils. 196 | Returns: 197 | image_buffer: string, JPEG encoding of RGB image. 198 | height: integer, image height in pixels. 199 | width: integer, image width in pixels. 200 | """ 201 | # Read the image file. 202 | with tf.gfile.FastGFile(filename, 'r') as f: 203 | image_data = f.read() 204 | 205 | # Convert any PNG to JPEG's for consistency. 206 | if _is_png(filename): 207 | print('Converting PNG to JPEG for %s' % filename) 208 | image_data = coder.png_to_jpeg(image_data) 209 | 210 | # Decode the RGB JPEG. 
211 | image = coder.decode_jpeg(image_data) 212 | 213 | # Check that image converted to RGB 214 | assert len(image.shape) == 3 215 | height = image.shape[0] 216 | width = image.shape[1] 217 | assert image.shape[2] == 3 218 | 219 | return image_data, height, width 220 | 221 | 222 | def _process_image_files_batch(coder, thread_index, ranges, name, filenames, 223 | texts, labels, num_shards): 224 | """Processes and saves list of images as TFRecord in 1 thread. 225 | 226 | Args: 227 | coder: instance of ImageCoder to provide TensorFlow image coding utils. 228 | thread_index: integer, unique batch to run index is within [0, len(ranges)). 229 | ranges: list of pairs of integers specifying ranges of each batches to 230 | analyze in parallel. 231 | name: string, unique identifier specifying the data set 232 | filenames: list of strings; each string is a path to an image file 233 | texts: list of strings; each string is human readable, e.g. 'dog' 234 | labels: list of integer; each integer identifies the ground truth 235 | num_shards: integer number of shards for this data set. 236 | """ 237 | # Each thread produces N shards where N = int(num_shards / num_threads). 238 | # For instance, if num_shards = 128, and the num_threads = 2, then the first 239 | # thread would produce shards [0, 64). 240 | num_threads = len(ranges) 241 | assert not num_shards % num_threads 242 | num_shards_per_batch = int(num_shards / num_threads) 243 | 244 | shard_ranges = np.linspace(ranges[thread_index][0], 245 | ranges[thread_index][1], 246 | num_shards_per_batch + 1).astype(int) 247 | num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0] 248 | 249 | counter = 0 250 | for s in range(num_shards_per_batch): 251 | # Generate a sharded version of the file name, e.g. 
'train-00002-of-00010' 252 | shard = thread_index * num_shards_per_batch + s 253 | output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards) 254 | output_file = os.path.join(FLAGS.output_directory, output_filename) 255 | writer = tf.python_io.TFRecordWriter(output_file) 256 | 257 | shard_counter = 0 258 | files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int) 259 | for i in files_in_shard: 260 | filename = filenames[i] 261 | label = labels[i] 262 | text = texts[i] 263 | 264 | image_buffer, height, width = _process_image(filename, coder) 265 | 266 | example = _convert_to_example(filename, image_buffer, label, 267 | text, height, width) 268 | writer.write(example.SerializeToString()) 269 | shard_counter += 1 270 | counter += 1 271 | 272 | if not counter % 1000: 273 | print('%s [thread %d]: Processed %d of %d images in thread batch.' % 274 | (datetime.now(), thread_index, counter, num_files_in_thread)) 275 | sys.stdout.flush() 276 | 277 | writer.close() 278 | print('%s [thread %d]: Wrote %d images to %s' % 279 | (datetime.now(), thread_index, shard_counter, output_file)) 280 | sys.stdout.flush() 281 | shard_counter = 0 282 | print('%s [thread %d]: Wrote %d images to %d shards.' % 283 | (datetime.now(), thread_index, counter, num_files_in_thread)) 284 | sys.stdout.flush() 285 | 286 | 287 | def _process_image_files(name, filenames, texts, labels, num_shards): 288 | """Process and save list of images as TFRecord of Example protos. 289 | 290 | Args: 291 | name: string, unique identifier specifying the data set 292 | filenames: list of strings; each string is a path to an image file 293 | texts: list of strings; each string is human readable, e.g. 'dog' 294 | labels: list of integer; each integer identifies the ground truth 295 | num_shards: integer number of shards for this data set. 
296 | """ 297 | assert len(filenames) == len(texts) 298 | assert len(filenames) == len(labels) 299 | 300 | # Break all images into batches with a [ranges[i][0], ranges[i][1]]. 301 | spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int) 302 | ranges = [] 303 | for i in range(len(spacing) - 1): 304 | ranges.append([spacing[i], spacing[i+1]]) 305 | 306 | # Launch a thread for each batch. 307 | print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges)) 308 | sys.stdout.flush() 309 | 310 | # Create a mechanism for monitoring when all threads are finished. 311 | coord = tf.train.Coordinator() 312 | 313 | # Create a generic TensorFlow-based utility for converting all image codings. 314 | coder = ImageCoder() 315 | 316 | threads = [] 317 | for thread_index in range(len(ranges)): 318 | args = (coder, thread_index, ranges, name, filenames, 319 | texts, labels, num_shards) 320 | t = threading.Thread(target=_process_image_files_batch, args=args) 321 | t.start() 322 | threads.append(t) 323 | 324 | # Wait for all the threads to terminate. 325 | coord.join(threads) 326 | print('%s: Finished writing all %d images in data set.' % 327 | (datetime.now(), len(filenames))) 328 | sys.stdout.flush() 329 | 330 | 331 | def _find_image_files(data_dir, labels_file): 332 | """Build a list of all images files and labels in the data set. 333 | 334 | Args: 335 | data_dir: string, path to the root directory of images. 336 | 337 | Assumes that the image data set resides in JPEG files located in 338 | the following directory structure. 339 | 340 | data_dir/dog/another-image.JPEG 341 | data_dir/dog/my-image.jpg 342 | 343 | where 'dog' is the label associated with these images. 344 | 345 | labels_file: string, path to the labels file. 346 | 347 | The list of valid labels are held in this file. Assumes that the file 348 | contains entries as such: 349 | dog 350 | cat 351 | flower 352 | where each line corresponds to a label. 
We map each label contained in 353 | the file to an integer starting with the integer 0 corresponding to the 354 | label contained in the first line. 355 | 356 | Returns: 357 | filenames: list of strings; each string is a path to an image file. 358 | texts: list of strings; each string is the class, e.g. 'dog' 359 | labels: list of integer; each integer identifies the ground truth. 360 | """ 361 | print('Determining list of input files and labels from %s.' % data_dir) 362 | unique_labels = [l.strip() for l in tf.gfile.FastGFile( 363 | labels_file, 'r').readlines()] 364 | 365 | labels = [] 366 | filenames = [] 367 | texts = [] 368 | 369 | # Leave label index 0 empty as a background class. 370 | label_index = 1 371 | 372 | # Construct the list of JPEG files and labels. 373 | for text in unique_labels: 374 | jpeg_file_path = '%s/%s/*' % (data_dir, text) 375 | matching_files = tf.gfile.Glob(jpeg_file_path) 376 | 377 | labels.extend([label_index] * len(matching_files)) 378 | texts.extend([text] * len(matching_files)) 379 | filenames.extend(matching_files) 380 | 381 | if not label_index % 100: 382 | print('Finished finding files in %d of %d classes.' % ( 383 | label_index, len(labels))) 384 | label_index += 1 385 | 386 | # Shuffle the ordering of all image files in order to guarantee 387 | # random ordering of the images with respect to label in the 388 | # saved TFRecord files. Make the randomization repeatable. 389 | shuffled_index = list(range(len(filenames))) 390 | random.seed(12345) 391 | random.shuffle(shuffled_index) 392 | 393 | filenames = [filenames[i] for i in shuffled_index] 394 | texts = [texts[i] for i in shuffled_index] 395 | labels = [labels[i] for i in shuffled_index] 396 | 397 | print('Found %d JPEG files across %d labels inside %s.' 
% 398 | (len(filenames), len(unique_labels), data_dir)) 399 | return filenames, texts, labels 400 | 401 | 402 | def _process_dataset(name, directory, num_shards, labels_file): 403 | """Process a complete data set and save it as a TFRecord. 404 | 405 | Args: 406 | name: string, unique identifier specifying the data set. 407 | directory: string, root path to the data set. 408 | num_shards: integer number of shards for this data set. 409 | labels_file: string, path to the labels file. 410 | """ 411 | filenames, texts, labels = _find_image_files(directory, labels_file) 412 | _process_image_files(name, filenames, texts, labels, num_shards) 413 | 414 | 415 | def main(unused_argv): 416 | assert not FLAGS.train_shards % FLAGS.num_threads, ( 417 | 'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards') 418 | assert not FLAGS.validation_shards % FLAGS.num_threads, ( 419 | 'Please make the FLAGS.num_threads commensurate with ' 420 | 'FLAGS.validation_shards') 421 | print('Saving results to %s' % FLAGS.output_directory) 422 | 423 | # Run it! 
424 | _process_dataset('validation', FLAGS.validation_directory, 425 | FLAGS.validation_shards, FLAGS.labels_file) 426 | _process_dataset('train', FLAGS.train_directory, 427 | FLAGS.train_shards, FLAGS.labels_file) 428 | 429 | 430 | if __name__ == '__main__': 431 | tf.app.run() 432 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Brief: Test the densenet for image classification 2 | # Data: 24/Aug./2017 3 | # E-mail: huyixuanhyx@gmail.com 4 | # License: Apache 2.0 5 | # By: Yeephycho @ Hong Kong 6 | 7 | import tensorflow as tf 8 | 9 | FLAGS = tf.app.flags.FLAGS 10 | 11 | tf.app.flags.DEFINE_string("train_data_path", "./tfrecord", "training data dir") 12 | tf.app.flags.DEFINE_string("log_dir", "./log", " the log dir") 13 | 14 | tf.app.flags.DEFINE_integer("TRAINING_SET_SIZE", 2512, "total image number of training set") 15 | tf.app.flags.DEFINE_integer("TESTING_SET_SIZE", 908, "total image number of training set") 16 | 17 | tf.app.flags.DEFINE_integer("BATCH_SIZE", 16, "batch size") 18 | tf.app.flags.DEFINE_integer("IMAGE_SIZE", 224, "image width and height") 19 | 20 | tf.app.flags.DEFINE_float("INIT_LEARNING_RATE", 0.005, "initial learning rate") 21 | tf.app.flags.DEFINE_float("DECAY_RATE", 0.5, "learning rate decay rate") 22 | tf.app.flags.DEFINE_integer("DECAY_STEPS", 2000, "learning rate decay step") 23 | 24 | tf.app.flags.DEFINE_float("weights_decay", 0.0001, "weights decay serve as l2 regularizer") 25 | -------------------------------------------------------------------------------- /data_provider/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeephycho/densenet-tensorflow/7fd51c7479c33f29b708a241215f7f8ea50f05ce/data_provider/__init__.py -------------------------------------------------------------------------------- /data_provider/data_provider.py: 
-------------------------------------------------------------------------------- 1 | # Brief: Data provdier for image classification using tfrecord 2 | # Data: 28/Aug./2017 3 | # E-mail: huyixuanhyx@gmail.com 4 | # License: Apache 2.0 5 | # By: Yeephycho @ Hong Kong 6 | 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import tensorflow as tf 13 | import numpy as np 14 | import os 15 | 16 | import sys 17 | import config as config 18 | 19 | 20 | 21 | FLAGS = tf.app.flags.FLAGS 22 | DATA_DIR = FLAGS.train_data_path 23 | TRAINING_SET_SIZE = FLAGS.TRAINING_SET_SIZE 24 | BATCH_SIZE = FLAGS.BATCH_SIZE 25 | IMAGE_SIZE = FLAGS.IMAGE_SIZE 26 | 27 | 28 | 29 | def _int64_feature(value): 30 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 31 | 32 | def _bytes_feature(value): 33 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 34 | 35 | 36 | # image object from tfrecord 37 | class _image_object: 38 | def __init__(self): 39 | self.image = tf.Variable([], dtype = tf.string, trainable=False) 40 | self.height = tf.Variable([], dtype = tf.int64, trainable=False) 41 | self.width = tf.Variable([], dtype = tf.int64, trainable=False) 42 | self.filename = tf.Variable([], dtype = tf.string, trainable=False) 43 | self.label = tf.Variable([], dtype = tf.int32, trainable=False) 44 | 45 | def read_and_decode(filename_queue): 46 | with tf.name_scope('data_provider'): 47 | reader = tf.TFRecordReader() 48 | _, serialized_example = reader.read(filename_queue) 49 | features = tf.parse_single_example(serialized_example, features = { 50 | "image/encoded": tf.FixedLenFeature([], tf.string), 51 | "image/height": tf.FixedLenFeature([], tf.int64), 52 | "image/width": tf.FixedLenFeature([], tf.int64), 53 | "image/filename": tf.FixedLenFeature([], tf.string), 54 | "image/class/label": tf.FixedLenFeature([], tf.int64),}) 55 | image_encoded = features["image/encoded"] 56 | 
image_raw = tf.image.decode_jpeg(image_encoded, channels=3) 57 | image_object = _image_object() 58 | # image_object.image = tf.image.resize_image_with_crop_or_pad(image_raw, IMAGE_SIZE, IMAGE_SIZE) 59 | image_object.image = tf.image.resize_images(image_raw, [IMAGE_SIZE, IMAGE_SIZE], method=0, align_corners=True) 60 | image_object.height = features["image/height"] 61 | image_object.width = features["image/width"] 62 | image_object.filename = features["image/filename"] 63 | image_object.label = tf.cast(features["image/class/label"], tf.int64) 64 | return image_object 65 | 66 | 67 | 68 | def feed_data(if_random = True, if_training = True): 69 | with tf.name_scope('image_reader_and_preprocessor') as scope: 70 | if(if_training): 71 | filenames = [os.path.join(DATA_DIR, "train.tfrecord")] 72 | else: 73 | filenames = [os.path.join(DATA_DIR, "test.tfrecord")] 74 | 75 | for f in filenames: 76 | if not tf.gfile.Exists(f): 77 | raise ValueError("Failed to find file: " + f) 78 | filename_queue = tf.train.string_input_producer(filenames) 79 | image_object = read_and_decode(filename_queue) 80 | 81 | if(if_training): 82 | image = tf.cast(tf.image.random_flip_left_right(image_object.image), tf.float32) 83 | # image = tf.image.adjust_gamma(tf.cast(image_object.image, tf.float32), gamma=1, gain=1) # Scale image to (0, 1) 84 | # image = tf.image.per_image_standardization(image) 85 | else: 86 | image = tf.cast(image_object.image, tf.float32) 87 | # image = tf.image.per_image_standardization(image_object.image) 88 | 89 | label = image_object.label 90 | filename = image_object.filename 91 | 92 | if(if_training): 93 | num_preprocess_threads = 2 94 | else: 95 | num_preprocess_threads = 1 96 | 97 | if(if_random): 98 | min_fraction_of_examples_in_queue = 0.4 99 | min_queue_examples = int(TRAINING_SET_SIZE * min_fraction_of_examples_in_queue) 100 | print("Filling queue with %d images before starting to train. " "This will take some time." 
% min_queue_examples) 101 | image_batch, label_batch, filename_batch = tf.train.shuffle_batch( 102 | [image, label, filename], 103 | batch_size = BATCH_SIZE, 104 | num_threads = num_preprocess_threads, 105 | capacity = min_queue_examples + 3 * BATCH_SIZE, 106 | min_after_dequeue = min_queue_examples) 107 | image_batch = tf.reshape(image_batch, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3)) 108 | label_offset = -tf.ones([BATCH_SIZE], dtype=tf.int64, name="label_batch_offset") 109 | label_batch = tf.one_hot(tf.add(label_batch, label_offset), depth=5, on_value=1.0, off_value=0.0) 110 | else: 111 | image_batch, label_batch, filename_batch = tf.train.batch( 112 | [image, label, filename], 113 | batch_size = BATCH_SIZE, 114 | num_threads = num_preprocess_threads) 115 | image_batch = tf.reshape(image_batch, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3)) 116 | label_offset = -tf.ones([BATCH_SIZE], dtype=tf.int64, name="label_batch_offset") 117 | label_batch = tf.one_hot(tf.add(label_batch, label_offset), depth=5, on_value=1.0, off_value=0.0) 118 | return image_batch, label_batch, filename_batch 119 | -------------------------------------------------------------------------------- /net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeephycho/densenet-tensorflow/7fd51c7479c33f29b708a241215f7f8ea50f05ce/net/__init__.py -------------------------------------------------------------------------------- /net/densenet.py: -------------------------------------------------------------------------------- 1 | # Brief: Build densnet graph 2 | # Data: 28/Aug./2017 3 | # E-mail: huyixuanhyx@gmail.com 4 | # License: Apache 2.0 5 | # By: Yeephycho @ Hong Kong 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import tensorflow as tf 12 | import numpy as np 13 | import config as config 14 | 15 | 16 | FLAGS = tf.app.flags.FLAGS 17 | 
weights_decay = FLAGS.weights_decay 18 | 19 | 20 | def _weight_variable_truncated_norm(shape): 21 | return tf.Variable(tf.truncated_normal(shape, stddev=0.05)) 22 | 23 | def _weight_variable_constant(shape): 24 | return tf.Variable(tf.constant(0.02, shape=shape)) 25 | 26 | def _weight_variable_with_decay(shape): 27 | var = tf.Variable(tf.truncated_normal(shape, stddev=0.05)) 28 | weight_decay = tf.multiply(tf.nn.l2_loss(var), weights_decay, name='weight_loss') 29 | tf.add_to_collection('regularzation_loss', weight_decay) 30 | return var 31 | 32 | 33 | 34 | def batch_norm(input_tensor, if_training): 35 | """ 36 | Batch normalization on convolutional feature maps. 37 | Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow 38 | Args: 39 | input_tensor: Tensor, 4D NHWC input feature maps 40 | depth: Integer, depth of input feature maps 41 | if_training: Boolean tf.Varialbe, true indicates training phase 42 | scope: String, variable scope 43 | Return: 44 | normed_tensor: Batch-normalized feature maps 45 | """ 46 | with tf.variable_scope('batch_normalization'): 47 | depth = int(input_tensor.get_shape()[-1]) 48 | beta = tf.Variable(tf.constant(0.0, shape=[depth]), 49 | name='beta', trainable=True) 50 | gamma = tf.Variable(tf.constant(1.0, shape=[depth]), 51 | name='gamma', trainable=True) 52 | batch_mean, batch_var = tf.nn.moments(input_tensor, [0,1,2], name='moments') 53 | ema = tf.train.ExponentialMovingAverage(decay=0.99) 54 | 55 | def mean_var_with_update(): 56 | ema_apply_op = ema.apply([batch_mean, batch_var]) 57 | with tf.control_dependencies([ema_apply_op]): 58 | return tf.identity(batch_mean), tf.identity(batch_var) 59 | 60 | mean, var = tf.cond(if_training, 61 | mean_var_with_update, 62 | lambda: (ema.average(batch_mean), ema.average(batch_var))) 63 | normed_tensor = tf.nn.batch_normalization(input_tensor, mean, var, beta, gamma, 1e-3) 64 | return normed_tensor 65 | 66 | 67 | 68 | def 
composite_function(__input_tensor, growth_rate, if_training): 69 | __input_tensor_depth = int(__input_tensor.get_shape()[-1]) 70 | __conv_weights = _weight_variable_with_decay([3, 3, __input_tensor_depth, growth_rate]) 71 | __output_tensor = batch_norm(__input_tensor, if_training) 72 | __output_tensor = tf.nn.relu(__output_tensor) 73 | __output_tensor = tf.nn.conv2d(input=__output_tensor, filter=__conv_weights, strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC', name='composite_3x3_s1') 74 | 75 | return __output_tensor 76 | 77 | 78 | def transition_layer(__input_tensor, theta, if_training): 79 | __input_tensor_depth = int(__input_tensor.get_shape()[-1]) 80 | __conv_weights = _weight_variable_with_decay([1, 1, __input_tensor_depth, int(theta * __input_tensor_depth)]) 81 | __output_tensor = batch_norm(__input_tensor, if_training) 82 | __output_tensor = tf.nn.conv2d(input=__output_tensor, filter=__conv_weights, strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC', name='transition_1x1_s1') 83 | __output_tensor = tf.nn.avg_pool(value=__output_tensor, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', data_format='NHWC', name='avg_pool_2x2') 84 | 85 | return __output_tensor 86 | 87 | 88 | def bottleneck_layer(_input_tensor, growth_rate, if_training): 89 | __input_tensor_depth = int(_input_tensor.get_shape()[-1]) 90 | __conv_weights = _weight_variable_with_decay([1, 1, __input_tensor_depth, 4 * growth_rate]) #NOTE: output_tensor should be 4k, 4 times of the growth_rate 91 | __output_tensor = batch_norm(_input_tensor, if_training) 92 | __output_tensor = tf.nn.relu(__output_tensor) 93 | __output_tensor = tf.nn.conv2d(input=__output_tensor, filter=__conv_weights, strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC', name='bottleneck_1x1_s1') 94 | __output_tensor = composite_function(__output_tensor, growth_rate, if_training) #NOTE: output/input_tensor_depth is different from input 95 | 96 | return __output_tensor 97 | 98 | 99 | 100 | def 
dense_block(_input_tensor, growth_rate, if_training): 101 | _input_tensor_depth = int(_input_tensor.get_shape()[-1]) 102 | with tf.name_scope("bottleneck_0") as scope: 103 | _bottleneck_output_0 = bottleneck_layer(_input_tensor, growth_rate, if_training)#NOTE:64 is 2k, here k = 32, 128 is 4k, output is k = 32 104 | _bottleneck_input_0 = tf.concat(values=[_input_tensor, _bottleneck_output_0], axis=3, name='stack0')# 96 105 | 106 | with tf.name_scope("bottleneck_1") as scope: 107 | _bottlenect_output_1 = bottleneck_layer(_bottleneck_input_0, growth_rate, if_training)#NOTE:96 = 64 + 32 108 | _bottleneck_input_1 = tf.concat(values=[_bottleneck_input_0, _bottlenect_output_1], axis=3, name='stack1')# 128 109 | 110 | with tf.name_scope("bottleneck_2") as scope: 111 | _bottlenect_output_2 = bottleneck_layer(_bottleneck_input_1, growth_rate, if_training)#NOTE:128 = 96 + 32 112 | _bottleneck_input_2 = tf.concat(values=[_bottleneck_input_1, _bottlenect_output_2], axis=3, name='stack2')# 160 113 | 114 | with tf.name_scope("bottleneck_3") as scope: 115 | _bottlenect_output_3 = bottleneck_layer(_bottleneck_input_2, growth_rate, if_training)#NOTE:160 = 128 + 32 116 | _bottleneck_input_3 = tf.concat(values=[_bottleneck_input_2, _bottlenect_output_3], axis=3, name='stack3')# 192 117 | 118 | with tf.name_scope("bottleneck_4") as scope: 119 | _bottlenect_output_4 = bottleneck_layer(_bottleneck_input_3, growth_rate, if_training)#NOTE:192 = 160 + 32 120 | _bottleneck_input_4 = tf.concat(values=[_bottleneck_input_3, _bottlenect_output_4], axis=3, name='stack4')# 224 121 | 122 | with tf.name_scope("bottleneck_5") as scope: 123 | _bottlenect_output_5 = bottleneck_layer(_bottleneck_input_4, growth_rate, if_training)#NOTE:192 = 160 + 32 124 | output_tensor = tf.concat(values=[_bottleneck_input_4, _bottlenect_output_5], axis=3, name='stack5') 125 | 126 | return output_tensor 127 | 128 | 129 | 130 | def densenet_inference(image_batch, if_training, dropout_prob): 131 | with 
tf.name_scope('DenseNet-BC-121'): 132 | _image_batch = tf.reshape(image_batch, [-1, 224, 224, 3]) 133 | 134 | with tf.name_scope('conv2d_7x7_s2') as scope: 135 | _conv_weights = _weight_variable_with_decay([7, 7, 3, 64]) 136 | _output_tensor = tf.nn.conv2d(input=_image_batch, filter=_conv_weights, strides=[1, 2, 2, 1], padding='SAME', data_format='NHWC', name='conv2d_7x7_s2') 137 | _output_tensor = tf.nn.max_pool(value=_output_tensor, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME', data_format='NHWC', name='max_pool_3x3_s2') 138 | _output_tensor = tf.nn.relu(_output_tensor) 139 | 140 | 141 | with tf.name_scope('dense_block_0') as scope: 142 | _output_tensor = dense_block(_output_tensor, 32, if_training) 143 | 144 | with tf.name_scope('transition_layer_0') as scope: 145 | _output_tensor = transition_layer(_output_tensor, 0.5, if_training) 146 | 147 | 148 | with tf.name_scope('dense_block_1') as scope: 149 | _output_tensor = dense_block(_output_tensor, 32, if_training) 150 | 151 | with tf.name_scope('transition_layer_1') as scope: 152 | _output_tensor = transition_layer(_output_tensor, 0.5, if_training) 153 | 154 | 155 | with tf.name_scope('dense_block_2') as scope: 156 | _output_tensor = dense_block(_output_tensor, 32, if_training) 157 | 158 | with tf.name_scope('transition_layer_2') as scope: 159 | _output_tensor = transition_layer(_output_tensor, 0.5, if_training) 160 | 161 | 162 | with tf.name_scope('dense_block_3') as scope: 163 | _output_tensor = dense_block(_output_tensor, 32, if_training) 164 | 165 | 166 | with tf.name_scope('avg_pool_7x7') as scope: 167 | _output_tensor = tf.nn.avg_pool(value=_output_tensor, ksize=[1, 7, 7, 1], strides=[1, 7, 7, 1], padding='SAME', data_format='NHWC', name='avg_pool_7x7') 168 | 169 | 170 | with tf.name_scope('fc'): 171 | W_fc0 = _weight_variable_with_decay([368, 128]) 172 | b_fc0 = _weight_variable_constant([128]) 173 | _output_tensor = tf.reshape(_output_tensor, [-1, 368]) 174 | _output_tensor = 
tf.nn.relu(tf.matmul(_output_tensor, W_fc0) + b_fc0) 175 | 176 | _output_tensor = tf.nn.dropout(_output_tensor, dropout_prob) 177 | 178 | W_fc1 = _weight_variable_with_decay([128, 5]) 179 | b_fc1 = _weight_variable_constant([5]) 180 | _output_tensor = tf.nn.relu(tf.matmul(_output_tensor, W_fc1) + b_fc1) 181 | 182 | return _output_tensor 183 | -------------------------------------------------------------------------------- /res/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeephycho/densenet-tensorflow/7fd51c7479c33f29b708a241215f7f8ea50f05ce/res/graph.png -------------------------------------------------------------------------------- /res/test_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeephycho/densenet-tensorflow/7fd51c7479c33f29b708a241215f7f8ea50f05ce/res/test_result.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Brief: Test the densenet for image classification 2 | # Data: 24/Aug./2017 3 | # E-mail: huyixuanhyx@gmail.com 4 | # License: Apache 2.0 5 | # By: Yeephycho @ Hong Kong 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import tensorflow as tf 12 | import numpy as np 13 | import os 14 | 15 | import net.densenet as densenet 16 | import config as config 17 | import data_provider.data_provider as data_provider 18 | 19 | 20 | 21 | FLAGS = tf.app.flags.FLAGS 22 | TEST_SET_SIZE = FLAGS.TESTING_SET_SIZE 23 | BATCH_SIZE = FLAGS.BATCH_SIZE 24 | 25 | 26 | 27 | def densenet_test(): 28 | image_batch_placeholder = tf.placeholder(tf.float32, shape=[None, 224, 224, 3]) 29 | label_batch_placeholder = tf.placeholder(tf.int64, shape=[BATCH_SIZE]) 30 | if_training_placeholder = 
tf.placeholder(tf.bool, shape=[]) 31 | 32 | image_batch, label_batch, filename_batch = data_provider.feed_data(if_random = False, if_training = False) 33 | label_batch_dense = tf.arg_max(label_batch, dimension = 1) 34 | 35 | if_training = tf.Variable(False, name='if_training', trainable=False) 36 | 37 | logits = tf.reshape(densenet.densenet_inference(image_batch_placeholder, if_training_placeholder, 1.0), [BATCH_SIZE, 5]) 38 | logits_batch = tf.to_int64(tf.arg_max(logits, dimension = 1)) 39 | 40 | correct_prediction = tf.equal(logits_batch, label_batch_placeholder) 41 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 42 | 43 | checkpoint = tf.train.get_checkpoint_state("./models") 44 | saver = tf.train.Saver() 45 | 46 | config = tf.ConfigProto() 47 | config.gpu_options.allow_growth=True 48 | with tf.Session(config=config) as sess: 49 | sess.run(tf.global_variables_initializer()) 50 | tf.logging.info("Restoring full model from checkpoint file %s",checkpoint.model_checkpoint_path) 51 | saver.restore(sess, checkpoint.model_checkpoint_path) 52 | 53 | accuracy_accu = 0 54 | 55 | coord = tf.train.Coordinator() 56 | threads = tf.train.start_queue_runners(coord=coord, sess = sess) 57 | 58 | for i in range(int(TEST_SET_SIZE / BATCH_SIZE)): 59 | image_out, label_batch_dense_out, filename_out = sess.run([image_batch, label_batch_dense, filename_batch]) 60 | print("label: ", label_batch_dense_out) 61 | accuracy_out, infer_out = sess.run([accuracy, logits_batch], feed_dict={image_batch_placeholder: image_out, 62 | label_batch_placeholder: label_batch_dense_out, 63 | if_training_placeholder: if_training}) 64 | accuracy_out = np.asarray(accuracy_out) 65 | print("infer: ", infer_out) 66 | print(' ') 67 | accuracy_accu = accuracy_out + accuracy_accu 68 | 69 | print(accuracy_accu / TEST_SET_SIZE * BATCH_SIZE) 70 | 71 | tf.train.write_graph(sess.graph_def, 'graph/', 'my_graph.pb', as_text=False) 72 | 73 | coord.request_stop() 74 | coord.join(threads) 75 | 
sess.close() 76 | return 0 77 | 78 | 79 | 80 | def main(): 81 | tf.reset_default_graph() 82 | 83 | densenet_test() 84 | 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Brief: Train a densenet for image classification 2 | # Data: 24/Aug./2017 3 | # E-mail: huyixuanhyx@gmail.com 4 | # License: Apache 2.0 5 | # By: Yeephycho @ Hong Kong 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import tensorflow as tf 12 | import numpy as np 13 | import os 14 | 15 | import net.densenet as densenet 16 | import config as config 17 | import data_provider.data_provider as data_provider 18 | 19 | 20 | 21 | FLAGS = tf.app.flags.FLAGS 22 | TRAINING_SET_SIZE = FLAGS.TRAINING_SET_SIZE 23 | BATCH_SIZE = FLAGS.BATCH_SIZE 24 | starter_learning_rate = FLAGS.INIT_LEARNING_RATE 25 | exp_decay_steps = FLAGS.DECAY_STEPS 26 | exp_decay_rate = FLAGS.DECAY_RATE 27 | 28 | 29 | 30 | def densenet_train(): 31 | image_batch_placeholder = tf.placeholder(tf.float32, shape=[None, 224, 224, 3]) 32 | label_batch_placeholder = tf.placeholder(tf.float32, shape=[None, 5]) 33 | if_training_placeholder = tf.placeholder(tf.bool, shape=[]) 34 | 35 | image_batch, label_batch, filename_batch = data_provider.feed_data(if_random = True, if_training = True) 36 | 37 | if_training = tf.Variable(True, name='if_training', trainable=False) 38 | 39 | logits = densenet.densenet_inference(image_batch_placeholder, if_training_placeholder, dropout_prob=0.7) 40 | 41 | loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=label_batch_placeholder, logits=logits)) 42 | #loss = tf.losses.mean_squared_error(labels=label_batch_placeholder, predictions=logits) 43 | tf.summary.scalar('loss', loss) # create a summary for training loss 44 | 45 | 
regularzation_loss = sum(tf.get_collection("regularzation_loss")) 46 | tf.summary.scalar('regularzation_loss', regularzation_loss) 47 | 48 | total_loss = regularzation_loss + loss 49 | tf.summary.scalar('total_loss', total_loss) 50 | 51 | global_step = tf.Variable(0, name='global_step', trainable=False) 52 | 53 | learning_rate = tf.train.exponential_decay(learning_rate=starter_learning_rate, 54 | global_step=global_step, 55 | decay_steps=exp_decay_steps, 56 | decay_rate=exp_decay_rate, 57 | staircase=True) 58 | tf.summary.scalar('learning_rate', learning_rate) 59 | 60 | train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss=total_loss, global_step=global_step) 61 | 62 | summary_op = tf.summary.merge_all() # merge all summaries into a single "operation" which we can execute in a session 63 | 64 | saver = tf.train.Saver() 65 | 66 | config = tf.ConfigProto() 67 | config.gpu_options.allow_growth=True 68 | sess = tf.Session(config=config) 69 | 70 | summary_writer = tf.summary.FileWriter("./log", sess.graph) 71 | 72 | sess.run(tf.global_variables_initializer()) 73 | 74 | checkpoint = tf.train.get_checkpoint_state("./models") 75 | if(checkpoint != None): 76 | tf.logging.info("Restoring full model from checkpoint file %s",checkpoint.model_checkpoint_path) 77 | saver.restore(sess, checkpoint.model_checkpoint_path) 78 | 79 | coord = tf.train.Coordinator() 80 | threads = tf.train.start_queue_runners(coord=coord, sess = sess) 81 | 82 | check_points = int(TRAINING_SET_SIZE/BATCH_SIZE) 83 | for epoch in range(250): 84 | for check_point in range(check_points): 85 | image_batch_train, label_batch_train, filename_train = sess.run([image_batch, label_batch, filename_batch]) 86 | 87 | _, training_loss, _global_step, summary = sess.run([train_step, loss, global_step, summary_op], 88 | feed_dict={image_batch_placeholder: image_batch_train, 89 | label_batch_placeholder: label_batch_train, 90 | if_training_placeholder: if_training}) 91 | 92 | if(bool(check_point%50 
== 0) & bool(check_point != 0)): 93 | print(_) 94 | print("batch: ", check_point + epoch * check_points) 95 | print("training loss: ", training_loss) 96 | summary_writer.add_summary(summary, _global_step) 97 | 98 | saver.save(sess, "./models/densenet.ckpt", _global_step) 99 | 100 | coord.request_stop() 101 | coord.join(threads) 102 | sess.close() 103 | return 0 104 | 105 | 106 | 107 | def main(): 108 | tf.reset_default_graph() 109 | densenet_train() 110 | 111 | 112 | 113 | if __name__ == '__main__': 114 | main() 115 | 116 | 117 | 118 | # weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 119 | # print("") 120 | # for w in weights: 121 | # shp = w.get_shape().as_list() 122 | # print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp))) 123 | # print("") 124 | --------------------------------------------------------------------------------