├── .idea
│   ├── Emotion_recognition_system.iml
│   ├── Face_Sentiment_analysis.iml
│   ├── Tic_tac_toe.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── Emotion_recognition.gif
├── Face_crop.py
├── LICENSE
├── README.md
├── Sample
│   ├── Face_crop.py
│   └── haarcascade_frontalface_alt.xml
├── android_recognition.py
├── avbin64.dll
├── haarcascade_frontalface_alt.xml
├── haarcascade_frontalface_default.xml
├── label_image.py
├── recognition_webcam.py
├── retrain.py
└── training code.txt
/.idea/Emotion_recognition_system.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/Face_Sentiment_analysis.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/Tic_tac_toe.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Emotion_recognition.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spidy20/Emotion_recognition_system/41241626a2c6cd24afcfb40f3ad5e0e469fff662/Emotion_recognition.gif
--------------------------------------------------------------------------------
/Face_crop.py:
--------------------------------------------------------------------------------
import cv2
import glob

# Crop the detected face from every .jpg in this folder and save it
# as a new "resized_" image next to the original.
images = glob.glob("*.jpg")

for image in images:
    facedata = "haarcascade_frontalface_alt.xml"
    cascade = cv2.CascadeClassifier(facedata)
    img = cv2.imread(image, 0)  # read as grayscale

    re = cv2.resize(img, (int(img.shape[1]), int(img.shape[0])))
    faces = cascade.detectMultiScale(re)

    for f in faces:
        x, y, w, h = [v for v in f]
        cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)

        sub_face = img[y:y + h, x:x + w]

        cv2.imshow("checking", sub_face)
        cv2.waitKey(500)
        cv2.destroyAllWindows()

        cv2.imwrite("resized_" + image, sub_face)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Kushal Bhavsar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Emotion😂😥😡😱 Recognition system [MIT License](https://github.com/Spidy20/Emotion_recognition_system/blob/master/LICENSE)

### Sourcerer


### Code Requirements
- TensorFlow
  Installation: `pip install tensorflow`
- Download my repository
- Your own expression dataset (Note: you can download expression images from Google, or record a video of yourself making different expressions and
  convert the frames into grayscale images for more accurate prediction; see the sketch below)

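If you go the recorded-video route, the following is a minimal sketch of dumping a clip into grayscale frames with OpenCV. The clip name `my_expressions.mp4` and the `Images/Happy` output folder are placeholders; use your own clip and one folder per emotion.

```python
import os
import cv2

# Placeholders: your recorded clip and the emotion folder to fill.
video_path = "my_expressions.mp4"
out_dir = os.path.join("Images", "Happy")
os.makedirs(out_dir, exist_ok=True)

cap = cv2.VideoCapture(video_path)
frame_no = 0
while True:
    ok, frame = cap.read()
    if not ok:  # end of the clip
        break
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)  # grayscale, as suggested above
    cv2.imwrite(os.path.join(out_dir, "frame_%04d.jpg" % frame_no), gray)
    frame_no += 1
cap.release()
```

You can then drop `Face_crop.py` and the Haar cascade into each emotion folder (see the steps below) so only the face region is kept.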

### What steps do you have to follow?
- Download my repository.
- Create an `Images` folder in your project and make a subfolder for each emotion, e.g. Happy, Sad, Angry.
- Put `Face_crop.py` & `haarcascade_frontalface_alt.xml` in every image folder (e.g. in the "Happy" folder) and run the script: it detects the
  faces in the images, converts them to grayscale, and saves the crops as new images in the same folder.
- After that, create the model: copy the command from `training code.txt`, open CMD in your project folder, paste it, and press Enter.
- Training takes around 20-25 minutes, so be patient.
- After training it creates two files, `retrained_graph.pb` & `retrained_labels.txt` (a quick sanity check is sketched after this list).
- Now run `recognition_webcam.py`.
- If you want to fetch video from your mobile camera, use `android_recognition.py` instead; you have to install the IP Webcam app on your phone
  and replace my server URL with yours.
- That's all!
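
As that sanity check, here is a minimal sketch that classifies a single image with `label_image.main`. It assumes `retrained_graph.pb`, `retrained_labels.txt`, and a cropped face image named `test.jpg` sit in the project folder.

```python
# Run this as a plain script with no extra command-line arguments,
# because label_image.main() also parses its own flags internally.
import label_image

prediction = label_image.main("test.jpg")  # any cropped face image
print(prediction)  # prints one of your emotion folder names, e.g. "happy"
```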

### How does it work? See :)

![Emotion recognition demo](Emotion_recognition.gif)

### Notes
- It needs fairly high processing power (I have 8 GB RAM & a 2 GB graphics card).
- Don't expect it to recognise expressions as well as a human; that is not realistic.
- Download around 300 images for every expression (you can use a batch downloader).
- Noisy images can reduce accuracy, so image quality matters.


## Just follow☝️ me and Star⭐ my repository

--------------------------------------------------------------------------------
/Sample/Face_crop.py:
--------------------------------------------------------------------------------
import cv2
import glob

# Sample copy of Face_crop.py: place this script and
# haarcascade_frontalface_alt.xml inside an emotion folder and run it to
# crop the detected faces from every .jpg in that folder.
images = glob.glob("*.jpg")

for image in images:
    facedata = "haarcascade_frontalface_alt.xml"
    cascade = cv2.CascadeClassifier(facedata)
    img = cv2.imread(image, 0)  # read as grayscale

    re = cv2.resize(img, (int(img.shape[1]), int(img.shape[0])))
    faces = cascade.detectMultiScale(re)

    for f in faces:
        x, y, w, h = [v for v in f]
        sub_face = img[y:y + h, x:x + w]

        cv2.imshow("checking", sub_face)
        cv2.waitKey(500)
        cv2.destroyAllWindows()

        cv2.imwrite("resized_" + image, sub_face)
--------------------------------------------------------------------------------
/android_recognition.py:
--------------------------------------------------------------------------------
import cv2
import label_image
import os
import numpy as np
from urllib.request import urlopen
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

size = 4

# We load the xml file
classifier = cv2.CascadeClassifier('haarcascade_frontalface_alt.xml')
# Replace this with the URL shown by the IP Webcam app on your phone.
mobile_video = "http://192.168.0.101:8080/shot.jpg"

while True:

    # Grab the latest frame from the phone camera.
    img_resp = urlopen(mobile_video)
    img_arr = np.array(bytearray(img_resp.read()), dtype=np.uint8)

    cap = cv2.imdecode(img_arr, -1)
    im = cv2.flip(cap, 1)  # Flip to act as a mirror
    # Resize the image to speed up detection
    mini = cv2.resize(im, (int(im.shape[1] / size), int(im.shape[0] / size)))

    # Detect faces
    faces = classifier.detectMultiScale(mini)

    # Draw rectangles around each face
    for f in faces:
        (x, y, w, h) = [v * size for v in f]  # Scale the shape size back up
        sub_face = im[y:y + h, x:x + w]
        FaceFileName = "test.jpg"  # Saving the current face crop for classification.
        cv2.imwrite(FaceFileName, sub_face)
        text = label_image.main(FaceFileName)  # Getting the classification result from label_image.
        text = text.title()  # Title Case looks stunning.
        font = cv2.FONT_HERSHEY_TRIPLEX

        if text == 'Angry':
            cv2.rectangle(im, (x, y), (x + w, y + h), (0, 25, 255), 7)
            cv2.putText(im, text, (x + h, y), font, 1, (0, 25, 255), 2)

        if text == 'Smile':
            cv2.rectangle(im, (x, y), (x + w, y + h), (0, 255, 0), 7)
            cv2.putText(im, text, (x + h, y), font, 1, (0, 255, 0), 2)

        if text == 'Fear':
            cv2.rectangle(im, (x, y), (x + w, y + h), (0, 255, 255), 7)
            cv2.putText(im, text, (x + h, y), font, 1, (0, 255, 255), 2)

        if text == 'Sad':
            cv2.rectangle(im, (x, y), (x + w, y + h), (0, 191, 255), 7)
            cv2.putText(im, text, (x + h, y), font, 1, (0, 191, 255), 2)

    # Show the image
    cv2.imshow('Emotion recognition from Android screen', im)
    key = cv2.waitKey(30) & 0xff
    if key == 27:  # The Esc key
        break
--------------------------------------------------------------------------------
/avbin64.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Spidy20/Emotion_recognition_system/41241626a2c6cd24afcfb40f3ad5e0e469fff662/avbin64.dll
--------------------------------------------------------------------------------
/label_image.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 |
6 | import argparse
7 | import sys
8 | import time
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | def load_graph(model_file):
14 | graph = tf.Graph()
15 | graph_def = tf.GraphDef()
16 |
17 | with open(model_file, "rb") as f:
18 | graph_def.ParseFromString(f.read())
19 | with graph.as_default():
20 | tf.import_graph_def(graph_def)
21 |
22 | return graph
23 |
24 | def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
25 | input_mean=0, input_std=255):
26 | input_name = "file_reader"
27 | output_name = "normalized"
28 | file_reader = tf.read_file(file_name, input_name)
29 | if file_name.endswith(".png"):
30 | image_reader = tf.image.decode_png(file_reader, channels = 3,
31 | name='png_reader')
32 | elif file_name.endswith(".gif"):
33 | image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
34 | name='gif_reader'))
35 | elif file_name.endswith(".bmp"):
36 | image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
37 | else:
38 | image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
39 | name='jpeg_reader')
40 | float_caster = tf.cast(image_reader, tf.float32)
41 | dims_expander = tf.expand_dims(float_caster, 0);
42 | resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
43 | normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
44 | sess = tf.Session()
45 | result = sess.run(normalized)
46 |
47 | return result
48 |
49 | def load_labels(label_file):
50 | label = []
51 |
52 | proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
53 | for l in proto_as_ascii_lines:
54 | label.append(l.rstrip())
55 |
56 | return label
57 |
58 |
59 |
60 | def main(img):
61 | file_name = img
62 | model_file = "retrained_graph.pb"
63 | label_file = "retrained_labels.txt"
64 | input_height = 224
65 | input_width = 224
66 | input_mean = 128
67 | input_std = 128
68 | input_layer = "input"
69 | output_layer = "final_result"
70 |
71 | parser = argparse.ArgumentParser()
72 | parser.add_argument("--image", help="image to be processed")
73 | parser.add_argument("--graph", help="graph/model to be executed")
74 | parser.add_argument("--labels", help="name of file containing labels")
75 | parser.add_argument("--input_height", type=int, help="input height")
76 | parser.add_argument("--input_width", type=int, help="input width")
77 | parser.add_argument("--input_mean", type=int, help="input mean")
78 | parser.add_argument("--input_std", type=int, help="input std")
79 | parser.add_argument("--input_layer", help="name of input layer")
80 | parser.add_argument("--output_layer", help="name of output layer")
81 | args = parser.parse_args()
82 |
83 | if args.graph:
84 | model_file = args.graph
85 | if args.image:
86 | file_name = args.image
87 | if args.labels:
88 | label_file = args.labels
89 | if args.input_height:
90 | input_height = args.input_height
91 | if args.input_width:
92 | input_width = args.input_width
93 | if args.input_mean:
94 | input_mean = args.input_mean
95 | if args.input_std:
96 | input_std = args.input_std
97 | if args.input_layer:
98 | input_layer = args.input_layer
99 | if args.output_layer:
100 | output_layer = args.output_layer
101 |
102 | graph = load_graph(model_file)
103 | t = read_tensor_from_image_file(file_name,
104 | input_height=input_height,
105 | input_width=input_width,
106 | input_mean=input_mean,
107 | input_std=input_std)
108 |
109 | input_name = "import/" + input_layer
110 | output_name = "import/" + output_layer
111 | input_operation = graph.get_operation_by_name(input_name);
112 | output_operation = graph.get_operation_by_name(output_name);
113 |
114 |
115 | with tf.Session(graph=graph) as sess:
116 | start = time.time()
117 | results = sess.run(output_operation.outputs[0],
118 | {input_operation.outputs[0]: t})
119 | end=time.time()
120 | results = np.squeeze(results)
121 |
122 | top_k = results.argsort()[-5:][::-1]
123 | labels = load_labels(label_file)
124 |
125 |
126 |
127 |
128 |
129 |
130 | for i in top_k:
131 | return labels[i]
--------------------------------------------------------------------------------
/recognition_webcam.py:
--------------------------------------------------------------------------------
# Coded by: Kushal Bhavsar
# From: Techmicra IT solution
import cv2
import label_image
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

size = 4
# We load the xml file
classifier = cv2.CascadeClassifier('haarcascade_frontalface_alt.xml')

webcam = cv2.VideoCapture(0)  # Using default WebCam connected to the PC.
while True:
    (rval, im) = webcam.read()
    im = cv2.flip(im, 1)  # Flip to act as a mirror
    # Resize the image to speed up detection
    mini = cv2.resize(im, (int(im.shape[1] / size), int(im.shape[0] / size)))
    # Detect faces
    faces = classifier.detectMultiScale(mini)
    # Draw rectangles around each face
    for f in faces:
        (x, y, w, h) = [v * size for v in f]  # Scale the shape size back up
        sub_face = im[y:y + h, x:x + w]
        FaceFileName = "test.jpg"  # Saving the current face crop for classification.
        cv2.imwrite(FaceFileName, sub_face)
        text = label_image.main(FaceFileName)  # Getting the classification result from label_image.
        text = text.title()  # Title Case looks stunning.
        font = cv2.FONT_HERSHEY_TRIPLEX

        if text == 'Angry':
            cv2.rectangle(im, (x, y), (x + w, y + h), (0, 25, 255), 7)
            cv2.putText(im, text, (x + h, y), font, 1, (0, 25, 255), 2)

        if text == 'Smile':
            cv2.rectangle(im, (x, y), (x + w, y + h), (0, 255, 0), 7)
            cv2.putText(im, text, (x + h, y), font, 1, (0, 255, 0), 2)

        if text == 'Fear':
            cv2.rectangle(im, (x, y), (x + w, y + h), (0, 255, 255), 7)
            cv2.putText(im, text, (x + h, y), font, 1, (0, 255, 255), 2)

        if text == 'Sad':
            cv2.rectangle(im, (x, y), (x + w, y + h), (0, 191, 255), 7)
            cv2.putText(im, text, (x + h, y), font, 1, (0, 191, 255), 2)

    # Show the image
    cv2.imshow('Emotion recognition', im)
    key = cv2.waitKey(30) & 0xff

    if key == 27:  # The Esc key
        break
--------------------------------------------------------------------------------
/retrain.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Simple transfer learning with Inception v3 or Mobilenet models.
16 |
17 | With support for TensorBoard.
18 |
19 | This example shows how to take a Inception v3 or Mobilenet model trained on
20 | ImageNet images, and train a new top layer that can recognize other classes of
21 | images.
22 |
23 | The top layer receives as input a 2048-dimensional vector (1001-dimensional for
24 | Mobilenet) for each image. We train a softmax layer on top of this
25 | representation. Assuming the softmax layer contains N labels, this corresponds
26 | to learning N + 2048*N (or 1001*N) model parameters corresponding to the
27 | learned biases and weights.
28 |
29 | Here's an example, which assumes you have a folder containing class-named
30 | subfolders, each full of images for each label. The example folder flower_photos
31 | should have a structure like this:
32 |
33 | ~/flower_photos/daisy/photo1.jpg
34 | ~/flower_photos/daisy/photo2.jpg
35 | ...
36 | ~/flower_photos/rose/anotherphoto77.jpg
37 | ...
38 | ~/flower_photos/sunflower/somepicture.jpg
39 |
40 | The subfolder names are important, since they define what label is applied to
41 | each image, but the filenames themselves don't matter. Once your images are
42 | prepared, you can run the training with a command like this:
43 |
44 |
45 | ```bash
46 | bazel build tensorflow/examples/image_retraining:retrain && \
47 | bazel-bin/tensorflow/examples/image_retraining/retrain \
48 | --image_dir ~/flower_photos
49 | ```
50 |
51 | Or, if you have a pip installation of tensorflow, `retrain.py` can be run
52 | without bazel:
53 |
54 | ```bash
55 | python tensorflow/examples/image_retraining/retrain.py \
56 | --image_dir ~/flower_photos
57 | ```
58 |
59 | You can replace the image_dir argument with any folder containing subfolders of
60 | images. The label for each image is taken from the name of the subfolder it's
61 | in.
62 |
63 | This produces a new model file that can be loaded and run by any TensorFlow
64 | program, for example the label_image sample code.
65 |
66 | By default this script will use the high accuracy, but comparatively large and
67 | slow Inception v3 model architecture. It's recommended that you start with this
68 | to validate that you have gathered good training data, but if you want to deploy
69 | on resource-limited platforms, you can try the `--architecture` flag with a
70 | Mobilenet model. For example:
71 |
72 | ```bash
73 | python tensorflow/examples/image_retraining/retrain.py \
74 | --image_dir ~/flower_photos --architecture mobilenet_1.0_224
75 | ```
76 |
77 | There are 32 different Mobilenet models to choose from, with a variety of file
78 | size and latency options. The first number can be '1.0', '0.75', '0.50', or
79 | '0.25' to control the size, and the second controls the input image size, either
80 | '224', '192', '160', or '128', with smaller sizes running faster. See
81 | https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
82 | for more information on Mobilenet.
83 |
84 | To use with TensorBoard:
85 |
86 | By default, this script will log summaries to /tmp/retrain_logs directory
87 |
88 | Visualize the summaries with this command:
89 |
90 | tensorboard --logdir /tmp/retrain_logs
91 |
92 | """
93 | from __future__ import absolute_import
94 | from __future__ import division
95 | from __future__ import print_function
96 |
97 | import argparse
98 | import collections
99 | from datetime import datetime
100 | import hashlib
101 | import os.path
102 | import random
103 | import re
104 | import sys
105 | import tarfile
106 |
107 | import numpy as np
108 | from six.moves import urllib
109 | import tensorflow as tf
110 |
111 |
112 | from tensorflow.python.framework import graph_util
113 | from tensorflow.python.framework import tensor_shape
114 | from tensorflow.python.platform import gfile
115 | from tensorflow.python.util import compat
116 |
117 | FLAGS = None
118 |
119 | # These are all parameters that are tied to the particular model architecture
120 | # we're using for Inception v3. These include things like tensor names and their
121 | # sizes. If you want to adapt this script to work with another model, you will
122 | # need to update these to reflect the values in the network you're using.
123 | MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 # ~134M
124 |
125 |
126 | def create_image_lists(image_dir, testing_percentage, validation_percentage):
127 | """Builds a list of training images from the file system.
128 |
129 | Analyzes the sub folders in the image directory, splits them into stable
130 | training, testing, and validation sets, and returns a data structure
131 | describing the lists of images for each label and their paths.
132 |
133 | Args:
134 | image_dir: String path to a folder containing subfolders of images.
135 | testing_percentage: Integer percentage of the images to reserve for tests.
136 | validation_percentage: Integer percentage of images reserved for validation.
137 |
138 | Returns:
139 | A dictionary containing an entry for each label subfolder, with images split
140 | into training, testing, and validation sets within each label.
141 | """
142 | if not gfile.Exists(image_dir):
143 | tf.logging.error("Image directory '" + image_dir + "' not found.")
144 | return None
145 | result = collections.OrderedDict()
146 | sub_dirs = [
147 | os.path.join(image_dir,item)
148 | for item in gfile.ListDirectory(image_dir)]
149 | sub_dirs = sorted(item for item in sub_dirs
150 | if gfile.IsDirectory(item))
151 | for sub_dir in sub_dirs:
152 | extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
153 | file_list = []
154 | dir_name = os.path.basename(sub_dir)
155 | if dir_name == image_dir:
156 | continue
157 | tf.logging.info("Looking for images in '" + dir_name + "'")
158 | for extension in extensions:
159 | file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
160 | file_list.extend(gfile.Glob(file_glob))
161 | if not file_list:
162 | tf.logging.warning('No files found')
163 | continue
164 | if len(file_list) < 20:
165 | tf.logging.warning(
166 | 'WARNING: Folder has less than 20 images, which may cause issues.')
167 | elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
168 | tf.logging.warning(
169 | 'WARNING: Folder {} has more than {} images. Some images will '
170 | 'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
171 | label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
172 | training_images = []
173 | testing_images = []
174 | validation_images = []
175 | for file_name in file_list:
176 | base_name = os.path.basename(file_name)
177 | # We want to ignore anything after '_nohash_' in the file name when
178 | # deciding which set to put an image in, the data set creator has a way of
179 | # grouping photos that are close variations of each other. For example
180 | # this is used in the plant disease data set to group multiple pictures of
181 | # the same leaf.
182 | hash_name = re.sub(r'_nohash_.*$', '', file_name)
183 | # This looks a bit magical, but we need to decide whether this file should
184 | # go into the training, testing, or validation sets, and we want to keep
185 | # existing files in the same set even if more files are subsequently
186 | # added.
187 | # To do that, we need a stable way of deciding based on just the file name
188 | # itself, so we do a hash of that and then use that to generate a
189 | # probability value that we use to assign it.
190 | hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
191 | percentage_hash = ((int(hash_name_hashed, 16) %
192 | (MAX_NUM_IMAGES_PER_CLASS + 1)) *
193 | (100.0 / MAX_NUM_IMAGES_PER_CLASS))
194 | if percentage_hash < validation_percentage:
195 | validation_images.append(base_name)
196 | elif percentage_hash < (testing_percentage + validation_percentage):
197 | testing_images.append(base_name)
198 | else:
199 | training_images.append(base_name)
200 | result[label_name] = {
201 | 'dir': dir_name,
202 | 'training': training_images,
203 | 'testing': testing_images,
204 | 'validation': validation_images,
205 | }
206 | return result
207 |
208 |
209 | def get_image_path(image_lists, label_name, index, image_dir, category):
210 | """"Returns a path to an image for a label at the given index.
211 |
212 | Args:
213 | image_lists: Dictionary of training images for each label.
214 | label_name: Label string we want to get an image for.
215 | index: Int offset of the image we want. This will be moduloed by the
216 | available number of images for the label, so it can be arbitrarily large.
217 | image_dir: Root folder string of the subfolders containing the training
218 | images.
219 | category: Name string of set to pull images from - training, testing, or
220 | validation.
221 |
222 | Returns:
223 | File system path string to an image that meets the requested parameters.
224 |
225 | """
226 | if label_name not in image_lists:
227 | tf.logging.fatal('Label does not exist %s.', label_name)
228 | label_lists = image_lists[label_name]
229 | if category not in label_lists:
230 | tf.logging.fatal('Category does not exist %s.', category)
231 | category_list = label_lists[category]
232 | if not category_list:
233 | tf.logging.fatal('Label %s has no images in the category %s.',
234 | label_name, category)
235 | mod_index = index % len(category_list)
236 | base_name = category_list[mod_index]
237 | sub_dir = label_lists['dir']
238 | full_path = os.path.join(image_dir, sub_dir, base_name)
239 | return full_path
240 |
241 |
242 | def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir,
243 | category, architecture):
244 | """"Returns a path to a bottleneck file for a label at the given index.
245 |
246 | Args:
247 | image_lists: Dictionary of training images for each label.
248 | label_name: Label string we want to get an image for.
249 | index: Integer offset of the image we want. This will be moduloed by the
250 | available number of images for the label, so it can be arbitrarily large.
251 | bottleneck_dir: Folder string holding cached files of bottleneck values.
252 | category: Name string of set to pull images from - training, testing, or
253 | validation.
254 | architecture: The name of the model architecture.
255 |
256 | Returns:
257 | File system path string to an image that meets the requested parameters.
258 | """
259 | return get_image_path(image_lists, label_name, index, bottleneck_dir,
260 | category) + '_' + architecture + '.txt'
261 |
262 |
263 | def create_model_graph(model_info):
264 | """"Creates a graph from saved GraphDef file and returns a Graph object.
265 |
266 | Args:
267 | model_info: Dictionary containing information about the model architecture.
268 |
269 | Returns:
270 | Graph holding the trained Inception network, and various tensors we'll be
271 | manipulating.
272 | """
273 | with tf.Graph().as_default() as graph:
274 | model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
275 | with gfile.FastGFile(model_path, 'rb') as f:
276 | graph_def = tf.GraphDef()
277 | graph_def.ParseFromString(f.read())
278 | bottleneck_tensor, resized_input_tensor = (tf.import_graph_def(
279 | graph_def,
280 | name='',
281 | return_elements=[
282 | model_info['bottleneck_tensor_name'],
283 | model_info['resized_input_tensor_name'],
284 | ]))
285 | return graph, bottleneck_tensor, resized_input_tensor
286 |
287 |
288 | def run_bottleneck_on_image(sess, image_data, image_data_tensor,
289 | decoded_image_tensor, resized_input_tensor,
290 | bottleneck_tensor):
291 | """Runs inference on an image to extract the 'bottleneck' summary layer.
292 |
293 | Args:
294 | sess: Current active TensorFlow Session.
295 | image_data: String of raw JPEG data.
296 | image_data_tensor: Input data layer in the graph.
297 | decoded_image_tensor: Output of initial image resizing and preprocessing.
298 | resized_input_tensor: The input node of the recognition graph.
299 | bottleneck_tensor: Layer before the final softmax.
300 |
301 | Returns:
302 | Numpy array of bottleneck values.
303 | """
304 | # First decode the JPEG image, resize it, and rescale the pixel values.
305 | resized_input_values = sess.run(decoded_image_tensor,
306 | {image_data_tensor: image_data})
307 | # Then run it through the recognition network.
308 | bottleneck_values = sess.run(bottleneck_tensor,
309 | {resized_input_tensor: resized_input_values})
310 | bottleneck_values = np.squeeze(bottleneck_values)
311 | return bottleneck_values
312 |
313 |
314 | def maybe_download_and_extract(data_url):
315 | """Download and extract model tar file.
316 |
317 | If the pretrained model we're using doesn't already exist, this function
318 | downloads it from the TensorFlow.org website and unpacks it into a directory.
319 |
320 | Args:
321 | data_url: Web location of the tar file containing the pretrained model.
322 | """
323 | dest_directory = FLAGS.model_dir
324 | if not os.path.exists(dest_directory):
325 | os.makedirs(dest_directory)
326 | filename = data_url.split('/')[-1]
327 | filepath = os.path.join(dest_directory, filename)
328 | if not os.path.exists(filepath):
329 |
330 | def _progress(count, block_size, total_size):
331 | sys.stdout.write('\r>> Downloading %s %.1f%%' %
332 | (filename,
333 | float(count * block_size) / float(total_size) * 100.0))
334 | sys.stdout.flush()
335 |
336 | filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress)
337 | print()
338 | statinfo = os.stat(filepath)
339 | tf.logging.info('Successfully downloaded', filename, statinfo.st_size,
340 | 'bytes.')
341 | tarfile.open(filepath, 'r:gz').extractall(dest_directory)
342 |
343 |
344 | def ensure_dir_exists(dir_name):
345 | """Makes sure the folder exists on disk.
346 |
347 | Args:
348 | dir_name: Path string to the folder we want to create.
349 | """
350 | if not os.path.exists(dir_name):
351 | os.makedirs(dir_name)
352 |
353 |
354 | bottleneck_path_2_bottleneck_values = {}
355 |
356 |
357 | def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
358 | image_dir, category, sess, jpeg_data_tensor,
359 | decoded_image_tensor, resized_input_tensor,
360 | bottleneck_tensor):
361 | """Create a single bottleneck file."""
362 | tf.logging.info('Creating bottleneck at ' + bottleneck_path)
363 | image_path = get_image_path(image_lists, label_name, index,
364 | image_dir, category)
365 | if not gfile.Exists(image_path):
366 | tf.logging.fatal('File does not exist %s', image_path)
367 | image_data = gfile.FastGFile(image_path, 'rb').read()
368 | try:
369 | bottleneck_values = run_bottleneck_on_image(
370 | sess, image_data, jpeg_data_tensor, decoded_image_tensor,
371 | resized_input_tensor, bottleneck_tensor)
372 | except Exception as e:
373 | raise RuntimeError('Error during processing file %s (%s)' % (image_path,
374 | str(e)))
375 | bottleneck_string = ','.join(str(x) for x in bottleneck_values)
376 | with open(bottleneck_path, 'w') as bottleneck_file:
377 | bottleneck_file.write(bottleneck_string)
378 |
379 |
380 | def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
381 | category, bottleneck_dir, jpeg_data_tensor,
382 | decoded_image_tensor, resized_input_tensor,
383 | bottleneck_tensor, architecture):
384 | """Retrieves or calculates bottleneck values for an image.
385 |
386 | If a cached version of the bottleneck data exists on-disk, return that,
387 | otherwise calculate the data and save it to disk for future use.
388 |
389 | Args:
390 | sess: The current active TensorFlow Session.
391 | image_lists: Dictionary of training images for each label.
392 | label_name: Label string we want to get an image for.
393 | index: Integer offset of the image we want. This will be modulo-ed by the
394 | available number of images for the label, so it can be arbitrarily large.
395 | image_dir: Root folder string of the subfolders containing the training
396 | images.
397 | category: Name string of which set to pull images from - training, testing,
398 | or validation.
399 | bottleneck_dir: Folder string holding cached files of bottleneck values.
400 | jpeg_data_tensor: The tensor to feed loaded jpeg data into.
401 | decoded_image_tensor: The output of decoding and resizing the image.
402 | resized_input_tensor: The input node of the recognition graph.
403 | bottleneck_tensor: The output tensor for the bottleneck values.
404 | architecture: The name of the model architecture.
405 |
406 | Returns:
407 | Numpy array of values produced by the bottleneck layer for the image.
408 | """
409 | label_lists = image_lists[label_name]
410 | sub_dir = label_lists['dir']
411 | sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
412 | ensure_dir_exists(sub_dir_path)
413 | bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
414 | bottleneck_dir, category, architecture)
415 | if not os.path.exists(bottleneck_path):
416 | create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
417 | image_dir, category, sess, jpeg_data_tensor,
418 | decoded_image_tensor, resized_input_tensor,
419 | bottleneck_tensor)
420 | with open(bottleneck_path, 'r') as bottleneck_file:
421 | bottleneck_string = bottleneck_file.read()
422 | did_hit_error = False
423 | try:
424 | bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
425 | except ValueError:
426 | tf.logging.warning('Invalid float found, recreating bottleneck')
427 | did_hit_error = True
428 | if did_hit_error:
429 | create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
430 | image_dir, category, sess, jpeg_data_tensor,
431 | decoded_image_tensor, resized_input_tensor,
432 | bottleneck_tensor)
433 | with open(bottleneck_path, 'r') as bottleneck_file:
434 | bottleneck_string = bottleneck_file.read()
435 | # Allow exceptions to propagate here, since they shouldn't happen after a
436 | # fresh creation
437 | bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
438 | return bottleneck_values
439 |
440 |
441 | def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
442 | jpeg_data_tensor, decoded_image_tensor,
443 | resized_input_tensor, bottleneck_tensor, architecture):
444 | """Ensures all the training, testing, and validation bottlenecks are cached.
445 |
446 | Because we're likely to read the same image multiple times (if there are no
447 | distortions applied during training) it can speed things up a lot if we
448 | calculate the bottleneck layer values once for each image during
449 | preprocessing, and then just read those cached values repeatedly during
450 | training. Here we go through all the images we've found, calculate those
451 | values, and save them off.
452 |
453 | Args:
454 | sess: The current active TensorFlow Session.
455 | image_lists: Dictionary of training images for each label.
456 | image_dir: Root folder string of the subfolders containing the training
457 | images.
458 | bottleneck_dir: Folder string holding cached files of bottleneck values.
459 | jpeg_data_tensor: Input tensor for jpeg data from file.
460 | decoded_image_tensor: The output of decoding and resizing the image.
461 | resized_input_tensor: The input node of the recognition graph.
462 | bottleneck_tensor: The penultimate output layer of the graph.
463 | architecture: The name of the model architecture.
464 |
465 | Returns:
466 | Nothing.
467 | """
468 | how_many_bottlenecks = 0
469 | ensure_dir_exists(bottleneck_dir)
470 | for label_name, label_lists in image_lists.items():
471 | for category in ['training', 'testing', 'validation']:
472 | category_list = label_lists[category]
473 | for index, unused_base_name in enumerate(category_list):
474 | get_or_create_bottleneck(
475 | sess, image_lists, label_name, index, image_dir, category,
476 | bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
477 | resized_input_tensor, bottleneck_tensor, architecture)
478 |
479 | how_many_bottlenecks += 1
480 | if how_many_bottlenecks % 100 == 0:
481 | tf.logging.info(
482 | str(how_many_bottlenecks) + ' bottleneck files created.')
483 |
484 |
485 | def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
486 | bottleneck_dir, image_dir, jpeg_data_tensor,
487 | decoded_image_tensor, resized_input_tensor,
488 | bottleneck_tensor, architecture):
489 | """Retrieves bottleneck values for cached images.
490 |
491 | If no distortions are being applied, this function can retrieve the cached
492 | bottleneck values directly from disk for images. It picks a random set of
493 | images from the specified category.
494 |
495 | Args:
496 | sess: Current TensorFlow Session.
497 | image_lists: Dictionary of training images for each label.
498 | how_many: If positive, a random sample of this size will be chosen.
499 | If negative, all bottlenecks will be retrieved.
500 | category: Name string of which set to pull from - training, testing, or
501 | validation.
502 | bottleneck_dir: Folder string holding cached files of bottleneck values.
503 | image_dir: Root folder string of the subfolders containing the training
504 | images.
505 | jpeg_data_tensor: The layer to feed jpeg image data into.
506 | decoded_image_tensor: The output of decoding and resizing the image.
507 | resized_input_tensor: The input node of the recognition graph.
508 | bottleneck_tensor: The bottleneck output layer of the CNN graph.
509 | architecture: The name of the model architecture.
510 |
511 | Returns:
512 | List of bottleneck arrays, their corresponding ground truths, and the
513 | relevant filenames.
514 | """
515 | class_count = len(image_lists.keys())
516 | bottlenecks = []
517 | ground_truths = []
518 | filenames = []
519 | if how_many >= 0:
520 | # Retrieve a random sample of bottlenecks.
521 | for unused_i in range(how_many):
522 | label_index = random.randrange(class_count)
523 | label_name = list(image_lists.keys())[label_index]
524 | image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
525 | image_name = get_image_path(image_lists, label_name, image_index,
526 | image_dir, category)
527 | bottleneck = get_or_create_bottleneck(
528 | sess, image_lists, label_name, image_index, image_dir, category,
529 | bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
530 | resized_input_tensor, bottleneck_tensor, architecture)
531 | ground_truth = np.zeros(class_count, dtype=np.float32)
532 | ground_truth[label_index] = 1.0
533 | bottlenecks.append(bottleneck)
534 | ground_truths.append(ground_truth)
535 | filenames.append(image_name)
536 | else:
537 | # Retrieve all bottlenecks.
538 | for label_index, label_name in enumerate(image_lists.keys()):
539 | for image_index, image_name in enumerate(
540 | image_lists[label_name][category]):
541 | image_name = get_image_path(image_lists, label_name, image_index,
542 | image_dir, category)
543 | bottleneck = get_or_create_bottleneck(
544 | sess, image_lists, label_name, image_index, image_dir, category,
545 | bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
546 | resized_input_tensor, bottleneck_tensor, architecture)
547 | ground_truth = np.zeros(class_count, dtype=np.float32)
548 | ground_truth[label_index] = 1.0
549 | bottlenecks.append(bottleneck)
550 | ground_truths.append(ground_truth)
551 | filenames.append(image_name)
552 | return bottlenecks, ground_truths, filenames
553 |
554 |
555 | def get_random_distorted_bottlenecks(
556 | sess, image_lists, how_many, category, image_dir, input_jpeg_tensor,
557 | distorted_image, resized_input_tensor, bottleneck_tensor):
558 | """Retrieves bottleneck values for training images, after distortions.
559 |
560 | If we're training with distortions like crops, scales, or flips, we have to
561 | recalculate the full model for every image, and so we can't use cached
562 | bottleneck values. Instead we find random images for the requested category,
563 | run them through the distortion graph, and then the full graph to get the
564 | bottleneck results for each.
565 |
566 | Args:
567 | sess: Current TensorFlow Session.
568 | image_lists: Dictionary of training images for each label.
569 | how_many: The integer number of bottleneck values to return.
570 | category: Name string of which set of images to fetch - training, testing,
571 | or validation.
572 | image_dir: Root folder string of the subfolders containing the training
573 | images.
574 | input_jpeg_tensor: The input layer we feed the image data to.
575 | distorted_image: The output node of the distortion graph.
576 | resized_input_tensor: The input node of the recognition graph.
577 | bottleneck_tensor: The bottleneck output layer of the CNN graph.
578 |
579 | Returns:
580 | List of bottleneck arrays and their corresponding ground truths.
581 | """
582 | class_count = len(image_lists.keys())
583 | bottlenecks = []
584 | ground_truths = []
585 | for unused_i in range(how_many):
586 | label_index = random.randrange(class_count)
587 | label_name = list(image_lists.keys())[label_index]
588 | image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
589 | image_path = get_image_path(image_lists, label_name, image_index, image_dir,
590 | category)
591 | if not gfile.Exists(image_path):
592 | tf.logging.fatal('File does not exist %s', image_path)
593 | jpeg_data = gfile.FastGFile(image_path, 'rb').read()
594 | # Note that we materialize the distorted_image_data as a numpy array before
595 | # sending running inference on the image. This involves 2 memory copies and
596 | # might be optimized in other implementations.
597 | distorted_image_data = sess.run(distorted_image,
598 | {input_jpeg_tensor: jpeg_data})
599 | bottleneck_values = sess.run(bottleneck_tensor,
600 | {resized_input_tensor: distorted_image_data})
601 | bottleneck_values = np.squeeze(bottleneck_values)
602 | ground_truth = np.zeros(class_count, dtype=np.float32)
603 | ground_truth[label_index] = 1.0
604 | bottlenecks.append(bottleneck_values)
605 | ground_truths.append(ground_truth)
606 | return bottlenecks, ground_truths
607 |
608 |
609 | def should_distort_images(flip_left_right, random_crop, random_scale,
610 | random_brightness):
611 | """Whether any distortions are enabled, from the input flags.
612 |
613 | Args:
614 | flip_left_right: Boolean whether to randomly mirror images horizontally.
615 | random_crop: Integer percentage setting the total margin used around the
616 | crop box.
617 | random_scale: Integer percentage of how much to vary the scale by.
618 | random_brightness: Integer range to randomly multiply the pixel values by.
619 |
620 | Returns:
621 | Boolean value indicating whether any distortions should be applied.
622 | """
623 | return (flip_left_right or (random_crop != 0) or (random_scale != 0) or
624 | (random_brightness != 0))
625 |
626 |
627 | def add_input_distortions(flip_left_right, random_crop, random_scale,
628 | random_brightness, input_width, input_height,
629 | input_depth, input_mean, input_std):
630 | """Creates the operations to apply the specified distortions.
631 |
632 | During training it can help to improve the results if we run the images
633 | through simple distortions like crops, scales, and flips. These reflect the
634 | kind of variations we expect in the real world, and so can help train the
635 | model to cope with natural data more effectively. Here we take the supplied
636 | parameters and construct a network of operations to apply them to an image.
637 |
638 | Cropping
639 | ~~~~~~~~
640 |
641 | Cropping is done by placing a bounding box at a random position in the full
642 | image. The cropping parameter controls the size of that box relative to the
643 | input image. If it's zero, then the box is the same size as the input and no
644 | cropping is performed. If the value is 50%, then the crop box will be half the
645 | width and height of the input. In a diagram it looks like this:
646 |
647 | < width >
648 | +---------------------+
649 | | |
650 | | width - crop% |
651 | | < > |
652 | | +------+ |
653 | | | | |
654 | | | | |
655 | | | | |
656 | | +------+ |
657 | | |
658 | | |
659 | +---------------------+
660 |
661 | Scaling
662 | ~~~~~~~
663 |
664 | Scaling is a lot like cropping, except that the bounding box is always
665 | centered and its size varies randomly within the given range. For example if
666 | the scale percentage is zero, then the bounding box is the same size as the
667 | input and no scaling is applied. If it's 50%, then the bounding box will be in
668 | a random range between half the width and height and full size.
669 |
670 | Args:
671 | flip_left_right: Boolean whether to randomly mirror images horizontally.
672 | random_crop: Integer percentage setting the total margin used around the
673 | crop box.
674 | random_scale: Integer percentage of how much to vary the scale by.
675 | random_brightness: Integer range to randomly multiply the pixel values by.
676 | graph.
677 | input_width: Horizontal size of expected input image to model.
678 | input_height: Vertical size of expected input image to model.
679 | input_depth: How many channels the expected input image should have.
680 | input_mean: Pixel value that should be zero in the image for the graph.
681 | input_std: How much to divide the pixel values by before recognition.
682 |
683 | Returns:
684 | The jpeg input layer and the distorted result tensor.
685 | """
686 |
687 | jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
688 | decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth)
689 | decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
690 | decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
691 | margin_scale = 1.0 + (random_crop / 100.0)
692 | resize_scale = 1.0 + (random_scale / 100.0)
693 | margin_scale_value = tf.constant(margin_scale)
694 | resize_scale_value = tf.random_uniform(tensor_shape.scalar(),
695 | minval=1.0,
696 | maxval=resize_scale)
697 | scale_value = tf.multiply(margin_scale_value, resize_scale_value)
698 | precrop_width = tf.multiply(scale_value, input_width)
699 | precrop_height = tf.multiply(scale_value, input_height)
700 | precrop_shape = tf.stack([precrop_height, precrop_width])
701 | precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32)
702 | precropped_image = tf.image.resize_bilinear(decoded_image_4d,
703 | precrop_shape_as_int)
704 | precropped_image_3d = tf.squeeze(precropped_image, squeeze_dims=[0])
705 | cropped_image = tf.random_crop(precropped_image_3d,
706 | [input_height, input_width, input_depth])
707 | if flip_left_right:
708 | flipped_image = tf.image.random_flip_left_right(cropped_image)
709 | else:
710 | flipped_image = cropped_image
711 | brightness_min = 1.0 - (random_brightness / 100.0)
712 | brightness_max = 1.0 + (random_brightness / 100.0)
713 | brightness_value = tf.random_uniform(tensor_shape.scalar(),
714 | minval=brightness_min,
715 | maxval=brightness_max)
716 | brightened_image = tf.multiply(flipped_image, brightness_value)
717 | offset_image = tf.subtract(brightened_image, input_mean)
718 | mul_image = tf.multiply(offset_image, 1.0 / input_std)
719 | distort_result = tf.expand_dims(mul_image, 0, name='DistortResult')
720 | return jpeg_data, distort_result
721 |
722 |
723 | def variable_summaries(var):
724 | """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
725 | with tf.name_scope('summaries'):
726 | mean = tf.reduce_mean(var)
727 | tf.summary.scalar('mean', mean)
728 | with tf.name_scope('stddev'):
729 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
730 | tf.summary.scalar('stddev', stddev)
731 | tf.summary.scalar('max', tf.reduce_max(var))
732 | tf.summary.scalar('min', tf.reduce_min(var))
733 | tf.summary.histogram('histogram', var)
734 |
735 |
736 | def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
737 | bottleneck_tensor_size):
738 | """Adds a new softmax and fully-connected layer for training.
739 |
740 | We need to retrain the top layer to identify our new classes, so this function
741 | adds the right operations to the graph, along with some variables to hold the
742 | weights, and then sets up all the gradients for the backward pass.
743 |
744 | The set up for the softmax and fully-connected layers is based on:
745 | https://www.tensorflow.org/versions/master/tutorials/mnist/beginners/index.html
746 |
747 | Args:
748 | class_count: Integer of how many categories of things we're trying to
749 | recognize.
750 | final_tensor_name: Name string for the new final node that produces results.
751 | bottleneck_tensor: The output of the main CNN graph.
752 | bottleneck_tensor_size: How many entries in the bottleneck vector.
753 |
754 | Returns:
755 | The tensors for the training and cross entropy results, and tensors for the
756 | bottleneck input and ground truth input.
757 | """
758 | with tf.name_scope('input'):
759 | bottleneck_input = tf.placeholder_with_default(
760 | bottleneck_tensor,
761 | shape=[None, bottleneck_tensor_size],
762 | name='BottleneckInputPlaceholder')
763 |
764 | ground_truth_input = tf.placeholder(tf.float32,
765 | [None, class_count],
766 | name='GroundTruthInput')
767 |
768 | # Organizing the following ops as `final_training_ops` so they're easier
769 | # to see in TensorBoard
770 | layer_name = 'final_training_ops'
771 | with tf.name_scope(layer_name):
772 | with tf.name_scope('weights'):
773 | initial_value = tf.truncated_normal(
774 | [bottleneck_tensor_size, class_count], stddev=0.001)
775 |
776 | layer_weights = tf.Variable(initial_value, name='final_weights')
777 |
778 | variable_summaries(layer_weights)
779 | with tf.name_scope('biases'):
780 | layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
781 | variable_summaries(layer_biases)
782 | with tf.name_scope('Wx_plus_b'):
783 | logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
784 | tf.summary.histogram('pre_activations', logits)
785 |
786 | final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
787 | tf.summary.histogram('activations', final_tensor)
788 |
789 | with tf.name_scope('cross_entropy'):
790 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
791 | labels=ground_truth_input, logits=logits)
792 | with tf.name_scope('total'):
793 | cross_entropy_mean = tf.reduce_mean(cross_entropy)
794 | tf.summary.scalar('cross_entropy', cross_entropy_mean)
795 |
796 | with tf.name_scope('train'):
797 | optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
798 | train_step = optimizer.minimize(cross_entropy_mean)
799 |
800 | return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
801 | final_tensor)
802 |
803 |
804 | def add_evaluation_step(result_tensor, ground_truth_tensor):
805 | """Inserts the operations we need to evaluate the accuracy of our results.
806 |
807 | Args:
808 | result_tensor: The new final node that produces results.
809 | ground_truth_tensor: The node we feed ground truth data
810 | into.
811 |
812 | Returns:
813 | Tuple of (evaluation step, prediction).
814 | """
815 | with tf.name_scope('accuracy'):
816 | with tf.name_scope('correct_prediction'):
817 | prediction = tf.argmax(result_tensor, 1)
818 | correct_prediction = tf.equal(
819 | prediction, tf.argmax(ground_truth_tensor, 1))
820 | with tf.name_scope('accuracy'):
821 | evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
822 | tf.summary.scalar('accuracy', evaluation_step)
823 | return evaluation_step, prediction
824 |
825 |
826 | def save_graph_to_file(sess, graph, graph_file_name):
827 | output_graph_def = graph_util.convert_variables_to_constants(
828 | sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
829 | with gfile.FastGFile(graph_file_name, 'wb') as f:
830 | f.write(output_graph_def.SerializeToString())
831 | return
832 |
833 |
834 | def prepare_file_system():
835 | # Setup the directory we'll write summaries to for TensorBoard
836 | if tf.gfile.Exists(FLAGS.summaries_dir):
837 | tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
838 | tf.gfile.MakeDirs(FLAGS.summaries_dir)
839 | if FLAGS.intermediate_store_frequency > 0:
840 | ensure_dir_exists(FLAGS.intermediate_output_graphs_dir)
841 | return
842 |
843 |
844 | def create_model_info(architecture):
845 | """Given the name of a model architecture, returns information about it.
846 |
847 | There are different base image recognition pretrained models that can be
848 | retrained using transfer learning, and this function translates from the name
849 | of a model to the attributes that are needed to download and train with it.
850 |
851 | Args:
852 | architecture: Name of a model architecture.
853 |
854 | Returns:
855 | Dictionary of information about the model, or None if the name isn't
856 | recognized
857 |
858 | Raises:
859 | ValueError: If architecture name is unknown.
860 | """
861 | architecture = architecture.lower()
862 | if architecture == 'inception_v3':
863 | # pylint: disable=line-too-long
864 | data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
865 | # pylint: enable=line-too-long
866 | bottleneck_tensor_name = 'pool_3/_reshape:0'
867 | bottleneck_tensor_size = 2048
868 | input_width = 299
869 | input_height = 299
870 | input_depth = 3
871 | resized_input_tensor_name = 'Mul:0'
872 | model_file_name = 'classify_image_graph_def.pb'
873 | input_mean = 128
874 | input_std = 128
875 | elif architecture.startswith('mobilenet_'):
876 | parts = architecture.split('_')
877 | if len(parts) != 3 and len(parts) != 4:
878 | tf.logging.error("Couldn't understand architecture name '%s'",
879 | architecture)
880 | return None
881 | version_string = parts[1]
882 | if (version_string != '1.0' and version_string != '0.75' and
883 | version_string != '0.50' and version_string != '0.25'):
884 | tf.logging.error(
885 | """"The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25',
886 | but found '%s' for architecture '%s'""",
887 | version_string, architecture)
888 | return None
889 | size_string = parts[2]
890 | if (size_string != '224' and size_string != '192' and
891 | size_string != '160' and size_string != '128'):
892 | tf.logging.error(
893 | """The Mobilenet input size should be '224', '192', '160', or '128',
894 | but found '%s' for architecture '%s'""",
895 | size_string, architecture)
896 | return None
897 | if len(parts) == 3:
898 | is_quantized = False
899 | else:
900 | if parts[3] != 'quantized':
901 | tf.logging.error(
902 | "Couldn't understand architecture suffix '%s' for '%s'", parts[3],
903 | architecture)
904 | return None
905 | is_quantized = True
906 | data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
907 | data_url += version_string + '_' + size_string + '_frozen.tgz'
908 | bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
909 | bottleneck_tensor_size = 1001
910 | input_width = int(size_string)
911 | input_height = int(size_string)
912 | input_depth = 3
913 | resized_input_tensor_name = 'input:0'
914 | if is_quantized:
915 | model_base_name = 'quantized_graph.pb'
916 | else:
917 | model_base_name = 'frozen_graph.pb'
918 | model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
919 | model_file_name = os.path.join(model_dir_name, model_base_name)
920 | input_mean = 127.5
921 | input_std = 127.5
922 | else:
923 | tf.logging.error("Couldn't understand architecture name '%s'", architecture)
924 | raise ValueError('Unknown architecture', architecture)
925 |
926 | return {
927 | 'data_url': data_url,
928 | 'bottleneck_tensor_name': bottleneck_tensor_name,
929 | 'bottleneck_tensor_size': bottleneck_tensor_size,
930 | 'input_width': input_width,
931 | 'input_height': input_height,
932 | 'input_depth': input_depth,
933 | 'resized_input_tensor_name': resized_input_tensor_name,
934 | 'model_file_name': model_file_name,
935 | 'input_mean': input_mean,
936 | 'input_std': input_std,
937 | }
938 |
939 |
940 | def add_jpeg_decoding(input_width, input_height, input_depth, input_mean,
941 | input_std):
942 | """Adds operations that perform JPEG decoding and resizing to the graph..
943 |
944 | Args:
945 | input_width: Desired width of the image fed into the recognizer graph.
946 |     input_height: Desired height of the image fed into the recognizer graph.
947 | input_depth: Desired channels of the image fed into the recognizer graph.
948 | input_mean: Pixel value that should be zero in the image for the graph.
949 | input_std: How much to divide the pixel values by before recognition.
950 |
951 | Returns:
952 | Tensors for the node to feed JPEG data into, and the output of the
953 | preprocessing steps.
954 | """
955 | jpeg_data = tf.placeholder(tf.string, name='DecodeJPGInput')
956 | decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth)
957 | decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
958 | decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
959 | resize_shape = tf.stack([input_height, input_width])
960 | resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
961 | resized_image = tf.image.resize_bilinear(decoded_image_4d,
962 | resize_shape_as_int)
963 | offset_image = tf.subtract(resized_image, input_mean)
964 | mul_image = tf.multiply(offset_image, 1.0 / input_std)
965 | return jpeg_data, mul_image
966 |
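# Usage sketch (illustrative only, not part of the original script): the two
# tensors returned above form a small standalone preprocessing sub-graph,
# roughly:
#   jpeg_data, preprocessed = add_jpeg_decoding(224, 224, 3, 127.5, 127.5)
#   with tf.Session() as sess:
#     image_4d = sess.run(preprocessed,
#                         feed_dict={jpeg_data: open('face.jpg', 'rb').read()})
# Here 'face.jpg' is a placeholder filename; in main() below the same tensors
# are wired into bottleneck caching and training rather than run directly.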
967 |
968 | def main(_):
969 | # Needed to make sure the logging output is visible.
970 | # See https://github.com/tensorflow/tensorflow/issues/3047
971 | tf.logging.set_verbosity(tf.logging.INFO)
972 |
973 | # Prepare necessary directories that can be used during training
974 | prepare_file_system()
975 |
976 | # Gather information about the model architecture we'll be using.
977 | model_info = create_model_info(FLAGS.architecture)
978 | if not model_info:
979 | tf.logging.error('Did not recognize architecture flag')
980 | return -1
981 |
982 | # Set up the pre-trained graph.
983 | maybe_download_and_extract(model_info['data_url'])
984 | graph, bottleneck_tensor, resized_image_tensor = (
985 | create_model_graph(model_info))
986 |
987 | # Look at the folder structure, and create lists of all the images.
988 | image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
989 | FLAGS.validation_percentage)
990 | class_count = len(image_lists.keys())
991 | if class_count == 0:
992 | tf.logging.error('No valid folders of images found at ' + FLAGS.image_dir)
993 | return -1
994 | if class_count == 1:
995 | tf.logging.error('Only one valid folder of images found at ' +
996 | FLAGS.image_dir +
997 | ' - multiple classes are needed for classification.')
998 | return -1
999 |
1000 | # See if the command-line flags mean we're applying any distortions.
1001 | do_distort_images = should_distort_images(
1002 | FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
1003 | FLAGS.random_brightness)
1004 |
1005 | with tf.Session(graph=graph) as sess:
1006 | # Set up the image decoding sub-graph.
1007 | jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(
1008 | model_info['input_width'], model_info['input_height'],
1009 | model_info['input_depth'], model_info['input_mean'],
1010 | model_info['input_std'])
1011 |
1012 | if do_distort_images:
1013 | # We will be applying distortions, so setup the operations we'll need.
1014 | (distorted_jpeg_data_tensor,
1015 | distorted_image_tensor) = add_input_distortions(
1016 | FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
1017 | FLAGS.random_brightness, model_info['input_width'],
1018 | model_info['input_height'], model_info['input_depth'],
1019 | model_info['input_mean'], model_info['input_std'])
1020 | else:
1021 | # We'll make sure we've calculated the 'bottleneck' image summaries and
1022 | # cached them on disk.
1023 | cache_bottlenecks(sess, image_lists, FLAGS.image_dir,
1024 | FLAGS.bottleneck_dir, jpeg_data_tensor,
1025 | decoded_image_tensor, resized_image_tensor,
1026 | bottleneck_tensor, FLAGS.architecture)
1027 |
1028 | # Add the new layer that we'll be training.
1029 | (train_step, cross_entropy, bottleneck_input, ground_truth_input,
1030 | final_tensor) = add_final_training_ops(
1031 | len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor,
1032 | model_info['bottleneck_tensor_size'])
1033 |
1034 | # Create the operations we need to evaluate the accuracy of our new layer.
1035 | evaluation_step, prediction = add_evaluation_step(
1036 | final_tensor, ground_truth_input)
1037 |
1038 | # Merge all the summaries and write them out to the summaries_dir
1039 | merged = tf.summary.merge_all()
1040 | train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
1041 | sess.graph)
1042 |
1043 | validation_writer = tf.summary.FileWriter(
1044 | FLAGS.summaries_dir + '/validation')
1045 |
1046 | # Set up all our weights to their initial default values.
1047 | init = tf.global_variables_initializer()
1048 | sess.run(init)
1049 |
1050 | # Run the training for as many cycles as requested on the command line.
1051 | for i in range(FLAGS.how_many_training_steps):
1052 | # Get a batch of input bottleneck values, either calculated fresh every
1053 | # time with distortions applied, or from the cache stored on disk.
1054 | if do_distort_images:
1055 | (train_bottlenecks,
1056 | train_ground_truth) = get_random_distorted_bottlenecks(
1057 | sess, image_lists, FLAGS.train_batch_size, 'training',
1058 | FLAGS.image_dir, distorted_jpeg_data_tensor,
1059 | distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
1060 | else:
1061 | (train_bottlenecks,
1062 | train_ground_truth, _) = get_random_cached_bottlenecks(
1063 | sess, image_lists, FLAGS.train_batch_size, 'training',
1064 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
1065 | decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
1066 | FLAGS.architecture)
1067 | # Feed the bottlenecks and ground truth into the graph, and run a training
1068 | # step. Capture training summaries for TensorBoard with the `merged` op.
1069 | train_summary, _ = sess.run(
1070 | [merged, train_step],
1071 | feed_dict={bottleneck_input: train_bottlenecks,
1072 | ground_truth_input: train_ground_truth})
1073 | train_writer.add_summary(train_summary, i)
1074 |
1075 | # Every so often, print out how well the graph is training.
1076 | is_last_step = (i + 1 == FLAGS.how_many_training_steps)
1077 | if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
1078 | train_accuracy, cross_entropy_value = sess.run(
1079 | [evaluation_step, cross_entropy],
1080 | feed_dict={bottleneck_input: train_bottlenecks,
1081 | ground_truth_input: train_ground_truth})
1082 | tf.logging.info('%s: Step %d: Train accuracy = %.1f%%' %
1083 | (datetime.now(), i, train_accuracy * 100))
1084 | tf.logging.info('%s: Step %d: Cross entropy = %f' %
1085 | (datetime.now(), i, cross_entropy_value))
1086 | validation_bottlenecks, validation_ground_truth, _ = (
1087 | get_random_cached_bottlenecks(
1088 | sess, image_lists, FLAGS.validation_batch_size, 'validation',
1089 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
1090 | decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
1091 | FLAGS.architecture))
1092 | # Run a validation step and capture training summaries for TensorBoard
1093 | # with the `merged` op.
1094 | validation_summary, validation_accuracy = sess.run(
1095 | [merged, evaluation_step],
1096 | feed_dict={bottleneck_input: validation_bottlenecks,
1097 | ground_truth_input: validation_ground_truth})
1098 | validation_writer.add_summary(validation_summary, i)
1099 | tf.logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
1100 | (datetime.now(), i, validation_accuracy * 100,
1101 | len(validation_bottlenecks)))
1102 |
1103 | # Store intermediate results
1104 | intermediate_frequency = FLAGS.intermediate_store_frequency
1105 |
1106 | if (intermediate_frequency > 0 and (i % intermediate_frequency == 0)
1107 | and i > 0):
1108 | intermediate_file_name = (FLAGS.intermediate_output_graphs_dir +
1109 | 'intermediate_' + str(i) + '.pb')
1110 | tf.logging.info('Save intermediate result to : ' +
1111 | intermediate_file_name)
1112 | save_graph_to_file(sess, graph, intermediate_file_name)
1113 |
1114 | # We've completed all our training, so run a final test evaluation on
1115 | # some new images we haven't used before.
1116 | test_bottlenecks, test_ground_truth, test_filenames = (
1117 | get_random_cached_bottlenecks(
1118 | sess, image_lists, FLAGS.test_batch_size, 'testing',
1119 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
1120 | decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
1121 | FLAGS.architecture))
1122 | test_accuracy, predictions = sess.run(
1123 | [evaluation_step, prediction],
1124 | feed_dict={bottleneck_input: test_bottlenecks,
1125 | ground_truth_input: test_ground_truth})
1126 | tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
1127 | (test_accuracy * 100, len(test_bottlenecks)))
1128 |
1129 | if FLAGS.print_misclassified_test_images:
1130 | tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===')
1131 | for i, test_filename in enumerate(test_filenames):
1132 | if predictions[i] != test_ground_truth[i].argmax():
1133 | tf.logging.info('%70s %s' %
1134 | (test_filename,
1135 | list(image_lists.keys())[predictions[i]]))
1136 |
1137 | # Write out the trained graph and labels with the weights stored as
1138 | # constants.
1139 | save_graph_to_file(sess, graph, FLAGS.output_graph)
1140 | with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
1141 | f.write('\n'.join(image_lists.keys()) + '\n')
1142 |
1143 |
1144 | if __name__ == '__main__':
1145 | parser = argparse.ArgumentParser()
1146 | parser.add_argument(
1147 | '--image_dir',
1148 | type=str,
1149 | default='',
1150 | help='Path to folders of labeled images.'
1151 | )
1152 | parser.add_argument(
1153 | '--output_graph',
1154 | type=str,
1155 | default='/tmp/output_graph.pb',
1156 | help='Where to save the trained graph.'
1157 | )
1158 | parser.add_argument(
1159 | '--intermediate_output_graphs_dir',
1160 | type=str,
1161 | default='/tmp/intermediate_graph/',
1162 | help='Where to save the intermediate graphs.'
1163 | )
1164 | parser.add_argument(
1165 | '--intermediate_store_frequency',
1166 | type=int,
1167 | default=0,
1168 | help="""\
1169 |       How many training steps to run between storing intermediate graphs.
1170 |       If "0", no intermediate graphs are stored.\
1171 | """
1172 | )
1173 | parser.add_argument(
1174 | '--output_labels',
1175 | type=str,
1176 | default='/tmp/output_labels.txt',
1177 | help='Where to save the trained graph\'s labels.'
1178 | )
1179 | parser.add_argument(
1180 | '--summaries_dir',
1181 | type=str,
1182 | default='/tmp/retrain_logs',
1183 | help='Where to save summary logs for TensorBoard.'
1184 | )
1185 | parser.add_argument(
1186 | '--how_many_training_steps',
1187 | type=int,
1188 | default=6000,
1189 | help='How many training steps to run before ending.'
1190 | )
1191 | parser.add_argument(
1192 | '--learning_rate',
1193 | type=float,
1194 | default=0.01,
1195 | help='How large a learning rate to use when training.'
1196 | )
1197 | parser.add_argument(
1198 | '--testing_percentage',
1199 | type=int,
1200 | default=10,
1201 | help='What percentage of images to use as a test set.'
1202 | )
1203 | parser.add_argument(
1204 | '--validation_percentage',
1205 | type=int,
1206 | default=10,
1207 | help='What percentage of images to use as a validation set.'
1208 | )
1209 | parser.add_argument(
1210 | '--eval_step_interval',
1211 | type=int,
1212 | default=10,
1213 | help='How often to evaluate the training results.'
1214 | )
1215 | parser.add_argument(
1216 | '--train_batch_size',
1217 | type=int,
1218 | default=100,
1219 | help='How many images to train on at a time.'
1220 | )
1221 | parser.add_argument(
1222 | '--test_batch_size',
1223 | type=int,
1224 | default=-1,
1225 | help="""\
1226 | How many images to test on. This test set is only used once, to evaluate
1227 | the final accuracy of the model after training completes.
1228 | A value of -1 causes the entire test set to be used, which leads to more
1229 | stable results across runs.\
1230 | """
1231 | )
1232 | parser.add_argument(
1233 | '--validation_batch_size',
1234 | type=int,
1235 | default=100,
1236 | help="""\
1237 | How many images to use in an evaluation batch. This validation set is
1238 | used much more often than the test set, and is an early indicator of how
1239 | accurate the model is during training.
1240 | A value of -1 causes the entire validation set to be used, which leads to
1241 | more stable results across training iterations, but may be slower on large
1242 | training sets.\
1243 | """
1244 | )
1245 | parser.add_argument(
1246 | '--print_misclassified_test_images',
1247 | default=False,
1248 | help="""\
1249 | Whether to print out a list of all misclassified test images.\
1250 | """,
1251 | action='store_true'
1252 | )
1253 | parser.add_argument(
1254 | '--model_dir',
1255 | type=str,
1256 | default='/tmp/imagenet',
1257 | help="""\
1258 | Path to classify_image_graph_def.pb,
1259 | imagenet_synset_to_human_label_map.txt, and
1260 | imagenet_2012_challenge_label_map_proto.pbtxt.\
1261 | """
1262 | )
1263 | parser.add_argument(
1264 | '--bottleneck_dir',
1265 | type=str,
1266 | default='/tmp/bottleneck',
1267 | help='Path to cache bottleneck layer values as files.'
1268 | )
1269 | parser.add_argument(
1270 | '--final_tensor_name',
1271 | type=str,
1272 | default='final_result',
1273 | help="""\
1274 | The name of the output classification layer in the retrained graph.\
1275 | """
1276 | )
1277 | parser.add_argument(
1278 | '--flip_left_right',
1279 | default=False,
1280 | help="""\
1281 | Whether to randomly flip half of the training images horizontally.\
1282 | """,
1283 | action='store_true'
1284 | )
1285 | parser.add_argument(
1286 | '--random_crop',
1287 | type=int,
1288 | default=0,
1289 | help="""\
1290 | A percentage determining how much of a margin to randomly crop off the
1291 | training images.\
1292 | """
1293 | )
1294 | parser.add_argument(
1295 | '--random_scale',
1296 | type=int,
1297 | default=0,
1298 | help="""\
1299 | A percentage determining how much to randomly scale up the size of the
1300 | training images by.\
1301 | """
1302 | )
1303 | parser.add_argument(
1304 | '--random_brightness',
1305 | type=int,
1306 | default=0,
1307 | help="""\
1308 | A percentage determining how much to randomly multiply the training image
1309 | input pixels up or down by.\
1310 | """
1311 | )
1312 | parser.add_argument(
1313 | '--architecture',
1314 | type=str,
1315 | default='inception_v3',
1316 | help="""\
1317 | Which model architecture to use. 'inception_v3' is the most accurate, but
1318 |       also the slowest. For faster or smaller models, choose a MobileNet with the
1319 |       form 'mobilenet_<parameter size>_<input size>[_quantized]'. For example,
1320 | 'mobilenet_1.0_224' will pick a model that is 17 MB in size and takes 224
1321 | pixel input images, while 'mobilenet_0.25_128_quantized' will choose a much
1322 | less accurate, but smaller and faster network that's 920 KB on disk and
1323 | takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
1324 | for more information on Mobilenet.\
1325 | """)
1326 | FLAGS, unparsed = parser.parse_known_args()
1327 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
1328 |
--------------------------------------------------------------------------------
/training code.txt:
--------------------------------------------------------------------------------
1 | python retrain.py --output_graph=retrained_graph.pb --output_labels=retrained_labels.txt --architecture=MobileNet_1.0_224 --image_dir=images
2 |
--------------------------------------------------------------------------------
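Note on the command above: retrain.py lower-cases the --architecture value
(see create_model_info), so 'MobileNet_1.0_224' resolves to the
'mobilenet_1.0_224' model. Training progress can be followed in TensorBoard by
pointing it at the summary directory the script writes (by default
/tmp/retrain_logs, or whatever --summaries_dir was set to), for example:

tensorboard --logdir /tmp/retrain_logs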