└── main.cc /main.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | #include "tensorflow/cc/ops/const_op.h" 12 | #include "tensorflow/cc/ops/image_ops.h" 13 | #include "tensorflow/cc/ops/standard_ops.h" 14 | #include "tensorflow/core/framework/graph.pb.h" 15 | #include "tensorflow/core/framework/tensor.h" 16 | #include "tensorflow/core/graph/default_device.h" 17 | #include "tensorflow/core/graph/graph_def_builder.h" 18 | #include "tensorflow/core/lib/core/errors.h" 19 | #include "tensorflow/core/lib/core/stringpiece.h" 20 | #include "tensorflow/core/lib/core/threadpool.h" 21 | #include "tensorflow/core/lib/io/path.h" 22 | #include "tensorflow/core/lib/strings/stringprintf.h" 23 | #include "tensorflow/core/platform/env.h" 24 | #include "tensorflow/core/platform/init_main.h" 25 | #include "tensorflow/core/platform/logging.h" 26 | #include "tensorflow/core/platform/types.h" 27 | #include "tensorflow/core/public/session.h" 28 | #include "tensorflow/core/util/command_line_flags.h" 29 | 30 | // These are all common classes it's handy to reference with no namespace. 31 | using tensorflow::Flag; 32 | using tensorflow::Tensor; 33 | using tensorflow::Status; 34 | using tensorflow::string; 35 | using tensorflow::int32; 36 | using tensorflow::ops::Softmax; 37 | 38 | #define printTensor(T, d) \ 39 | std::cout<< (T).tensor() << std::endl; 40 | 41 | #include 42 | #include 43 | #include 44 | 45 | 46 | #define YOLOV3_SIZE 416 47 | #define IMG_CHANNELS 3 48 | 49 | 50 | float bboxThreshold = 0.4; // BBox threshold 51 | float nmsThreshold = 0.4; // Non-maximum suppression threshold 52 | std::vector classes; 53 | 54 | // Reads a model graph definition from disk, and creates a session object you 55 | // can use to run it. 56 | Status LoadGraph(const string& graph_file_name, 57 | std::unique_ptr* session) { 58 | tensorflow::GraphDef graph_def; 59 | Status load_graph_status = 60 | ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def); 61 | if (!load_graph_status.ok()) { 62 | return tensorflow::errors::NotFound("Failed to load compute graph at '", 63 | graph_file_name, "'"); 64 | } 65 | session->reset(tensorflow::NewSession(tensorflow::SessionOptions())); 66 | Status session_create_status = (*session)->Create(graph_def); 67 | if (!session_create_status.ok()) { 68 | return session_create_status; 69 | } 70 | return Status::OK(); 71 | } 72 | 73 | cv::Mat resizeKeepAspectRatio(const cv::Mat &input, int width, int height) 74 | { 75 | cv::Mat output; 76 | 77 | double h1 = width * (input.rows/(double)input.cols); 78 | double w2 = height * (input.cols/(double)input.rows); 79 | if( h1 <= height) { 80 | cv::resize( input, output, cv::Size(width, h1)); 81 | } else { 82 | cv::resize( input, output, cv::Size(w2, height)); 83 | } 84 | 85 | int top = (height - output.rows) / 2; 86 | int down = (height - output.rows + 1) / 2; 87 | int left = (width - output.cols) / 2; 88 | int right = (width - output.cols + 1) / 2; 89 | 90 | cv::copyMakeBorder(output, output, top, down, left, right, cv::BORDER_CONSTANT, cv::Scalar(128,128,128) ); 91 | 92 | return output; 93 | } 94 | 95 | Status readTensorFromMat(const cv::Mat &mat, Tensor &outTensor) { 96 | 97 | auto root = tensorflow::Scope::NewRootScope(); 98 | using namespace ::tensorflow::ops; 99 | // Trick from https://github.com/tensorflow/tensorflow/issues/8033 100 | float *p = outTensor.flat().data(); 101 | cv::Mat fakeMat(mat.rows, mat.cols, CV_32FC3, p); 102 | mat.convertTo(fakeMat, CV_32FC3, 1.f); 103 | 104 | auto input_tensor = Placeholder(root.WithOpName("input"), tensorflow::DT_FLOAT); 105 | std::vector> inputs = {{"input", outTensor}}; 106 | auto noOp = Identity(root.WithOpName("noOp"), outTensor); 107 | 108 | 109 | // This runs the GraphDef network definition that we've just constructed, and 110 | // returns the results in the output outTensor. 111 | tensorflow::GraphDef graph; 112 | TF_RETURN_IF_ERROR(root.ToGraphDef(&graph)); 113 | 114 | std::vector outTensors; 115 | std::unique_ptr session(tensorflow::NewSession(tensorflow::SessionOptions())); 116 | 117 | TF_RETURN_IF_ERROR(session->Create(graph)); 118 | TF_RETURN_IF_ERROR(session->Run({inputs}, {"noOp"}, {}, &outTensors)); 119 | 120 | outTensor = outTensors.at(0); 121 | return Status::OK(); 122 | } 123 | 124 | // Draw the predicted bounding box 125 | void drawPred(int classId, float conf, int left, int top, int right, int bottom, cv::Mat& frame) 126 | { 127 | //Draw a rectangle displaying the bounding box 128 | cv::rectangle(frame, cv::Point(left, top), cv::Point(right, bottom), cv::Scalar(255, 178, 50), 2); 129 | 130 | //Get the label for the class name and its confidence 131 | string label = cv::format("%.2f", conf); 132 | if (!classes.empty()) 133 | { 134 | label = classes[classId] + ":" + label; 135 | } 136 | 137 | //Display the label at the top of the bounding box 138 | int baseLine; 139 | cv::Size labelSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); 140 | top = cv::max(top, labelSize.height); 141 | cv::rectangle(frame, cv::Point(left, top - round(1.5*labelSize.height)), 142 | cv::Point(left + round(1.5*labelSize.width), top + baseLine), cv::Scalar(255, 255, 255), cv::FILLED); 143 | cv::putText(frame, label, cv::Point(left, top), cv::FONT_HERSHEY_SIMPLEX, 0.75, cv::Scalar(0,0,0),1); 144 | } 145 | 146 | 147 | // Remove the bounding boxes with low confidence using non-maxima suppression 148 | void postprocess(cv::Mat& frame, const std::vector& outs) 149 | { 150 | std::vector classIds; 151 | std::vector confidences; 152 | std::vector boxes; 153 | 154 | for (size_t i = 0; i < outs.size(); ++i) 155 | { 156 | // Scan through all the bounding boxes output from the network and keep only the 157 | // ones with high confidence scores. Assign the box's class label as the class 158 | // with the highest score for the box. 159 | float* data = (float*)outs[i].data; 160 | for (int j = 0; j < outs[i].rows; ++j, data += outs[i].cols) 161 | { 162 | cv::Mat scores = outs[i].row(j).colRange(5, outs[i].cols); 163 | cv::Point classIdPoint; 164 | double confidence; 165 | //// Get the value and location of the maximum score 166 | cv::minMaxLoc(scores, 0, &confidence, 0, &classIdPoint); 167 | if (data[4] > bboxThreshold) 168 | { 169 | int x0 = (int)(data[0]); 170 | int y0 = (int)(data[1]); 171 | int x1 = (int)(data[2]); 172 | int y1 = (int)(data[3]); 173 | 174 | //recover bbox according to input size 175 | int current_size = YOLOV3_SIZE; 176 | int rows = frame.rows; 177 | int cols = frame.cols; 178 | float final_ratio = std::min((float)current_size/cols, (float)current_size/rows); 179 | int padx = 0.5f * (current_size - final_ratio * cols); 180 | int pady = 0.5f * (current_size - final_ratio * rows); 181 | 182 | x0 = (x0 - padx) / final_ratio; 183 | y0 = (y0 - pady) / final_ratio; 184 | x1 = (x1 - padx) / final_ratio; 185 | y1 = (y1 - pady) / final_ratio; 186 | 187 | int left = x0; 188 | int top = y0; 189 | int width = x1 - x0; 190 | int height = y1 - y0; 191 | 192 | classIds.push_back(classIdPoint.x); 193 | confidences.push_back((float)confidence); 194 | boxes.push_back(cv::Rect(left, top, width, height)); 195 | } 196 | } 197 | } 198 | 199 | // Perform non maximum suppression to eliminate redundant overlapping boxes with 200 | // lower confidences 201 | std::vector indices; 202 | cv::dnn::NMSBoxes(boxes, confidences, bboxThreshold, nmsThreshold, indices); 203 | for (size_t i = 0; i < indices.size(); ++i) 204 | { 205 | int idx = indices[i]; 206 | cv::Rect box = boxes[idx]; 207 | drawPred(classIds[idx], confidences[idx], box.x, box.y, 208 | box.x + box.width, box.y + box.height, frame); 209 | } 210 | } 211 | 212 | int main(int argc, char* argv[]) { 213 | 214 | string dataset = "dataset/"; 215 | string graph = "model/frozen_model.pb"; 216 | std::vector files; 217 | string input_layer = "inputs"; //input ops 218 | string final_out = "output_boxes"; //output ops 219 | string root_dir = ""; 220 | 221 | string classesFile = "coco.names"; 222 | std::ifstream ifs(classesFile.c_str()); 223 | string line; 224 | while (getline(ifs, line)) classes.push_back(line); 225 | 226 | // We need to call this to set up global state for TensorFlow. 227 | tensorflow::port::InitMain(argv[0], &argc, &argv); 228 | if (argc > 1) { 229 | LOG(ERROR) << "Unknown argument " << argv[1] << "\n"; 230 | return -1; 231 | } 232 | 233 | // First we load and initialize the model. 234 | std::unique_ptr session; 235 | string graph_path = tensorflow::io::JoinPath(root_dir, graph); 236 | Status load_graph_status = LoadGraph(graph_path, &session); 237 | if (!load_graph_status.ok()) { 238 | LOG(ERROR) << load_graph_status; 239 | return -1; 240 | } 241 | cv::VideoCapture cap; 242 | if(!cap.open(0)) { 243 | return 0; 244 | } 245 | 246 | for(;;) { 247 | 248 | cv::Mat srcImage, rgbImage; 249 | cap >> srcImage; 250 | if(srcImage.empty()){ 251 | break; 252 | } 253 | cv::cvtColor(srcImage, rgbImage, CV_BGR2RGB); 254 | cv::Mat padImage = resizeKeepAspectRatio(rgbImage, YOLOV3_SIZE, YOLOV3_SIZE); 255 | 256 | Tensor resized_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1, YOLOV3_SIZE, YOLOV3_SIZE, IMG_CHANNELS})); 257 | Status read_tensor_status = readTensorFromMat(padImage, resized_tensor); 258 | if (!read_tensor_status.ok()) { 259 | LOG(ERROR) << read_tensor_status; 260 | return -1; 261 | } 262 | 263 | // Actually run the image through the model. 264 | std::vector outputs; 265 | Status run_status = session->Run({{input_layer, resized_tensor}}, 266 | {final_out}, {}, &outputs); 267 | if (!run_status.ok()) { 268 | LOG(ERROR) << "Running model failed: " << run_status; 269 | return -1; 270 | } 271 | //std::cout << outputs[0].shape() << "\n"; 272 | float *p = outputs[0].flat().data(); 273 | cv::Mat result(outputs[0].dim_size(1), outputs[0].dim_size(2), CV_32FC(1), p); 274 | std::vector outs; 275 | outs.push_back (result); 276 | 277 | postprocess(rgbImage, outs); 278 | 279 | cv::cvtColor(rgbImage, srcImage , CV_RGB2BGR); 280 | cv::imshow( "Yolov3", srcImage ); 281 | if( cv::waitKey(10) == 27 ) break; 282 | 283 | } 284 | return 0; 285 | } 286 | --------------------------------------------------------------------------------