├── .gitignore
├── 0_verify_models_and_protoc_install.py
├── 1_xml_to_csv.py
├── 2_generate_tfrecords.py
├── 3_train.py
├── 4_export_inference_graph.py
├── 5_test.py
├── README.md
├── TensorFlow_Tut_3_Object_Detection_Walk-through.docx
├── label_map.pbtxt
└── ssd_inception_v2_coco.config

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # TensorFlow graph data files
2 | events.out.tfevents*
3 |
4 | /images
5 | /training_images
6 | /test_images
7 |
8 | /traffic_lights
9 | /traffic_lights_with_info
10 |
11 | /data
12 | /training_data
13 | /test_data
14 |
15 | /ssd_inception_v2_coco_2017_11_17
16 |
17 | ssd_inception_v2_coco_2017_11_17.tar
18 | ssd_inception_v2_coco_2017_11_17.tar.gz
19 |
20 | /ssd_inception_v2_coco_2018_01_28
21 |
22 | ssd_inception_v2_coco_2018_01_28.tar
23 | ssd_inception_v2_coco_2018_01_28.tar.gz
24 |
25 | /ssd_mobilenet_v1_coco_2017_11_17
26 |
27 | ssd_mobilenet_v1_coco_2017_11_17.tar.gz
28 |
29 | /training
30 |
31 | /inference_graph
--------------------------------------------------------------------------------
/0_verify_models_and_protoc_install.py:
--------------------------------------------------------------------------------
1 | # 0_verify_models_and_protoc_install.py
2 |
3 | # This code is essentially this Python Jupyter Notebook by Google:
4 | # https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
5 | # refactored to run as a regular Python script
6 |
7 | import numpy as np
8 | import os
9 | from six.moves import urllib
10 | import tarfile
11 | import tensorflow as tf
12 | import cv2
13 | from distutils.version import StrictVersion
14 |
15 | # module level variables ##############################################################################################
16 | PROTOS_DIR = "C:/TensorFlow/models/research/object_detection/protos"
17 | MIN_NUM_PY_FILES_IN_PROTOS_DIR = 5
18 |
19 | DOWNLOAD_MODEL_FROM_LOC = 'http://download.tensorflow.org/models/object_detection/'
20 |
21 | # choose either MobileNet or Inception
22 | # MobileNet is a smaller download and runs faster, but is less accurate
23 | MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'
24 | # MODEL_NAME = 'ssd_inception_v2_coco_2017_11_17'
25 |
26 | MODEL_FILE_NAME = MODEL_NAME + '.tar.gz'
27 |
28 | MODEL_SAVE_DIR_LOC = "C:/TensorFlow/models/research/object_detection"
29 | FROZEN_INFERENCE_GRAPH_LOC = MODEL_SAVE_DIR_LOC + "/" + MODEL_NAME + "/" + "frozen_inference_graph.pb"
30 | LABEL_MAP_LOC = "C:/TensorFlow/models/research/object_detection/data/mscoco_label_map.pbtxt"
31 | TEST_IMAGES_DIR = "C:/TensorFlow/models/research/object_detection/test_images"
32 |
33 | NUM_CLASSES = 90
34 |
35 | #######################################################################################################################
36 | def main():
37 |     print("starting program . . .")
.") 38 | 39 | if not checkIfNecessaryPathsAndFilesExist(): 40 | return 41 | # end if 42 | 43 | # now that we've checked for the protoc compile, import the TensorFlow models repo utils content 44 | from utils import label_map_util 45 | from utils import visualization_utils as vis_util 46 | 47 | # if TensorFlow version is too old, show error message and bail 48 | # this next comment line is necessary to avoid a false warning if using the editor PyCharm 49 | # noinspection PyUnresolvedReferences 50 | if StrictVersion(tf.__version__) < StrictVersion('1.5.0'): 51 | print('error: Please upgrade your tensorflow installation to v1.5.* or later!') 52 | return 53 | # end if 54 | 55 | # if the frozen inference graph file does not already exist, download the model tar file and unzip it 56 | try: 57 | if not os.path.exists(FROZEN_INFERENCE_GRAPH_LOC): 58 | # if the model tar file has not already been downloaded, download it 59 | if not os.path.exists(os.path.join(MODEL_SAVE_DIR_LOC, MODEL_FILE_NAME)): 60 | # download the model 61 | print("downloading model . . .") 62 | # instantiate a URLopener object, then download the file 63 | opener = urllib.request.URLopener() 64 | opener.retrieve(DOWNLOAD_MODEL_FROM_LOC + MODEL_FILE_NAME, os.path.join(MODEL_SAVE_DIR_LOC, MODEL_FILE_NAME)) 65 | # end if 66 | 67 | # unzip the tar to get the frozen inference graph 68 | print("unzipping model . . .") 69 | tar_file = tarfile.open(os.path.join(MODEL_SAVE_DIR_LOC, MODEL_FILE_NAME)) 70 | for file in tar_file.getmembers(): 71 | file_name = os.path.basename(file.name) 72 | if 'frozen_inference_graph.pb' in file_name: 73 | tar_file.extract(file, MODEL_SAVE_DIR_LOC) 74 | # end if 75 | # end for 76 | # end if 77 | except Exception as e: 78 | print("error downloading or unzipping model: " + str(e)) 79 | return 80 | # end try 81 | 82 | # if the frozen inference graph does not exist after the above, show an error message and bail 83 | if not os.path.exists(FROZEN_INFERENCE_GRAPH_LOC): 84 | print("unable to get / create the frozen inference graph") 85 | return 86 | # end if 87 | 88 | # load the frozen model into memory 89 | print("loading frozen model into memory . . .") 90 | detection_graph = tf.Graph() 91 | try: 92 | with detection_graph.as_default(): 93 | od_graph_def = tf.GraphDef() 94 | with tf.gfile.GFile(FROZEN_INFERENCE_GRAPH_LOC, 'rb') as fid: 95 | serialized_graph = fid.read() 96 | od_graph_def.ParseFromString(serialized_graph) 97 | tf.import_graph_def(od_graph_def, name='') 98 | # end with 99 | # end with 100 | except Exception as e: 101 | print("error loading the frozen model into memory: " + str(e)) 102 | return 103 | # end try 104 | 105 | # load the label map 106 | print("loading label map . . .") 107 | label_map = label_map_util.load_labelmap(LABEL_MAP_LOC) 108 | categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True) 109 | category_index = label_map_util.create_category_index(categories) 110 | 111 | print("starting object detection . . .") 112 | with detection_graph.as_default(): 113 | with tf.Session(graph=detection_graph) as sess: 114 | for fileName in os.listdir(TEST_IMAGES_DIR): 115 | if fileName.endswith(".jpg"): 116 | image_np = cv2.imread(os.path.join(TEST_IMAGES_DIR, fileName)) 117 | if image_np is not None: 118 | # Definite input and output Tensors for detection_graph 119 | image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 120 | # Each box represents a part of the image where a particular object was detected. 
121 |                         detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
122 |                         # Each score represents the level of confidence for each detected object.
123 |                         # Score is shown on the result image, together with the class label.
124 |                         detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
125 |                         detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
126 |                         num_detections = detection_graph.get_tensor_by_name('num_detections:0')
127 |
128 |                         # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
129 |                         image_np_expanded = np.expand_dims(image_np, axis=0)
130 |                         # Actual detection.
131 |                         (boxes, scores, classes, num) = sess.run(
132 |                             [detection_boxes, detection_scores, detection_classes, num_detections],
133 |                             feed_dict={image_tensor: image_np_expanded})
134 |                         # Visualization of the results of a detection.
135 |                         vis_util.visualize_boxes_and_labels_on_image_array(image_np,
136 |                                                                            np.squeeze(boxes),
137 |                                                                            np.squeeze(classes).astype(np.int32),
138 |                                                                            np.squeeze(scores),
139 |                                                                            category_index,
140 |                                                                            use_normalized_coordinates=True,
141 |                                                                            line_thickness=8)
142 |                         cv2.imshow("result", image_np)
143 |                         cv2.waitKey()
144 |                     # end if
145 |                 # end if
146 |             # end for
147 |         # end with
148 |     # end with
149 | # end main
150 |
151 | #######################################################################################################################
152 | def checkIfNecessaryPathsAndFilesExist():
153 |     if not os.path.exists(PROTOS_DIR):
154 |         print('ERROR: PROTOS_DIR "' + PROTOS_DIR + '" does not seem to exist')
155 |         print('Did you run protoc to compile the .proto files in the TensorFlow models repository?')
156 |         return False
157 |     # end if
158 |
159 |     # count the number of .py files in the protos directory, there should be many (20+)
160 |     numPyFilesInProtosDir = 0
161 |     for fileName in os.listdir(PROTOS_DIR):
162 |         if fileName.endswith(".py"):
163 |             numPyFilesInProtosDir += 1
164 |         # end if
165 |     # end for
166 |
167 |     # if there are not enough .py files in the protos directory then the .proto files must not have been compiled,
168 |     # so show an error and return False
169 |     if numPyFilesInProtosDir < MIN_NUM_PY_FILES_IN_PROTOS_DIR:
170 |         print('ERROR: less than ' + str(MIN_NUM_PY_FILES_IN_PROTOS_DIR) + ' .py files were found in PROTOS_DIR "' + PROTOS_DIR + '"')
171 |         print('Did you run protoc to compile the .proto files in the TensorFlow models repository?')
172 |         return False
173 |     # end if
174 |
175 |     if not os.path.exists(MODEL_SAVE_DIR_LOC):
176 |         print('ERROR: MODEL_SAVE_DIR_LOC "' + MODEL_SAVE_DIR_LOC + '" does not seem to exist')
177 |         return False
178 |     # end if
179 |
180 |     if not os.path.exists(LABEL_MAP_LOC):
181 |         print('ERROR: LABEL_MAP_LOC "' + LABEL_MAP_LOC + '" does not seem to exist')
182 |         return False
183 |     # end if
184 |
185 |     if not os.path.exists(TEST_IMAGES_DIR):
186 |         print('ERROR: TEST_IMAGES_DIR "' + TEST_IMAGES_DIR + '" does not seem to exist')
187 |         return False
188 |     # end if
189 |
190 |     return True
191 | # end function
192 |
193 | #######################################################################################################################
194 | if __name__ == "__main__":
195 |     main()
196 |
--------------------------------------------------------------------------------
/1_xml_to_csv.py:
--------------------------------------------------------------------------------
1 | # 1_xml_to_csv.py
2 |
3 | # Note: substantial portions of this code, especially the actual XML to CSV conversion, are credited to Dat Tran
4 | # see his website here: https://towardsdatascience.com/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9
5 | # and his GitHub here: https://github.com/datitran/raccoon_dataset/blob/master/xml_to_csv.py
6 |
7 | import os
8 | import glob
9 | import pandas as pd
10 | import xml.etree.ElementTree as ET
11 |
12 | # module level variables ##############################################################################################
13 | # train and test directories
14 | TRAINING_IMAGES_DIR = os.getcwd() + "/training_images/"
15 | TEST_IMAGES_DIR = os.getcwd() + "/test_images/"
16 |
17 | MIN_NUM_IMAGES_REQUIRED_FOR_TRAINING = 10
18 | MIN_NUM_IMAGES_SUGGESTED_FOR_TRAINING = 100
19 |
20 | MIN_NUM_IMAGES_REQUIRED_FOR_TESTING = 3
21 |
22 | # output .csv file names/locations
23 | TRAINING_DATA_DIR = os.getcwd() + "/" + "training_data"
24 | TRAIN_CSV_FILE_LOC = TRAINING_DATA_DIR + "/" + "train_labels.csv"
25 | EVAL_CSV_FILE_LOC = TRAINING_DATA_DIR + "/" + "eval_labels.csv"
26 |
27 | #######################################################################################################################
28 | def main():
29 |     if not checkIfNecessaryPathsAndFilesExist():
30 |         return
31 |     # end if
32 |
33 |     # if the training data directory does not exist, create it
34 |     try:
35 |         if not os.path.exists(TRAINING_DATA_DIR):
36 |             os.makedirs(TRAINING_DATA_DIR)
37 |         # end if
38 |     except Exception as e:
39 |         print("unable to create directory " + TRAINING_DATA_DIR + ", error: " + str(e))
40 |     # end try
41 |
42 |
43 |     # convert training xml data to a single .csv file
44 |     print("converting xml training data . . .")
45 |     trainCsvResults = xml_to_csv(TRAINING_IMAGES_DIR)
46 |     trainCsvResults.to_csv(TRAIN_CSV_FILE_LOC, index=None)
47 |     print("training xml to .csv conversion successful, saved result to " + TRAIN_CSV_FILE_LOC)
48 |
49 |     # convert test xml data to a single .csv file
50 |     print("converting xml test data . . .")
.") 51 | testCsvResults = xml_to_csv(TEST_IMAGES_DIR) 52 | testCsvResults.to_csv(EVAL_CSV_FILE_LOC, index=None) 53 | print("test xml to .csv conversion successful, saved result to " + EVAL_CSV_FILE_LOC) 54 | 55 | # end main 56 | 57 | ####################################################################################################################### 58 | def checkIfNecessaryPathsAndFilesExist(): 59 | if not os.path.exists(TRAINING_IMAGES_DIR): 60 | print('') 61 | print('ERROR: the training images directory "' + TRAINING_IMAGES_DIR + '" does not seem to exist') 62 | print('Did you set up the training images?') 63 | print('') 64 | return False 65 | # end if 66 | 67 | # get a list of all the .jpg / .xml file pairs in the training images directory 68 | trainingImagesWithAMatchingXmlFile = [] 69 | for fileName in os.listdir(TRAINING_IMAGES_DIR): 70 | if fileName.endswith(".jpg"): 71 | xmlFileName = os.path.splitext(fileName)[0] + ".xml" 72 | if os.path.exists(os.path.join(TRAINING_IMAGES_DIR, xmlFileName)): 73 | trainingImagesWithAMatchingXmlFile.append(fileName) 74 | # end if 75 | # end if 76 | # end for 77 | 78 | # show an error and return false if there are no images in the training directory 79 | if len(trainingImagesWithAMatchingXmlFile) <= 0: 80 | print("ERROR: there don't seem to be any images and matching XML files in " + TRAINING_IMAGES_DIR) 81 | print("Did you set up the training images?") 82 | return False 83 | # end if 84 | 85 | # show an error and return false if there are not at least 10 images and 10 matching XML files in TRAINING_IMAGES_DIR 86 | if len(trainingImagesWithAMatchingXmlFile) < MIN_NUM_IMAGES_REQUIRED_FOR_TRAINING: 87 | print("ERROR: there are not at least " + str(MIN_NUM_IMAGES_REQUIRED_FOR_TRAINING) + " images and matching XML files in " + TRAINING_IMAGES_DIR) 88 | print("Did you set up the training images?") 89 | return False 90 | # end if 91 | 92 | # show a warning if there are not at least 100 images and 100 matching XML files in TEST_IMAGES_DIR 93 | if len(trainingImagesWithAMatchingXmlFile) < MIN_NUM_IMAGES_SUGGESTED_FOR_TRAINING: 94 | print("WARNING: there are not at least " + str(MIN_NUM_IMAGES_SUGGESTED_FOR_TRAINING) + " images and matching XML files in " + TRAINING_IMAGES_DIR) 95 | print("At least " + str(MIN_NUM_IMAGES_SUGGESTED_FOR_TRAINING) + " image / xml pairs are recommended for bare minimum acceptable results") 96 | # note we do not return false here b/c this is a warning, not an error 97 | # end if 98 | 99 | if not os.path.exists(TEST_IMAGES_DIR): 100 | print('ERROR: TEST_IMAGES_DIR "' + TEST_IMAGES_DIR + '" does not seem to exist') 101 | return False 102 | # end if 103 | 104 | # get a list of all the .jpg / .xml file pairs in the test images directory 105 | testImagesWithAMatchingXmlFile = [] 106 | for fileName in os.listdir(TEST_IMAGES_DIR): 107 | if fileName.endswith(".jpg"): 108 | xmlFileName = os.path.splitext(fileName)[0] + ".xml" 109 | if os.path.exists(os.path.join(TEST_IMAGES_DIR, xmlFileName)): 110 | testImagesWithAMatchingXmlFile.append(fileName) 111 | # end if 112 | # end if 113 | # end for 114 | 115 | # show an error and return false if there are not at least 3 images and 3 matching XML files in TEST_IMAGES_DIR 116 | if len(testImagesWithAMatchingXmlFile) <= 3: 117 | print("ERROR: there are not at least " + str(MIN_NUM_IMAGES_REQUIRED_FOR_TESTING) + " image / xml pairs in " + TEST_IMAGES_DIR) 118 | print("Did you separate out the test image / xml pairs from the training image / xml pairs?") 119 | return False 120 | # end if 121 | 122 
126 | def xml_to_csv(path):
127 |     xml_list = []
128 |     for xml_file in glob.glob(path + '/*.xml'):
129 |         tree = ET.parse(xml_file)
130 |         root = tree.getroot()
131 |         for member in root.findall('object'):
132 |             value = (root.find('filename').text, int(root.find('size')[0].text), int(root.find('size')[1].text), member[0].text,
133 |                      int(member[4][0].text), int(member[4][1].text), int(member[4][2].text), int(member[4][3].text))
134 |             xml_list.append(value)
135 |         # end for
136 |     # end for
137 |
138 |     column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
139 |     xml_df = pd.DataFrame(xml_list, columns=column_name)
140 |     return xml_df
141 | # end function
142 |
143 | #######################################################################################################################
144 | if __name__ == "__main__":
145 |     main()
146 |
--------------------------------------------------------------------------------
/2_generate_tfrecords.py:
--------------------------------------------------------------------------------
1 | # 2_generate_tfrecords.py
2 |
3 | # Note: substantial portions of this code, especially the create_tf_example() function, are credited to Dat Tran
4 | # see his website here: https://towardsdatascience.com/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9
5 | # and his GitHub here: https://github.com/datitran/raccoon_dataset/blob/master/generate_tfrecord.py
6 |
7 | import os
8 | import io
9 | import pandas as pd
10 | import tensorflow as tf
11 | from PIL import Image
12 | from object_detection.utils import dataset_util
13 | from collections import namedtuple
14 |
15 | # module-level variables ##############################################################################################
16 |
17 | # input training CSV file and training images directory
18 | TRAIN_CSV_FILE_LOC = os.getcwd() + "/training_data/" + "train_labels.csv"
19 | TRAIN_IMAGES_DIR = os.getcwd() + "/training_images"
20 |
21 | # input test CSV file and test images directory
22 | EVAL_CSV_FILE_LOC = os.getcwd() + "/training_data/" + "eval_labels.csv"
23 | TEST_IMAGES_DIR = os.getcwd() + "/test_images"
24 |
25 | # training and testing output .tfrecord files
26 | TRAIN_TFRECORD_FILE_LOC = os.getcwd() + "/training_data/" + "train.tfrecord"
27 | EVAL_TFRECORD_FILE_LOC = os.getcwd() + "/training_data/" + "eval.tfrecord"
28 |
29 | #######################################################################################################################
30 | def main():
31 |     if not checkIfNecessaryPathsAndFilesExist():
32 |         return
33 |     # end if
34 |
35 |     # write the train data .tfrecord file
36 |     trainTfRecordFileWriteSuccessful = writeTfRecordFile(TRAIN_CSV_FILE_LOC, TRAIN_TFRECORD_FILE_LOC, TRAIN_IMAGES_DIR)
37 |     if trainTfRecordFileWriteSuccessful:
38 |         print("successfully created the training TFRecords, saved to: " + TRAIN_TFRECORD_FILE_LOC)
39 |     # end if
40 |
41 |     # write the eval data .tfrecord file
42 |     evalTfRecordFileWriteSuccessful = writeTfRecordFile(EVAL_CSV_FILE_LOC, EVAL_TFRECORD_FILE_LOC, TEST_IMAGES_DIR)
43 |     if evalTfRecordFileWriteSuccessful:
44 |         print("successfully created the eval TFRecords, saved to: " + EVAL_TFRECORD_FILE_LOC)
45 |     # end if
46 |
47 | # end main
48 |
49 | #######################################################################################################################
50 | def writeTfRecordFile(csvFileName, tfRecordFileName, imagesDir):
51 |     # use pandas to read in the .csv file data; pandas.read_csv() returns the file's contents as a DataFrame
52 |     csvFileDataFrame = pd.read_csv(csvFileName)
53 |
54 |     # reformat the CSV data into a format TensorFlow can work with
55 |     csvFileDataList = reformatCsvFileData(csvFileDataFrame)
56 |
57 |     # instantiate a TFRecordWriter for the file data
58 |     tfRecordWriter = tf.python_io.TFRecordWriter(tfRecordFileName)
59 |
60 |     # for each file (not each line) in the CSV file data . . .
61 |     # (each image/.xml file pair can have more than one box, and therefore more than one line for that file in the CSV file)
62 |     for singleFileData in csvFileDataList:
63 |         tfExample = createTfExample(singleFileData, imagesDir)
64 |         tfRecordWriter.write(tfExample.SerializeToString())
65 |     # end for
66 |     tfRecordWriter.close()
67 |     return True  # return True to indicate success
68 | # end function
69 |
70 | #######################################################################################################################
71 | def checkIfNecessaryPathsAndFilesExist():
72 |     if not os.path.exists(TRAIN_CSV_FILE_LOC):
73 |         print('ERROR: TRAIN_CSV_FILE_LOC "' + TRAIN_CSV_FILE_LOC + '" does not seem to exist')
74 |         return False
75 |     # end if
76 |
77 |     if not os.path.exists(TRAIN_IMAGES_DIR):
78 |         print('ERROR: TRAIN_IMAGES_DIR "' + TRAIN_IMAGES_DIR + '" does not seem to exist')
79 |         return False
80 |     # end if
81 |
82 |     if not os.path.exists(EVAL_CSV_FILE_LOC):
83 |         print('ERROR: EVAL_CSV_FILE_LOC "' + EVAL_CSV_FILE_LOC + '" does not seem to exist')
84 |         return False
85 |     # end if
86 |
87 |     if not os.path.exists(TEST_IMAGES_DIR):
88 |         print('ERROR: TEST_IMAGES_DIR "' + TEST_IMAGES_DIR + '" does not seem to exist')
89 |         return False
90 |     # end if
91 |
92 |     return True
93 | # end function
94 |
95 | #######################################################################################################################
96 | def reformatCsvFileData(csvFileDataFrame):
97 |     # the purpose of this function is to translate the data from one CSV file in pandas.DataFrame format
98 |     # into a list of the named tuple below, which then can be fed into TensorFlow
99 |
100 |     # establish the named tuple data format
101 |     dataFormat = namedtuple('data', ['filename', 'object'])
102 |
103 |     # pandas.DataFrame.groupby() returns type pandas.core.groupby.DataFrameGroupBy
104 |     csvFileDataFrameGroupBy = csvFileDataFrame.groupby('filename')
105 |
106 |     # declare, populate, and return the list of named tuples of CSV data
107 |     csvFileDataList = []
108 |     for filename, x in zip(csvFileDataFrameGroupBy.groups.keys(), csvFileDataFrameGroupBy.groups):
109 |         csvFileDataList.append(dataFormat(filename, csvFileDataFrameGroupBy.get_group(x)))
110 |     # end for
111 |     return csvFileDataList
112 | # end function
113 |
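# To illustrate what reformatCsvFileData() produces (hypothetical data, not part of
# this repo): if the .csv file contains these rows . . .
#
#   filename,width,height,class,xmin,ymin,xmax,ymax
#   image1.jpg,640,480,traffic_light,100,120,160,200
#   image1.jpg,640,480,traffic_light,300,110,350,190
#   image2.jpg,640,480,traffic_light,50,60,90,140
#
# . . . then the returned list holds two named tuples: one for image1.jpg whose .object
# DataFrame contains the two box rows for that image, and one for image2.jpg with a
# single box row, which is why createTfExample() below is called once per image rather
# than once per CSV line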
114 | #######################################################################################################################
115 | def createTfExample(singleFileData, path):
116 |     # use TensorFlow's GFile function to open the .jpg image matching the current box data
117 |     with tf.gfile.GFile(os.path.join(path, '{}'.format(singleFileData.filename)), 'rb') as tensorFlowImageFile:
118 |         tensorFlowImage = tensorFlowImageFile.read()
119 |     # end with
120 |
121 |     # get the image width and height via converting from a TensorFlow image to an io library BytesIO image,
122 |     # then to a PIL Image, then breaking out the width and height
123 |     bytesIoImage = io.BytesIO(tensorFlowImage)
124 |     pilImage = Image.open(bytesIoImage)
125 |     width, height = pilImage.size
126 |
127 |     # get the file name from the file data passed in, and set the image format to .jpg
128 |     fileName = singleFileData.filename.encode('utf8')
129 |     imageFormat = b'jpg'
130 |
131 |     # declare empty lists for the box x, y, mins and maxes, and the class as text and as an integer
132 |     xMins = []
133 |     xMaxs = []
134 |     yMins = []
135 |     yMaxs = []
136 |     classesAsText = []
137 |     classesAsInts = []
138 |
139 |     # for each row in the current .xml file's data . . . (each row in the .xml file corresponds to one box)
140 |     for index, row in singleFileData.object.iterrows():
141 |         xMins.append(row['xmin'] / width)
142 |         xMaxs.append(row['xmax'] / width)
143 |         yMins.append(row['ymin'] / height)
144 |         yMaxs.append(row['ymax'] / height)
145 |         classesAsText.append(row['class'].encode('utf8'))
146 |         classesAsInts.append(classAsTextToClassAsInt(row['class']))
147 |     # end for
148 |
149 |     # finally we can calculate and return the TensorFlow Example
150 |     tfExample = tf.train.Example(features=tf.train.Features(feature={
151 |         'image/height': dataset_util.int64_feature(height),
152 |         'image/width': dataset_util.int64_feature(width),
153 |         'image/filename': dataset_util.bytes_feature(fileName),
154 |         'image/source_id': dataset_util.bytes_feature(fileName),
155 |         'image/encoded': dataset_util.bytes_feature(tensorFlowImage),
156 |         'image/format': dataset_util.bytes_feature(imageFormat),
157 |         'image/object/bbox/xmin': dataset_util.float_list_feature(xMins),
158 |         'image/object/bbox/xmax': dataset_util.float_list_feature(xMaxs),
159 |         'image/object/bbox/ymin': dataset_util.float_list_feature(yMins),
160 |         'image/object/bbox/ymax': dataset_util.float_list_feature(yMaxs),
161 |         'image/object/class/text': dataset_util.bytes_list_feature(classesAsText),
162 |         'image/object/class/label': dataset_util.int64_list_feature(classesAsInts)}))
163 |
164 |     return tfExample
165 | # end function
166 |
167 | #######################################################################################################################
168 | def classAsTextToClassAsInt(classAsText):
169 |
170 |     # ToDo: If you have more than one classification, add an if statement for each
171 |     # ToDo: i.e. if you have 3 classes, you would have 3 if statements and then the else
172 |
173 |     if classAsText == 'traffic_light':
174 |         return 1
175 |     else:
176 |         print("error in classAsTextToClassAsInt(), classAsText could not be identified")
177 |         return -1
178 |     # end if
179 | # end function
180 |
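# As an illustrative sketch only (the 2nd and 3rd class names here are hypothetical,
# this project only uses 'traffic_light'), a three-class version of the function
# above would look like:
#
#   def classAsTextToClassAsInt(classAsText):
#       if classAsText == 'traffic_light':
#           return 1
#       elif classAsText == 'stop_sign':
#           return 2
#       elif classAsText == 'speed_limit_sign':
#           return 3
#       else:
#           print("error in classAsTextToClassAsInt(), classAsText could not be identified")
#           return -1
#       # end if
#   # end function
#
# each class name must match a 'name' entry in label_map.pbtxt, and each returned
# integer must match that entry's 'id'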
181 | #######################################################################################################################
182 | if __name__ == '__main__':
183 |     main()
--------------------------------------------------------------------------------
/3_train.py:
--------------------------------------------------------------------------------
1 | # 3_train.py
2 | #
3 | # original source from Google:
4 | # https://github.com/tensorflow/models/blob/master/research/object_detection/train.py
5 |
6 | import functools
7 | import json
8 | import os
9 | import tensorflow as tf
10 |
11 | from object_detection.legacy import trainer
12 | from object_detection.builders import dataset_builder
13 | from object_detection.builders import model_builder
14 | from object_detection.utils import config_util
15 | from object_detection.utils import dataset_util
16 |
17 | # module-level variables ##############################################################################################
18 |
19 | # this is the big (pipeline).config file that contains various directory locations and many tunable parameters
20 | PIPELINE_CONFIG_PATH = os.getcwd() + "/" + "ssd_inception_v2_coco.config"
21 |
22 | # verify this extracted directory exists,
23 | # also verify it's the directory referred to by the 'fine_tune_checkpoint' parameter in your (pipeline).config file
24 | MODEL_DIR = os.getcwd() + "/" + "ssd_inception_v2_coco_2018_01_28"
25 |
26 | # verify that your MODEL_DIR contains these files
27 | FILES_MODEL_DIR_MUST_CONTAIN = ["checkpoint",
28 |                                 "frozen_inference_graph.pb",
29 |                                 "model.ckpt.data-00000-of-00001",
30 |                                 "model.ckpt.index",
31 |                                 "model.ckpt.meta"]
32 |
33 | # directory to save the checkpoints and training summaries
34 | TRAINING_DATA_DIR = os.getcwd() + "/training_data/"
35 |
36 | # number of clones to deploy per worker
37 | NUM_CLONES = 1
38 |
39 | # Force clones to be deployed on CPU. Note that even if set to False (allowing ops to run on GPU),
40 | # some ops may still be run on the CPU if they have no GPU kernel
41 | CLONE_ON_CPU = False
42 |
43 | #######################################################################################################################
44 | # this next comment line is necessary to suppress a false PyCharm warning
45 | # noinspection PyUnresolvedReferences
46 | def main(_):
47 |     print("starting program . . .")
48 |
49 |     # show info to std out during the training process
50 |     tf.logging.set_verbosity(tf.logging.INFO)
51 |
52 |     if not checkIfNecessaryPathsAndFilesExist():
53 |         return
54 |     # end if
55 |
56 |     configs = config_util.get_configs_from_pipeline_file(PIPELINE_CONFIG_PATH)
57 |     tf.gfile.Copy(PIPELINE_CONFIG_PATH, os.path.join(TRAINING_DATA_DIR, 'pipeline.config'), overwrite=True)
58 |
59 |     model_config = configs['model']
60 |     train_config = configs['train_config']
61 |     input_config = configs['train_input_config']
62 |
63 |     model_fn = functools.partial(model_builder.build, model_config=model_config, is_training=True)
64 |
65 |     # ToDo: this nested function seems odd, factor this out eventually ??
66 |     # nested function
67 |     def get_next(config):
68 |         return dataset_builder.make_initializable_iterator(dataset_builder.build(config)).get_next()
69 |     # end nested function
70 |
71 |     create_input_dict_fn = functools.partial(get_next, input_config)
72 |
73 |     env = json.loads(os.environ.get('TF_CONFIG', '{}'))
74 |     cluster_data = env.get('cluster', None)
75 |     cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
76 |     task_data = env.get('task', None) or {'type': 'master', 'index': 0}
77 |     task_info = type('TaskSpec', (object,), task_data)
78 |
79 |     # parameters for a single worker
80 |     ps_tasks = 0
81 |     worker_replicas = 1
82 |     worker_job_name = 'lonely_worker'
83 |     task = 0
84 |     is_chief = True
85 |     master = ''
86 |
87 |     if cluster_data and 'worker' in cluster_data:
88 |         # the number of total worker replicas includes the "worker"s and the "master"
89 |         worker_replicas = len(cluster_data['worker']) + 1
90 |     # end if
91 |
92 |     if cluster_data and 'ps' in cluster_data:
93 |         ps_tasks = len(cluster_data['ps'])
94 |     # end if
95 |
96 |     if worker_replicas > 1 and ps_tasks < 1:
97 |         raise ValueError('At least 1 ps task is needed for distributed training.')
98 |     # end if
99 |
100 |     if worker_replicas >= 1 and ps_tasks > 0:
101 |         # set up distributed training
102 |         server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc', job_name=task_info.type, task_index=task_info.index)
103 |         if task_info.type == 'ps':
104 |             server.join()
105 |             return
106 |         # end if
107 |
108 |         worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
109 |         task = task_info.index
110 |         is_chief = (task_info.type == 'master')
111 |         master = server.target
112 |     # end if
113 |
114 |     trainer.train(create_input_dict_fn, model_fn, train_config, master, task, NUM_CLONES, worker_replicas,
115 |                   CLONE_ON_CPU, ps_tasks, worker_job_name, is_chief, TRAINING_DATA_DIR)
116 |
117 | #######################################################################################################################
118 | def checkIfNecessaryPathsAndFilesExist():
119 |     if not os.path.exists(PIPELINE_CONFIG_PATH):
120 |         print('ERROR: the big (pipeline).config file "' + PIPELINE_CONFIG_PATH + '" does not seem to exist')
121 |         return False
122 |     # end if
123 |
124 |     missingModelMessage = "Did you download and extract the model from the TensorFlow GitHub models repository detection model zoo?" + "\n" + \
+ "\n" + \ 125 | "https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md" + "\n" + \ 126 | "ssd_inception_v2_coco is recommended" 127 | 128 | # check if the model directory exists 129 | if not os.path.exists(MODEL_DIR): 130 | print('ERROR: the model directory "' + MODEL_DIR + '" does not seem to exist') 131 | print(missingModelMessage) 132 | return False 133 | # end if 134 | 135 | # check if each of the files that should be in the model directory are there 136 | for necessaryModelFileName in FILES_MODEL_DIR_MUST_CONTAIN: 137 | if not os.path.exists(os.path.join(MODEL_DIR, necessaryModelFileName)): 138 | print('ERROR: the model file "' + MODEL_DIR + "/" + necessaryModelFileName + '" does not seem to exist') 139 | print(missingModelMessage) 140 | return False 141 | # end if 142 | # end for 143 | 144 | if not os.path.exists(TRAINING_DATA_DIR): 145 | print('ERROR: TRAINING_DATA_DIR "' + TRAINING_DATA_DIR + '" does not seem to exist') 146 | return False 147 | # end if 148 | 149 | return True 150 | # end function 151 | 152 | ####################################################################################################################### 153 | if __name__ == '__main__': 154 | tf.app.run() 155 | -------------------------------------------------------------------------------- /4_export_inference_graph.py: -------------------------------------------------------------------------------- 1 | # export_inference_graph.py 2 | # 3 | # original file by Google: 4 | # https://github.com/tensorflow/models/blob/master/research/object_detection/export_inference_graph.py 5 | 6 | import os 7 | import tensorflow as tf 8 | from google.protobuf import text_format 9 | from object_detection import exporter 10 | from object_detection.protos import pipeline_pb2 11 | 12 | # module-level variables ############################################################################################## 13 | 14 | # INPUT_TYPE can be "image_tensor", "encoded_image_string_tensor", or "tf_example" 15 | INPUT_TYPE = "image_tensor" 16 | 17 | # If INPUT_TYPE is "image_tensor", INPUT_SHAPE can explicitly set. The shape of this input tensor to a fixed size. 18 | # The dimensions are to be provided as a comma-separated list of integers. A value of -1 can be used for unknown dimensions. 19 | # If not specified, for an image_tensor, the default shape will be partially specified as [None, None, None, 3] 20 | INPUT_SHAPE = None 21 | 22 | # the location of the big config file 23 | PIPELINE_CONFIG_LOC = os.getcwd() + "/" + "ssd_inception_v2_coco.config" 24 | 25 | # the final checkpoint result of the training process 26 | TRAINED_CHECKPOINT_PREFIX_LOC = os.getcwd() + "/training_data/model.ckpt-500" 27 | 28 | # the output directory to place the inference graph data, note that it's ok if this directory does not already exist 29 | # because the call to export_inference_graph() below will create this directory if it does not exist already 30 | OUTPUT_DIR = os.getcwd() + "/" + "inference_graph" 31 | 32 | ####################################################################################################################### 33 | def main(_): 34 | print("starting script . . .") 35 | 36 | if not checkIfNecessaryPathsAndFilesExist(): 37 | return 38 | # end if 39 | 40 | print("calling TrainEvalPipelineConfig() . . .") 41 | trainEvalPipelineConfig = pipeline_pb2.TrainEvalPipelineConfig() 42 | 43 | print("checking and merging " + os.path.basename(PIPELINE_CONFIG_LOC) + " into trainEvalPipelineConfig . . 
.") 44 | with tf.gfile.GFile(PIPELINE_CONFIG_LOC, 'r') as f: 45 | text_format.Merge(f.read(), trainEvalPipelineConfig) 46 | # end with 47 | 48 | print("calculating input shape . . .") 49 | if INPUT_SHAPE: 50 | input_shape = [ int(dim) if dim != '-1' else None for dim in INPUT_SHAPE.split(',') ] 51 | else: 52 | input_shape = None 53 | # end if 54 | 55 | print("calling export_inference_graph() . . .") 56 | exporter.export_inference_graph(INPUT_TYPE, trainEvalPipelineConfig, TRAINED_CHECKPOINT_PREFIX_LOC, OUTPUT_DIR, input_shape) 57 | 58 | print("done !!") 59 | # end main 60 | 61 | ####################################################################################################################### 62 | def checkIfNecessaryPathsAndFilesExist(): 63 | if not os.path.exists(PIPELINE_CONFIG_LOC): 64 | print('ERROR: PIPELINE_CONFIG_LOC "' + PIPELINE_CONFIG_LOC + '" does not seem to exist') 65 | return False 66 | # end if 67 | 68 | # TRAINED_CHECKPOINT_PREFIX_LOC is a special case because there is no actual file with this name. 69 | # i.e. if TRAINED_CHECKPOINT_PREFIX_LOC is: 70 | # "C:\Users\cdahms\Documents\TensorFlow_Tut_3_Object_Detection_Walk-through\training_data\training_data\model.ckpt-500" 71 | # this exact file does not exist, but there should be 3 files including this name, which would be: 72 | # "model.ckpt-500.data-00000-of-00001" 73 | # "model.ckpt-500.index" 74 | # "model.ckpt-500.meta" 75 | # therefore it's necessary to verify that the stated directory exists and then check if there are at least three files 76 | # in the stated directory that START with the stated name 77 | 78 | # break out the directory location and the file prefix 79 | trainedCkptPrefixPath, filePrefix = os.path.split(TRAINED_CHECKPOINT_PREFIX_LOC) 80 | 81 | # return false if the directory does not exist 82 | if not os.path.exists(trainedCkptPrefixPath): 83 | print('ERROR: directory "' + trainedCkptPrefixPath + '" does not seem to exist') 84 | print('was the training completed successfully?') 85 | return False 86 | # end if 87 | 88 | # count how many files in the stated directory start with the stated prefix 89 | numFilesThatStartWithPrefix = 0 90 | for fileName in os.listdir(trainedCkptPrefixPath): 91 | if fileName.startswith(filePrefix): 92 | numFilesThatStartWithPrefix += 1 93 | # end if 94 | # end if 95 | 96 | # if less than 3 files start with the stated prefix, return false 97 | if numFilesThatStartWithPrefix < 3: 98 | print('ERROR: 3 files statring with "' + filePrefix + '" do not seem to be present in the directory "' + trainedCkptPrefixPath + '"') 99 | print('was the training completed successfully?') 100 | # end if 101 | 102 | # if we get here the necessary directories and files are present, so return True 103 | return True 104 | # end function 105 | 106 | ####################################################################################################################### 107 | if __name__ == '__main__': 108 | tf.app.run() 109 | -------------------------------------------------------------------------------- /5_test.py: -------------------------------------------------------------------------------- 1 | # test.py 2 | 3 | import numpy as np 4 | import os 5 | import tensorflow as tf 6 | import cv2 7 | 8 | from utils import label_map_util 9 | from utils import visualization_utils as vis_util 10 | from distutils.version import StrictVersion 11 | 12 | # module level variables ############################################################################################## 13 | TEST_IMAGE_DIR = os.getcwd() + 
"/test_images" 14 | FROZEN_INFERENCE_GRAPH_LOC = os.getcwd() + "/inference_graph/frozen_inference_graph.pb" 15 | LABELS_LOC = os.getcwd() + "/" + "label_map.pbtxt" 16 | NUM_CLASSES = 1 17 | 18 | ####################################################################################################################### 19 | def main(): 20 | print("starting program . . .") 21 | 22 | if not checkIfNecessaryPathsAndFilesExist(): 23 | return 24 | # end if 25 | 26 | # this next comment line is necessary to avoid a false PyCharm warning 27 | # noinspection PyUnresolvedReferences 28 | if StrictVersion(tf.__version__) < StrictVersion('1.5.0'): 29 | raise ImportError('Please upgrade your tensorflow installation to v1.5.* or later!') 30 | # end if 31 | 32 | # load a (frozen) TensorFlow model into memory 33 | detection_graph = tf.Graph() 34 | with detection_graph.as_default(): 35 | od_graph_def = tf.GraphDef() 36 | with tf.gfile.GFile(FROZEN_INFERENCE_GRAPH_LOC, 'rb') as fid: 37 | serialized_graph = fid.read() 38 | od_graph_def.ParseFromString(serialized_graph) 39 | tf.import_graph_def(od_graph_def, name='') 40 | # end with 41 | # end with 42 | 43 | # Loading label map 44 | # Label maps map indices to category names, so that when our convolution network predicts `5`, we know that this corresponds to `airplane`. Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine 45 | label_map = label_map_util.load_labelmap(LABELS_LOC) 46 | categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, 47 | use_display_name=True) 48 | category_index = label_map_util.create_category_index(categories) 49 | 50 | imageFilePaths = [] 51 | for imageFileName in os.listdir(TEST_IMAGE_DIR): 52 | if imageFileName.endswith(".jpg"): 53 | imageFilePaths.append(TEST_IMAGE_DIR + "/" + imageFileName) 54 | # end if 55 | # end for 56 | 57 | with detection_graph.as_default(): 58 | with tf.Session(graph=detection_graph) as sess: 59 | for image_path in imageFilePaths: 60 | 61 | print(image_path) 62 | 63 | image_np = cv2.imread(image_path) 64 | 65 | if image_np is None: 66 | print("error reading file " + image_path) 67 | continue 68 | # end if 69 | 70 | # Definite input and output Tensors for detection_graph 71 | image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 72 | # Each box represents a part of the image where a particular object was detected. 73 | detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 74 | # Each score represent how level of confidence for each of the objects. 75 | # Score is shown on the result image, together with the class label. 76 | detection_scores = detection_graph.get_tensor_by_name('detection_scores:0') 77 | detection_classes = detection_graph.get_tensor_by_name('detection_classes:0') 78 | num_detections = detection_graph.get_tensor_by_name('num_detections:0') 79 | 80 | # Expand dimensions since the model expects images to have shape: [1, None, None, 3] 81 | image_np_expanded = np.expand_dims(image_np, axis=0) 82 | # Actual detection. 83 | (boxes, scores, classes, num) = sess.run( 84 | [detection_boxes, detection_scores, detection_classes, num_detections], 85 | feed_dict={image_tensor: image_np_expanded}) 86 | # Visualization of the results of a detection. 
87 |                 vis_util.visualize_boxes_and_labels_on_image_array(image_np,
88 |                                                                    np.squeeze(boxes),
89 |                                                                    np.squeeze(classes).astype(np.int32),
90 |                                                                    np.squeeze(scores),
91 |                                                                    category_index,
92 |                                                                    use_normalized_coordinates=True,
93 |                                                                    line_thickness=8)
94 |                 cv2.imshow("image_np", image_np)
95 |                 cv2.waitKey()
96 |             # end for
97 |         # end with
98 |     # end with
99 | # end main
100 |
101 | #######################################################################################################################
102 | def checkIfNecessaryPathsAndFilesExist():
103 |     if not os.path.exists(TEST_IMAGE_DIR):
104 |         print('ERROR: TEST_IMAGE_DIR "' + TEST_IMAGE_DIR + '" does not seem to exist')
105 |         return False
106 |     # end if
107 |
108 |     # ToDo: check here that the test image directory contains at least one image
109 |
110 |     if not os.path.exists(FROZEN_INFERENCE_GRAPH_LOC):
111 |         print('ERROR: FROZEN_INFERENCE_GRAPH_LOC "' + FROZEN_INFERENCE_GRAPH_LOC + '" does not seem to exist')
112 |         print('was the inference graph exported successfully?')
113 |         return False
114 |     # end if
115 |
116 |     if not os.path.exists(LABELS_LOC):
117 |         print('ERROR: the label map file "' + LABELS_LOC + '" does not seem to exist')
118 |         return False
119 |     # end if
120 |
121 |     return True
122 | # end function
123 |
124 | #######################################################################################################################
125 | if __name__ == "__main__":
126 |     main()
127 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow Tutorial 3: Object Detection Walk-through
2 |
3 | See the documents "TensorFlow_Tut_3_Object_Detection_Walk-through" above (MS Word or PDF version)
4 |
5 | Walk-through video:
6 | https://www.youtube.com/watch?v=rWFg6R5ccOc
--------------------------------------------------------------------------------
/TensorFlow_Tut_3_Object_Detection_Walk-through.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MicrocontrollersAndMore/TensorFlow_Tut_3_Object_Detection_Walk-through/09c54356716f0ffe9dad2c4652e33b72f585ca4f/TensorFlow_Tut_3_Object_Detection_Walk-through.docx
--------------------------------------------------------------------------------
/label_map.pbtxt:
--------------------------------------------------------------------------------
1 | item {
2 |   id: 1
3 |   name: 'traffic_light'
4 | }
5 |
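As a hedged illustration only (the extra class name below is hypothetical, this project
uses only 'traffic_light'): if more classes were trained, label_map.pbtxt would gain one
item per class, with each id matching the integer returned for that class by
classAsTextToClassAsInt() in 2_generate_tfrecords.py, and num_classes in
ssd_inception_v2_coco.config would be raised to match:

item {
  id: 1
  name: 'traffic_light'
}
item {
  id: 2
  name: 'stop_sign'
}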
--------------------------------------------------------------------------------
/ssd_inception_v2_coco.config:
--------------------------------------------------------------------------------
1 | # SSD with Inception v2 configuration, originally written by Google for the MSCOCO
2 | # dataset and adapted here for the single-class traffic light dataset. The
3 | # fine_tune_checkpoint field in the train config, and the label_map_path and
4 | # input_path fields in the train_input_reader and eval_input_reader, have already
5 | # been filled in for this walk-through.
6 |
7 | model {
8 |   ssd {
9 |     num_classes: 1
10 |     box_coder {
11 |       faster_rcnn_box_coder {
12 |         y_scale: 10.0
13 |         x_scale: 10.0
14 |         height_scale: 5.0
15 |         width_scale: 5.0
16 |       }
17 |     }
18 |     matcher {
19 |       argmax_matcher {
20 |         matched_threshold: 0.5
21 |         unmatched_threshold: 0.5
22 |         ignore_thresholds: false
23 |         negatives_lower_than_unmatched: true
24 |         force_match_for_each_row: true
25 |       }
26 |     }
27 |     similarity_calculator {
28 |       iou_similarity {
29 |       }
30 |     }
31 |     anchor_generator {
32 |       ssd_anchor_generator {
33 |         num_layers: 6
34 |         min_scale: 0.2
35 |         max_scale: 0.95
36 |         aspect_ratios: 1.0
37 |         aspect_ratios: 2.0
38 |         aspect_ratios: 0.5
39 |         aspect_ratios: 3.0
40 |         aspect_ratios: 0.3333
41 |         reduce_boxes_in_lowest_layer: true
42 |       }
43 |     }
44 |     image_resizer {
45 |       fixed_shape_resizer {
46 |         height: 300
47 |         width: 300
48 |       }
49 |     }
50 |     box_predictor {
51 |       convolutional_box_predictor {
52 |         min_depth: 0
53 |         max_depth: 0
54 |         num_layers_before_predictor: 0
55 |         use_dropout: false
56 |         dropout_keep_probability: 0.8
57 |         kernel_size: 3
58 |         box_code_size: 4
59 |         apply_sigmoid_to_scores: false
60 |         conv_hyperparams {
61 |           activation: RELU_6,
62 |           regularizer {
63 |             l2_regularizer {
64 |               weight: 0.00004
65 |             }
66 |           }
67 |           initializer {
68 |             truncated_normal_initializer {
69 |               stddev: 0.03
70 |               mean: 0.0
71 |             }
72 |           }
73 |         }
74 |       }
75 |     }
76 |     feature_extractor {
77 |       type: 'ssd_inception_v2'
78 |       min_depth: 16
79 |       depth_multiplier: 1.0
80 |       conv_hyperparams {
81 |         activation: RELU_6,
82 |         regularizer {
83 |           l2_regularizer {
84 |             weight: 0.00004
85 |           }
86 |         }
87 |         initializer {
88 |           truncated_normal_initializer {
89 |             stddev: 0.03
90 |             mean: 0.0
91 |           }
92 |         }
93 |         batch_norm {
94 |           train: true,
95 |           scale: true,
96 |           center: true,
97 |           decay: 0.9997,
98 |           epsilon: 0.001,
99 |         }
100 |       }
101 |       override_base_feature_extractor_hyperparams: true
102 |     }
103 |     loss {
104 |       classification_loss {
105 |         weighted_sigmoid {
106 |           anchorwise_output: true
107 |         }
108 |       }
109 |       localization_loss {
110 |         weighted_smooth_l1 {
111 |           anchorwise_output: true
112 |         }
113 |       }
114 |       hard_example_miner {
115 |         num_hard_examples: 3000
116 |         iou_threshold: 0.99
117 |         loss_type: CLASSIFICATION
118 |         max_negatives_per_positive: 3
119 |         min_negatives_per_image: 0
120 |       }
121 |       classification_weight: 1.0
122 |       localization_weight: 1.0
123 |     }
124 |     normalize_loss_by_num_matches: true
125 |     post_processing {
126 |       batch_non_max_suppression {
127 |         score_threshold: 1e-8
128 |         iou_threshold: 0.6
129 |         max_detections_per_class: 100
130 |         max_total_detections: 100
131 |       }
132 |       score_converter: SIGMOID
133 |     }
134 |   }
135 | }
136 |
137 | train_config: {
138 |   batch_size: 24
139 |   optimizer {
140 |     rms_prop_optimizer: {
141 |       learning_rate: {
142 |         exponential_decay_learning_rate {
143 |           initial_learning_rate: 0.004
144 |           decay_steps: 800720
145 |           decay_factor: 0.95
146 |         }
147 |       }
148 |       momentum_optimizer_value: 0.9
149 |       decay: 0.9
150 |       epsilon: 1.0
151 |     }
152 |   }
153 |   fine_tune_checkpoint: "ssd_inception_v2_coco_2018_01_28/model.ckpt"
154 |   from_detection_checkpoint: true
155 |   # Note: The below line limits the training process to 500 steps for this
156 |   # walk-through (Google's original config used 200K steps). This
157 |   # effectively bypasses the learning rate schedule (the learning rate will
158 |   # never decay). Remove the below line to train indefinitely.
159 |   num_steps: 500
160 |   data_augmentation_options {
161 |     random_horizontal_flip {
162 |     }
163 |   }
164 |   data_augmentation_options {
165 |     ssd_random_crop {
166 |     }
167 |   }
168 | }
169 |
170 | train_input_reader: {
171 |   tf_record_input_reader {
172 |     input_path: "training_data/train.tfrecord"
173 |   }
174 |   label_map_path: "label_map.pbtxt"
175 | }
176 |
177 | eval_config: {
178 |   num_examples: 8000
179 |   # Note: The below line limits the evaluation process to 10 evaluations.
180 |   # Remove the below line to evaluate indefinitely.
181 |   max_evals: 10
182 | }
183 |
184 | eval_input_reader: {
185 |   tf_record_input_reader {
186 |     input_path: "training_data/eval.tfrecord"
187 |   }
188 |   label_map_path: "label_map.pbtxt"
189 |   shuffle: false
190 |   num_readers: 1
191 |   num_epochs: 1
192 | }
--------------------------------------------------------------------------------
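For reference, the data flow through the numbered scripts: 1_xml_to_csv.py turns the
.xml annotations in training_images and test_images into train_labels.csv and
eval_labels.csv, 2_generate_tfrecords.py turns those into train.tfrecord and
eval.tfrecord (the input_path files named in the config above), 3_train.py writes
model.ckpt-* checkpoints to training_data, 4_export_inference_graph.py freezes a
checkpoint into inference_graph/frozen_inference_graph.pb, and 5_test.py runs that
frozen graph on the images in test_images.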