├── .idea ├── Emotion_recognition_system.iml ├── Face_Sentiment_analysis.iml ├── Tic_tac_toe.iml ├── misc.xml ├── modules.xml └── vcs.xml ├── Emotion_recognition.gif ├── Face_crop.py ├── LICENSE ├── README.md ├── Sample ├── Face_crop.py └── haarcascade_frontalface_alt.xml ├── android_recognition.py ├── avbin64.dll ├── haarcascade_frontalface_alt.xml ├── haarcascade_frontalface_default.xml ├── label_image.py ├── recognition_webcam.py ├── retrain.py └── training code.txt /.idea/Emotion_recognition_system.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/Face_Sentiment_analysis.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/Tic_tac_toe.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Emotion_recognition.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spidy20/Emotion_recognition_system/41241626a2c6cd24afcfb40f3ad5e0e469fff662/Emotion_recognition.gif -------------------------------------------------------------------------------- /Face_crop.py: -------------------------------------------------------------------------------- 1 | import cv2,glob 2 | 3 | images=glob.glob("*.jpg") 4 | 5 | for image in images: 6 | facedata = "haarcascade_frontalface_alt.xml" 7 | cascade = cv2.CascadeClassifier(facedata) 8 | img=cv2.imread(image,0) 9 | 10 | re=cv2.resize(img,(int(img.shape[1]),int(img.shape[0]))) 11 | faces = cascade.detectMultiScale(re) 12 | 13 | for f in faces: 14 | x, y, w, h = [v for v in f] 15 | Rect=cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2) 16 | 17 | sub_face = img[y:y + h, x:x + w] 18 | 19 | f_name = image.split('/') 20 | f_name = f_name[-1] 21 | cv2.imshow("checking",sub_face) 22 | cv2.waitKey(500) 23 | cv2.destroyAllWindows() 24 | 25 | 26 | 27 | cv2.imwrite("resized_"+image,sub_face) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Kushal Bhavsar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do 
so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Emotion😂😥😡😱 Recognition system [![](https://img.shields.io/github/license/sourcerer-io/hall-of-fame.svg)](https://github.com/Spidy20/Emotion_recognition_system/blob/master/LICENSE) 2 | 3 | ### Sourcerer 4 | 5 | 6 | ### Code Requirements 7 | - TensorFlow 8 | Install it with `pip install tensorflow` 9 | - Download my repository 10 | - Your own expression dataset (Note: you can download expression images from Google, or record a video of yourself making different expressions and 11 | convert the frames into grayscale images for more accurate prediction) 12 | 13 | 14 | ### What steps do you have to follow? 15 | - Download my repository 16 | - Make an `Images` folder in your project and create a subfolder for each emotion, e.g. Happy, Sad, Angry. 17 | - Put `Face_crop.py` & `haarcascade_frontalface_alt.xml` in every emotion folder (e.g. in the "Happy" folder) and run the script; 18 | it will detect the faces in the images, convert them to grayscale, and save the new images in the same folder. 19 | - After that you have to create the model: copy the command from `training code.txt`, open CMD in your project folder, paste it, and press Enter (see the example command after the Notes section below). 20 | - Training will take around 20-25 minutes, so be patient. 21 | - After training it will create two files: `retrained_graph.pb` & `retrained_labels.txt`. 22 | - Now run `recognition_webcam.py`. 23 | - If you want to fetch video from your mobile camera, use `android_recognition.py`; you have to install the IP Webcam app on your phone 24 | and replace my server URL in the script with your own. 25 | - That's all! 26 | 27 | ### How does it work? See :) 28 | 29 | 30 | 31 | ### Notes 32 | - It requires high processing power (I have 8 GB RAM & a 2 GB graphics card). 33 | - Don't expect it to recognize expressions as well as a human does; that is not possible with this setup. 34 | - Download around 300 images for every expression (you can use a batch downloader). 35 | - Noisy images can reduce accuracy, so image quality matters.
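For reference, here is a minimal sketch of the training step. The exact command is in `training code.txt` (not reproduced here); the flag names below follow the stock TensorFlow `retrain.py` bundled in this repository, and the folder names are only examples:

```bash
# Assumed layout: one subfolder per emotion, filled with the grayscale face crops made by Face_crop.py
#   Images/Happy/*.jpg   Images/Sad/*.jpg   Images/Angry/*.jpg
python retrain.py --image_dir Images --output_graph retrained_graph.pb --output_labels retrained_labels.txt
```

`label_image.py` (used by both recognition scripts) loads `retrained_graph.pb` and `retrained_labels.txt` from the project folder, so the output files should end up there.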
36 | 37 | 38 | ## Just follow☝️ me and Star⭐ my repository 39 |
-------------------------------------------------------------------------------- /Sample/Face_crop.py: -------------------------------------------------------------------------------- 1 | import cv2,glob 2 | 3 | images=glob.glob("*.jpg") 4 | 5 | for image in images: 6 | cascade = cv2.CascadeClassifier("haarcascade_frontalface_alt.xml") # load the Haar face detector 7 | img=cv2.imread(image,0) # read the image as grayscale 8 | 9 | faces = cascade.detectMultiScale(img) # detect faces in the image 10 | 11 | 12 | for f in faces: 13 | x, y, w, h = [v for v in f] 14 | sub_face = img[y:y + h, x:x + w] # crop the detected face 15 | cv2.imshow("checking",sub_face) 16 | cv2.waitKey(500) 17 | cv2.destroyAllWindows() 18 | 19 | 20 | 21 | cv2.imwrite("resized_"+image,sub_face) # save the grayscale face crop
-------------------------------------------------------------------------------- /android_recognition.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import label_image 3 | import os 4 | import numpy as np 5 | from urllib.request import urlopen 6 | import time 7 | from playsound import playsound 8 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 9 | 10 | size = 4 11 | 12 | # Load the Haar cascade used for face detection 13 | classifier = cv2.CascadeClassifier('haarcascade_frontalface_alt.xml') 14 | mobile_video="http://192.168.0.101:8080/shot.jpg" 15 | # Fetching frames from the IP Webcam app running on the phone; replace the URL above with your own server URL. 16 | 17 | while True: 18 | 19 | img_resp = urlopen(mobile_video) 20 | img_arr = np.array(bytearray(img_resp.read()), dtype=np.uint8) 21 | 22 | cap = cv2.imdecode(img_arr, -1) 23 | im=cv2.flip(cap,1,0) 24 | mini = cv2.resize(im, (int(im.shape[1] / size), int(im.shape[0] / size))) 25 | 26 | # detect MultiScale / faces 27 | faces = classifier.detectMultiScale(mini) 28 | 29 | # Draw rectangles around each face 30 | for f in faces: 31 | (x, y, w, h) = [v * size for v in f] # Scale the face coordinates back up to the full-size frame 32 | sub_face = im[y:y + h, x:x + w] 33 | FaceFileName = "test.jpg" # Saving the current image from the webcam for testing. 34 | cv2.imwrite(FaceFileName, sub_face) 35 | text = label_image.main(FaceFileName) # Getting the Result from the label_image file, i.e., Classification Result. 36 | text = text.title() # Title Case looks Stunning.
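# NOTE: the label strings compared below ('Angry', 'Smile', 'Fear', 'Sad') must match the emotion folder names used for training, since retrain.py builds retrained_labels.txt from those subfolder names and label_image.py returns one of them (title-cased here).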
37 | font = cv2.FONT_HERSHEY_TRIPLEX 38 | 39 | if text == 'Angry': 40 | cv2.rectangle(im, (x, y), (x + w, y + h), (0, 25, 255), 7) 41 | cv2.putText(im, text, (x + h, y), font, 1, (0, 25,255), 2) 42 | 43 | if text == 'Smile': 44 | cv2.rectangle(im, (x, y), (x + w, y + h), (0,260,0), 7) 45 | cv2.putText(im, text, (x + h, y), font, 1, (0,260,0), 2) 46 | 47 | if text == 'Fear': 48 | cv2.rectangle(im, (x, y), (x + w, y + h), (0, 255, 255), 7) 49 | cv2.putText(im, text, (x + h, y), font, 1, (0, 255, 255), 2) 50 | 51 | if text == 'Sad': 52 | cv2.rectangle(im, (x, y), (x + w, y + h), (0,191,255), 7) 53 | cv2.putText(im, text, (x + h, y), font, 1, (0,191,255), 2) 54 | 55 | # Show the image/ 56 | cv2.imshow('Emotion recognition from Android screen', im) 57 | key = cv2.waitKey(30)& 0xff 58 | if key == 27: # The Esc key 59 | 60 | break 61 | 62 | 63 | -------------------------------------------------------------------------------- /avbin64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Spidy20/Emotion_recognition_system/41241626a2c6cd24afcfb40f3ad5e0e469fff662/avbin64.dll -------------------------------------------------------------------------------- /label_image.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | import argparse 7 | import sys 8 | import time 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | def load_graph(model_file): 14 | graph = tf.Graph() 15 | graph_def = tf.GraphDef() 16 | 17 | with open(model_file, "rb") as f: 18 | graph_def.ParseFromString(f.read()) 19 | with graph.as_default(): 20 | tf.import_graph_def(graph_def) 21 | 22 | return graph 23 | 24 | def read_tensor_from_image_file(file_name, input_height=299, input_width=299, 25 | input_mean=0, input_std=255): 26 | input_name = "file_reader" 27 | output_name = "normalized" 28 | file_reader = tf.read_file(file_name, input_name) 29 | if file_name.endswith(".png"): 30 | image_reader = tf.image.decode_png(file_reader, channels = 3, 31 | name='png_reader') 32 | elif file_name.endswith(".gif"): 33 | image_reader = tf.squeeze(tf.image.decode_gif(file_reader, 34 | name='gif_reader')) 35 | elif file_name.endswith(".bmp"): 36 | image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader') 37 | else: 38 | image_reader = tf.image.decode_jpeg(file_reader, channels = 3, 39 | name='jpeg_reader') 40 | float_caster = tf.cast(image_reader, tf.float32) 41 | dims_expander = tf.expand_dims(float_caster, 0); 42 | resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width]) 43 | normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std]) 44 | sess = tf.Session() 45 | result = sess.run(normalized) 46 | 47 | return result 48 | 49 | def load_labels(label_file): 50 | label = [] 51 | 52 | proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines() 53 | for l in proto_as_ascii_lines: 54 | label.append(l.rstrip()) 55 | 56 | return label 57 | 58 | 59 | 60 | def main(img): 61 | file_name = img 62 | model_file = "retrained_graph.pb" 63 | label_file = "retrained_labels.txt" 64 | input_height = 224 65 | input_width = 224 66 | input_mean = 128 67 | input_std = 128 68 | input_layer = "input" 69 | output_layer = "final_result" 70 | 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument("--image", help="image to be processed") 73 | parser.add_argument("--graph", 
help="graph/model to be executed") 74 | parser.add_argument("--labels", help="name of file containing labels") 75 | parser.add_argument("--input_height", type=int, help="input height") 76 | parser.add_argument("--input_width", type=int, help="input width") 77 | parser.add_argument("--input_mean", type=int, help="input mean") 78 | parser.add_argument("--input_std", type=int, help="input std") 79 | parser.add_argument("--input_layer", help="name of input layer") 80 | parser.add_argument("--output_layer", help="name of output layer") 81 | args = parser.parse_args() 82 | 83 | if args.graph: 84 | model_file = args.graph 85 | if args.image: 86 | file_name = args.image 87 | if args.labels: 88 | label_file = args.labels 89 | if args.input_height: 90 | input_height = args.input_height 91 | if args.input_width: 92 | input_width = args.input_width 93 | if args.input_mean: 94 | input_mean = args.input_mean 95 | if args.input_std: 96 | input_std = args.input_std 97 | if args.input_layer: 98 | input_layer = args.input_layer 99 | if args.output_layer: 100 | output_layer = args.output_layer 101 | 102 | graph = load_graph(model_file) 103 | t = read_tensor_from_image_file(file_name, 104 | input_height=input_height, 105 | input_width=input_width, 106 | input_mean=input_mean, 107 | input_std=input_std) 108 | 109 | input_name = "import/" + input_layer 110 | output_name = "import/" + output_layer 111 | input_operation = graph.get_operation_by_name(input_name); 112 | output_operation = graph.get_operation_by_name(output_name); 113 | 114 | 115 | with tf.Session(graph=graph) as sess: 116 | start = time.time() 117 | results = sess.run(output_operation.outputs[0], 118 | {input_operation.outputs[0]: t}) 119 | end=time.time() 120 | results = np.squeeze(results) 121 | 122 | top_k = results.argsort()[-5:][::-1] 123 | labels = load_labels(label_file) 124 | 125 | 126 | 127 | 128 | 129 | 130 | for i in top_k: 131 | return labels[i] -------------------------------------------------------------------------------- /recognition_webcam.py: -------------------------------------------------------------------------------- 1 | #Coded by:- Kushal Bhavsra 2 | #From:- Techmicra IT solution 3 | import time 4 | import cv2 5 | import label_image 6 | import os,random 7 | import subprocess 8 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 9 | 10 | size = 4 11 | # We load the xml file 12 | classifier = cv2.CascadeClassifier('haarcascade_frontalface_alt.xml') 13 | global text 14 | webcam = cv2.VideoCapture(0) # Using default WebCam connected to the PC. 15 | while True: 16 | (rval, im) = webcam.read() 17 | im = cv2.flip(im, 1, 0) # Flip to act as a mirror 18 | # Resize the image to speed up detection 19 | mini = cv2.resize(im, (int(im.shape[1] / size), int(im.shape[0] / size))) 20 | # detect MultiScale / faces 21 | faces = classifier.detectMultiScale(mini) 22 | # Draw rectangles around each face 23 | for f in faces: 24 | (x, y, w, h) = [v * size for v in f] # Scale the shapesize backup 25 | sub_face = im[y:y + h, x:x + w] 26 | FaceFileName = "test.jpg" # Saving the current image from the webcam for testing. 27 | cv2.imwrite(FaceFileName, sub_face) 28 | text = label_image.main(FaceFileName) # Getting the Result from the label_image file, i.e., Classification Result. 29 | text = text.title() # Title Case looks Stunning. 
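# NOTE: label_image.main() reloads retrained_graph.pb and opens a fresh TensorFlow session for every detected face, so each classification is slow; this is a large part of why the README asks for a powerful machine.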
30 | font = cv2.FONT_HERSHEY_TRIPLEX 31 | 32 | if text == 'Angry': 33 | cv2.rectangle(im, (x, y), (x + w, y + h), (0, 25, 255), 7) 34 | cv2.putText(im, text, (x + h, y), font, 1, (0, 25,255), 2) 35 | 36 | if text == 'Smile': 37 | cv2.rectangle(im, (x, y), (x + w, y + h), (0,260,0), 7) 38 | cv2.putText(im, text, (x + h, y), font, 1, (0,260,0), 2) 39 | 40 | if text == 'Fear': 41 | cv2.rectangle(im, (x, y), (x + w, y + h), (0, 255, 255), 7) 42 | cv2.putText(im, text, (x + h, y), font, 1, (0, 255, 255), 2) 43 | 44 | if text == 'Sad': 45 | cv2.rectangle(im, (x, y), (x + w, y + h), (0,191,255), 7) 46 | cv2.putText(im, text, (x + h, y), font, 1, (0,191,255), 2) 47 | 48 | # Show the image/ 49 | cv2.imshow('Emotion recognition', im) 50 | key = cv2.waitKey(30)& 0xff 51 | 52 | if key == 27: # The Esc key 53 | break -------------------------------------------------------------------------------- /retrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Simple transfer learning with Inception v3 or Mobilenet models. 16 | 17 | With support for TensorBoard. 18 | 19 | This example shows how to take a Inception v3 or Mobilenet model trained on 20 | ImageNet images, and train a new top layer that can recognize other classes of 21 | images. 22 | 23 | The top layer receives as input a 2048-dimensional vector (1001-dimensional for 24 | Mobilenet) for each image. We train a softmax layer on top of this 25 | representation. Assuming the softmax layer contains N labels, this corresponds 26 | to learning N + 2048*N (or 1001*N) model parameters corresponding to the 27 | learned biases and weights. 28 | 29 | Here's an example, which assumes you have a folder containing class-named 30 | subfolders, each full of images for each label. The example folder flower_photos 31 | should have a structure like this: 32 | 33 | ~/flower_photos/daisy/photo1.jpg 34 | ~/flower_photos/daisy/photo2.jpg 35 | ... 36 | ~/flower_photos/rose/anotherphoto77.jpg 37 | ... 38 | ~/flower_photos/sunflower/somepicture.jpg 39 | 40 | The subfolder names are important, since they define what label is applied to 41 | each image, but the filenames themselves don't matter. Once your images are 42 | prepared, you can run the training with a command like this: 43 | 44 | 45 | ```bash 46 | bazel build tensorflow/examples/image_retraining:retrain && \ 47 | bazel-bin/tensorflow/examples/image_retraining/retrain \ 48 | --image_dir ~/flower_photos 49 | ``` 50 | 51 | Or, if you have a pip installation of tensorflow, `retrain.py` can be run 52 | without bazel: 53 | 54 | ```bash 55 | python tensorflow/examples/image_retraining/retrain.py \ 56 | --image_dir ~/flower_photos 57 | ``` 58 | 59 | You can replace the image_dir argument with any folder containing subfolders of 60 | images. 
The label for each image is taken from the name of the subfolder it's 61 | in. 62 | 63 | This produces a new model file that can be loaded and run by any TensorFlow 64 | program, for example the label_image sample code. 65 | 66 | By default this script will use the high accuracy, but comparatively large and 67 | slow Inception v3 model architecture. It's recommended that you start with this 68 | to validate that you have gathered good training data, but if you want to deploy 69 | on resource-limited platforms, you can try the `--architecture` flag with a 70 | Mobilenet model. For example: 71 | 72 | ```bash 73 | python tensorflow/examples/image_retraining/retrain.py \ 74 | --image_dir ~/flower_photos --architecture mobilenet_1.0_224 75 | ``` 76 | 77 | There are 32 different Mobilenet models to choose from, with a variety of file 78 | size and latency options. The first number can be '1.0', '0.75', '0.50', or 79 | '0.25' to control the size, and the second controls the input image size, either 80 | '224', '192', '160', or '128', with smaller sizes running faster. See 81 | https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html 82 | for more information on Mobilenet. 83 | 84 | To use with TensorBoard: 85 | 86 | By default, this script will log summaries to /tmp/retrain_logs directory 87 | 88 | Visualize the summaries with this command: 89 | 90 | tensorboard --logdir /tmp/retrain_logs 91 | 92 | """ 93 | from __future__ import absolute_import 94 | from __future__ import division 95 | from __future__ import print_function 96 | 97 | import argparse 98 | import collections 99 | from datetime import datetime 100 | import hashlib 101 | import os.path 102 | import random 103 | import re 104 | import sys 105 | import tarfile 106 | 107 | import numpy as np 108 | from six.moves import urllib 109 | import tensorflow as tf 110 | 111 | 112 | from tensorflow.python.framework import graph_util 113 | from tensorflow.python.framework import tensor_shape 114 | from tensorflow.python.platform import gfile 115 | from tensorflow.python.util import compat 116 | 117 | FLAGS = None 118 | 119 | # These are all parameters that are tied to the particular model architecture 120 | # we're using for Inception v3. These include things like tensor names and their 121 | # sizes. If you want to adapt this script to work with another model, you will 122 | # need to update these to reflect the values in the network you're using. 123 | MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 # ~134M 124 | 125 | 126 | def create_image_lists(image_dir, testing_percentage, validation_percentage): 127 | """Builds a list of training images from the file system. 128 | 129 | Analyzes the sub folders in the image directory, splits them into stable 130 | training, testing, and validation sets, and returns a data structure 131 | describing the lists of images for each label and their paths. 132 | 133 | Args: 134 | image_dir: String path to a folder containing subfolders of images. 135 | testing_percentage: Integer percentage of the images to reserve for tests. 136 | validation_percentage: Integer percentage of images reserved for validation. 137 | 138 | Returns: 139 | A dictionary containing an entry for each label subfolder, with images split 140 | into training, testing, and validation sets within each label. 
141 | """ 142 | if not gfile.Exists(image_dir): 143 | tf.logging.error("Image directory '" + image_dir + "' not found.") 144 | return None 145 | result = collections.OrderedDict() 146 | sub_dirs = [ 147 | os.path.join(image_dir,item) 148 | for item in gfile.ListDirectory(image_dir)] 149 | sub_dirs = sorted(item for item in sub_dirs 150 | if gfile.IsDirectory(item)) 151 | for sub_dir in sub_dirs: 152 | extensions = ['jpg', 'jpeg', 'JPG', 'JPEG'] 153 | file_list = [] 154 | dir_name = os.path.basename(sub_dir) 155 | if dir_name == image_dir: 156 | continue 157 | tf.logging.info("Looking for images in '" + dir_name + "'") 158 | for extension in extensions: 159 | file_glob = os.path.join(image_dir, dir_name, '*.' + extension) 160 | file_list.extend(gfile.Glob(file_glob)) 161 | if not file_list: 162 | tf.logging.warning('No files found') 163 | continue 164 | if len(file_list) < 20: 165 | tf.logging.warning( 166 | 'WARNING: Folder has less than 20 images, which may cause issues.') 167 | elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS: 168 | tf.logging.warning( 169 | 'WARNING: Folder {} has more than {} images. Some images will ' 170 | 'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS)) 171 | label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower()) 172 | training_images = [] 173 | testing_images = [] 174 | validation_images = [] 175 | for file_name in file_list: 176 | base_name = os.path.basename(file_name) 177 | # We want to ignore anything after '_nohash_' in the file name when 178 | # deciding which set to put an image in, the data set creator has a way of 179 | # grouping photos that are close variations of each other. For example 180 | # this is used in the plant disease data set to group multiple pictures of 181 | # the same leaf. 182 | hash_name = re.sub(r'_nohash_.*$', '', file_name) 183 | # This looks a bit magical, but we need to decide whether this file should 184 | # go into the training, testing, or validation sets, and we want to keep 185 | # existing files in the same set even if more files are subsequently 186 | # added. 187 | # To do that, we need a stable way of deciding based on just the file name 188 | # itself, so we do a hash of that and then use that to generate a 189 | # probability value that we use to assign it. 190 | hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest() 191 | percentage_hash = ((int(hash_name_hashed, 16) % 192 | (MAX_NUM_IMAGES_PER_CLASS + 1)) * 193 | (100.0 / MAX_NUM_IMAGES_PER_CLASS)) 194 | if percentage_hash < validation_percentage: 195 | validation_images.append(base_name) 196 | elif percentage_hash < (testing_percentage + validation_percentage): 197 | testing_images.append(base_name) 198 | else: 199 | training_images.append(base_name) 200 | result[label_name] = { 201 | 'dir': dir_name, 202 | 'training': training_images, 203 | 'testing': testing_images, 204 | 'validation': validation_images, 205 | } 206 | return result 207 | 208 | 209 | def get_image_path(image_lists, label_name, index, image_dir, category): 210 | """"Returns a path to an image for a label at the given index. 211 | 212 | Args: 213 | image_lists: Dictionary of training images for each label. 214 | label_name: Label string we want to get an image for. 215 | index: Int offset of the image we want. This will be moduloed by the 216 | available number of images for the label, so it can be arbitrarily large. 217 | image_dir: Root folder string of the subfolders containing the training 218 | images. 
219 | category: Name string of set to pull images from - training, testing, or 220 | validation. 221 | 222 | Returns: 223 | File system path string to an image that meets the requested parameters. 224 | 225 | """ 226 | if label_name not in image_lists: 227 | tf.logging.fatal('Label does not exist %s.', label_name) 228 | label_lists = image_lists[label_name] 229 | if category not in label_lists: 230 | tf.logging.fatal('Category does not exist %s.', category) 231 | category_list = label_lists[category] 232 | if not category_list: 233 | tf.logging.fatal('Label %s has no images in the category %s.', 234 | label_name, category) 235 | mod_index = index % len(category_list) 236 | base_name = category_list[mod_index] 237 | sub_dir = label_lists['dir'] 238 | full_path = os.path.join(image_dir, sub_dir, base_name) 239 | return full_path 240 | 241 | 242 | def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir, 243 | category, architecture): 244 | """"Returns a path to a bottleneck file for a label at the given index. 245 | 246 | Args: 247 | image_lists: Dictionary of training images for each label. 248 | label_name: Label string we want to get an image for. 249 | index: Integer offset of the image we want. This will be moduloed by the 250 | available number of images for the label, so it can be arbitrarily large. 251 | bottleneck_dir: Folder string holding cached files of bottleneck values. 252 | category: Name string of set to pull images from - training, testing, or 253 | validation. 254 | architecture: The name of the model architecture. 255 | 256 | Returns: 257 | File system path string to an image that meets the requested parameters. 258 | """ 259 | return get_image_path(image_lists, label_name, index, bottleneck_dir, 260 | category) + '_' + architecture + '.txt' 261 | 262 | 263 | def create_model_graph(model_info): 264 | """"Creates a graph from saved GraphDef file and returns a Graph object. 265 | 266 | Args: 267 | model_info: Dictionary containing information about the model architecture. 268 | 269 | Returns: 270 | Graph holding the trained Inception network, and various tensors we'll be 271 | manipulating. 272 | """ 273 | with tf.Graph().as_default() as graph: 274 | model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name']) 275 | with gfile.FastGFile(model_path, 'rb') as f: 276 | graph_def = tf.GraphDef() 277 | graph_def.ParseFromString(f.read()) 278 | bottleneck_tensor, resized_input_tensor = (tf.import_graph_def( 279 | graph_def, 280 | name='', 281 | return_elements=[ 282 | model_info['bottleneck_tensor_name'], 283 | model_info['resized_input_tensor_name'], 284 | ])) 285 | return graph, bottleneck_tensor, resized_input_tensor 286 | 287 | 288 | def run_bottleneck_on_image(sess, image_data, image_data_tensor, 289 | decoded_image_tensor, resized_input_tensor, 290 | bottleneck_tensor): 291 | """Runs inference on an image to extract the 'bottleneck' summary layer. 292 | 293 | Args: 294 | sess: Current active TensorFlow Session. 295 | image_data: String of raw JPEG data. 296 | image_data_tensor: Input data layer in the graph. 297 | decoded_image_tensor: Output of initial image resizing and preprocessing. 298 | resized_input_tensor: The input node of the recognition graph. 299 | bottleneck_tensor: Layer before the final softmax. 300 | 301 | Returns: 302 | Numpy array of bottleneck values. 303 | """ 304 | # First decode the JPEG image, resize it, and rescale the pixel values. 
305 | resized_input_values = sess.run(decoded_image_tensor, 306 | {image_data_tensor: image_data}) 307 | # Then run it through the recognition network. 308 | bottleneck_values = sess.run(bottleneck_tensor, 309 | {resized_input_tensor: resized_input_values}) 310 | bottleneck_values = np.squeeze(bottleneck_values) 311 | return bottleneck_values 312 | 313 | 314 | def maybe_download_and_extract(data_url): 315 | """Download and extract model tar file. 316 | 317 | If the pretrained model we're using doesn't already exist, this function 318 | downloads it from the TensorFlow.org website and unpacks it into a directory. 319 | 320 | Args: 321 | data_url: Web location of the tar file containing the pretrained model. 322 | """ 323 | dest_directory = FLAGS.model_dir 324 | if not os.path.exists(dest_directory): 325 | os.makedirs(dest_directory) 326 | filename = data_url.split('/')[-1] 327 | filepath = os.path.join(dest_directory, filename) 328 | if not os.path.exists(filepath): 329 | 330 | def _progress(count, block_size, total_size): 331 | sys.stdout.write('\r>> Downloading %s %.1f%%' % 332 | (filename, 333 | float(count * block_size) / float(total_size) * 100.0)) 334 | sys.stdout.flush() 335 | 336 | filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress) 337 | print() 338 | statinfo = os.stat(filepath) 339 | tf.logging.info('Successfully downloaded', filename, statinfo.st_size, 340 | 'bytes.') 341 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 342 | 343 | 344 | def ensure_dir_exists(dir_name): 345 | """Makes sure the folder exists on disk. 346 | 347 | Args: 348 | dir_name: Path string to the folder we want to create. 349 | """ 350 | if not os.path.exists(dir_name): 351 | os.makedirs(dir_name) 352 | 353 | 354 | bottleneck_path_2_bottleneck_values = {} 355 | 356 | 357 | def create_bottleneck_file(bottleneck_path, image_lists, label_name, index, 358 | image_dir, category, sess, jpeg_data_tensor, 359 | decoded_image_tensor, resized_input_tensor, 360 | bottleneck_tensor): 361 | """Create a single bottleneck file.""" 362 | tf.logging.info('Creating bottleneck at ' + bottleneck_path) 363 | image_path = get_image_path(image_lists, label_name, index, 364 | image_dir, category) 365 | if not gfile.Exists(image_path): 366 | tf.logging.fatal('File does not exist %s', image_path) 367 | image_data = gfile.FastGFile(image_path, 'rb').read() 368 | try: 369 | bottleneck_values = run_bottleneck_on_image( 370 | sess, image_data, jpeg_data_tensor, decoded_image_tensor, 371 | resized_input_tensor, bottleneck_tensor) 372 | except Exception as e: 373 | raise RuntimeError('Error during processing file %s (%s)' % (image_path, 374 | str(e))) 375 | bottleneck_string = ','.join(str(x) for x in bottleneck_values) 376 | with open(bottleneck_path, 'w') as bottleneck_file: 377 | bottleneck_file.write(bottleneck_string) 378 | 379 | 380 | def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, 381 | category, bottleneck_dir, jpeg_data_tensor, 382 | decoded_image_tensor, resized_input_tensor, 383 | bottleneck_tensor, architecture): 384 | """Retrieves or calculates bottleneck values for an image. 385 | 386 | If a cached version of the bottleneck data exists on-disk, return that, 387 | otherwise calculate the data and save it to disk for future use. 388 | 389 | Args: 390 | sess: The current active TensorFlow Session. 391 | image_lists: Dictionary of training images for each label. 392 | label_name: Label string we want to get an image for. 
393 | index: Integer offset of the image we want. This will be modulo-ed by the 394 | available number of images for the label, so it can be arbitrarily large. 395 | image_dir: Root folder string of the subfolders containing the training 396 | images. 397 | category: Name string of which set to pull images from - training, testing, 398 | or validation. 399 | bottleneck_dir: Folder string holding cached files of bottleneck values. 400 | jpeg_data_tensor: The tensor to feed loaded jpeg data into. 401 | decoded_image_tensor: The output of decoding and resizing the image. 402 | resized_input_tensor: The input node of the recognition graph. 403 | bottleneck_tensor: The output tensor for the bottleneck values. 404 | architecture: The name of the model architecture. 405 | 406 | Returns: 407 | Numpy array of values produced by the bottleneck layer for the image. 408 | """ 409 | label_lists = image_lists[label_name] 410 | sub_dir = label_lists['dir'] 411 | sub_dir_path = os.path.join(bottleneck_dir, sub_dir) 412 | ensure_dir_exists(sub_dir_path) 413 | bottleneck_path = get_bottleneck_path(image_lists, label_name, index, 414 | bottleneck_dir, category, architecture) 415 | if not os.path.exists(bottleneck_path): 416 | create_bottleneck_file(bottleneck_path, image_lists, label_name, index, 417 | image_dir, category, sess, jpeg_data_tensor, 418 | decoded_image_tensor, resized_input_tensor, 419 | bottleneck_tensor) 420 | with open(bottleneck_path, 'r') as bottleneck_file: 421 | bottleneck_string = bottleneck_file.read() 422 | did_hit_error = False 423 | try: 424 | bottleneck_values = [float(x) for x in bottleneck_string.split(',')] 425 | except ValueError: 426 | tf.logging.warning('Invalid float found, recreating bottleneck') 427 | did_hit_error = True 428 | if did_hit_error: 429 | create_bottleneck_file(bottleneck_path, image_lists, label_name, index, 430 | image_dir, category, sess, jpeg_data_tensor, 431 | decoded_image_tensor, resized_input_tensor, 432 | bottleneck_tensor) 433 | with open(bottleneck_path, 'r') as bottleneck_file: 434 | bottleneck_string = bottleneck_file.read() 435 | # Allow exceptions to propagate here, since they shouldn't happen after a 436 | # fresh creation 437 | bottleneck_values = [float(x) for x in bottleneck_string.split(',')] 438 | return bottleneck_values 439 | 440 | 441 | def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir, 442 | jpeg_data_tensor, decoded_image_tensor, 443 | resized_input_tensor, bottleneck_tensor, architecture): 444 | """Ensures all the training, testing, and validation bottlenecks are cached. 445 | 446 | Because we're likely to read the same image multiple times (if there are no 447 | distortions applied during training) it can speed things up a lot if we 448 | calculate the bottleneck layer values once for each image during 449 | preprocessing, and then just read those cached values repeatedly during 450 | training. Here we go through all the images we've found, calculate those 451 | values, and save them off. 452 | 453 | Args: 454 | sess: The current active TensorFlow Session. 455 | image_lists: Dictionary of training images for each label. 456 | image_dir: Root folder string of the subfolders containing the training 457 | images. 458 | bottleneck_dir: Folder string holding cached files of bottleneck values. 459 | jpeg_data_tensor: Input tensor for jpeg data from file. 460 | decoded_image_tensor: The output of decoding and resizing the image. 461 | resized_input_tensor: The input node of the recognition graph. 
462 | bottleneck_tensor: The penultimate output layer of the graph. 463 | architecture: The name of the model architecture. 464 | 465 | Returns: 466 | Nothing. 467 | """ 468 | how_many_bottlenecks = 0 469 | ensure_dir_exists(bottleneck_dir) 470 | for label_name, label_lists in image_lists.items(): 471 | for category in ['training', 'testing', 'validation']: 472 | category_list = label_lists[category] 473 | for index, unused_base_name in enumerate(category_list): 474 | get_or_create_bottleneck( 475 | sess, image_lists, label_name, index, image_dir, category, 476 | bottleneck_dir, jpeg_data_tensor, decoded_image_tensor, 477 | resized_input_tensor, bottleneck_tensor, architecture) 478 | 479 | how_many_bottlenecks += 1 480 | if how_many_bottlenecks % 100 == 0: 481 | tf.logging.info( 482 | str(how_many_bottlenecks) + ' bottleneck files created.') 483 | 484 | 485 | def get_random_cached_bottlenecks(sess, image_lists, how_many, category, 486 | bottleneck_dir, image_dir, jpeg_data_tensor, 487 | decoded_image_tensor, resized_input_tensor, 488 | bottleneck_tensor, architecture): 489 | """Retrieves bottleneck values for cached images. 490 | 491 | If no distortions are being applied, this function can retrieve the cached 492 | bottleneck values directly from disk for images. It picks a random set of 493 | images from the specified category. 494 | 495 | Args: 496 | sess: Current TensorFlow Session. 497 | image_lists: Dictionary of training images for each label. 498 | how_many: If positive, a random sample of this size will be chosen. 499 | If negative, all bottlenecks will be retrieved. 500 | category: Name string of which set to pull from - training, testing, or 501 | validation. 502 | bottleneck_dir: Folder string holding cached files of bottleneck values. 503 | image_dir: Root folder string of the subfolders containing the training 504 | images. 505 | jpeg_data_tensor: The layer to feed jpeg image data into. 506 | decoded_image_tensor: The output of decoding and resizing the image. 507 | resized_input_tensor: The input node of the recognition graph. 508 | bottleneck_tensor: The bottleneck output layer of the CNN graph. 509 | architecture: The name of the model architecture. 510 | 511 | Returns: 512 | List of bottleneck arrays, their corresponding ground truths, and the 513 | relevant filenames. 514 | """ 515 | class_count = len(image_lists.keys()) 516 | bottlenecks = [] 517 | ground_truths = [] 518 | filenames = [] 519 | if how_many >= 0: 520 | # Retrieve a random sample of bottlenecks. 521 | for unused_i in range(how_many): 522 | label_index = random.randrange(class_count) 523 | label_name = list(image_lists.keys())[label_index] 524 | image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1) 525 | image_name = get_image_path(image_lists, label_name, image_index, 526 | image_dir, category) 527 | bottleneck = get_or_create_bottleneck( 528 | sess, image_lists, label_name, image_index, image_dir, category, 529 | bottleneck_dir, jpeg_data_tensor, decoded_image_tensor, 530 | resized_input_tensor, bottleneck_tensor, architecture) 531 | ground_truth = np.zeros(class_count, dtype=np.float32) 532 | ground_truth[label_index] = 1.0 533 | bottlenecks.append(bottleneck) 534 | ground_truths.append(ground_truth) 535 | filenames.append(image_name) 536 | else: 537 | # Retrieve all bottlenecks. 
538 | for label_index, label_name in enumerate(image_lists.keys()): 539 | for image_index, image_name in enumerate( 540 | image_lists[label_name][category]): 541 | image_name = get_image_path(image_lists, label_name, image_index, 542 | image_dir, category) 543 | bottleneck = get_or_create_bottleneck( 544 | sess, image_lists, label_name, image_index, image_dir, category, 545 | bottleneck_dir, jpeg_data_tensor, decoded_image_tensor, 546 | resized_input_tensor, bottleneck_tensor, architecture) 547 | ground_truth = np.zeros(class_count, dtype=np.float32) 548 | ground_truth[label_index] = 1.0 549 | bottlenecks.append(bottleneck) 550 | ground_truths.append(ground_truth) 551 | filenames.append(image_name) 552 | return bottlenecks, ground_truths, filenames 553 | 554 | 555 | def get_random_distorted_bottlenecks( 556 | sess, image_lists, how_many, category, image_dir, input_jpeg_tensor, 557 | distorted_image, resized_input_tensor, bottleneck_tensor): 558 | """Retrieves bottleneck values for training images, after distortions. 559 | 560 | If we're training with distortions like crops, scales, or flips, we have to 561 | recalculate the full model for every image, and so we can't use cached 562 | bottleneck values. Instead we find random images for the requested category, 563 | run them through the distortion graph, and then the full graph to get the 564 | bottleneck results for each. 565 | 566 | Args: 567 | sess: Current TensorFlow Session. 568 | image_lists: Dictionary of training images for each label. 569 | how_many: The integer number of bottleneck values to return. 570 | category: Name string of which set of images to fetch - training, testing, 571 | or validation. 572 | image_dir: Root folder string of the subfolders containing the training 573 | images. 574 | input_jpeg_tensor: The input layer we feed the image data to. 575 | distorted_image: The output node of the distortion graph. 576 | resized_input_tensor: The input node of the recognition graph. 577 | bottleneck_tensor: The bottleneck output layer of the CNN graph. 578 | 579 | Returns: 580 | List of bottleneck arrays and their corresponding ground truths. 581 | """ 582 | class_count = len(image_lists.keys()) 583 | bottlenecks = [] 584 | ground_truths = [] 585 | for unused_i in range(how_many): 586 | label_index = random.randrange(class_count) 587 | label_name = list(image_lists.keys())[label_index] 588 | image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1) 589 | image_path = get_image_path(image_lists, label_name, image_index, image_dir, 590 | category) 591 | if not gfile.Exists(image_path): 592 | tf.logging.fatal('File does not exist %s', image_path) 593 | jpeg_data = gfile.FastGFile(image_path, 'rb').read() 594 | # Note that we materialize the distorted_image_data as a numpy array before 595 | # sending running inference on the image. This involves 2 memory copies and 596 | # might be optimized in other implementations. 
597 | distorted_image_data = sess.run(distorted_image, 598 | {input_jpeg_tensor: jpeg_data}) 599 | bottleneck_values = sess.run(bottleneck_tensor, 600 | {resized_input_tensor: distorted_image_data}) 601 | bottleneck_values = np.squeeze(bottleneck_values) 602 | ground_truth = np.zeros(class_count, dtype=np.float32) 603 | ground_truth[label_index] = 1.0 604 | bottlenecks.append(bottleneck_values) 605 | ground_truths.append(ground_truth) 606 | return bottlenecks, ground_truths 607 | 608 | 609 | def should_distort_images(flip_left_right, random_crop, random_scale, 610 | random_brightness): 611 | """Whether any distortions are enabled, from the input flags. 612 | 613 | Args: 614 | flip_left_right: Boolean whether to randomly mirror images horizontally. 615 | random_crop: Integer percentage setting the total margin used around the 616 | crop box. 617 | random_scale: Integer percentage of how much to vary the scale by. 618 | random_brightness: Integer range to randomly multiply the pixel values by. 619 | 620 | Returns: 621 | Boolean value indicating whether any distortions should be applied. 622 | """ 623 | return (flip_left_right or (random_crop != 0) or (random_scale != 0) or 624 | (random_brightness != 0)) 625 | 626 | 627 | def add_input_distortions(flip_left_right, random_crop, random_scale, 628 | random_brightness, input_width, input_height, 629 | input_depth, input_mean, input_std): 630 | """Creates the operations to apply the specified distortions. 631 | 632 | During training it can help to improve the results if we run the images 633 | through simple distortions like crops, scales, and flips. These reflect the 634 | kind of variations we expect in the real world, and so can help train the 635 | model to cope with natural data more effectively. Here we take the supplied 636 | parameters and construct a network of operations to apply them to an image. 637 | 638 | Cropping 639 | ~~~~~~~~ 640 | 641 | Cropping is done by placing a bounding box at a random position in the full 642 | image. The cropping parameter controls the size of that box relative to the 643 | input image. If it's zero, then the box is the same size as the input and no 644 | cropping is performed. If the value is 50%, then the crop box will be half the 645 | width and height of the input. In a diagram it looks like this: 646 | 647 | < width > 648 | +---------------------+ 649 | | | 650 | | width - crop% | 651 | | < > | 652 | | +------+ | 653 | | | | | 654 | | | | | 655 | | | | | 656 | | +------+ | 657 | | | 658 | | | 659 | +---------------------+ 660 | 661 | Scaling 662 | ~~~~~~~ 663 | 664 | Scaling is a lot like cropping, except that the bounding box is always 665 | centered and its size varies randomly within the given range. For example if 666 | the scale percentage is zero, then the bounding box is the same size as the 667 | input and no scaling is applied. If it's 50%, then the bounding box will be in 668 | a random range between half the width and height and full size. 669 | 670 | Args: 671 | flip_left_right: Boolean whether to randomly mirror images horizontally. 672 | random_crop: Integer percentage setting the total margin used around the 673 | crop box. 674 | random_scale: Integer percentage of how much to vary the scale by. 675 | random_brightness: Integer range to randomly multiply the pixel values by. 676 | graph. 677 | input_width: Horizontal size of expected input image to model. 678 | input_height: Vertical size of expected input image to model. 
679 | input_depth: How many channels the expected input image should have. 680 | input_mean: Pixel value that should be zero in the image for the graph. 681 | input_std: How much to divide the pixel values by before recognition. 682 | 683 | Returns: 684 | The jpeg input layer and the distorted result tensor. 685 | """ 686 | 687 | jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput') 688 | decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth) 689 | decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32) 690 | decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0) 691 | margin_scale = 1.0 + (random_crop / 100.0) 692 | resize_scale = 1.0 + (random_scale / 100.0) 693 | margin_scale_value = tf.constant(margin_scale) 694 | resize_scale_value = tf.random_uniform(tensor_shape.scalar(), 695 | minval=1.0, 696 | maxval=resize_scale) 697 | scale_value = tf.multiply(margin_scale_value, resize_scale_value) 698 | precrop_width = tf.multiply(scale_value, input_width) 699 | precrop_height = tf.multiply(scale_value, input_height) 700 | precrop_shape = tf.stack([precrop_height, precrop_width]) 701 | precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32) 702 | precropped_image = tf.image.resize_bilinear(decoded_image_4d, 703 | precrop_shape_as_int) 704 | precropped_image_3d = tf.squeeze(precropped_image, squeeze_dims=[0]) 705 | cropped_image = tf.random_crop(precropped_image_3d, 706 | [input_height, input_width, input_depth]) 707 | if flip_left_right: 708 | flipped_image = tf.image.random_flip_left_right(cropped_image) 709 | else: 710 | flipped_image = cropped_image 711 | brightness_min = 1.0 - (random_brightness / 100.0) 712 | brightness_max = 1.0 + (random_brightness / 100.0) 713 | brightness_value = tf.random_uniform(tensor_shape.scalar(), 714 | minval=brightness_min, 715 | maxval=brightness_max) 716 | brightened_image = tf.multiply(flipped_image, brightness_value) 717 | offset_image = tf.subtract(brightened_image, input_mean) 718 | mul_image = tf.multiply(offset_image, 1.0 / input_std) 719 | distort_result = tf.expand_dims(mul_image, 0, name='DistortResult') 720 | return jpeg_data, distort_result 721 | 722 | 723 | def variable_summaries(var): 724 | """Attach a lot of summaries to a Tensor (for TensorBoard visualization).""" 725 | with tf.name_scope('summaries'): 726 | mean = tf.reduce_mean(var) 727 | tf.summary.scalar('mean', mean) 728 | with tf.name_scope('stddev'): 729 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) 730 | tf.summary.scalar('stddev', stddev) 731 | tf.summary.scalar('max', tf.reduce_max(var)) 732 | tf.summary.scalar('min', tf.reduce_min(var)) 733 | tf.summary.histogram('histogram', var) 734 | 735 | 736 | def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, 737 | bottleneck_tensor_size): 738 | """Adds a new softmax and fully-connected layer for training. 739 | 740 | We need to retrain the top layer to identify our new classes, so this function 741 | adds the right operations to the graph, along with some variables to hold the 742 | weights, and then sets up all the gradients for the backward pass. 743 | 744 | The set up for the softmax and fully-connected layers is based on: 745 | https://www.tensorflow.org/versions/master/tutorials/mnist/beginners/index.html 746 | 747 | Args: 748 | class_count: Integer of how many categories of things we're trying to 749 | recognize. 750 | final_tensor_name: Name string for the new final node that produces results. 751 | bottleneck_tensor: The output of the main CNN graph. 
752 | bottleneck_tensor_size: How many entries in the bottleneck vector. 753 | 754 | Returns: 755 | The tensors for the training and cross entropy results, and tensors for the 756 | bottleneck input and ground truth input. 757 | """ 758 | with tf.name_scope('input'): 759 | bottleneck_input = tf.placeholder_with_default( 760 | bottleneck_tensor, 761 | shape=[None, bottleneck_tensor_size], 762 | name='BottleneckInputPlaceholder') 763 | 764 | ground_truth_input = tf.placeholder(tf.float32, 765 | [None, class_count], 766 | name='GroundTruthInput') 767 | 768 | # Organizing the following ops as `final_training_ops` so they're easier 769 | # to see in TensorBoard 770 | layer_name = 'final_training_ops' 771 | with tf.name_scope(layer_name): 772 | with tf.name_scope('weights'): 773 | initial_value = tf.truncated_normal( 774 | [bottleneck_tensor_size, class_count], stddev=0.001) 775 | 776 | layer_weights = tf.Variable(initial_value, name='final_weights') 777 | 778 | variable_summaries(layer_weights) 779 | with tf.name_scope('biases'): 780 | layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases') 781 | variable_summaries(layer_biases) 782 | with tf.name_scope('Wx_plus_b'): 783 | logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases 784 | tf.summary.histogram('pre_activations', logits) 785 | 786 | final_tensor = tf.nn.softmax(logits, name=final_tensor_name) 787 | tf.summary.histogram('activations', final_tensor) 788 | 789 | with tf.name_scope('cross_entropy'): 790 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 791 | labels=ground_truth_input, logits=logits) 792 | with tf.name_scope('total'): 793 | cross_entropy_mean = tf.reduce_mean(cross_entropy) 794 | tf.summary.scalar('cross_entropy', cross_entropy_mean) 795 | 796 | with tf.name_scope('train'): 797 | optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate) 798 | train_step = optimizer.minimize(cross_entropy_mean) 799 | 800 | return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input, 801 | final_tensor) 802 | 803 | 804 | def add_evaluation_step(result_tensor, ground_truth_tensor): 805 | """Inserts the operations we need to evaluate the accuracy of our results. 806 | 807 | Args: 808 | result_tensor: The new final node that produces results. 809 | ground_truth_tensor: The node we feed ground truth data 810 | into. 811 | 812 | Returns: 813 | Tuple of (evaluation step, prediction). 
814 | """ 815 | with tf.name_scope('accuracy'): 816 | with tf.name_scope('correct_prediction'): 817 | prediction = tf.argmax(result_tensor, 1) 818 | correct_prediction = tf.equal( 819 | prediction, tf.argmax(ground_truth_tensor, 1)) 820 | with tf.name_scope('accuracy'): 821 | evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 822 | tf.summary.scalar('accuracy', evaluation_step) 823 | return evaluation_step, prediction 824 | 825 | 826 | def save_graph_to_file(sess, graph, graph_file_name): 827 | output_graph_def = graph_util.convert_variables_to_constants( 828 | sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) 829 | with gfile.FastGFile(graph_file_name, 'wb') as f: 830 | f.write(output_graph_def.SerializeToString()) 831 | return 832 | 833 | 834 | def prepare_file_system(): 835 | # Setup the directory we'll write summaries to for TensorBoard 836 | if tf.gfile.Exists(FLAGS.summaries_dir): 837 | tf.gfile.DeleteRecursively(FLAGS.summaries_dir) 838 | tf.gfile.MakeDirs(FLAGS.summaries_dir) 839 | if FLAGS.intermediate_store_frequency > 0: 840 | ensure_dir_exists(FLAGS.intermediate_output_graphs_dir) 841 | return 842 | 843 | 844 | def create_model_info(architecture): 845 | """Given the name of a model architecture, returns information about it. 846 | 847 | There are different base image recognition pretrained models that can be 848 | retrained using transfer learning, and this function translates from the name 849 | of a model to the attributes that are needed to download and train with it. 850 | 851 | Args: 852 | architecture: Name of a model architecture. 853 | 854 | Returns: 855 | Dictionary of information about the model, or None if the name isn't 856 | recognized 857 | 858 | Raises: 859 | ValueError: If architecture name is unknown. 
860 | """ 861 | architecture = architecture.lower() 862 | if architecture == 'inception_v3': 863 | # pylint: disable=line-too-long 864 | data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 865 | # pylint: enable=line-too-long 866 | bottleneck_tensor_name = 'pool_3/_reshape:0' 867 | bottleneck_tensor_size = 2048 868 | input_width = 299 869 | input_height = 299 870 | input_depth = 3 871 | resized_input_tensor_name = 'Mul:0' 872 | model_file_name = 'classify_image_graph_def.pb' 873 | input_mean = 128 874 | input_std = 128 875 | elif architecture.startswith('mobilenet_'): 876 | parts = architecture.split('_') 877 | if len(parts) != 3 and len(parts) != 4: 878 | tf.logging.error("Couldn't understand architecture name '%s'", 879 | architecture) 880 | return None 881 | version_string = parts[1] 882 | if (version_string != '1.0' and version_string != '0.75' and 883 | version_string != '0.50' and version_string != '0.25'): 884 | tf.logging.error( 885 | """"The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25', 886 | but found '%s' for architecture '%s'""", 887 | version_string, architecture) 888 | return None 889 | size_string = parts[2] 890 | if (size_string != '224' and size_string != '192' and 891 | size_string != '160' and size_string != '128'): 892 | tf.logging.error( 893 | """The Mobilenet input size should be '224', '192', '160', or '128', 894 | but found '%s' for architecture '%s'""", 895 | size_string, architecture) 896 | return None 897 | if len(parts) == 3: 898 | is_quantized = False 899 | else: 900 | if parts[3] != 'quantized': 901 | tf.logging.error( 902 | "Couldn't understand architecture suffix '%s' for '%s'", parts[3], 903 | architecture) 904 | return None 905 | is_quantized = True 906 | data_url = 'http://download.tensorflow.org/models/mobilenet_v1_' 907 | data_url += version_string + '_' + size_string + '_frozen.tgz' 908 | bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' 909 | bottleneck_tensor_size = 1001 910 | input_width = int(size_string) 911 | input_height = int(size_string) 912 | input_depth = 3 913 | resized_input_tensor_name = 'input:0' 914 | if is_quantized: 915 | model_base_name = 'quantized_graph.pb' 916 | else: 917 | model_base_name = 'frozen_graph.pb' 918 | model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string 919 | model_file_name = os.path.join(model_dir_name, model_base_name) 920 | input_mean = 127.5 921 | input_std = 127.5 922 | else: 923 | tf.logging.error("Couldn't understand architecture name '%s'", architecture) 924 | raise ValueError('Unknown architecture', architecture) 925 | 926 | return { 927 | 'data_url': data_url, 928 | 'bottleneck_tensor_name': bottleneck_tensor_name, 929 | 'bottleneck_tensor_size': bottleneck_tensor_size, 930 | 'input_width': input_width, 931 | 'input_height': input_height, 932 | 'input_depth': input_depth, 933 | 'resized_input_tensor_name': resized_input_tensor_name, 934 | 'model_file_name': model_file_name, 935 | 'input_mean': input_mean, 936 | 'input_std': input_std, 937 | } 938 | 939 | 940 | def add_jpeg_decoding(input_width, input_height, input_depth, input_mean, 941 | input_std): 942 | """Adds operations that perform JPEG decoding and resizing to the graph.. 943 | 944 | Args: 945 | input_width: Desired width of the image fed into the recognizer graph. 946 | input_height: Desired width of the image fed into the recognizer graph. 947 | input_depth: Desired channels of the image fed into the recognizer graph. 
948 | input_mean: Pixel value that should be zero in the image for the graph. 949 | input_std: How much to divide the pixel values by before recognition. 950 | 951 | Returns: 952 | Tensors for the node to feed JPEG data into, and the output of the 953 | preprocessing steps. 954 | """ 955 | jpeg_data = tf.placeholder(tf.string, name='DecodeJPGInput') 956 | decoded_image = tf.image.decode_jpeg(jpeg_data, channels=input_depth) 957 | decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32) 958 | decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0) 959 | resize_shape = tf.stack([input_height, input_width]) 960 | resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32) 961 | resized_image = tf.image.resize_bilinear(decoded_image_4d, 962 | resize_shape_as_int) 963 | offset_image = tf.subtract(resized_image, input_mean) 964 | mul_image = tf.multiply(offset_image, 1.0 / input_std) 965 | return jpeg_data, mul_image 966 | 967 | 968 | def main(_): 969 | # Needed to make sure the logging output is visible. 970 | # See https://github.com/tensorflow/tensorflow/issues/3047 971 | tf.logging.set_verbosity(tf.logging.INFO) 972 | 973 | # Prepare necessary directories that can be used during training 974 | prepare_file_system() 975 | 976 | # Gather information about the model architecture we'll be using. 977 | model_info = create_model_info(FLAGS.architecture) 978 | if not model_info: 979 | tf.logging.error('Did not recognize architecture flag') 980 | return -1 981 | 982 | # Set up the pre-trained graph. 983 | maybe_download_and_extract(model_info['data_url']) 984 | graph, bottleneck_tensor, resized_image_tensor = ( 985 | create_model_graph(model_info)) 986 | 987 | # Look at the folder structure, and create lists of all the images. 988 | image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage, 989 | FLAGS.validation_percentage) 990 | class_count = len(image_lists.keys()) 991 | if class_count == 0: 992 | tf.logging.error('No valid folders of images found at ' + FLAGS.image_dir) 993 | return -1 994 | if class_count == 1: 995 | tf.logging.error('Only one valid folder of images found at ' + 996 | FLAGS.image_dir + 997 | ' - multiple classes are needed for classification.') 998 | return -1 999 | 1000 | # See if the command-line flags mean we're applying any distortions. 1001 | do_distort_images = should_distort_images( 1002 | FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale, 1003 | FLAGS.random_brightness) 1004 | 1005 | with tf.Session(graph=graph) as sess: 1006 | # Set up the image decoding sub-graph. 1007 | jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding( 1008 | model_info['input_width'], model_info['input_height'], 1009 | model_info['input_depth'], model_info['input_mean'], 1010 | model_info['input_std']) 1011 | 1012 | if do_distort_images: 1013 | # We will be applying distortions, so setup the operations we'll need. 1014 | (distorted_jpeg_data_tensor, 1015 | distorted_image_tensor) = add_input_distortions( 1016 | FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale, 1017 | FLAGS.random_brightness, model_info['input_width'], 1018 | model_info['input_height'], model_info['input_depth'], 1019 | model_info['input_mean'], model_info['input_std']) 1020 | else: 1021 | # We'll make sure we've calculated the 'bottleneck' image summaries and 1022 | # cached them on disk. 
1023 | cache_bottlenecks(sess, image_lists, FLAGS.image_dir, 1024 | FLAGS.bottleneck_dir, jpeg_data_tensor, 1025 | decoded_image_tensor, resized_image_tensor, 1026 | bottleneck_tensor, FLAGS.architecture) 1027 | 1028 | # Add the new layer that we'll be training. 1029 | (train_step, cross_entropy, bottleneck_input, ground_truth_input, 1030 | final_tensor) = add_final_training_ops( 1031 | len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor, 1032 | model_info['bottleneck_tensor_size']) 1033 | 1034 | # Create the operations we need to evaluate the accuracy of our new layer. 1035 | evaluation_step, prediction = add_evaluation_step( 1036 | final_tensor, ground_truth_input) 1037 | 1038 | # Merge all the summaries and write them out to the summaries_dir 1039 | merged = tf.summary.merge_all() 1040 | train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', 1041 | sess.graph) 1042 | 1043 | validation_writer = tf.summary.FileWriter( 1044 | FLAGS.summaries_dir + '/validation') 1045 | 1046 | # Set up all our weights to their initial default values. 1047 | init = tf.global_variables_initializer() 1048 | sess.run(init) 1049 | 1050 | # Run the training for as many cycles as requested on the command line. 1051 | for i in range(FLAGS.how_many_training_steps): 1052 | # Get a batch of input bottleneck values, either calculated fresh every 1053 | # time with distortions applied, or from the cache stored on disk. 1054 | if do_distort_images: 1055 | (train_bottlenecks, 1056 | train_ground_truth) = get_random_distorted_bottlenecks( 1057 | sess, image_lists, FLAGS.train_batch_size, 'training', 1058 | FLAGS.image_dir, distorted_jpeg_data_tensor, 1059 | distorted_image_tensor, resized_image_tensor, bottleneck_tensor) 1060 | else: 1061 | (train_bottlenecks, 1062 | train_ground_truth, _) = get_random_cached_bottlenecks( 1063 | sess, image_lists, FLAGS.train_batch_size, 'training', 1064 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, 1065 | decoded_image_tensor, resized_image_tensor, bottleneck_tensor, 1066 | FLAGS.architecture) 1067 | # Feed the bottlenecks and ground truth into the graph, and run a training 1068 | # step. Capture training summaries for TensorBoard with the `merged` op. 1069 | train_summary, _ = sess.run( 1070 | [merged, train_step], 1071 | feed_dict={bottleneck_input: train_bottlenecks, 1072 | ground_truth_input: train_ground_truth}) 1073 | train_writer.add_summary(train_summary, i) 1074 | 1075 | # Every so often, print out how well the graph is training. 1076 | is_last_step = (i + 1 == FLAGS.how_many_training_steps) 1077 | if (i % FLAGS.eval_step_interval) == 0 or is_last_step: 1078 | train_accuracy, cross_entropy_value = sess.run( 1079 | [evaluation_step, cross_entropy], 1080 | feed_dict={bottleneck_input: train_bottlenecks, 1081 | ground_truth_input: train_ground_truth}) 1082 | tf.logging.info('%s: Step %d: Train accuracy = %.1f%%' % 1083 | (datetime.now(), i, train_accuracy * 100)) 1084 | tf.logging.info('%s: Step %d: Cross entropy = %f' % 1085 | (datetime.now(), i, cross_entropy_value)) 1086 | validation_bottlenecks, validation_ground_truth, _ = ( 1087 | get_random_cached_bottlenecks( 1088 | sess, image_lists, FLAGS.validation_batch_size, 'validation', 1089 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, 1090 | decoded_image_tensor, resized_image_tensor, bottleneck_tensor, 1091 | FLAGS.architecture)) 1092 | # Run a validation step and capture training summaries for TensorBoard 1093 | # with the `merged` op. 
1094 | validation_summary, validation_accuracy = sess.run( 1095 | [merged, evaluation_step], 1096 | feed_dict={bottleneck_input: validation_bottlenecks, 1097 | ground_truth_input: validation_ground_truth}) 1098 | validation_writer.add_summary(validation_summary, i) 1099 | tf.logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' % 1100 | (datetime.now(), i, validation_accuracy * 100, 1101 | len(validation_bottlenecks))) 1102 | 1103 | # Store intermediate results 1104 | intermediate_frequency = FLAGS.intermediate_store_frequency 1105 | 1106 | if (intermediate_frequency > 0 and (i % intermediate_frequency == 0) 1107 | and i > 0): 1108 | intermediate_file_name = (FLAGS.intermediate_output_graphs_dir + 1109 | 'intermediate_' + str(i) + '.pb') 1110 | tf.logging.info('Save intermediate result to : ' + 1111 | intermediate_file_name) 1112 | save_graph_to_file(sess, graph, intermediate_file_name) 1113 | 1114 | # We've completed all our training, so run a final test evaluation on 1115 | # some new images we haven't used before. 1116 | test_bottlenecks, test_ground_truth, test_filenames = ( 1117 | get_random_cached_bottlenecks( 1118 | sess, image_lists, FLAGS.test_batch_size, 'testing', 1119 | FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, 1120 | decoded_image_tensor, resized_image_tensor, bottleneck_tensor, 1121 | FLAGS.architecture)) 1122 | test_accuracy, predictions = sess.run( 1123 | [evaluation_step, prediction], 1124 | feed_dict={bottleneck_input: test_bottlenecks, 1125 | ground_truth_input: test_ground_truth}) 1126 | tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % 1127 | (test_accuracy * 100, len(test_bottlenecks))) 1128 | 1129 | if FLAGS.print_misclassified_test_images: 1130 | tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===') 1131 | for i, test_filename in enumerate(test_filenames): 1132 | if predictions[i] != test_ground_truth[i].argmax(): 1133 | tf.logging.info('%70s %s' % 1134 | (test_filename, 1135 | list(image_lists.keys())[predictions[i]])) 1136 | 1137 | # Write out the trained graph and labels with the weights stored as 1138 | # constants. 1139 | save_graph_to_file(sess, graph, FLAGS.output_graph) 1140 | with gfile.FastGFile(FLAGS.output_labels, 'w') as f: 1141 | f.write('\n'.join(image_lists.keys()) + '\n') 1142 | 1143 | 1144 | if __name__ == '__main__': 1145 | parser = argparse.ArgumentParser() 1146 | parser.add_argument( 1147 | '--image_dir', 1148 | type=str, 1149 | default='', 1150 | help='Path to folders of labeled images.' 1151 | ) 1152 | parser.add_argument( 1153 | '--output_graph', 1154 | type=str, 1155 | default='/tmp/output_graph.pb', 1156 | help='Where to save the trained graph.' 1157 | ) 1158 | parser.add_argument( 1159 | '--intermediate_output_graphs_dir', 1160 | type=str, 1161 | default='/tmp/intermediate_graph/', 1162 | help='Where to save the intermediate graphs.' 1163 | ) 1164 | parser.add_argument( 1165 | '--intermediate_store_frequency', 1166 | type=int, 1167 | default=0, 1168 | help="""\ 1169 | How many steps to store intermediate graph. If "0" then will not 1170 | store.\ 1171 | """ 1172 | ) 1173 | parser.add_argument( 1174 | '--output_labels', 1175 | type=str, 1176 | default='/tmp/output_labels.txt', 1177 | help='Where to save the trained graph\'s labels.' 1178 | ) 1179 | parser.add_argument( 1180 | '--summaries_dir', 1181 | type=str, 1182 | default='/tmp/retrain_logs', 1183 | help='Where to save summary logs for TensorBoard.' 
1184 | ) 1185 | parser.add_argument( 1186 | '--how_many_training_steps', 1187 | type=int, 1188 | default=6000, 1189 | help='How many training steps to run before ending.' 1190 | ) 1191 | parser.add_argument( 1192 | '--learning_rate', 1193 | type=float, 1194 | default=0.01, 1195 | help='How large a learning rate to use when training.' 1196 | ) 1197 | parser.add_argument( 1198 | '--testing_percentage', 1199 | type=int, 1200 | default=10, 1201 | help='What percentage of images to use as a test set.' 1202 | ) 1203 | parser.add_argument( 1204 | '--validation_percentage', 1205 | type=int, 1206 | default=10, 1207 | help='What percentage of images to use as a validation set.' 1208 | ) 1209 | parser.add_argument( 1210 | '--eval_step_interval', 1211 | type=int, 1212 | default=10, 1213 | help='How often to evaluate the training results.' 1214 | ) 1215 | parser.add_argument( 1216 | '--train_batch_size', 1217 | type=int, 1218 | default=100, 1219 | help='How many images to train on at a time.' 1220 | ) 1221 | parser.add_argument( 1222 | '--test_batch_size', 1223 | type=int, 1224 | default=-1, 1225 | help="""\ 1226 | How many images to test on. This test set is only used once, to evaluate 1227 | the final accuracy of the model after training completes. 1228 | A value of -1 causes the entire test set to be used, which leads to more 1229 | stable results across runs.\ 1230 | """ 1231 | ) 1232 | parser.add_argument( 1233 | '--validation_batch_size', 1234 | type=int, 1235 | default=100, 1236 | help="""\ 1237 | How many images to use in an evaluation batch. This validation set is 1238 | used much more often than the test set, and is an early indicator of how 1239 | accurate the model is during training. 1240 | A value of -1 causes the entire validation set to be used, which leads to 1241 | more stable results across training iterations, but may be slower on large 1242 | training sets.\ 1243 | """ 1244 | ) 1245 | parser.add_argument( 1246 | '--print_misclassified_test_images', 1247 | default=False, 1248 | help="""\ 1249 | Whether to print out a list of all misclassified test images.\ 1250 | """, 1251 | action='store_true' 1252 | ) 1253 | parser.add_argument( 1254 | '--model_dir', 1255 | type=str, 1256 | default='/tmp/imagenet', 1257 | help="""\ 1258 | Path to classify_image_graph_def.pb, 1259 | imagenet_synset_to_human_label_map.txt, and 1260 | imagenet_2012_challenge_label_map_proto.pbtxt.\ 1261 | """ 1262 | ) 1263 | parser.add_argument( 1264 | '--bottleneck_dir', 1265 | type=str, 1266 | default='/tmp/bottleneck', 1267 | help='Path to cache bottleneck layer values as files.' 
1268 | ) 1269 | parser.add_argument( 1270 | '--final_tensor_name', 1271 | type=str, 1272 | default='final_result', 1273 | help="""\ 1274 | The name of the output classification layer in the retrained graph.\ 1275 | """ 1276 | ) 1277 | parser.add_argument( 1278 | '--flip_left_right', 1279 | default=False, 1280 | help="""\ 1281 | Whether to randomly flip half of the training images horizontally.\ 1282 | """, 1283 | action='store_true' 1284 | ) 1285 | parser.add_argument( 1286 | '--random_crop', 1287 | type=int, 1288 | default=0, 1289 | help="""\ 1290 | A percentage determining how much of a margin to randomly crop off the 1291 | training images.\ 1292 | """ 1293 | ) 1294 | parser.add_argument( 1295 | '--random_scale', 1296 | type=int, 1297 | default=0, 1298 | help="""\ 1299 | A percentage determining how much to randomly scale up the size of the 1300 | training images by.\ 1301 | """ 1302 | ) 1303 | parser.add_argument( 1304 | '--random_brightness', 1305 | type=int, 1306 | default=0, 1307 | help="""\ 1308 | A percentage determining how much to randomly multiply the training image 1309 | input pixels up or down by.\ 1310 | """ 1311 | ) 1312 | parser.add_argument( 1313 | '--architecture', 1314 | type=str, 1315 | default='inception_v3', 1316 | help="""\ 1317 | Which model architecture to use. 'inception_v3' is the most accurate, but 1318 | also the slowest. For faster or smaller models, choose a MobileNet with the 1319 | form 'mobilenet_<parameter size>_<input_size>[_quantized]'. For example, 1320 | 'mobilenet_1.0_224' will pick a model that is 17 MB in size and takes 224 1321 | pixel input images, while 'mobilenet_0.25_128_quantized' will choose a much 1322 | less accurate, but smaller and faster network that's 920 KB on disk and 1323 | takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html 1324 | for more information on Mobilenet.\ 1325 | """) 1326 | FLAGS, unparsed = parser.parse_known_args() 1327 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 1328 | -------------------------------------------------------------------------------- /training code.txt: -------------------------------------------------------------------------------- 1 | python retrain.py --output_graph=retrained_graph.pb --output_labels=retrained_labels.txt --architecture=MobileNet_1.0_224 --image_dir=images 2 | --------------------------------------------------------------------------------
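The command in training code.txt writes retrained_graph.pb (the frozen classifier) and retrained_labels.txt (one emotion label per line). Below is a minimal, illustrative sketch of running that frozen graph on a single image; it assumes the mobilenet_1.0_224 architecture from the command above and retrain.py's default --final_tensor_name, and test_face.jpg is a placeholder path rather than a file in this repository (see label_image.py and recognition_webcam.py for the repository's own inference code).

import numpy as np
import tensorflow as tf

# Assumed names: these follow the training command above and retrain.py's
# defaults; adjust them if different flags were used.
GRAPH_PATH = 'retrained_graph.pb'
LABELS_PATH = 'retrained_labels.txt'
IMAGE_PATH = 'test_face.jpg'      # placeholder image path
INPUT_TENSOR = 'input:0'          # MobileNet resized-input tensor ('Mul:0' for Inception V3)
OUTPUT_TENSOR = 'final_result:0'  # default --final_tensor_name

# Load the frozen, retrained graph into the default graph.
graph_def = tf.GraphDef()
with tf.gfile.FastGFile(GRAPH_PATH, 'rb') as f:
  graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')

# One class label per line, in the order retrain.py wrote them.
with open(LABELS_PATH) as f:
  labels = [line.strip() for line in f if line.strip()]

with tf.Session() as sess:
  # Preprocess the same way retrain.py does for mobilenet_1.0_224:
  # decode, resize to 224x224, then scale pixel values to roughly [-1, 1].
  jpeg = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
  image = tf.cast(tf.image.decode_jpeg(jpeg, channels=3), tf.float32)
  image = tf.image.resize_bilinear(tf.expand_dims(image, 0), [224, 224])
  image = (image - 127.5) / 127.5
  image_value = sess.run(image)

  # Run the classifier and print emotions from most to least likely.
  scores = sess.run(OUTPUT_TENSOR, feed_dict={INPUT_TENSOR: image_value})[0]
  for idx in np.argsort(scores)[::-1]:
    print('%s: %.3f' % (labels[idx], scores[idx]))

While or after training, the accuracy and cross-entropy summaries that retrain.py writes under --summaries_dir (default /tmp/retrain_logs) can be viewed with: tensorboard --logdir /tmp/retrain_logs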