├── test
│   ├── WechatIMG383.jpg
│   ├── WechatIMG389.jpg
│   └── WechatIMG390.jpg
├── label_image.py
├── README.md
├── LICENSE
└── retrain.py
/test/WechatIMG383.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anymake/tensorflow_flow_demo/HEAD/test/WechatIMG383.jpg
--------------------------------------------------------------------------------
/test/WechatIMG389.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anymake/tensorflow_flow_demo/HEAD/test/WechatIMG389.jpg
--------------------------------------------------------------------------------
/test/WechatIMG390.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anymake/tensorflow_flow_demo/HEAD/test/WechatIMG390.jpg
--------------------------------------------------------------------------------
/label_image.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | import tensorflow as tf
4 |
5 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
6 |
7 | # change this as you see fit
8 | image_path = sys.argv[1]
9 |
10 | # Read in the image_data
11 | image_data = tf.gfile.FastGFile(image_path, 'rb').read()
12 |
13 | # Loads label file, strips off carriage return
14 | label_lines = [line.rstrip() for line
15 |                in tf.gfile.GFile("retrained_labels.txt")]
16 |
17 | # Unpersists graph from file
18 | with tf.gfile.FastGFile("retrained_graph.pb", 'rb') as f:
19 |     graph_def = tf.GraphDef()
20 |     graph_def.ParseFromString(f.read())
21 |     tf.import_graph_def(graph_def, name='')
22 |
23 | with tf.Session() as sess:
24 |     # Feed the image_data as input to the graph and get first prediction
25 |     softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
26 |
27 |     predictions = sess.run(softmax_tensor, \
28 |                            {'DecodeJpeg/contents:0': image_data})
29 |
30 |     # Sort to show labels of first prediction in order of confidence
31 |     top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
32 |
33 |     for node_id in top_k:
34 |         human_string = label_lines[node_id]
35 |         score = predictions[0][node_id]
36 |         print('%s (score = %.5f)' % (human_string, score))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Source code and demo for training a flower-recognition model with TensorFlow
3 | Below we retrain the existing Google Inception-V3 model on sample data for 5 flower species to build a model that can recognize five kinds of flowers, then test and deploy the newly trained model, giving you a feel for the complete workflow.
4 |
5 | 
6 |
7 |
8 | ### Installing TensorFlow (macOS as an example)
9 |
10 | For other platforms, refer directly to the official guide: [Installing TensorFlow](https://www.tensorflow.org/install/)
11 |
12 | #### First, check whether Python is installed on your system
13 |
14 | To install `TensorFlow`, your system must already have one of the following `Python` versions installed:
15 |
16 | * **Python 2.7**
17 | * **Python 3.3+**
18 |
19 | If you do a lot of data processing, installing Anaconda is recommended. **Anaconda** is a freemium open-source distribution of Python for large-scale data processing, predictive analytics, and scientific computing, which aims to simplify package management and deployment. Anaconda manages packages with the Conda package-management system. Once it is installed, type `python` in a shell to see the Python version Anaconda provides; I am using Python 2.7.14:
20 | ```
21 | ➜ ~ python
22 | Python 2.7.14 |Anaconda, Inc.| (default, Dec 7 2017, 11:07:58)
23 | [GCC 4.2.1 Compatible Clang 4.0.1 (tags/RELEASE_401/final)] on darwin
24 | Type "help", "copyright", "credits" or "license" for more information.
25 |
26 | ```
27 | If your system does not yet have one of the Python versions above installed, install it now.
28 |
29 |
30 |
31 | #### Install TensorFlow via pip
32 |
33 | ```
34 | # Python 2
35 | ➜ pip install tensorflow
36 | # Python 3
37 | ➜ pip3 install tensorflow
38 |
39 | ```
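Note: this walkthrough was written in the TensorFlow 1.x era; the `tf.Session` and `tf.gfile` APIs used below were removed in TensorFlow 2.x, so to follow along verbatim you would likely need to pin a 1.x release (for example `pip install "tensorflow<2"`; the exact version is my assumption, not specified by the original author).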
40 |
41 | #### Verify the TensorFlow installation with the official example
42 |
43 | After entering the Python environment, type the following code. When `"Hello, TensorFlow!"` is printed, the installation succeeded and TensorFlow is ready to use.
44 |
45 | ```
46 | ➜ python
47 | import tensorflow as tf
48 | hello = tf.constant('Hello, TensorFlow!')
49 | sess = tf.Session()
50 | print(sess.run(hello))
51 | Hello, TensorFlow!
52 |
53 | ```
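If you are on TensorFlow 2.x instead, `tf.Session` no longer exists and tensors evaluate eagerly; a minimal sketch of the same smoke test under 2.x (my adaptation, not part of the original demo) is:

```
import tensorflow as tf

# Under TF 2.x eager execution, tensors evaluate immediately; no Session is needed.
hello = tf.constant('Hello, TensorFlow!')
print(hello.numpy())
```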
54 |
55 | ### Preparing the training samples
56 |
57 | Now we will train a flower-recognition model. This is an example Google provides with TensorFlow, and it includes training images for 5 flower classes. Create a new flower_demo folder to hold the data and the trained model.
58 |
59 | **Download and extract the training samples**
60 |
61 | ```
62 | cd flower_demo
63 | # Download and extract the flower training data
64 | curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
65 | tar xzf flower_photos.tgz
66 |
67 | ```
68 |
69 | Open the flower_photos training-sample folder. It contains 5 flower classes: `daisy, dandelion, roses, sunflowers, tulips`, 3672 files in total, with roughly 600-900 training images per class, as listed below:
70 |
71 | ```
72 | cd flower_photos
73 | for dir in `find ./ -maxdepth 1 -type d`;do echo -n -e "$dir\t";find $dir -type f|wc -l ;done;
74 | ./ 3672
75 | .//roses 641
76 | .//sunflowers 699
77 | .//daisy 633
78 | .//dandelion 898
79 | .//tulips 799
80 |
81 | ```
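The same per-class count can be done from Python; a minimal sketch, assuming `flower_photos/` sits in the current directory:

```
import os

# Count the files in each class subfolder, mirroring the shell loop above.
base = 'flower_photos'
for name in sorted(os.listdir(base)):
    path = os.path.join(base, name)
    if os.path.isdir(path):
        count = sum(len(files) for _, _, files in os.walk(path))
        print('%s\t%d' % (name, count))
```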
82 | ### Starting the training
83 |
84 | **Download the retrain script used to train the model**
85 | The script automatically downloads the files for the Google Inception v3 model. `retrain.py` is a script provided by Google that takes the ImageNet image-classification model as its base and uses the flower_photos data for transfer learning to train a flower-recognition model.
86 |
87 | ```
88 | cd flower_demo
89 | curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/r1.1/tensorflow/examples/image_retraining/retrain.py
90 |
91 | ```
92 | **Run the training script to start training the model**
93 |
94 | When running the `retrain.py` script, you need to supply some command-line arguments, such as the names of the model inputs and outputs and other training settings. The `--how_many_training_steps=4000` flag sets the number of training iterations; the default is 4000, and you can lower it if your machine is slow.
95 |
96 | ```
97 | ➜ cd flower_demo
98 | ➜ python3 retrain.py \
99 | --bottleneck_dir=bottlenecks \
100 | --how_many_training_steps=4000 \
101 | --model_dir=inception \
102 | --summaries_dir=training_summaries/basic \
103 | --output_graph=retrained_graph.pb \
104 | --output_labels=retrained_labels.txt \
105 | --image_dir=flower_photos
106 |
107 | ```
108 | Here we train for 4000 steps, which does not take long: on an iMac with a 4.2 GHz Intel Core i7 processor, without using a GPU, training completes in about 5 minutes. When training finishes you can see `Final test accuracy = 92.1%`, meaning our 5-class flower-recognition model already recognizes flowers with 92% accuracy on the test set. The generated `retrained_labels.txt` and `retrained_graph.pb` are the two model files.
109 | ```
110 | 2018-06-02 15:47:00.266119: Step 3950: Train accuracy = 94.0%
111 | 2018-06-02 15:47:00.266159: Step 3950: Cross entropy = 0.135385
112 | 2018-06-02 15:47:00.327843: Step 3950: Validation accuracy = 93.0% (N=100)
113 | 2018-06-02 15:47:00.976543: Step 3960: Train accuracy = 94.0%
114 | 2018-06-02 15:47:00.976591: Step 3960: Cross entropy = 0.234760
115 | 2018-06-02 15:47:01.038559: Step 3960: Validation accuracy = 91.0% (N=100)
116 | 2018-06-02 15:47:01.667255: Step 3970: Train accuracy = 97.0%
117 | 2018-06-02 15:47:01.667372: Step 3970: Cross entropy = 0.167394
118 | 2018-06-02 15:47:01.731935: Step 3970: Validation accuracy = 87.0% (N=100)
119 | 2018-06-02 15:47:02.355780: Step 3980: Train accuracy = 96.0%
120 | 2018-06-02 15:47:02.355818: Step 3980: Cross entropy = 0.151201
121 | 2018-06-02 15:47:02.418314: Step 3980: Validation accuracy = 91.0% (N=100)
122 | 2018-06-02 15:47:03.042364: Step 3990: Train accuracy = 99.0%
123 | 2018-06-02 15:47:03.042402: Step 3990: Cross entropy = 0.094383
124 | 2018-06-02 15:47:03.103718: Step 3990: Validation accuracy = 91.0% (N=100)
125 | 2018-06-02 15:47:03.667861: Step 3999: Train accuracy = 99.0%
126 | 2018-06-02 15:47:03.667899: Step 3999: Cross entropy = 0.106797
127 | 2018-06-02 15:47:03.729215: Step 3999: Validation accuracy = 94.0% (N=100)
128 | Final test accuracy = 92.1% (N=353)
129 | ```
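Before running a full classification you can sanity-check the two exported files. The following is a minimal sketch of mine (assuming TensorFlow 1.x and that both files sit in the current directory); it prints the learned labels and confirms the graph contains the retrained `final_result` output node:

```
import tensorflow as tf

# One class name per line, in the order the model outputs its scores.
labels = [line.rstrip() for line in tf.gfile.GFile('retrained_labels.txt')]
print('labels:', labels)

# Parse the frozen graph and check that the retrained output op exists.
graph_def = tf.GraphDef()
with tf.gfile.FastGFile('retrained_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
ops = set(op.name for op in tf.get_default_graph().get_operations())
print('has final_result:', 'final_result' in ops)
```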
130 | ### Testing the trained model
131 |
132 | As before, we first download the model-testing script `label_image.py`, then pick the image 488202750_c420cbce61.jpg from the flower_photos/daisy/ folder to test the recognition accuracy of the trained model. You can of course also test with any picture of the 5 flower classes found via a Baidu search. As the output below shows, our trained model puts the probability that this image is a `daisy` at 98.9%.
133 |
134 | ```
135 | ➜ cd flower_demo
136 | ➜ curl -L https://goo.gl/3lTKZs > label_image.py
137 | ➜ python label_image.py flower_photos/daisy/488202750_c420cbce61.jpg
138 |
139 | daisy (score = 0.98921)
140 | sunflowers (score = 0.00948)
141 | dandelion (score = 0.00088)
142 | tulips (score = 0.00038)
143 | roses (score = 0.00005)
144 | ```
145 | 
146 | Some readers report that `label_image.py` cannot be downloaded; its code is as follows:
147 | ```
148 | import os, sys
149 | import tensorflow as tf
150 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
151 |
152 | # change this as you see fit
153 | image_path = sys.argv[1]
154 |
155 | # Read in the image_data
156 | image_data = tf.gfile.FastGFile(image_path, 'rb').read()
157 |
158 | # Loads label file, strips off carriage return
159 | label_lines = [line.rstrip() for line in tf.gfile.GFile("retrained_labels.txt")]
160 |
161 | # Unpersists graph from file
162 | with tf.gfile.FastGFile("retrained_graph.pb", 'rb') as f:
163 |     graph_def = tf.GraphDef()
164 |     graph_def.ParseFromString(f.read())
165 |     tf.import_graph_def(graph_def, name='')
166 |
167 | with tf.Session() as sess:
168 |     # Feed the image_data as input to the graph and get first prediction
169 |     softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
170 |
171 |     predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
172 |
173 |     # Sort to show labels of first prediction in order of confidence
174 |     top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
175 |
176 |     for node_id in top_k:
177 |         human_string = label_lines[node_id]
178 |         score = predictions[0][node_id]
179 |         print('%s (score = %.5f)' % (human_string, score))
180 | ```
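The script above uses TensorFlow 1.x endpoints (`tf.gfile`, `tf.GraphDef`, `tf.Session`). On TensorFlow 2.x those moved under `tf.compat.v1` and `tf.io`; a minimal, untested sketch of the same inference on 2.x (my adaptation, not part of the original repo) looks like:

```
import sys
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # run the frozen graph TF1-style

image_data = tf.io.gfile.GFile(sys.argv[1], 'rb').read()
label_lines = [line.rstrip() for line in tf.io.gfile.GFile('retrained_labels.txt')]

graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile('retrained_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')

with tf.compat.v1.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
    print(label_lines[predictions[0].argmax()])
```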
181 | We grab a random dandelion picture from a Baidu search, save it as `test/WechatIMG383.jpg`, and the test result puts the probability of dandelion at 99.59%.
182 |
183 | ```
184 | python label_image.py test/WechatIMG383.jpg
185 |
186 | dandelion (score = 0.99592)
187 | sunflowers (score = 0.00359)
188 | daisy (score = 0.00042)
189 | tulips (score = 0.00005)
190 | roses (score = 0.00001)
191 | ```
192 | That is essentially the whole process of training and testing the model; hopefully it gives you a rough picture of a complete deep-learning project.
193 |
194 | **Launching TensorBoard**
195 | TensorBoard is TensorFlow's built-in tool for visualizing training results. We can use it to monitor and analyze how the model converges, for example watching the loss fall and the accuracy rise, or viewing a visualization of the network structure. From the project directory we created, launch TensorBoard with:
196 |
197 | ```
198 | ➜ cd flower_demo
199 | ➜ tensorboard --logdir training_summaries
200 |
201 | ```
202 |
203 | Starting TensorBoard occupies system port `6006`. Before launching a new TensorBoard instance, you must kill any TensorBoard task that is already running.
204 |
205 | ```
206 | ➜ pkill -f "tensorboard"
207 |
208 | ```
209 | **Open a browser to view TensorBoard**
210 |
211 | Once TensorBoard is running, open a browser and enter `localhost:6006` in the address bar to follow training progress, watch the loss and accuracy change, and analyze the model.
212 |
213 | 
214 |
215 |
216 | 
217 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |
635 | Copyright (C) <year> <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/retrain.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Simple transfer learning with an Inception v3 architecture model which
16 | displays summaries in TensorBoard.
17 |
18 | This example shows how to take an Inception v3 architecture model trained on
19 | ImageNet images, and train a new top layer that can recognize other classes of
20 | images.
21 |
22 | The top layer receives as input a 2048-dimensional vector for each image. We
23 | train a softmax layer on top of this representation. Assuming the softmax layer
24 | contains N labels, this corresponds to learning N + 2048*N model parameters
25 | corresponding to the learned biases and weights.
26 |
27 | Here's an example, which assumes you have a folder containing class-named
28 | subfolders, each full of images for each label. The example folder flower_photos
29 | should have a structure like this:
30 |
31 | ~/flower_photos/daisy/photo1.jpg
32 | ~/flower_photos/daisy/photo2.jpg
33 | ...
34 | ~/flower_photos/rose/anotherphoto77.jpg
35 | ...
36 | ~/flower_photos/sunflower/somepicture.jpg
37 |
38 | The subfolder names are important, since they define what label is applied to
39 | each image, but the filenames themselves don't matter. Once your images are
40 | prepared, you can run the training with a command like this:
41 |
42 | bazel build tensorflow/examples/image_retraining:retrain && \
43 | bazel-bin/tensorflow/examples/image_retraining/retrain \
44 | --image_dir ~/flower_photos
45 |
46 | You can replace the image_dir argument with any folder containing subfolders of
47 | images. The label for each image is taken from the name of the subfolder it's
48 | in.
49 |
50 | This produces a new model file that can be loaded and run by any TensorFlow
51 | program, for example the label_image sample code.
52 |
53 |
54 | To use with TensorBoard:
55 |
56 | By default, this script will log summaries to /tmp/retrain_logs directory
57 |
58 | Visualize the summaries with this command:
59 |
60 | tensorboard --logdir /tmp/retrain_logs
61 |
62 | """
63 | from __future__ import absolute_import
64 | from __future__ import division
65 | from __future__ import print_function
66 |
67 | import argparse
68 | from datetime import datetime
69 | import hashlib
70 | import os.path
71 | import random
72 | import re
73 | import struct
74 | import sys
75 | import tarfile
76 |
77 | import numpy as np
78 | from six.moves import urllib
79 | import tensorflow as tf
80 |
81 | from tensorflow.python.framework import graph_util
82 | from tensorflow.python.framework import tensor_shape
83 | from tensorflow.python.platform import gfile
84 | from tensorflow.python.util import compat
85 |
86 | FLAGS = None
87 |
88 | # These are all parameters that are tied to the particular model architecture
89 | # we're using for Inception v3. These include things like tensor names and their
90 | # sizes. If you want to adapt this script to work with another model, you will
91 | # need to update these to reflect the values in the network you're using.
92 | # pylint: disable=line-too-long
93 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
94 | # pylint: enable=line-too-long
95 | BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
96 | BOTTLENECK_TENSOR_SIZE = 2048
97 | MODEL_INPUT_WIDTH = 299
98 | MODEL_INPUT_HEIGHT = 299
99 | MODEL_INPUT_DEPTH = 3
100 | JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
101 | RESIZED_INPUT_TENSOR_NAME = 'ResizeBilinear:0'
102 | MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 # ~134M
103 |
104 |
105 | def create_image_lists(image_dir, testing_percentage, validation_percentage):
106 |   """Builds a list of training images from the file system.
107 |
108 |   Analyzes the sub folders in the image directory, splits them into stable
109 |   training, testing, and validation sets, and returns a data structure
110 |   describing the lists of images for each label and their paths.
111 |
112 |   Args:
113 |     image_dir: String path to a folder containing subfolders of images.
114 |     testing_percentage: Integer percentage of the images to reserve for tests.
115 |     validation_percentage: Integer percentage of images reserved for validation.
116 |
117 |   Returns:
118 |     A dictionary containing an entry for each label subfolder, with images split
119 |     into training, testing, and validation sets within each label.
120 |   """
121 |   if not gfile.Exists(image_dir):
122 |     print("Image directory '" + image_dir + "' not found.")
123 |     return None
124 |   result = {}
125 |   sub_dirs = [x[0] for x in gfile.Walk(image_dir)]
126 |   # The root directory comes first, so skip it.
127 |   is_root_dir = True
128 |   for sub_dir in sub_dirs:
129 |     if is_root_dir:
130 |       is_root_dir = False
131 |       continue
132 |     extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
133 |     file_list = []
134 |     dir_name = os.path.basename(sub_dir)
135 |     if dir_name == image_dir:
136 |       continue
137 |     print("Looking for images in '" + dir_name + "'")
138 |     for extension in extensions:
139 |       file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
140 |       file_list.extend(gfile.Glob(file_glob))
141 |     if not file_list:
142 |       print('No files found')
143 |       continue
144 |     if len(file_list) < 20:
145 |       print('WARNING: Folder has less than 20 images, which may cause issues.')
146 |     elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
147 |       print('WARNING: Folder {} has more than {} images. Some images will '
148 |             'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
149 |     label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
150 |     training_images = []
151 |     testing_images = []
152 |     validation_images = []
153 |     for file_name in file_list:
154 |       base_name = os.path.basename(file_name)
155 |       # We want to ignore anything after '_nohash_' in the file name when
156 |       # deciding which set to put an image in, the data set creator has a way of
157 |       # grouping photos that are close variations of each other. For example
158 |       # this is used in the plant disease data set to group multiple pictures of
159 |       # the same leaf.
160 |       hash_name = re.sub(r'_nohash_.*$', '', file_name)
161 |       # This looks a bit magical, but we need to decide whether this file should
162 |       # go into the training, testing, or validation sets, and we want to keep
163 |       # existing files in the same set even if more files are subsequently
164 |       # added.
165 |       # To do that, we need a stable way of deciding based on just the file name
166 |       # itself, so we do a hash of that and then use that to generate a
167 |       # probability value that we use to assign it.
168 |       hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
169 |       percentage_hash = ((int(hash_name_hashed, 16) %
170 |                           (MAX_NUM_IMAGES_PER_CLASS + 1)) *
171 |                          (100.0 / MAX_NUM_IMAGES_PER_CLASS))
172 |       if percentage_hash < validation_percentage:
173 |         validation_images.append(base_name)
174 |       elif percentage_hash < (testing_percentage + validation_percentage):
175 |         testing_images.append(base_name)
176 |       else:
177 |         training_images.append(base_name)
178 |     result[label_name] = {
179 |         'dir': dir_name,
180 |         'training': training_images,
181 |         'testing': testing_images,
182 |         'validation': validation_images,
183 |     }
184 |   return result
185 |
186 |
187 | def get_image_path(image_lists, label_name, index, image_dir, category):
188 | """"Returns a path to an image for a label at the given index.
189 |
190 | Args:
191 | image_lists: Dictionary of training images for each label.
192 | label_name: Label string we want to get an image for.
193 | index: Int offset of the image we want. This will be moduloed by the
194 | available number of images for the label, so it can be arbitrarily large.
195 | image_dir: Root folder string of the subfolders containing the training
196 | images.
197 | category: Name string of set to pull images from - training, testing, or
198 | validation.
199 |
200 | Returns:
201 | File system path string to an image that meets the requested parameters.
202 |
203 | """
204 | if label_name not in image_lists:
205 | tf.logging.fatal('Label does not exist %s.', label_name)
206 | label_lists = image_lists[label_name]
207 | if category not in label_lists:
208 |     tf.logging.fatal('Category does not exist: %s', category)
209 | category_list = label_lists[category]
210 | if not category_list:
211 | tf.logging.fatal('Label %s has no images in the category %s.',
212 | label_name, category)
213 | mod_index = index % len(category_list)
214 | base_name = category_list[mod_index]
215 | sub_dir = label_lists['dir']
216 | full_path = os.path.join(image_dir, sub_dir, base_name)
217 | return full_path
218 |
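# Hedged usage sketch for get_image_path: the index is taken modulo the number
# of available images, so any non-negative integer is safe. The label, index,
# and paths below are hypothetical.
#
#   path = get_image_path(image_lists, 'roses', 1234, '/tmp/flower_photos',
#                         'training')
#   # e.g. '/tmp/flower_photos/roses/IMG_0042.jpg'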
219 |
220 | def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir,
221 | category):
222 | """"Returns a path to a bottleneck file for a label at the given index.
223 |
224 | Args:
225 | image_lists: Dictionary of training images for each label.
226 | label_name: Label string we want to get an image for.
227 |     index: Integer offset of the image we want. This will be taken modulo the
228 |       available number of images for the label, so it can be arbitrarily large.
229 | bottleneck_dir: Folder string holding cached files of bottleneck values.
230 | category: Name string of set to pull images from - training, testing, or
231 | validation.
232 |
233 | Returns:
234 |     File system path string to the bottleneck file for the requested image.
235 | """
236 | return get_image_path(image_lists, label_name, index, bottleneck_dir,
237 | category) + '.txt'
238 |
239 |
240 | def create_inception_graph():
241 | """"Creates a graph from saved GraphDef file and returns a Graph object.
242 |
243 | Returns:
244 | Graph holding the trained Inception network, and various tensors we'll be
245 | manipulating.
246 | """
247 | with tf.Session() as sess:
248 | model_filename = os.path.join(
249 | FLAGS.model_dir, 'classify_image_graph_def.pb')
250 | with gfile.FastGFile(model_filename, 'rb') as f:
251 | graph_def = tf.GraphDef()
252 | graph_def.ParseFromString(f.read())
253 | bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
254 | tf.import_graph_def(graph_def, name='', return_elements=[
255 | BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
256 | RESIZED_INPUT_TENSOR_NAME]))
257 | return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor
258 |
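# Assuming maybe_download_and_extract() has already populated FLAGS.model_dir,
# a typical (hedged) way to obtain the graph and its key tensors:
#
#   graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
#       create_inception_graph())
#   # bottleneck_tensor is the BOTTLENECK_TENSOR_SIZE-wide layer just before
#   # Inception's original softmax; the new flower classifier is trained on it.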
259 |
260 | def run_bottleneck_on_image(sess, image_data, image_data_tensor,
261 | bottleneck_tensor):
262 | """Runs inference on an image to extract the 'bottleneck' summary layer.
263 |
264 | Args:
265 | sess: Current active TensorFlow Session.
266 | image_data: String of raw JPEG data.
267 | image_data_tensor: Input data layer in the graph.
268 | bottleneck_tensor: Layer before the final softmax.
269 |
270 | Returns:
271 | Numpy array of bottleneck values.
272 | """
273 | bottleneck_values = sess.run(
274 | bottleneck_tensor,
275 | {image_data_tensor: image_data})
276 | bottleneck_values = np.squeeze(bottleneck_values)
277 | return bottleneck_values
278 |
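# Hedged example: extracting a bottleneck vector for one image. The file path
# is hypothetical; jpeg_data_tensor and bottleneck_tensor come from
# create_inception_graph() above.
#
#   image_data = gfile.FastGFile('/tmp/flower_photos/roses/IMG_0042.jpg',
#                                'rb').read()
#   values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor,
#                                    bottleneck_tensor)
#   # values is a 1-D numpy array of length BOTTLENECK_TENSOR_SIZE.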
279 |
280 | def maybe_download_and_extract():
281 | """Download and extract model tar file.
282 |
283 | If the pretrained model we're using doesn't already exist, this function
284 | downloads it from the TensorFlow.org website and unpacks it into a directory.
285 | """
286 | dest_directory = FLAGS.model_dir
287 | if not os.path.exists(dest_directory):
288 | os.makedirs(dest_directory)
289 | filename = DATA_URL.split('/')[-1]
290 | filepath = os.path.join(dest_directory, filename)
291 | if not os.path.exists(filepath):
292 |
293 | def _progress(count, block_size, total_size):
294 | sys.stdout.write('\r>> Downloading %s %.1f%%' %
295 | (filename,
296 | float(count * block_size) / float(total_size) * 100.0))
297 | sys.stdout.flush()
298 |
299 | filepath, _ = urllib.request.urlretrieve(DATA_URL,
300 | filepath,
301 | _progress)
302 | print()
303 | statinfo = os.stat(filepath)
304 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
305 | tarfile.open(filepath, 'r:gz').extractall(dest_directory)
306 |
307 |
308 | def ensure_dir_exists(dir_name):
309 | """Makes sure the folder exists on disk.
310 |
311 | Args:
312 | dir_name: Path string to the folder we want to create.
313 | """
314 | if not os.path.exists(dir_name):
315 | os.makedirs(dir_name)
316 |
317 |
318 | def write_list_of_floats_to_file(list_of_floats, file_path):
319 | """Writes a given list of floats to a binary file.
320 |
321 | Args:
322 | list_of_floats: List of floats we want to write to a file.
323 | file_path: Path to a file where list of floats will be stored.
324 |
325 | """
326 |
327 | s = struct.pack('d' * BOTTLENECK_TENSOR_SIZE, *list_of_floats)
328 | with open(file_path, 'wb') as f:
329 | f.write(s)
330 |
331 |
332 | def read_list_of_floats_from_file(file_path):
333 | """Reads list of floats from a given file.
334 |
335 | Args:
336 | file_path: Path to a file where list of floats was stored.
337 | Returns:
338 | Array of bottleneck values (list of floats).
339 |
340 | """
341 |
342 | with open(file_path, 'rb') as f:
343 | s = struct.unpack('d' * BOTTLENECK_TENSOR_SIZE, f.read())
344 | return list(s)
345 |
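# A small round-trip sketch for the two helpers above, assuming a scratch
# path; the file name is hypothetical.
#
#   floats = [0.0] * BOTTLENECK_TENSOR_SIZE
#   write_list_of_floats_to_file(floats, '/tmp/bottleneck.bin')
#   assert read_list_of_floats_from_file('/tmp/bottleneck.bin') == floats
#   # struct.pack('d' * N, ...) stores each value as an 8-byte double, so the
#   # file is exactly 8 * BOTTLENECK_TENSOR_SIZE bytes.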
346 |
347 | bottleneck_path_2_bottleneck_values = {}
348 |
349 | def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
350 | image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor):
351 | print('Creating bottleneck at ' + bottleneck_path)
352 | image_path = get_image_path(image_lists, label_name, index, image_dir, category)
353 | if not gfile.Exists(image_path):
354 |     tf.logging.fatal('File does not exist: %s', image_path)
355 | image_data = gfile.FastGFile(image_path, 'rb').read()
356 | bottleneck_values = run_bottleneck_on_image(sess, image_data, jpeg_data_tensor, bottleneck_tensor)
357 | bottleneck_string = ','.join(str(x) for x in bottleneck_values)
358 | with open(bottleneck_path, 'w') as bottleneck_file:
359 | bottleneck_file.write(bottleneck_string)
360 |
361 | def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
362 | category, bottleneck_dir, jpeg_data_tensor,
363 | bottleneck_tensor):
364 | """Retrieves or calculates bottleneck values for an image.
365 |
366 | If a cached version of the bottleneck data exists on-disk, return that,
367 | otherwise calculate the data and save it to disk for future use.
368 |
369 | Args:
370 | sess: The current active TensorFlow Session.
371 | image_lists: Dictionary of training images for each label.
372 | label_name: Label string we want to get an image for.
373 |     index: Integer offset of the image we want. This will be taken modulo the
374 |       available number of images for the label, so it can be arbitrarily large.
375 | image_dir: Root folder string of the subfolders containing the training
376 | images.
377 | category: Name string of which set to pull images from - training, testing,
378 | or validation.
379 | bottleneck_dir: Folder string holding cached files of bottleneck values.
380 | jpeg_data_tensor: The tensor to feed loaded jpeg data into.
381 | bottleneck_tensor: The output tensor for the bottleneck values.
382 |
383 | Returns:
384 | Numpy array of values produced by the bottleneck layer for the image.
385 | """
386 | label_lists = image_lists[label_name]
387 | sub_dir = label_lists['dir']
388 | sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
389 | ensure_dir_exists(sub_dir_path)
390 | bottleneck_path = get_bottleneck_path(image_lists, label_name, index, bottleneck_dir, category)
391 | if not os.path.exists(bottleneck_path):
392 | create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
393 | with open(bottleneck_path, 'r') as bottleneck_file:
394 | bottleneck_string = bottleneck_file.read()
395 | did_hit_error = False
396 | try:
397 | bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
398 |   except ValueError:
399 | print("Invalid float found, recreating bottleneck")
400 | did_hit_error = True
401 | if did_hit_error:
402 | create_bottleneck_file(bottleneck_path, image_lists, label_name, index, image_dir, category, sess, jpeg_data_tensor, bottleneck_tensor)
403 | with open(bottleneck_path, 'r') as bottleneck_file:
404 | bottleneck_string = bottleneck_file.read()
405 | # Allow exceptions to propagate here, since they shouldn't happen after a fresh creation
406 | bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
407 | return bottleneck_values
408 |
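# Hedged sketch of the caching contract: the first call computes the values
# and writes the bottleneck file; later calls with the same arguments just
# parse the cached text. The label and paths are hypothetical.
#
#   v1 = get_or_create_bottleneck(sess, image_lists, 'roses', 0,
#                                 '/tmp/flower_photos', 'training',
#                                 '/tmp/bottleneck', jpeg_data_tensor,
#                                 bottleneck_tensor)  # computes and caches
#   v2 = get_or_create_bottleneck(sess, image_lists, 'roses', 0,
#                                 '/tmp/flower_photos', 'training',
#                                 '/tmp/bottleneck', jpeg_data_tensor,
#                                 bottleneck_tensor)  # read back from disk
#   # v1 == v2 up to float-to-text round-tripping.
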
409 | def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
410 | jpeg_data_tensor, bottleneck_tensor):
411 | """Ensures all the training, testing, and validation bottlenecks are cached.
412 |
413 | Because we're likely to read the same image multiple times (if there are no
414 | distortions applied during training) it can speed things up a lot if we
415 | calculate the bottleneck layer values once for each image during
416 | preprocessing, and then just read those cached values repeatedly during
417 | training. Here we go through all the images we've found, calculate those
418 | values, and save them off.
419 |
420 | Args:
421 | sess: The current active TensorFlow Session.
422 | image_lists: Dictionary of training images for each label.
423 | image_dir: Root folder string of the subfolders containing the training
424 | images.
425 | bottleneck_dir: Folder string holding cached files of bottleneck values.
426 | jpeg_data_tensor: Input tensor for jpeg data from file.
427 | bottleneck_tensor: The penultimate output layer of the graph.
428 |
429 | Returns:
430 | Nothing.
431 | """
432 | how_many_bottlenecks = 0
433 | ensure_dir_exists(bottleneck_dir)
434 | for label_name, label_lists in image_lists.items():
435 | for category in ['training', 'testing', 'validation']:
436 | category_list = label_lists[category]
437 | for index, unused_base_name in enumerate(category_list):
438 | get_or_create_bottleneck(sess, image_lists, label_name, index,
439 | image_dir, category, bottleneck_dir,
440 | jpeg_data_tensor, bottleneck_tensor)
441 |
442 | how_many_bottlenecks += 1
443 | if how_many_bottlenecks % 100 == 0:
444 | print(str(how_many_bottlenecks) + ' bottleneck files created.')
445 |
446 |
447 | def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
448 | bottleneck_dir, image_dir, jpeg_data_tensor,
449 | bottleneck_tensor):
450 | """Retrieves bottleneck values for cached images.
451 |
452 | If no distortions are being applied, this function can retrieve the cached
453 | bottleneck values directly from disk for images. It picks a random set of
454 | images from the specified category.
455 |
456 | Args:
457 | sess: Current TensorFlow Session.
458 | image_lists: Dictionary of training images for each label.
459 | how_many: If positive, a random sample of this size will be chosen.
460 | If negative, all bottlenecks will be retrieved.
461 | category: Name string of which set to pull from - training, testing, or
462 | validation.
463 | bottleneck_dir: Folder string holding cached files of bottleneck values.
464 | image_dir: Root folder string of the subfolders containing the training
465 | images.
466 | jpeg_data_tensor: The layer to feed jpeg image data into.
467 | bottleneck_tensor: The bottleneck output layer of the CNN graph.
468 |
469 | Returns:
470 | List of bottleneck arrays, their corresponding ground truths, and the
471 | relevant filenames.
472 | """
473 | class_count = len(image_lists.keys())
474 | bottlenecks = []
475 | ground_truths = []
476 | filenames = []
477 | if how_many >= 0:
478 | # Retrieve a random sample of bottlenecks.
479 | for unused_i in range(how_many):
480 | label_index = random.randrange(class_count)
481 | label_name = list(image_lists.keys())[label_index]
482 | image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
483 | image_name = get_image_path(image_lists, label_name, image_index,
484 | image_dir, category)
485 | bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
486 | image_index, image_dir, category,
487 | bottleneck_dir, jpeg_data_tensor,
488 | bottleneck_tensor)
489 | ground_truth = np.zeros(class_count, dtype=np.float32)
490 | ground_truth[label_index] = 1.0
491 | bottlenecks.append(bottleneck)
492 | ground_truths.append(ground_truth)
493 | filenames.append(image_name)
494 | else:
495 | # Retrieve all bottlenecks.
496 | for label_index, label_name in enumerate(image_lists.keys()):
497 | for image_index, image_name in enumerate(
498 | image_lists[label_name][category]):
499 | image_name = get_image_path(image_lists, label_name, image_index,
500 | image_dir, category)
501 | bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
502 | image_index, image_dir, category,
503 | bottleneck_dir, jpeg_data_tensor,
504 | bottleneck_tensor)
505 | ground_truth = np.zeros(class_count, dtype=np.float32)
506 | ground_truth[label_index] = 1.0
507 | bottlenecks.append(bottleneck)
508 | ground_truths.append(ground_truth)
509 | filenames.append(image_name)
510 | return bottlenecks, ground_truths, filenames
511 |
512 |
513 | def get_random_distorted_bottlenecks(
514 | sess, image_lists, how_many, category, image_dir, input_jpeg_tensor,
515 | distorted_image, resized_input_tensor, bottleneck_tensor):
516 | """Retrieves bottleneck values for training images, after distortions.
517 |
518 | If we're training with distortions like crops, scales, or flips, we have to
519 | recalculate the full model for every image, and so we can't use cached
520 | bottleneck values. Instead we find random images for the requested category,
521 | run them through the distortion graph, and then the full graph to get the
522 | bottleneck results for each.
523 |
524 | Args:
525 | sess: Current TensorFlow Session.
526 | image_lists: Dictionary of training images for each label.
527 | how_many: The integer number of bottleneck values to return.
528 | category: Name string of which set of images to fetch - training, testing,
529 | or validation.
530 | image_dir: Root folder string of the subfolders containing the training
531 | images.
532 | input_jpeg_tensor: The input layer we feed the image data to.
533 | distorted_image: The output node of the distortion graph.
534 | resized_input_tensor: The input node of the recognition graph.
535 | bottleneck_tensor: The bottleneck output layer of the CNN graph.
536 |
537 | Returns:
538 | List of bottleneck arrays and their corresponding ground truths.
539 | """
540 | class_count = len(image_lists.keys())
541 | bottlenecks = []
542 | ground_truths = []
543 | for unused_i in range(how_many):
544 | label_index = random.randrange(class_count)
545 | label_name = list(image_lists.keys())[label_index]
546 | image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
547 | image_path = get_image_path(image_lists, label_name, image_index, image_dir,
548 | category)
549 | if not gfile.Exists(image_path):
550 |       tf.logging.fatal('File does not exist: %s', image_path)
551 | jpeg_data = gfile.FastGFile(image_path, 'rb').read()
552 |     # Note that we materialize the distorted_image_data as a numpy array
553 |     # before running inference on the image. This involves two memory copies
554 |     # and might be optimized in other implementations.
555 | distorted_image_data = sess.run(distorted_image,
556 | {input_jpeg_tensor: jpeg_data})
557 | bottleneck = run_bottleneck_on_image(sess, distorted_image_data,
558 | resized_input_tensor,
559 | bottleneck_tensor)
560 | ground_truth = np.zeros(class_count, dtype=np.float32)
561 | ground_truth[label_index] = 1.0
562 | bottlenecks.append(bottleneck)
563 | ground_truths.append(ground_truth)
564 | return bottlenecks, ground_truths
565 |
566 |
567 | def should_distort_images(flip_left_right, random_crop, random_scale,
568 | random_brightness):
569 | """Whether any distortions are enabled, from the input flags.
570 |
571 | Args:
572 | flip_left_right: Boolean whether to randomly mirror images horizontally.
573 | random_crop: Integer percentage setting the total margin used around the
574 | crop box.
575 | random_scale: Integer percentage of how much to vary the scale by.
576 | random_brightness: Integer range to randomly multiply the pixel values by.
577 |
578 | Returns:
579 | Boolean value indicating whether any distortions should be applied.
580 | """
581 | return (flip_left_right or (random_crop != 0) or (random_scale != 0) or
582 | (random_brightness != 0))
583 |
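# should_distort_images is a pure predicate over the command-line flags, e.g.:
#
#   should_distort_images(False, 0, 0, 0)   # -> False, cached bottlenecks OK
#   should_distort_images(True, 0, 0, 0)    # -> True
#   should_distort_images(False, 10, 0, 0)  # -> True, any nonzero distortion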
584 |
585 | def add_input_distortions(flip_left_right, random_crop, random_scale,
586 | random_brightness):
587 | """Creates the operations to apply the specified distortions.
588 |
589 | During training it can help to improve the results if we run the images
590 | through simple distortions like crops, scales, and flips. These reflect the
591 | kind of variations we expect in the real world, and so can help train the
592 | model to cope with natural data more effectively. Here we take the supplied
593 | parameters and construct a network of operations to apply them to an image.
594 |
595 | Cropping
596 | ~~~~~~~~
597 |
598 | Cropping is done by placing a bounding box at a random position in the full
599 | image. The cropping parameter controls the size of that box relative to the
600 | input image. If it's zero, then the box is the same size as the input and no
601 | cropping is performed. If the value is 50%, then the crop box will be half the
602 | width and height of the input. In a diagram it looks like this:
603 |
604 | < width >
605 | +---------------------+
606 | | |
607 | | width - crop% |
608 | | < > |
609 | | +------+ |
610 | | | | |
611 | | | | |
612 | | | | |
613 | | +------+ |
614 | | |
615 | | |
616 | +---------------------+
617 |
618 | Scaling
619 | ~~~~~~~
620 |
621 | Scaling is a lot like cropping, except that the bounding box is always
622 | centered and its size varies randomly within the given range. For example if
623 | the scale percentage is zero, then the bounding box is the same size as the
624 |   input and no scaling is applied. If it's 50%, then the bounding box will be
625 |   randomly sized between half the input's width and height and its full size.
626 |
627 | Args:
628 | flip_left_right: Boolean whether to randomly mirror images horizontally.
629 | random_crop: Integer percentage setting the total margin used around the
630 | crop box.
631 | random_scale: Integer percentage of how much to vary the scale by.
632 | random_brightness: Integer range to randomly multiply the pixel values by.
633 |
634 |
635 | Returns:
636 | The jpeg input layer and the distorted result tensor.
637 | """
638 |
639 | jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
640 | decoded_image = tf.image.decode_jpeg(jpeg_data, channels=MODEL_INPUT_DEPTH)
641 | decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
642 | decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
643 | margin_scale = 1.0 + (random_crop / 100.0)
644 | resize_scale = 1.0 + (random_scale / 100.0)
645 | margin_scale_value = tf.constant(margin_scale)
646 | resize_scale_value = tf.random_uniform(tensor_shape.scalar(),
647 | minval=1.0,
648 | maxval=resize_scale)
649 | scale_value = tf.multiply(margin_scale_value, resize_scale_value)
650 | precrop_width = tf.multiply(scale_value, MODEL_INPUT_WIDTH)
651 | precrop_height = tf.multiply(scale_value, MODEL_INPUT_HEIGHT)
652 | precrop_shape = tf.stack([precrop_height, precrop_width])
653 | precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32)
654 | precropped_image = tf.image.resize_bilinear(decoded_image_4d,
655 | precrop_shape_as_int)
656 | precropped_image_3d = tf.squeeze(precropped_image, squeeze_dims=[0])
657 | cropped_image = tf.random_crop(precropped_image_3d,
658 | [MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH,
659 | MODEL_INPUT_DEPTH])
660 | if flip_left_right:
661 | flipped_image = tf.image.random_flip_left_right(cropped_image)
662 | else:
663 | flipped_image = cropped_image
664 | brightness_min = 1.0 - (random_brightness / 100.0)
665 | brightness_max = 1.0 + (random_brightness / 100.0)
666 | brightness_value = tf.random_uniform(tensor_shape.scalar(),
667 | minval=brightness_min,
668 | maxval=brightness_max)
669 | brightened_image = tf.multiply(flipped_image, brightness_value)
670 | distort_result = tf.expand_dims(brightened_image, 0, name='DistortResult')
671 | return jpeg_data, distort_result
672 |
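# Hedged usage sketch: build the distortion ops once, then run them per image.
# The flag values below are examples, not recommendations, and jpeg_bytes is a
# hypothetical string of raw JPEG data.
#
#   distort_input, distort_result = add_input_distortions(
#       flip_left_right=True, random_crop=10, random_scale=10,
#       random_brightness=10)
#   distorted = sess.run(distort_result, {distort_input: jpeg_bytes})
#   # 'distorted' is a [1, MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH,
#   # MODEL_INPUT_DEPTH] float image, ready for the resized-input tensor.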
673 |
674 | def variable_summaries(var):
675 | """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
676 | with tf.name_scope('summaries'):
677 | mean = tf.reduce_mean(var)
678 | tf.summary.scalar('mean', mean)
679 | with tf.name_scope('stddev'):
680 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
681 | tf.summary.scalar('stddev', stddev)
682 | tf.summary.scalar('max', tf.reduce_max(var))
683 | tf.summary.scalar('min', tf.reduce_min(var))
684 | tf.summary.histogram('histogram', var)
685 |
686 |
687 | def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
688 | """Adds a new softmax and fully-connected layer for training.
689 |
690 | We need to retrain the top layer to identify our new classes, so this function
691 | adds the right operations to the graph, along with some variables to hold the
692 | weights, and then sets up all the gradients for the backward pass.
693 |
694 | The set up for the softmax and fully-connected layers is based on:
695 | https://tensorflow.org/versions/master/tutorials/mnist/beginners/index.html
696 |
697 | Args:
698 | class_count: Integer of how many categories of things we're trying to
699 | recognize.
700 | final_tensor_name: Name string for the new final node that produces results.
701 | bottleneck_tensor: The output of the main CNN graph.
702 |
703 | Returns:
704 | The tensors for the training and cross entropy results, and tensors for the
705 | bottleneck input and ground truth input.
706 | """
707 | with tf.name_scope('input'):
708 | bottleneck_input = tf.placeholder_with_default(
709 | bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE],
710 | name='BottleneckInputPlaceholder')
711 |
712 | ground_truth_input = tf.placeholder(tf.float32,
713 | [None, class_count],
714 | name='GroundTruthInput')
715 |
716 | # Organizing the following ops as `final_training_ops` so they're easier
717 | # to see in TensorBoard
718 | layer_name = 'final_training_ops'
719 | with tf.name_scope(layer_name):
720 | with tf.name_scope('weights'):
721 | layer_weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count], stddev=0.001), name='final_weights')
722 | variable_summaries(layer_weights)
723 | with tf.name_scope('biases'):
724 | layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
725 | variable_summaries(layer_biases)
726 | with tf.name_scope('Wx_plus_b'):
727 | logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
728 | tf.summary.histogram('pre_activations', logits)
729 |
730 | final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
731 | tf.summary.histogram('activations', final_tensor)
732 |
733 | with tf.name_scope('cross_entropy'):
734 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
735 | labels=ground_truth_input, logits=logits)
736 | with tf.name_scope('total'):
737 | cross_entropy_mean = tf.reduce_mean(cross_entropy)
738 | tf.summary.scalar('cross_entropy', cross_entropy_mean)
739 |
740 | with tf.name_scope('train'):
741 | train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(
742 | cross_entropy_mean)
743 |
744 | return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
745 | final_tensor)
746 |
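# Hedged sketch of wiring the new head onto the frozen Inception graph
# (mirroring what main() does below, with 5 flower classes):
#
#   (train_step, cross_entropy, bottleneck_input, ground_truth_input,
#    final_tensor) = add_final_training_ops(5, 'final_result',
#                                           bottleneck_tensor)
#   # Each GradientDescentOptimizer step updates only final_weights and
#   # final_biases, since everything upstream of the bottleneck was imported
#   # as constants.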
747 |
748 | def add_evaluation_step(result_tensor, ground_truth_tensor):
749 | """Inserts the operations we need to evaluate the accuracy of our results.
750 |
751 | Args:
752 | result_tensor: The new final node that produces results.
753 | ground_truth_tensor: The node we feed ground truth data
754 | into.
755 |
756 | Returns:
757 | Tuple of (evaluation step, prediction).
758 | """
759 | with tf.name_scope('accuracy'):
760 | with tf.name_scope('correct_prediction'):
761 | prediction = tf.argmax(result_tensor, 1)
762 | correct_prediction = tf.equal(
763 | prediction, tf.argmax(ground_truth_tensor, 1))
764 | with tf.name_scope('accuracy'):
765 | evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
766 | tf.summary.scalar('accuracy', evaluation_step)
767 | return evaluation_step, prediction
768 |
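# Hedged example of reading the accuracy op during training:
#
#   accuracy, _ = sess.run([evaluation_step, prediction],
#                          feed_dict={bottleneck_input: bottlenecks,
#                                     ground_truth_input: truths})
#   # accuracy is the fraction of rows where argmax(final_tensor) matches
#   # argmax(ground_truth), as a float32 scalar in [0, 1].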
769 |
770 | def main(_):
771 |   # Set up the directory we'll write summaries to for TensorBoard.
772 | if tf.gfile.Exists(FLAGS.summaries_dir):
773 | tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
774 | tf.gfile.MakeDirs(FLAGS.summaries_dir)
775 |
776 | # Set up the pre-trained graph.
777 | maybe_download_and_extract()
778 | graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = (
779 | create_inception_graph())
780 |
781 | # Look at the folder structure, and create lists of all the images.
782 | image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
783 | FLAGS.validation_percentage)
784 | class_count = len(image_lists.keys())
785 | if class_count == 0:
786 | print('No valid folders of images found at ' + FLAGS.image_dir)
787 | return -1
788 | if class_count == 1:
789 | print('Only one valid folder of images found at ' + FLAGS.image_dir +
790 | ' - multiple classes are needed for classification.')
791 | return -1
792 |
793 | # See if the command-line flags mean we're applying any distortions.
794 | do_distort_images = should_distort_images(
795 | FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
796 | FLAGS.random_brightness)
797 | sess = tf.Session()
798 |
799 | if do_distort_images:
800 | # We will be applying distortions, so setup the operations we'll need.
801 | distorted_jpeg_data_tensor, distorted_image_tensor = add_input_distortions(
802 | FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
803 | FLAGS.random_brightness)
804 | else:
805 | # We'll make sure we've calculated the 'bottleneck' image summaries and
806 | # cached them on disk.
807 | cache_bottlenecks(sess, image_lists, FLAGS.image_dir, FLAGS.bottleneck_dir,
808 | jpeg_data_tensor, bottleneck_tensor)
809 |
810 | # Add the new layer that we'll be training.
811 | (train_step, cross_entropy, bottleneck_input, ground_truth_input,
812 | final_tensor) = add_final_training_ops(len(image_lists.keys()),
813 | FLAGS.final_tensor_name,
814 | bottleneck_tensor)
815 |
816 | # Create the operations we need to evaluate the accuracy of our new layer.
817 | evaluation_step, prediction = add_evaluation_step(
818 | final_tensor, ground_truth_input)
819 |
820 | # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
821 | merged = tf.summary.merge_all()
822 | train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
823 | sess.graph)
824 | validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')
825 |
826 | # Set up all our weights to their initial default values.
827 | init = tf.global_variables_initializer()
828 | sess.run(init)
829 |
830 | # Run the training for as many cycles as requested on the command line.
831 | for i in range(FLAGS.how_many_training_steps):
832 | # Get a batch of input bottleneck values, either calculated fresh every time
833 | # with distortions applied, or from the cache stored on disk.
834 | if do_distort_images:
835 | train_bottlenecks, train_ground_truth = get_random_distorted_bottlenecks(
836 | sess, image_lists, FLAGS.train_batch_size, 'training',
837 | FLAGS.image_dir, distorted_jpeg_data_tensor,
838 | distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
839 | else:
840 | train_bottlenecks, train_ground_truth, _ = get_random_cached_bottlenecks(
841 | sess, image_lists, FLAGS.train_batch_size, 'training',
842 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
843 | bottleneck_tensor)
844 | # Feed the bottlenecks and ground truth into the graph, and run a training
845 | # step. Capture training summaries for TensorBoard with the `merged` op.
846 | train_summary, _ = sess.run([merged, train_step],
847 | feed_dict={bottleneck_input: train_bottlenecks,
848 | ground_truth_input: train_ground_truth})
849 | train_writer.add_summary(train_summary, i)
850 |
851 | # Every so often, print out how well the graph is training.
852 | is_last_step = (i + 1 == FLAGS.how_many_training_steps)
853 | if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
854 | train_accuracy, cross_entropy_value = sess.run(
855 | [evaluation_step, cross_entropy],
856 | feed_dict={bottleneck_input: train_bottlenecks,
857 | ground_truth_input: train_ground_truth})
858 | print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i,
859 | train_accuracy * 100))
860 | print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i,
861 | cross_entropy_value))
862 | validation_bottlenecks, validation_ground_truth, _ = (
863 | get_random_cached_bottlenecks(
864 | sess, image_lists, FLAGS.validation_batch_size, 'validation',
865 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
866 | bottleneck_tensor))
867 | # Run a validation step and capture training summaries for TensorBoard
868 | # with the `merged` op.
869 | validation_summary, validation_accuracy = sess.run(
870 | [merged, evaluation_step],
871 | feed_dict={bottleneck_input: validation_bottlenecks,
872 | ground_truth_input: validation_ground_truth})
873 | validation_writer.add_summary(validation_summary, i)
874 | print('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
875 | (datetime.now(), i, validation_accuracy * 100,
876 | len(validation_bottlenecks)))
877 |
878 | # We've completed all our training, so run a final test evaluation on
879 | # some new images we haven't used before.
880 | test_bottlenecks, test_ground_truth, test_filenames = (
881 | get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size,
882 | 'testing', FLAGS.bottleneck_dir,
883 | FLAGS.image_dir, jpeg_data_tensor,
884 | bottleneck_tensor))
885 | test_accuracy, predictions = sess.run(
886 | [evaluation_step, prediction],
887 | feed_dict={bottleneck_input: test_bottlenecks,
888 | ground_truth_input: test_ground_truth})
889 | print('Final test accuracy = %.1f%% (N=%d)' % (
890 | test_accuracy * 100, len(test_bottlenecks)))
891 |
892 | if FLAGS.print_misclassified_test_images:
893 | print('=== MISCLASSIFIED TEST IMAGES ===')
894 | for i, test_filename in enumerate(test_filenames):
895 | if predictions[i] != test_ground_truth[i].argmax():
896 | print('%70s %s' % (test_filename,
897 | list(image_lists.keys())[predictions[i]]))
898 |
899 | # Write out the trained graph and labels with the weights stored as constants.
900 | output_graph_def = graph_util.convert_variables_to_constants(
901 | sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
902 | with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
903 | f.write(output_graph_def.SerializeToString())
904 | with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
905 | f.write('\n'.join(image_lists.keys()) + '\n')
906 |
907 |
908 | if __name__ == '__main__':
909 | parser = argparse.ArgumentParser()
910 | parser.add_argument(
911 | '--image_dir',
912 | type=str,
913 | default='',
914 | help='Path to folders of labeled images.'
915 | )
916 | parser.add_argument(
917 | '--output_graph',
918 | type=str,
919 | default='/tmp/output_graph.pb',
920 | help='Where to save the trained graph.'
921 | )
922 | parser.add_argument(
923 | '--output_labels',
924 | type=str,
925 | default='/tmp/output_labels.txt',
926 | help='Where to save the trained graph\'s labels.'
927 | )
928 | parser.add_argument(
929 | '--summaries_dir',
930 | type=str,
931 | default='/tmp/retrain_logs',
932 | help='Where to save summary logs for TensorBoard.'
933 | )
934 | parser.add_argument(
935 | '--how_many_training_steps',
936 | type=int,
937 | default=4000,
938 | help='How many training steps to run before ending.'
939 | )
940 | parser.add_argument(
941 | '--learning_rate',
942 | type=float,
943 | default=0.01,
944 | help='How large a learning rate to use when training.'
945 | )
946 | parser.add_argument(
947 | '--testing_percentage',
948 | type=int,
949 | default=10,
950 | help='What percentage of images to use as a test set.'
951 | )
952 | parser.add_argument(
953 | '--validation_percentage',
954 | type=int,
955 | default=10,
956 | help='What percentage of images to use as a validation set.'
957 | )
958 | parser.add_argument(
959 | '--eval_step_interval',
960 | type=int,
961 | default=10,
962 | help='How often to evaluate the training results.'
963 | )
964 | parser.add_argument(
965 | '--train_batch_size',
966 | type=int,
967 | default=100,
968 | help='How many images to train on at a time.'
969 | )
970 | parser.add_argument(
971 | '--test_batch_size',
972 | type=int,
973 | default=-1,
974 | help="""\
975 | How many images to test on. This test set is only used once, to evaluate
976 | the final accuracy of the model after training completes.
977 | A value of -1 causes the entire test set to be used, which leads to more
978 | stable results across runs.\
979 | """
980 | )
981 | parser.add_argument(
982 | '--validation_batch_size',
983 | type=int,
984 | default=100,
985 | help="""\
986 | How many images to use in an evaluation batch. This validation set is
987 | used much more often than the test set, and is an early indicator of how
988 | accurate the model is during training.
989 | A value of -1 causes the entire validation set to be used, which leads to
990 | more stable results across training iterations, but may be slower on large
991 | training sets.\
992 | """
993 | )
994 | parser.add_argument(
995 | '--print_misclassified_test_images',
996 | default=False,
997 | help="""\
998 | Whether to print out a list of all misclassified test images.\
999 | """,
1000 | action='store_true'
1001 | )
1002 | parser.add_argument(
1003 | '--model_dir',
1004 | type=str,
1005 | default='/tmp/imagenet',
1006 | help="""\
1007 | Path to classify_image_graph_def.pb,
1008 | imagenet_synset_to_human_label_map.txt, and
1009 | imagenet_2012_challenge_label_map_proto.pbtxt.\
1010 | """
1011 | )
1012 | parser.add_argument(
1013 | '--bottleneck_dir',
1014 | type=str,
1015 | default='/tmp/bottleneck',
1016 | help='Path to cache bottleneck layer values as files.'
1017 | )
1018 | parser.add_argument(
1019 | '--final_tensor_name',
1020 | type=str,
1021 | default='final_result',
1022 | help="""\
1023 | The name of the output classification layer in the retrained graph.\
1024 | """
1025 | )
1026 | parser.add_argument(
1027 | '--flip_left_right',
1028 | default=False,
1029 | help="""\
1030 | Whether to randomly flip half of the training images horizontally.\
1031 | """,
1032 | action='store_true'
1033 | )
1034 | parser.add_argument(
1035 | '--random_crop',
1036 | type=int,
1037 | default=0,
1038 | help="""\
1039 | A percentage determining how much of a margin to randomly crop off the
1040 | training images.\
1041 | """
1042 | )
1043 | parser.add_argument(
1044 | '--random_scale',
1045 | type=int,
1046 | default=0,
1047 | help="""\
1048 | A percentage determining how much to randomly scale up the size of the
1049 | training images by.\
1050 | """
1051 | )
1052 | parser.add_argument(
1053 | '--random_brightness',
1054 | type=int,
1055 | default=0,
1056 | help="""\
1057 | A percentage determining how much to randomly multiply the training image
1058 | input pixels up or down by.\
1059 | """
1060 | )
1061 | FLAGS, unparsed = parser.parse_known_args()
1062 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
1063 |
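# Example invocation (hedged; /tmp/flower_photos stands for whatever directory
# of labeled image subfolders you prepared):
#
#   python retrain.py --image_dir /tmp/flower_photos \
#       --how_many_training_steps 4000 --learning_rate 0.01
#
# The retrained graph and labels land at --output_graph and --output_labels
# (/tmp/output_graph.pb and /tmp/output_labels.txt by default).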
--------------------------------------------------------------------------------