├── README.md
├── create_tf_record.py
├── data
│   ├── 1.jpg
│   ├── 2.jpg
│   └── 3.jpg
├── mask_rcnn.cpp
├── mask_rcnn.py
├── mask_rcnn_inception_v2_coco.config
├── maskrcnn.ipynb
├── read_pbtxt_file.py
├── road.pbtxt
├── string_int_label_map_pb2.py
└── tf_text_graph_mask_rcnn.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tensorflow_maksrcnn_opencv
A training and OpenCV-inference demo for Mask R-CNN

#### Dependencies
- python3
- opencv (> 3.4.3)
- tensorflow-gpu (1.x)
- CUDA 9.0, cuDNN 7.0

#### Data generation

- install labelme
```
sudo pip3 install labelme
```

- generate tf.record

```
python3 create_tf_record.py \
    --images_dir=path_to_images_dir \
    --annotations_json_dir=path_to_train_annotations_json_dir \
    --label_map_path=path_to_label_map.pbtxt \
    --output_path=path_to_train.record
```
#### Download the pretrained model

[Download the Inception pretrained model](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md/)

#### Download the TensorFlow models repository
```
git clone https://github.com/tensorflow/models.git
```

#### Training
Run the following from the models/research/object_detection directory:
```
python model_main.py \
    --model_dir=path/to/save/directory \
    --pipeline_config_path=path/to/mask_rcnn_inception_v2_xxx.config
```

#### pb2graph

Generate the text graph of the model for OpenCV (the script requires the frozen graph, an output path, and the training config):
```
python3 tf_text_graph_mask_rcnn.py \
    --input path/to/frozen_inference_graph.pb \
    --output path/to/graph.pbtxt \
    --config path/to/mask_rcnn_inception_v2_xxx.config
```

![Image text](https://github.com/mahxn0/tensorflow_maksrcnn_opencv/blob/master/data/1.jpg)

![Image text](https://github.com/mahxn0/tensorflow_maksrcnn_opencv/blob/master/data/2.jpg)

![Image text](https://github.com/mahxn0/tensorflow_maksrcnn_opencv/blob/master/data/3.jpg)

--------------------------------------------------------------------------------
/create_tf_record.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Aug 26 10:57:09 2018

@author: shirhe-lyh
"""

"""Convert raw dataset to TFRecord for object_detection.

Please note that this tool only applies to labelme's annotations (json files).

Example usage:
    python3 create_tf_record.py \
        --images_dir=your absolute path to read images.
        --annotations_json_dir=your path to annotation json files.
        --label_map_path=your path to label_map.pbtxt
        --output_path=your path to write .record.
"""
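# A labelme annotation consumed by this script looks roughly like the sketch
# below -- these are just the keys `_get_annotation_dict` actually reads
# (the file name and coordinates are made-up examples; shapes labelled 'hole'
# are subtracted from the other masks instead of becoming objects):
#
#   {
#     "imagePath": "images/0001.jpg",
#     "shapes": [
#       {"label": "RoadMarking_Sidewalk",
#        "points": [[10.0, 20.0], [200.0, 20.0], [200.0, 80.0], [10.0, 80.0]]},
#       {"label": "hole",
#        "points": [[50.0, 40.0], [80.0, 40.0], [80.0, 60.0], [50.0, 60.0]]}
#     ]
#   }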

import cv2
import glob
import hashlib
import io
import json
import numpy as np
import os
import PIL.Image
import tensorflow as tf

import read_pbtxt_file


flags = tf.app.flags

flags.DEFINE_string('images_dir', None, 'Path to images directory.')
flags.DEFINE_string('annotations_json_dir', 'datasets/annotations',
                    'Path to annotations directory.')
flags.DEFINE_string('label_map_path', None, 'Path to label map proto.')
flags.DEFINE_string('output_path', None, 'Path to the output tfrecord.')

FLAGS = flags.FLAGS


def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def bytes_list_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def create_tf_example(annotation_dict, label_map_dict=None):
    """Converts image and annotations to a tf.Example proto.

    Args:
        annotation_dict: A dictionary containing the following keys:
            ['height', 'width', 'filename', 'sha256_key', 'encoded_jpg',
             'format', 'xmins', 'xmaxs', 'ymins', 'ymaxs', 'masks',
             'class_names'].
        label_map_dict: A dictionary mapping class_names to indices.

    Returns:
        example: The converted tf.Example.

    Raises:
        ValueError: If label_map_dict is None or does not contain a
            class_name.
    """
    if annotation_dict is None:
        return None
    if label_map_dict is None:
        raise ValueError('`label_map_dict` is None')

    height = annotation_dict.get('height', None)
    width = annotation_dict.get('width', None)
    filename = annotation_dict.get('filename', None)
    sha256_key = annotation_dict.get('sha256_key', None)
    encoded_jpg = annotation_dict.get('encoded_jpg', None)
    image_format = annotation_dict.get('format', None)
    xmins = annotation_dict.get('xmins', None)
    xmaxs = annotation_dict.get('xmaxs', None)
    ymins = annotation_dict.get('ymins', None)
    ymaxs = annotation_dict.get('ymaxs', None)
    masks = annotation_dict.get('masks', None)
    class_names = annotation_dict.get('class_names', None)
    print('class_names:', class_names)
    labels = []
    for class_name in class_names:
        # Default to None (not the string 'None') so that missing classes
        # are actually caught by the check below.
        label = label_map_dict.get(class_name, None)
        print('label:', label)
        if label is None:
            raise ValueError('`label_map_dict` does not contain {}.'.format(
                class_name))
        labels.append(label)

    encoded_masks = []
    for mask in masks:
        pil_image = PIL.Image.fromarray(mask.astype(np.uint8))
        output_io = io.BytesIO()
        pil_image.save(output_io, format='PNG')
        encoded_masks.append(output_io.getvalue())

    feature_dict = {
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(filename.encode('utf8')),
        'image/source_id': bytes_feature(filename.encode('utf8')),
        'image/key/sha256': bytes_feature(sha256_key.encode('utf8')),
        'image/encoded': bytes_feature(encoded_jpg),
        'image/format': bytes_feature(image_format.encode('utf8')),
        'image/object/bbox/xmin': float_list_feature(xmins),
        'image/object/bbox/xmax': float_list_feature(xmaxs),
        'image/object/bbox/ymin': float_list_feature(ymins),
        'image/object/bbox/ymax': float_list_feature(ymaxs),
        'image/object/mask': bytes_list_feature(encoded_masks),
        'image/object/class/label': int64_list_feature(labels)}
    example = tf.train.Example(features=tf.train.Features(
        feature=feature_dict))
    return example


def _get_annotation_dict(images_dir, annotation_json_path):
    """Gets bounding boxes and masks.

    Args:
        images_dir: Path to images directory.
        annotation_json_path: Path to the annotated json file corresponding
            to the image. The json file is annotated by labelme with keys:
            ['lineColor', 'imageData', 'fillColor', 'imagePath', 'shapes',
             'flags'].

    Returns:
        annotation_dict: A dictionary containing the following keys:
            ['height', 'width', 'filename', 'sha256_key', 'encoded_jpg',
             'format', 'xmins', 'xmaxs', 'ymins', 'ymaxs', 'masks',
             'class_names'], or None if the image or the json file does
            not exist.
    """
    if (not os.path.exists(images_dir) or
            not os.path.exists(annotation_json_path)):
        return None

    with open(annotation_json_path, 'r') as f:
        json_text = json.load(f)
    shapes = json_text.get('shapes', None)
    if shapes is None:
        return None
    image_relative_path = json_text.get('imagePath', None)
    print('imagePath:', image_relative_path)
    if image_relative_path is None:
        return None
    image_name = image_relative_path.split('/')[-1]
    image_path = os.path.join(images_dir, image_name)
    image_format = image_name.split('.')[-1].replace('jpg', 'jpeg')
    if not os.path.exists(image_path):
        return None

    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    image = cv2.imread(image_path)
    height = image.shape[0]
    width = image.shape[1]
    key = hashlib.sha256(encoded_jpg).hexdigest()

    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    masks = []
    class_names = []
    hole_polygons = []
    for mark in shapes:
        class_name = mark.get('label')
        # labelme stores points as floats; cv2.fillPoly needs int32.
        polygon = np.array(mark.get('points'), dtype=np.int32)
        if class_name == 'hole':
            hole_polygons.append(polygon)
        else:
            # Only non-hole shapes become objects, so the class name is
            # recorded here to keep class_names aligned with masks/boxes.
            class_names.append(class_name)
            mask = np.zeros(image.shape[:2])
            cv2.fillPoly(mask, [polygon], 1)
            masks.append(mask)

            # Bounding box
            x = polygon[:, 0]
            y = polygon[:, 1]
            xmin = np.min(x)
            xmax = np.max(x)
            ymin = np.min(y)
            ymax = np.max(y)
            xmins.append(float(xmin) / width)
            xmaxs.append(float(xmax) / width)
            ymins.append(float(ymin) / height)
            ymaxs.append(float(ymax) / height)
    # Remove holes from the masks (cv2.fillPoly modifies the mask in place)
    for mask in masks:
        cv2.fillPoly(mask, hole_polygons, 0)

    annotation_dict = {'height': height,
                       'width': width,
                       'filename': image_name,
                       'sha256_key': key,
                       'encoded_jpg': encoded_jpg,
                       'format': image_format,
                       'xmins': xmins,
                       'xmaxs': xmaxs,
                       'ymins': ymins,
                       'ymaxs': ymaxs,
                       'masks': masks,
                       'class_names': class_names}
    return annotation_dict


def main(_):
    if not os.path.exists(FLAGS.images_dir):
        raise ValueError('`images_dir` does not exist.')
    if not os.path.exists(FLAGS.annotations_json_dir):
        raise ValueError('`annotations_json_dir` does not exist.')
    if not os.path.exists(FLAGS.label_map_path):
        raise ValueError('`label_map_path` does not exist.')

    label_map = read_pbtxt_file.get_label_map_dict(FLAGS.label_map_path)

    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

    num_annotations_skipped = 0
    annotations_json_path = os.path.join(FLAGS.annotations_json_dir, '*.json')
    for i, annotation_file in enumerate(glob.glob(annotations_json_path)):
        if i % 100 == 0:
            print('On image %d' % i)

        annotation_dict = _get_annotation_dict(FLAGS.images_dir,
                                               annotation_file)
        if annotation_dict is None:
            num_annotations_skipped += 1
            continue
        tf_example = create_tf_example(annotation_dict, label_map)
        writer.write(tf_example.SerializeToString())
    writer.close()

    print('Skipped %d annotations.' % num_annotations_skipped)
    print('Successfully created TFRecord at {}.'.format(FLAGS.output_path))
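# Editor's note -- a quick, optional way to sanity-check the record this
# script writes (TF 1.x API, matching the import above; the path is a
# placeholder):
#
#   for i, record in enumerate(
#           tf.python_io.tf_record_iterator('path_to_train.record')):
#       example = tf.train.Example.FromString(record)
#       print(i, example.features.feature['image/filename'].bytes_list.value)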

if __name__ == '__main__':
    tf.app.run()

--------------------------------------------------------------------------------
/data/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mahxn0/tensorflow_maksrcnn_opencv/fffab00d57689f2ef556e7fb141eefc1b708f76c/data/1.jpg
--------------------------------------------------------------------------------
/data/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mahxn0/tensorflow_maksrcnn_opencv/fffab00d57689f2ef556e7fb141eefc1b708f76c/data/2.jpg
--------------------------------------------------------------------------------
/data/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mahxn0/tensorflow_maksrcnn_opencv/fffab00d57689f2ef556e7fb141eefc1b708f76c/data/3.jpg
--------------------------------------------------------------------------------
/mask_rcnn.cpp:
--------------------------------------------------------------------------------
// This code is written at BigVision LLC. It is based on the OpenCV project. It is subject to the license terms in the LICENSE file found in this distribution and at http://opencv.org/license.html

// Usage example: ./mask_rcnn.out --video=run.mp4
//                ./mask_rcnn.out --image=bird.jpg
#include <fstream>
#include <sstream>
#include <iostream>
#include <string.h>

#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

const char* keys =
"{help h usage ? | | Usage examples: \n\t\t./mask-rcnn.out --image=traffic.jpg \n\t\t./mask-rcnn.out --video=sample.mp4}"
"{image i || input image }"
"{video v || input video }"
;
using namespace cv;
using namespace dnn;
using namespace std;

// Initialize the parameters
float confThreshold = 0.5; // Confidence threshold
float maskThreshold = 0.3; // Mask threshold

vector<string> classes;
vector<Scalar> colors;

// Draw the predicted bounding box
void drawBox(Mat& frame, int classId, float conf, Rect box, Mat& objectMask);

// Postprocess the neural network's output for each frame
void postprocess(Mat& frame, const vector<Mat>& outs);

int main(int argc, char** argv)
{
    CommandLineParser parser(argc, argv, keys);
    parser.about("Use this script to run object detection and segmentation using Mask-RCNN in OpenCV.");
    if (parser.has("help"))
    {
        parser.printMessage();
        return 0;
    }
    // Load names of classes
    string classesFile = "labels.names";
    ifstream ifs(classesFile.c_str());
    string line;
    while (getline(ifs, line)) classes.push_back(line);

    // Load the colors
    string colorsFile = "colors.txt";
    ifstream colorFptr(colorsFile.c_str());
    while (getline(colorFptr, line)) {
        char* pEnd;
        double r, g, b;
        r = strtod (line.c_str(), &pEnd);
        g = strtod (pEnd, &pEnd); // advance pEnd so b is parsed from the next token
        b = strtod (pEnd, NULL);
        colors.push_back(Scalar(r, g, b, 255.0));
    }

    // Give the configuration and weight files for the model
    String textGraph = "./inference_512/graph.pbtxt";
    String modelWeights = "./inference_512/frozen_inference_graph.pb";
    // Load the network
    Net net = readNetFromTensorflow(modelWeights, textGraph);
    net.setPreferableBackend(DNN_BACKEND_OPENCV);
    net.setPreferableTarget(DNN_TARGET_CPU);
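    // Note: DNN_BACKEND_OPENCV / DNN_TARGET_CPU is the portable default. With
    // an OpenCV >= 4.2 build that has CUDA enabled, the two calls above could
    // instead use DNN_BACKEND_CUDA / DNN_TARGET_CUDA for GPU inference.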

    // Open a video file or an image file or a camera stream.
    string str, outputFile;
    VideoCapture cap;
    VideoWriter video;
    Mat frame, blob;

    try {

        outputFile = "mask_rcnn_out_cpp.avi";
        if (parser.has("image"))
        {
            // Open the image file
            str = parser.get<String>("image");
            //cout << "Image file input : " << str << endl;
            ifstream ifile(str);
            if (!ifile) throw("error");
            cap.open(str);
            str.replace(str.end()-4, str.end(), "_mask_rcnn_out.jpg");
            outputFile = str;
        }
        else if (parser.has("video"))
        {
            // Open the video file
            str = parser.get<String>("video");
            ifstream ifile(str);
            if (!ifile) throw("error");
            cap.open(str);
            str.replace(str.end()-4, str.end(), "_mask_rcnn_out.avi");
            outputFile = str;
        }
        // Open the webcam (no "device" key is defined in `keys`, so use camera 0)
        else cap.open(0);

    }
    catch(...) {
        cout << "Could not open the input image/video stream" << endl;
        return 0;
    }

    // Get the video writer initialized to save the output video
    if (!parser.has("image")) {
        video.open(outputFile, VideoWriter::fourcc('M','J','P','G'), 28, Size(cap.get(CAP_PROP_FRAME_WIDTH), cap.get(CAP_PROP_FRAME_HEIGHT)));
    }

    // Create a window
    static const string kWinName = "Deep learning object detection in OpenCV";
    namedWindow(kWinName, WINDOW_NORMAL);

    // Process frames.
    while (waitKey(1) < 0)
    {
        // get frame from the video
        cap >> frame;

        // Stop the program if reached end of video
        if (frame.empty()) {
            cout << "Done processing !!!" << endl;
            cout << "Output file is stored as " << outputFile << endl;
            waitKey(3000);
            break;
        }
        // Create a 4D blob from a frame.
        blobFromImage(frame, blob, 1.0, Size(frame.cols, frame.rows), Scalar(), true, false);
        //blobFromImage(frame, blob);

        //Sets the input to the network
        net.setInput(blob);

        // Runs the forward pass to get output from the output layers
        std::vector<String> outNames(2);
        outNames[0] = "detection_out_final";
        outNames[1] = "detection_masks";
        vector<Mat> outs;
        net.forward(outs, outNames);

        // Extract the bounding box and mask for each of the detected objects
        postprocess(frame, outs);

        // Put efficiency information. The function getPerfProfile returns the overall time for inference (t) and the timings for each of the layers (in layersTimes)
        vector<double> layersTimes;
        double freq = getTickFrequency() / 1000;
        double t = net.getPerfProfile(layersTimes) / freq;
        string label = format("Mask-RCNN, inference time for a frame : %0.0f ms", t);
        putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0));

        // Write the frame with the detection boxes
        Mat detectedFrame;
        frame.convertTo(detectedFrame, CV_8U);
        if (parser.has("image")) imwrite(outputFile, detectedFrame);
        else video.write(detectedFrame);

        imshow(kWinName, frame);

    }

    cap.release();
    if (!parser.has("image")) video.release();

    return 0;
}

// For each frame, extract the bounding box and mask for each detected object
void postprocess(Mat& frame, const vector<Mat>& outs)
{
    Mat outDetections = outs[0];
    Mat outMasks = outs[1];

    // Output size of masks is NxCxHxW where
    // N - number of detected boxes
    // C - number of classes (excluding background)
    // HxW - segmentation shape
    const int numDetections = outDetections.size[2];
    const int numClasses = outMasks.size[1];

    outDetections = outDetections.reshape(1, outDetections.total() / 7);
    for (int i = 0; i < numDetections; ++i)
    {
        float score = outDetections.at<float>(i, 2);
        if (score > confThreshold)
        {
            // Extract the bounding box
            int classId = static_cast<int>(outDetections.at<float>(i, 1));
            int left = static_cast<int>(frame.cols * outDetections.at<float>(i, 3));
            int top = static_cast<int>(frame.rows * outDetections.at<float>(i, 4));
            int right = static_cast<int>(frame.cols * outDetections.at<float>(i, 5));
            int bottom = static_cast<int>(frame.rows * outDetections.at<float>(i, 6));

            left = max(0, min(left, frame.cols - 1));
            top = max(0, min(top, frame.rows - 1));
            right = max(0, min(right, frame.cols - 1));
            bottom = max(0, min(bottom, frame.rows - 1));
            Rect box = Rect(left, top, right - left + 1, bottom - top + 1);

            // Extract the mask for the object
            Mat objectMask(outMasks.size[2], outMasks.size[3], CV_32F, outMasks.ptr<float>(i, classId));

            // Draw bounding box, colorize and show the mask on the image
            drawBox(frame, classId, score, box, objectMask);

        }
    }
}

// Draw the predicted bounding box, colorize and show the mask on the image
void drawBox(Mat& frame, int classId, float conf, Rect box, Mat& objectMask)
{
    //Draw a rectangle displaying the bounding box
    rectangle(frame, Point(box.x, box.y), Point(box.x+box.width, box.y+box.height), Scalar(255, 178, 50), 3);

    //Get the label for the class name and its confidence
    string label = format("%.2f", conf);
    if (!classes.empty())
    {
        CV_Assert(classId < (int)classes.size());
        label = classes[classId] + ":" + label;
    }

    //Display the label at the top of the bounding box
    int baseLine;
    Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
    box.y = max(box.y, labelSize.height);
    rectangle(frame, Point(box.x, box.y - round(1.5*labelSize.height)), Point(box.x + round(1.5*labelSize.width), box.y + baseLine), Scalar(255, 255, 255), FILLED);
    putText(frame, label, Point(box.x, box.y), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0,0,0), 1);

    Scalar color = colors[classId%colors.size()];

    // Resize the mask, threshold, color and apply it on the image
    resize(objectMask, objectMask, Size(box.width, box.height));
    Mat mask = (objectMask > maskThreshold);
    Mat coloredRoi = (0.3 * color + 0.7 * frame(box));
    coloredRoi.convertTo(coloredRoi, CV_8UC3);

    // Draw the contours on the image
    vector<vector<Point> > contours;
    Mat hierarchy;
    mask.convertTo(mask, CV_8U);
    findContours(mask, contours, hierarchy, RETR_CCOMP, CHAIN_APPROX_SIMPLE);
    drawContours(coloredRoi, contours, -1, color, 5, LINE_8, hierarchy, 100);
    coloredRoi.copyTo(frame(box), mask);

}
--------------------------------------------------------------------------------
/mask_rcnn.py:
--------------------------------------------------------------------------------
import cv2 as cv
import argparse
import numpy as np
import os.path
import sys
import random

# Initialize the parameters
confThreshold = 0.5  # Confidence threshold
maskThreshold = 0.3  # Mask threshold

parser = argparse.ArgumentParser(description='Use this script to run Mask-RCNN object detection and segmentation')
parser.add_argument('--image', help='Path to image file')
parser.add_argument('--video', help='Path to video file.')
args = parser.parse_args()

# Draw the predicted bounding box, colorize and show the mask on the image
def drawBox(frame, classId, conf, left, top, right, bottom, classMask):
    # Draw a bounding box.
    needle_box = frame[top:bottom, left:right]  # crop of the detection (currently unused)
    cv.rectangle(frame, (left, top), (right, bottom), (255, 178, 50), 3)
    print(left, top, right, bottom)
    # Build the class label.
    label = '%.2f' % conf
    if classes:
        assert(classId < len(classes))
        label = '%s:%s' % (classes[classId], label)
    print("label", label)
    # Display the label at the top of the bounding box
    labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1)
    top = max(top, labelSize[1])
    cv.rectangle(frame, (left, top - round(1.5*labelSize[1])), (left + round(1.5*labelSize[0]), top + baseLine), (255, 255, 255), cv.FILLED)
    cv.putText(frame, label, (left, top), cv.FONT_HERSHEY_SIMPLEX, 0.75, (0,0,0), 1)
    # Resize the mask, threshold, color and apply it on the image
    classMask = cv.resize(classMask, (right - left + 1, bottom - top + 1))
    mask = (classMask > maskThreshold)
    roi = frame[top:bottom+1, left:right+1][mask]

    # color = colors[classId%len(colors)]
    # Comment the above line and uncomment the two lines below to generate different instance colors
    colorIndex = random.randint(0, len(colors)-1)
    color = colors[colorIndex]

    frame[top:bottom+1, left:right+1][mask] = ([0.3*color[0], 0.3*color[1], 0.3*color[2]] + 0.7 * roi).astype(np.uint8)

    # Draw the contours on the image
    mask = mask.astype(np.uint8)
    #im2, contours, hierarchy = cv.findContours(mask, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
    #cv.drawContours(frame[top:bottom+1, left:right+1], contours, -1, color, 3, cv.LINE_8, hierarchy, 100)

# For each frame, extract the bounding box and mask for each detected object
def postprocess(boxes, masks):
    # Output size of masks is NxCxHxW where
    # N - number of detected boxes
    # C - number of classes (excluding background)
    # HxW - segmentation shape
    numClasses = masks.shape[1]
    numDetections = boxes.shape[2]

    frameH = frame.shape[0]
    frameW = frame.shape[1]

    print("numClasses %d" % numClasses)
    print("numDetections %d" % numDetections)
    for i in range(numDetections):
        box = boxes[0, 0, i]

        mask = masks[i]
        score = box[2]
        if score > confThreshold:
            classId = int(box[1])
            print("classId %d" % classId)
            print("score %f" % score)  # score is a float in [0, 1]
            # Extract the bounding box
            left = int(frameW * box[3])
            top = int(frameH * box[4])
            right = int(frameW * box[5])
            bottom = int(frameH * box[6])

            left = max(0, min(left, frameW - 1))
            top = max(0, min(top, frameH - 1))
            right = max(0, min(right, frameW - 1))
            bottom = max(0, min(bottom, frameH - 1))

            # Extract the mask for the object
            classMask = mask[classId]

            # Draw bounding box, colorize and show the mask on the image
            drawBox(frame, classId, score, left, top, right, bottom, classMask)


# Load names of classes
classesFile = "labels.names"
classes = None
with open(classesFile, 'rt') as f:
    classes = f.read().rstrip('\n').split('\n')

# Give the textGraph and weight files for the model
#textGraph = "./mask_rcnn_inception_v2_coco_2018_01_28.pbtxt"
#modelWeights = "./mask_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb"
textGraph = "./inference_512/graph.pbtxt"
#textGraph = "./mask_rcnn_inception_v2_coco_2018_11_28/graph.pbtxt"
#modelWeights = "./mask_rcnn_inception_v2_coco_2018_11_28/saved_model/saved_model.pb"
#modelWeights = "./mask_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb"
modelWeights = "./inference_512/frozen_inference_graph.pb"
# Load the network
net = cv.dnn.readNetFromTensorflow(modelWeights, textGraph)
net.setPreferableBackend(cv.dnn.DNN_BACKEND_OPENCV)
net.setPreferableTarget(cv.dnn.DNN_TARGET_CPU)

# Load the colors
colorsFile = "colors.txt"
with open(colorsFile, 'rt') as f:
    colorsStr = f.read().rstrip('\n').split('\n')
colors = []
for i in range(len(colorsStr)):
    rgb = colorsStr[i].split(' ')
    color = np.array([float(rgb[0]), float(rgb[1]), float(rgb[2])])
    colors.append(color)

winName = 'Mask-RCNN Object detection and Segmentation in OpenCV'
cv.namedWindow(winName, cv.WINDOW_NORMAL)

outputFile = "mask_rcnn_out_py.avi"
if (args.image):
    # Open the image file
    if not os.path.isfile(args.image):
        print("Input image file ", args.image, " doesn't exist")
        sys.exit(1)
    cap = cv.VideoCapture(args.image)
    outputFile = args.image[:-4]+'_mask_rcnn_out_py.jpg'
elif (args.video):
    # Open the video file
    if not os.path.isfile(args.video):
        print("Input video file ", args.video, " doesn't exist")
        sys.exit(1)
    cap = cv.VideoCapture(args.video)
    outputFile = args.video[:-4]+'_mask_rcnn_out_py.avi'
else:
    # Webcam input
    cap = cv.VideoCapture(0)

# Get the video writer initialized to save the output video
if (not args.image):
    vid_writer = cv.VideoWriter(outputFile, cv.VideoWriter_fourcc('M','J','P','G'), 28, (round(cap.get(cv.CAP_PROP_FRAME_WIDTH)), round(cap.get(cv.CAP_PROP_FRAME_HEIGHT))))

while cv.waitKey(1) < 0:

    # Get frame from the video
    hasFrame, frame = cap.read()

    # Stop the program if reached end of video
    if not hasFrame:
        print("Done processing !!!")
        print("Output file is stored as ", outputFile)
        cv.waitKey(3000)
        break

    # Create a 4D blob from a frame.
    blob = cv.dnn.blobFromImage(frame, swapRB=True, crop=False)

    # Set the input to the network
    net.setInput(blob)

    # Run the forward pass to get output from the output layers
    boxes, masks = net.forward(['detection_out_final', 'detection_masks'])
    print(type(boxes), type(masks))
    print(boxes, masks)
    # Extract the bounding box and mask for each of the detected objects
    postprocess(boxes, masks)

    # Put efficiency information.
    t, _ = net.getPerfProfile()
    #label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
    #cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
    # Write the frame with the detection boxes
    if (args.image):
        cv.imwrite(outputFile, frame.astype(np.uint8))
    else:
        vid_writer.write(frame.astype(np.uint8))

    cv.imshow(winName, frame)

--------------------------------------------------------------------------------
/mask_rcnn_inception_v2_coco.config:
--------------------------------------------------------------------------------
# Mask R-CNN with Inception V2, originally configured for the MSCOCO dataset
# and adapted here for the 12-class RoadMarking dataset (see road.pbtxt).
# Users should configure the fine_tune_checkpoint field in the train config,
# as well as the label_map_path and input_path fields in the
# train_input_reader and eval_input_reader.
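# Field notes (added for clarity; the values below are this repository's
# choices, not MSCOCO defaults):
#   - num_classes: 12 matches the 12 items defined in road.pbtxt.
#   - number_of_stages: 3 runs the mask-prediction stage on top of the two
#     usual Faster R-CNN stages; stage 3 is what makes this a Mask R-CNN.
#   - fine_tune_checkpoint points at the extracted
#     mask_rcnn_inception_v2_coco_2018_01_28 checkpoint from the model zoo.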

model {
  faster_rcnn {
    num_classes: 12
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 800
        max_dimension: 1365
      }
    }
    number_of_stages: 3
    feature_extractor {
      type: 'faster_rcnn_inception_v2'
      first_stage_features_stride: 16
    }
    first_stage_anchor_generator {
      grid_anchor_generator {
        scales: [0.25, 0.5, 1.0, 2.0]
        aspect_ratios: [0.5, 1.0, 2.0]
        height_stride: 16
        width_stride: 16
      }
    }
    first_stage_box_predictor_conv_hyperparams {
      op: CONV
      regularizer {
        l2_regularizer {
          weight: 0.0
        }
      }
      initializer {
        truncated_normal_initializer {
          stddev: 0.01
        }
      }
    }
    first_stage_nms_score_threshold: 0.0
    first_stage_nms_iou_threshold: 0.7
    first_stage_max_proposals: 300
    first_stage_localization_loss_weight: 2.0
    first_stage_objectness_loss_weight: 1.0
    initial_crop_size: 14
    maxpool_kernel_size: 2
    maxpool_stride: 2
    second_stage_box_predictor {
      mask_rcnn_box_predictor {
        use_dropout: false
        dropout_keep_probability: 1.0
        predict_instance_masks: true
        mask_height: 15
        mask_width: 15
        mask_prediction_conv_depth: 0
        mask_prediction_num_conv_layers: 2
        fc_hyperparams {
          op: FC
          regularizer {
            l2_regularizer {
              weight: 0.0
            }
          }
          initializer {
            variance_scaling_initializer {
              factor: 1.0
              uniform: true
              mode: FAN_AVG
            }
          }
        }
        conv_hyperparams {
          op: CONV
          regularizer {
            l2_regularizer {
              weight: 0.0
            }
          }
          initializer {
            truncated_normal_initializer {
              stddev: 0.01
            }
          }
        }
      }
    }
    second_stage_post_processing {
      batch_non_max_suppression {
        score_threshold: 0.0
        iou_threshold: 0.6
        max_detections_per_class: 100
        max_total_detections: 300
      }
      score_converter: SOFTMAX
    }
    second_stage_localization_loss_weight: 2.0
    second_stage_classification_loss_weight: 1.0
    second_stage_mask_prediction_loss_weight: 4.0
  }
}

train_config: {
  batch_size: 1
  optimizer {
    momentum_optimizer: {
      learning_rate: {
        manual_step_learning_rate {
          initial_learning_rate: 0.0002
          schedule {
            step: 2000
            learning_rate: .00002
          }
          schedule {
            step: 5000
            learning_rate: .000002
          }
        }
      }
      momentum_optimizer_value: 0.9
    }
    use_moving_average: false
  }
  gradient_clipping_by_norm: 10.0
  fine_tune_checkpoint: "/home/mahxn0/M_DeepLearning/models/research/object_detection/mask_rcnn_test/RoadMarking/mask_rcnn_inception_v2_coco_2018_01_28/model.ckpt"
  from_detection_checkpoint: true
  # Note: The line below limits training to 4,000 steps, so the second
  # learning-rate decay (scheduled at step 5,000) is never reached.
  # Remove it to train longer.
  num_steps: 4000
  data_augmentation_options {
    random_horizontal_flip {
    }
  }
}

train_input_reader: {
  tf_record_input_reader {
    input_path: "/home/mahxn0/M_DeepLearning/models/research/object_detection/mask_rcnn_test/RoadMarking/tf_train.record"
  }
  label_map_path: "/home/mahxn0/M_DeepLearning/models/research/object_detection/mask_rcnn_test/RoadMarking/road.pbtxt"
  load_instance_masks: true
  mask_type: PNG_MASKS
}

eval_config: {
  num_examples: 8000
  # Note: The below line limits the evaluation process to 10 evaluations.
  # Remove the below line to evaluate indefinitely.
  max_evals: 10
}

eval_input_reader: {
  tf_record_input_reader {
    input_path: "/home/mahxn0/M_DeepLearning/models/research/object_detection/mask_rcnn_test/RoadMarking/tf_val.record"
  }
  label_map_path: "/home/mahxn0/M_DeepLearning/models/research/object_detection/mask_rcnn_test/RoadMarking/road.pbtxt"
  load_instance_masks: true
  mask_type: PNG_MASKS
  shuffle: false
  num_readers: 1
}
--------------------------------------------------------------------------------
/read_pbtxt_file.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Aug 26 13:42:50 2018

@author: shirhe-lyh
"""

"""A tool to read .pbtxt files.

See details at:
    TensorFlow models/research/object_detection/protos/string_int_label_map_pb2.py
    TensorFlow models/research/object_detection/utils/label_map_util.py
"""
import sys
import tensorflow as tf

from google.protobuf import text_format

sys.path.append('/home/mahxn0/M_DeepLearning/models/research/object_detection')
#import string_int_label_map_pb2
from protos import string_int_label_map_pb2


def load_pbtxt_file(path):
    """Reads a .pbtxt file.

    Args:
        path: Path to StringIntLabelMap proto text file (.pbtxt file).

    Returns:
        A StringIntLabelMapProto.

    Raises:
        ValueError: If path does not exist.
    """
    if not tf.gfile.Exists(path):
        raise ValueError('`path` does not exist.')

    with tf.gfile.GFile(path, 'r') as fid:
        pbtxt_string = fid.read()
        pbtxt = string_int_label_map_pb2.StringIntLabelMap()
        try:
            text_format.Merge(pbtxt_string, pbtxt)
        except text_format.ParseError:
            pbtxt.ParseFromString(pbtxt_string)
    return pbtxt


def get_label_map_dict(path):
    """Reads a .pbtxt file and returns a dictionary.

    Args:
        path: Path to StringIntLabelMap proto text file.

    Returns:
        A dictionary mapping class names to indices.
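
        For example, with this repository's road.pbtxt the result starts:
            {'RoadMarking_LongSolidLine': 1, 'RoadMarking_DottedLine': 2, ...}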
57 | """ 58 | pbtxt = load_pbtxt_file(path) 59 | 60 | result_dict = {} 61 | for item in pbtxt.item: 62 | result_dict[item.name] = item.id 63 | return result_dict 64 | 65 | -------------------------------------------------------------------------------- /road.pbtxt: -------------------------------------------------------------------------------- 1 | item{ 2 | id:1 3 | name:'RoadMarking_LongSolidLine' 4 | } 5 | 6 | item{ 7 | id:2 8 | name:'RoadMarking_DottedLine' 9 | } 10 | 11 | item{ 12 | id:3 13 | name:'RoadMarking_ArrowLine' 14 | } 15 | 16 | item{ 17 | id:4 18 | name:'RoadMarking_EntranceLine' 19 | } 20 | 21 | item{ 22 | id:5 23 | name:'RoadMarking_TransverseSolidLine' 24 | } 25 | 26 | item{ 27 | id:6 28 | name:'RoadMarking_Sidewalk' 29 | } 30 | 31 | item{ 32 | id:7 33 | name:'RoadMarking_DottedLineChangXi' 34 | } 35 | 36 | item{ 37 | id:8 38 | name:'mark' 39 | } 40 | 41 | item{ 42 | id:9 43 | name:'RoadMarking_MeshLine' 44 | } 45 | 46 | item{ 47 | id:10 48 | name:'RoadMarking_DecelerationHeng' 49 | } 50 | 51 | item{ 52 | id:11 53 | name:'RoadMarking_DecelerationZong' 54 | } 55 | 56 | item{ 57 | id:12 58 | name:'RoadMarking_DottedLineDuanXi' 59 | } 60 | -------------------------------------------------------------------------------- /string_int_label_map_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/string_int_label_map.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor.FileDescriptor( 18 | name='object_detection/protos/string_int_label_map.proto', 19 | package='object_detection.protos', 20 | syntax='proto2', 21 | serialized_options=None, 22 | serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem') 23 | ) 24 | 25 | 26 | 27 | 28 | _STRINGINTLABELMAPITEM = _descriptor.Descriptor( 29 | name='StringIntLabelMapItem', 30 | full_name='object_detection.protos.StringIntLabelMapItem', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | serialized_options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | serialized_options=None, file=DESCRIPTOR), 49 | 
_descriptor.FieldDescriptor( 50 | name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | serialized_options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | serialized_options=None, 63 | is_extendable=False, 64 | syntax='proto2', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=79, 69 | serialized_end=150, 70 | ) 71 | 72 | 73 | _STRINGINTLABELMAP = _descriptor.Descriptor( 74 | name='StringIntLabelMap', 75 | full_name='object_detection.protos.StringIntLabelMap', 76 | filename=None, 77 | file=DESCRIPTOR, 78 | containing_type=None, 79 | fields=[ 80 | _descriptor.FieldDescriptor( 81 | name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0, 82 | number=1, type=11, cpp_type=10, label=3, 83 | has_default_value=False, default_value=[], 84 | message_type=None, enum_type=None, containing_type=None, 85 | is_extension=False, extension_scope=None, 86 | serialized_options=None, file=DESCRIPTOR), 87 | ], 88 | extensions=[ 89 | ], 90 | nested_types=[], 91 | enum_types=[ 92 | ], 93 | serialized_options=None, 94 | is_extendable=False, 95 | syntax='proto2', 96 | extension_ranges=[], 97 | oneofs=[ 98 | ], 99 | serialized_start=152, 100 | serialized_end=233, 101 | ) 102 | 103 | _STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM 104 | DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM 105 | DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP 106 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 107 | 108 | StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict( 109 | DESCRIPTOR = _STRINGINTLABELMAPITEM, 110 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 111 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) 112 | )) 113 | _sym_db.RegisterMessage(StringIntLabelMapItem) 114 | 115 | StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict( 116 | DESCRIPTOR = _STRINGINTLABELMAP, 117 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 118 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) 119 | )) 120 | _sym_db.RegisterMessage(StringIntLabelMap) 121 | 122 | 123 | # @@protoc_insertion_point(module_scope) 124 | -------------------------------------------------------------------------------- /tf_text_graph_mask_rcnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from tf_text_graph_common import * 4 | 5 | parser = argparse.ArgumentParser(description='Run this script to get a text graph of ' 6 | 'Mask-RCNN model from TensorFlow Object Detection API. 
' 7 | 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.') 8 | parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.') 9 | parser.add_argument('--output', required=True, help='Path to output text graph.') 10 | parser.add_argument('--config', required=True, help='Path to a *.config file is used for training.') 11 | args = parser.parse_args() 12 | 13 | scopesToKeep = ('FirstStageFeatureExtractor', 'Conv', 14 | 'FirstStageBoxPredictor/BoxEncodingPredictor', 15 | 'FirstStageBoxPredictor/ClassPredictor', 16 | 'CropAndResize', 17 | 'MaxPool2D', 18 | 'SecondStageFeatureExtractor', 19 | 'SecondStageBoxPredictor', 20 | 'Preprocessor/sub', 21 | 'Preprocessor/mul', 22 | 'image_tensor') 23 | 24 | scopesToIgnore = ('FirstStageFeatureExtractor/Assert', 25 | 'FirstStageFeatureExtractor/Shape', 26 | 'FirstStageFeatureExtractor/strided_slice', 27 | 'FirstStageFeatureExtractor/GreaterEqual', 28 | 'FirstStageFeatureExtractor/LogicalAnd') 29 | 30 | # Load a config file. 31 | config = readTextMessage(args.config) 32 | config = config['model'][0]['faster_rcnn'][0] 33 | num_classes = int(config['num_classes'][0]) 34 | 35 | grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0] 36 | scales = [float(s) for s in grid_anchor_generator['scales']] 37 | aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']] 38 | width_stride = float(grid_anchor_generator['width_stride'][0]) 39 | height_stride = float(grid_anchor_generator['height_stride'][0]) 40 | features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0]) 41 | first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0]) 42 | first_stage_max_proposals = int(config['first_stage_max_proposals'][0]) 43 | 44 | print('Number of classes: %d' % num_classes) 45 | print('Scales: %s' % str(scales)) 46 | print('Aspect ratios: %s' % str(aspect_ratios)) 47 | print('Width stride: %f' % width_stride) 48 | print('Height stride: %f' % height_stride) 49 | print('Features stride: %f' % features_stride) 50 | 51 | # Read the graph. 52 | writeTextGraph(args.input, args.output, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes', 'detection_masks']) 53 | graph_def = parseTextGraph(args.output) 54 | 55 | removeIdentity(graph_def) 56 | 57 | def to_remove(name, op): 58 | return name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \ 59 | (name.startswith('CropAndResize') and op != 'CropAndResize') 60 | 61 | removeUnusedNodesAndAttrs(to_remove, graph_def) 62 | 63 | 64 | # Connect input node to the first layer 65 | assert(graph_def.node[0].op == 'Placeholder') 66 | graph_def.node[1].input.insert(0, graph_def.node[0].name) 67 | 68 | # Temporarily remove top nodes. 
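# (Everything from the end of the graph back through the second CropAndResize
# node -- the second-stage box/class head and the mask branch -- is popped
# into `topNodes` here; these nodes are re-appended further below, once the
# custom PriorBox/DetectionOutput nodes have been spliced in.)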
69 | topNodes = [] 70 | numCropAndResize = 0 71 | while True: 72 | node = graph_def.node.pop() 73 | topNodes.append(node) 74 | if node.op == 'CropAndResize': 75 | numCropAndResize += 1 76 | if numCropAndResize == 2: 77 | break 78 | 79 | addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd', 80 | 'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def) 81 | 82 | addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1', 83 | 'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def) # Compare with Reshape_4 84 | 85 | addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax', 86 | 'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def) 87 | 88 | # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd 89 | addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd', 90 | 'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def) 91 | 92 | proposals = NodeDef() 93 | proposals.name = 'proposals' # Compare with ClipToWindow/Gather/Gather (NOTE: normalized) 94 | proposals.op = 'PriorBox' 95 | proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd') 96 | proposals.input.append(graph_def.node[0].name) # image_tensor 97 | 98 | proposals.addAttr('flip', False) 99 | proposals.addAttr('clip', True) 100 | proposals.addAttr('step', features_stride) 101 | proposals.addAttr('offset', 0.0) 102 | proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2]) 103 | 104 | widths = [] 105 | heights = [] 106 | for a in aspect_ratios: 107 | for s in scales: 108 | ar = np.sqrt(a) 109 | heights.append((features_stride**2) * s / ar) 110 | widths.append((features_stride**2) * s * ar) 111 | 112 | proposals.addAttr('width', widths) 113 | proposals.addAttr('height', heights) 114 | 115 | graph_def.node.extend([proposals]) 116 | 117 | # Compare with Reshape_5 118 | detectionOut = NodeDef() 119 | detectionOut.name = 'detection_out' 120 | detectionOut.op = 'DetectionOutput' 121 | 122 | detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten') 123 | detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten') 124 | detectionOut.input.append('proposals') 125 | 126 | detectionOut.addAttr('num_classes', 2) 127 | detectionOut.addAttr('share_location', True) 128 | detectionOut.addAttr('background_label_id', 0) 129 | detectionOut.addAttr('nms_threshold', first_stage_nms_iou_threshold) 130 | detectionOut.addAttr('top_k', 6000) 131 | detectionOut.addAttr('code_type', "CENTER_SIZE") 132 | detectionOut.addAttr('keep_top_k', first_stage_max_proposals) 133 | detectionOut.addAttr('clip', True) 134 | 135 | graph_def.node.extend([detectionOut]) 136 | 137 | # Save as text. 138 | cropAndResizeNodesNames = [] 139 | for node in reversed(topNodes): 140 | if node.op != 'CropAndResize': 141 | graph_def.node.extend([node]) 142 | topNodes.pop() 143 | else: 144 | cropAndResizeNodesNames.append(node.name) 145 | if numCropAndResize == 1: 146 | break 147 | else: 148 | graph_def.node.extend([node]) 149 | topNodes.pop() 150 | numCropAndResize -= 1 151 | 152 | addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def) 153 | 154 | addSlice('SecondStageBoxPredictor/Reshape_1/softmax', 155 | 'SecondStageBoxPredictor/Reshape_1/slice', 156 | [0, 0, 1], [-1, -1, -1], graph_def) 157 | 158 | addReshape('SecondStageBoxPredictor/Reshape_1/slice', 159 | 'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def) 160 | 161 | # Replace Flatten subgraph onto a single node. 
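# (The loops below delete the Shape/strided_slice/Reshape-shape helper nodes
# that TensorFlow emits for each flatten call, then re-type the surviving
# Reshape nodes as single 'Flatten' layers that OpenCV's importer understands.)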
162 | for i in reversed(range(len(graph_def.node))): 163 | if graph_def.node[i].op == 'CropAndResize': 164 | graph_def.node[i].input.insert(1, 'detection_out') 165 | 166 | if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape': 167 | addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def) 168 | 169 | graph_def.node[i].input.pop() 170 | graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2') 171 | 172 | if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape', 173 | 'SecondStageBoxPredictor/Flatten/flatten/strided_slice', 174 | 'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape', 175 | 'SecondStageBoxPredictor/Flatten_1/flatten/Shape', 176 | 'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice', 177 | 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']: 178 | del graph_def.node[i] 179 | 180 | for node in graph_def.node: 181 | if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \ 182 | node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape': 183 | node.op = 'Flatten' 184 | node.input.pop() 185 | 186 | if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D', 187 | 'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']: 188 | node.addAttr('loc_pred_transposed', True) 189 | 190 | if node.name.startswith('MaxPool2D'): 191 | assert(node.op == 'MaxPool') 192 | assert(len(cropAndResizeNodesNames) == 2) 193 | node.input = [cropAndResizeNodesNames[0]] 194 | del cropAndResizeNodesNames[0] 195 | 196 | ################################################################################ 197 | ### Postprocessing 198 | ################################################################################ 199 | addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def) 200 | 201 | variance = NodeDef() 202 | variance.name = 'proposals/variance' 203 | variance.op = 'Const' 204 | variance.addAttr('value', [0.1, 0.1, 0.2, 0.2]) 205 | graph_def.node.extend([variance]) 206 | 207 | varianceEncoder = NodeDef() 208 | varianceEncoder.name = 'variance_encoded' 209 | varianceEncoder.op = 'Mul' 210 | varianceEncoder.input.append('SecondStageBoxPredictor/Reshape') 211 | varianceEncoder.input.append(variance.name) 212 | varianceEncoder.addAttr('axis', 2) 213 | graph_def.node.extend([varianceEncoder]) 214 | 215 | addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def) 216 | addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def) 217 | 218 | detectionOut = NodeDef() 219 | detectionOut.name = 'detection_out_final' 220 | detectionOut.op = 'DetectionOutput' 221 | 222 | detectionOut.input.append('variance_encoded/flatten') 223 | detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape') 224 | detectionOut.input.append('detection_out/slice/reshape') 225 | 226 | detectionOut.addAttr('num_classes', num_classes) 227 | detectionOut.addAttr('share_location', False) 228 | detectionOut.addAttr('background_label_id', num_classes + 1) 229 | detectionOut.addAttr('nms_threshold', 0.6) 230 | detectionOut.addAttr('code_type', "CENTER_SIZE") 231 | detectionOut.addAttr('keep_top_k',100) 232 | detectionOut.addAttr('clip', True) 233 | detectionOut.addAttr('variance_encoded_in_target', True) 234 | detectionOut.addAttr('confidence_threshold', 0.3) 235 | detectionOut.addAttr('group_by_classes', False) 236 | graph_def.node.extend([detectionOut]) 237 | 238 | for node in reversed(topNodes): 239 | graph_def.node.extend([node]) 240 | 241 | if 
node.name.startswith('MaxPool2D'): 242 | assert(node.op == 'MaxPool') 243 | assert(len(cropAndResizeNodesNames) == 1) 244 | node.input = [cropAndResizeNodesNames[0]] 245 | 246 | for i in reversed(range(len(graph_def.node))): 247 | if graph_def.node[i].op == 'CropAndResize': 248 | graph_def.node[i].input.insert(1, 'detection_out_final') 249 | break 250 | 251 | graph_def.node[-1].name = 'detection_masks' 252 | graph_def.node[-1].op = 'Sigmoid' 253 | graph_def.node[-1].input.pop() 254 | 255 | # Save as text. 256 | graph_def.save(args.output) 257 | --------------------------------------------------------------------------------
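A minimal end-to-end check of the exported model (a sketch: the paths are the placeholders used by mask_rcnn.py in this repository, and it assumes OpenCV > 3.4.3 as listed in the README):

```
import cv2 as cv

# First generate the text graph next to the frozen weights:
#   python3 tf_text_graph_mask_rcnn.py \
#       --input ./inference_512/frozen_inference_graph.pb \
#       --output ./inference_512/graph.pbtxt \
#       --config mask_rcnn_inception_v2_coco.config
net = cv.dnn.readNetFromTensorflow('./inference_512/frozen_inference_graph.pb',
                                   './inference_512/graph.pbtxt')
img = cv.imread('data/1.jpg')
net.setInput(cv.dnn.blobFromImage(img, swapRB=True, crop=False))
boxes, masks = net.forward(['detection_out_final', 'detection_masks'])
# boxes has shape (1, 1, N, 7); masks is (N, C, 15, 15) mask probabilities,
# where 15x15 comes from mask_height/mask_width in the .config -- see
# postprocess() in mask_rcnn.py for how these are consumed.
print(boxes.shape, masks.shape)
```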