├── s_l.jpg
├── selfie.jpg
├── README.md
├── main_export_onnx.py
├── main.cpp
└── main.py

/s_l.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/yolov5-face-landmarks-opencv-v2/HEAD/s_l.jpg
--------------------------------------------------------------------------------
/selfie.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hpc203/yolov5-face-landmarks-opencv-v2/HEAD/selfie.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# yolov5-face-landmarks-opencv-v2
An updated YOLOv5 that detects faces and five facial landmarks. It runs with OpenCV as its only dependency, and the program is provided in both C++ and Python versions.
Compared with the previous version, this release differs in the following ways:

(1). The input resolution is changed to 640x640.

(2). Three face + landmark detection models are provided: yolov5s, yolov5m, and yolov5l.

(3). The post-processing is slightly different.

(4). The first module of the YOLOv5 network is now a StemBlock instead of Focus.

The ONNX files are hosted on Baidu Netdisk. Download link: https://pan.baidu.com/s/1KXtKykQ0qroA5WyO3hko0g
Extraction code: 376l
--------------------------------------------------------------------------------
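To make point (4) of the README concrete: the Focus module rearranges every 2x2 pixel patch into channels before a convolution, while a StemBlock (in the PeleeNet style adopted by yolov5-face) downsamples with a strided convolution plus a parallel max-pool branch. Below is a minimal PyTorch sketch of the two patterns; the channel widths are illustrative assumptions, and BatchNorm/activations are omitted, so this is not the exact released architecture.

import torch
import torch.nn as nn

class Focus(nn.Module):
    # Slices the image into 4 interleaved sub-images and stacks them on the
    # channel axis, so a stride-1 conv sees a 2x downsampled, 4*c-channel input.
    def __init__(self, c1, c2):
        super().__init__()
        self.conv = nn.Conv2d(c1 * 4, c2, 3, 1, 1)
    def forward(self, x):
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2],
                                    x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))

class StemBlock(nn.Module):
    # Stride-2 stem conv, then a conv branch and a max-pool branch in parallel,
    # concatenated and fused by a 1x1 conv (PeleeNet-style stem).
    def __init__(self, c1, c2):
        super().__init__()
        self.stem = nn.Conv2d(c1, c2, 3, 2, 1)
        self.branch_conv = nn.Sequential(nn.Conv2d(c2, c2 // 2, 1, 1, 0),
                                         nn.Conv2d(c2 // 2, c2, 3, 2, 1))
        self.branch_pool = nn.MaxPool2d(2, 2)
        self.fuse = nn.Conv2d(c2 * 2, c2, 1, 1, 0)
    def forward(self, x):
        x = self.stem(x)
        return self.fuse(torch.cat([self.branch_conv(x), self.branch_pool(x)], 1))

x = torch.randn(1, 3, 640, 640)
print(Focus(3, 32)(x).shape)      # torch.Size([1, 32, 320, 320])
print(StemBlock(3, 32)(x).shape)  # torch.Size([1, 32, 160, 160])

One common motivation for this swap is that the slice-and-concat in Focus has historically been awkward for ONNX consumers such as OpenCV's DNN module; a plain strided stem exports more cleanly.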
/main_export_onnx.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import argparse
import cv2
import copy
import yaml
from models.experimental import attempt_load
from models.yolo import parse_model
from utils.datasets import letterbox
from utils.general import check_img_size
from models.common import Conv, Contract
from utils.activations import Hardswish, SiLU

device = 'cuda' if torch.cuda.is_available() else 'cpu'
stride = [8, 16, 32]

def test_export(opt):
    ch = 3
    with open(opt.cfg) as f:
        yaml_info = yaml.load(f, Loader=yaml.FullLoader)

    anchors = yaml_info['anchors']
    nc = yaml_info['nc']
    na = len(anchors[0]) // 2
    no = nc + 5 + 10  # outputs per anchor: classes + box(4) + objectness(1) + 5 landmarks (x, y)
    nl = len(anchors)

    _, save = parse_model(yaml_info, ch=[ch])
    model = attempt_load(opt.weights, map_location=device)  # load FP32 model
    img_size = opt.imgsize
    conf_thres = 0.3
    iou_thres = 0.5

    orgimg = cv2.imread(opt.image)  # BGR
    assert orgimg is not None, 'Image Not Found ' + opt.image
    img0 = copy.deepcopy(orgimg)
    h0, w0 = orgimg.shape[:2]  # original hw
    r = img_size / max(h0, w0)  # resize image to img_size
    if r != 1:  # always resize down, only resize up if training with augmentation
        interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
        img0 = cv2.resize(img0, (int(w0 * r), int(h0 * r)), interpolation=interp)

    imgsz = check_img_size(img_size, s=model.stride.max())  # check img_size

    img = letterbox(img0, new_shape=imgsz)[0]
    # Convert
    img = img[:, :, ::-1].transpose(2, 0, 1).copy()  # BGR to RGB, HWC to CHW

    # Run inference
    img = torch.from_numpy(img).to(device)
    img = img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)

    # Inference
    pred = model(img)[0]

    # Re-run the layers manually to verify the traced path matches the model output
    x = copy.deepcopy(img)
    onnxmodel = model.model
    y = []
    for m in onnxmodel:
        if m.f != -1:  # if not from previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
        x = m(x)  # run
        y.append(x if m.i in save else None)  # save output
    print(torch.equal(x[0], pred))
    return onnxmodel, img, save, na, no

class my_yolov5_model(nn.Module):
    def __init__(self, model, save, na, no):
        super().__init__()
        self.model = model
        self.contract = Contract(gain=2)
        self.len_model = len(model)
        self.save = save
        self.na = na
        self.no = no

    def forward(self, x):
        y = []
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output

        # Flatten each of the three detection heads to (N, no) and concatenate,
        # so OpenCV sees a single 2-D output tensor
        x[0] = x[0].view(-1, self.no)
        x[1] = x[1].view(-1, self.no)
        x[2] = x[2].view(-1, self.no)
        return torch.cat(x, 0)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='yaml file path')
    parser.add_argument('--weights', type=str, default='weights/yolov5s-face.pt', help='model.pt path')
    parser.add_argument('--image', type=str, default='data/images/test.jpg', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--imgsize', type=int, default=640, help='inference size (pixels)')
    opt = parser.parse_args()

    onnxmodel, img, save, na, no = test_export(opt)

    onnxmodel[-1].export = True
    net = my_yolov5_model(onnxmodel, save, na, no).to(device)
    net.eval()
    # with torch.no_grad():
    #     out = net(img)
    #     print(out)

    f = opt.weights.replace('.pt', '.onnx')  # filename
    input = torch.zeros(1, 3, opt.imgsize, opt.imgsize).to(device)
    # Update model
    for k, m in net.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
        if isinstance(m, Conv):  # assign export-friendly activations
            if isinstance(m.act, nn.Hardswish):
                m.act = Hardswish()
            elif isinstance(m.act, nn.SiLU):
                m.act = SiLU()
    torch.onnx.export(net, input, f, verbose=False, opset_version=12, input_names=['data'], output_names=['out'])

    # Smoke-test the exported model with OpenCV's DNN module
    cvnet = cv2.dnn.readNet(f)
    input = cv2.imread(opt.image)
    input = cv2.resize(input, (opt.imgsize, opt.imgsize))
    blob = cv2.dnn.blobFromImage(input)
    cvnet.setInput(blob)
    outs = cvnet.forward(cvnet.getUnconnectedOutLayersNames())
--------------------------------------------------------------------------------
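The script ends by running the exported ONNX file through OpenCV but never compares the two backends. A hedged sketch of such a parity check, meant to be appended to the __main__ block above; it reuses net, blob, outs, and device from the script, assumes the two outputs have matching shapes, and the tolerance is a guess:

import numpy as np

# Feed OpenCV's own blob back through the PyTorch wrapper so both paths see
# identical input, then compare the flattened (N, 16) predictions.
with torch.no_grad():
    torch_out = net(torch.from_numpy(blob).to(device)).cpu().numpy()
print('max abs diff:', np.abs(torch_out - outs[0]).max())
print('close:', np.allclose(torch_out, outs[0], atol=1e-4))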
/main.cpp:
--------------------------------------------------------------------------------
#include <fstream>
#include <sstream>
#include <iostream>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>

using namespace cv;
using namespace dnn;
using namespace std;

struct Net_config
{
    float confThreshold; // Confidence threshold
    float nmsThreshold;  // Non-maximum suppression threshold
    float objThreshold;  // Object confidence threshold
    string netname;
};

class YOLO
{
public:
    YOLO(Net_config config);
    void detect(Mat& frame);
private:
    const float anchors[3][6] = { {4,5, 8,10, 13,16}, {23,29, 43,55, 73,105}, {146,217, 231,300, 335,433} };
    const float stride[3] = { 8.0, 16.0, 32.0 };
    const int inpWidth = 640;
    const int inpHeight = 640;
    float confThreshold;
    float nmsThreshold;
    float objThreshold;

    char netname[20];
    Net net;
    void drawPred(float conf, int left, int top, int right, int bottom, Mat& frame, vector<int> landmark);
    void sigmoid(Mat* out, int length);
};

static inline float sigmoid_x(float x)
{
    return static_cast<float>(1.f / (1.f + exp(-x)));
}

YOLO::YOLO(Net_config config)
{
    cout << "Net use " << config.netname << endl;
    this->confThreshold = config.confThreshold;
    this->nmsThreshold = config.nmsThreshold;
    this->objThreshold = config.objThreshold;
    strcpy_s(this->netname, config.netname.c_str()); // MSVC-specific; use strncpy on other toolchains

    string modelFile = this->netname;
    modelFile += "-face.onnx";
    this->net = readNet(modelFile);
}

void YOLO::drawPred(float conf, int left, int top, int right, int bottom, Mat& frame, vector<int> landmark) // Draw the predicted bounding box
{
    // Draw a rectangle displaying the bounding box
    rectangle(frame, Point(left, top), Point(right, bottom), Scalar(0, 0, 255), 2);

    // Get the label for the face confidence
    string label = format("%.2f", conf);

    // Display the label at the top of the bounding box
    int baseLine;
    Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
    top = max(top, labelSize.height);
    //rectangle(frame, Point(left, top - int(1.5 * labelSize.height)), Point(left + int(1.5 * labelSize.width), top + baseLine), Scalar(0, 255, 0), FILLED);
    putText(frame, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 0.75, Scalar(0, 255, 0), 1);
    for (int i = 0; i < 5; i++)
    {
        circle(frame, Point(landmark[2 * i], landmark[2 * i + 1]), 1, Scalar(0, 255, 0), -1);
    }
}

void YOLO::sigmoid(Mat* out, int length)
{
    float* pdata = (float*)(out->data);
    for (int i = 0; i < length; i++)
    {
        pdata[i] = 1.0 / (1 + expf(-pdata[i]));
    }
}

void YOLO::detect(Mat& frame)
{
    Mat blob;
    blobFromImage(frame, blob, 1 / 255.0, Size(this->inpWidth, this->inpHeight), Scalar(0, 0, 0), true, false);
    this->net.setInput(blob);
    vector<Mat> outs;
    this->net.forward(outs, this->net.getUnconnectedOutLayersNames());

    ///// generate proposals
    vector<float> confidences;
    vector<Rect> boxes;
    vector<vector<int>> landmarks;
    float ratioh = (float)frame.rows / this->inpHeight, ratiow = (float)frame.cols / this->inpWidth;
    int n = 0, q = 0, i = 0, j = 0, nout = 16, row_ind = 0, k = 0; /// each row: cx,cy,w,h,box_score,x1,y1,...,x5,y5,face_score
    for (n = 0; n < 3; n++) /// feature map scales
    {
        int num_grid_x = (int)(this->inpWidth / this->stride[n]);
        int num_grid_y = (int)(this->inpHeight / this->stride[n]);
        for (q = 0; q < 3; q++) /// anchors
        {
            const float anchor_w = this->anchors[n][q * 2];
            const float anchor_h = this->anchors[n][q * 2 + 1];
            for (i = 0; i < num_grid_y; i++)
            {
                for (j = 0; j < num_grid_x; j++)
                {
                    float* pdata = (float*)outs[0].data + row_ind * nout;
                    float box_score = sigmoid_x(pdata[4]);
                    if (box_score > this->objThreshold)
                    {
                        float face_score = sigmoid_x(pdata[15]);
                        //if (face_score > this->confThreshold)
                        //{
                        float cx = (sigmoid_x(pdata[0]) * 2.f - 0.5f + j) * this->stride[n]; /// cx
                        float cy = (sigmoid_x(pdata[1]) * 2.f - 0.5f + i) * this->stride[n]; /// cy
                        float w = powf(sigmoid_x(pdata[2]) * 2.f, 2.f) * anchor_w; /// w
                        float h = powf(sigmoid_x(pdata[3]) * 2.f, 2.f) * anchor_h; /// h

                        int left = (cx - 0.5 * w) * ratiow;
                        int top = (cy - 0.5 * h) * ratioh;

                        confidences.push_back(face_score);
                        boxes.push_back(Rect(left, top, (int)(w * ratiow), (int)(h * ratioh)));
                        vector<int> landmark(10);
                        for (k = 5; k < 15; k += 2)
                        {
                            const int ind = k - 5;
                            // Landmarks are linear offsets (no sigmoid), scaled by the anchor
                            landmark[ind] = (int)(pdata[k] * anchor_w + j * this->stride[n]) * ratiow;
                            landmark[ind + 1] = (int)(pdata[k + 1] * anchor_h + i * this->stride[n]) * ratioh;
                        }
                        landmarks.push_back(landmark);
                        //}
                    }
                    row_ind++;
                }
            }
        }
    }

    // Perform non-maximum suppression to eliminate redundant overlapping boxes with
    // lower confidences
    vector<int> indices;
    NMSBoxes(boxes, confidences, this->confThreshold, this->nmsThreshold, indices);
    for (size_t i = 0; i < indices.size(); ++i)
    {
        int idx = indices[i];
        Rect box = boxes[idx];
        this->drawPred(confidences[idx], box.x, box.y,
            box.x + box.width, box.y + box.height, frame, landmarks[idx]);
    }
}

int main()
{
    Net_config yolo_nets = { 0.3, 0.5, 0.3, "yolov5s" }; /// choices: yolov5s, yolov5m, yolov5l
    YOLO yolo_model(yolo_nets);
    string imgpath = "selfie.jpg";
    Mat srcimg = imread(imgpath);
    yolo_model.detect(srcimg);

    static const string kWinName = "Deep learning object detection in OpenCV";
    namedWindow(kWinName, WINDOW_NORMAL);
    imshow(kWinName, srcimg);
    waitKey(0);
    destroyAllWindows();
}
--------------------------------------------------------------------------------
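Both the C++ above and main.py below decode boxes with the same YOLOv5 rule: center = (sigmoid(t) * 2 - 0.5 + grid) * stride, and size = (sigmoid(t) * 2)^2 * anchor. A tiny self-contained Python example that works through one grid cell; all the raw outputs are made-up numbers, used only to show the arithmetic:

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

# Made-up raw outputs (tx, ty, tw, th) for one anchor at grid cell (j=5, i=3)
tx, ty, tw, th = 0.2, -0.1, 0.5, 0.3
stride = 8.0                   # first detection scale
anchor_w, anchor_h = 4.0, 5.0  # first anchor at that scale
j, i = 5, 3                    # grid column and row

cx = (sigmoid(tx) * 2 - 0.5 + j) * stride  # box center x in 640x640 input space
cy = (sigmoid(ty) * 2 - 0.5 + i) * stride  # box center y
w = (sigmoid(tw) * 2) ** 2 * anchor_w      # width, bounded to (0, 4*anchor_w)
h = (sigmoid(th) * 2) ** 2 * anchor_h      # height
print(cx, cy, w, h)  # roughly (44.8, 27.6, 6.2, 6.6)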
/main.py:
--------------------------------------------------------------------------------
import cv2
import argparse
import numpy as np

class yolov5():
    def __init__(self, yolo_type, confThreshold=0.5, nmsThreshold=0.5, objThreshold=0.5):
        anchors = [[4,5, 8,10, 13,16], [23,29, 43,55, 73,105], [146,217, 231,300, 335,433]]
        num_classes = 1
        self.nl = len(anchors)          # number of detection scales
        self.na = len(anchors[0]) // 2  # anchors per scale
        self.no = num_classes + 5 + 10  # outputs per anchor: class + box + objectness + 5 landmarks
        self.grid = [np.zeros(1)] * self.nl
        self.stride = np.array([8., 16., 32.])
        self.anchor_grid = np.asarray(anchors, dtype=np.float32).reshape(self.nl, -1, 2)
        self.inpWidth = 640
        self.inpHeight = 640
        self.net = cv2.dnn.readNet(yolo_type + '-face.onnx')
        self.confThreshold = confThreshold
        self.nmsThreshold = nmsThreshold
        self.objThreshold = objThreshold

    def _make_grid(self, nx=20, ny=20):
        # Returns the (x, y) cell offsets of an nx-by-ny grid, flattened to (nx*ny, 2)
        xv, yv = np.meshgrid(np.arange(ny), np.arange(nx))
        return np.stack((xv, yv), 2).reshape((-1, 2)).astype(np.float32)

    def postprocess(self, frame, outs):
        frameHeight = frame.shape[0]
        frameWidth = frame.shape[1]
        ratioh, ratiow = frameHeight / self.inpHeight, frameWidth / self.inpWidth
        # Scan through all the bounding boxes output from the network and keep only
        # the ones with high confidence scores.
        confidences = []
        boxes = []
        landmarks = []
        for detection in outs:
            confidence = detection[15]
            # if confidence > self.confThreshold and detection[4] > self.objThreshold:
            if detection[4] > self.objThreshold:
                center_x = int(detection[0] * ratiow)
                center_y = int(detection[1] * ratioh)
                width = int(detection[2] * ratiow)
                height = int(detection[3] * ratioh)
                left = int(center_x - width / 2)
                top = int(center_y - height / 2)

                confidences.append(float(confidence))
                boxes.append([left, top, width, height])
                landmark = detection[5:15] * np.tile(np.float32([ratiow, ratioh]), 5)
                landmarks.append(landmark.astype(np.int32))
        # Perform non-maximum suppression to eliminate redundant overlapping boxes with
        # lower confidences.
        indices = cv2.dnn.NMSBoxes(boxes, confidences, self.confThreshold, self.nmsThreshold)
        for i in np.array(indices).flatten():  # flatten handles both the old Nx1 and the newer 1-D return shape
            box = boxes[i]
            left = box[0]
            top = box[1]
            width = box[2]
            height = box[3]
            landmark = landmarks[i]
            frame = self.drawPred(frame, confidences[i], left, top, left + width, top + height, landmark)
        return frame

    def drawPred(self, frame, conf, left, top, right, bottom, landmark):
        # Draw a bounding box.
        cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255), thickness=2)
        # label = '%.2f' % conf
        # Display the label at the top of the bounding box
        # labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        # top = max(top, labelSize[1])
        # cv2.putText(frame, label, (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=2)
        for i in range(5):
            cv2.circle(frame, (landmark[i * 2], landmark[i * 2 + 1]), 1, (0, 255, 0), thickness=-1)
        return frame
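    # detect() below mirrors the C++ decode, vectorized in NumPy. Only channels
    # 0-4 (box + objectness) and channel 15 (face score) pass through a sigmoid;
    # the ten landmark channels 5-14 stay linear and are scaled by the anchor
    # size and shifted by the grid cell.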
    def detect(self, srcimg):
        blob = cv2.dnn.blobFromImage(srcimg, 1 / 255.0, (self.inpWidth, self.inpHeight), [0, 0, 0], swapRB=True, crop=False)
        # Sets the input to the network
        self.net.setInput(blob)

        # Runs the forward pass to get output of the output layers
        outs = self.net.forward(self.net.getUnconnectedOutLayersNames())[0]

        # Apply sigmoid to box, objectness and face-score channels
        outs[..., [0, 1, 2, 3, 4, 15]] = 1 / (1 + np.exp(-outs[..., [0, 1, 2, 3, 4, 15]]))
        row_ind = 0
        for i in range(self.nl):
            h, w = int(self.inpHeight / self.stride[i]), int(self.inpWidth / self.stride[i])
            length = int(self.na * h * w)
            if self.grid[i].shape[0] != h * w:  # rebuild the cached grid only when the size changes
                self.grid[i] = self._make_grid(w, h)

            g_i = np.tile(self.grid[i], (self.na, 1))
            a_g_i = np.repeat(self.anchor_grid[i], h * w, axis=0)
            # Box decode: center = (sigmoid*2 - 0.5 + grid) * stride, size = (sigmoid*2)^2 * anchor
            outs[row_ind:row_ind + length, 0:2] = (outs[row_ind:row_ind + length, 0:2] * 2. - 0.5 + g_i) * int(self.stride[i])
            outs[row_ind:row_ind + length, 2:4] = (outs[row_ind:row_ind + length, 2:4] * 2) ** 2 * a_g_i

            # Landmark decode: offset * anchor + grid * stride, for each of the five points
            outs[row_ind:row_ind + length, 5:7] = outs[row_ind:row_ind + length, 5:7] * a_g_i + g_i * int(self.stride[i])      # landmark x1 y1
            outs[row_ind:row_ind + length, 7:9] = outs[row_ind:row_ind + length, 7:9] * a_g_i + g_i * int(self.stride[i])      # landmark x2 y2
            outs[row_ind:row_ind + length, 9:11] = outs[row_ind:row_ind + length, 9:11] * a_g_i + g_i * int(self.stride[i])    # landmark x3 y3
            outs[row_ind:row_ind + length, 11:13] = outs[row_ind:row_ind + length, 11:13] * a_g_i + g_i * int(self.stride[i])  # landmark x4 y4
            outs[row_ind:row_ind + length, 13:15] = outs[row_ind:row_ind + length, 13:15] * a_g_i + g_i * int(self.stride[i])  # landmark x5 y5
            row_ind += length
        return outs

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--yolo_type', type=str, default='yolov5m', choices=['yolov5s', 'yolov5m', 'yolov5l'], help="yolo type")
    parser.add_argument("--imgpath", type=str, default='selfie.jpg', help="image path")
    parser.add_argument('--confThreshold', default=0.3, type=float, help='class confidence')
    parser.add_argument('--nmsThreshold', default=0.5, type=float, help='nms iou thresh')
    parser.add_argument('--objThreshold', default=0.3, type=float, help='object confidence')
    args = parser.parse_args()

    yolonet = yolov5(args.yolo_type, confThreshold=args.confThreshold, nmsThreshold=args.nmsThreshold, objThreshold=args.objThreshold)
    srcimg = cv2.imread(args.imgpath)
    dets = yolonet.detect(srcimg)
    srcimg = yolonet.postprocess(srcimg, dets)

    winName = 'Deep learning object detection in OpenCV'
    cv2.namedWindow(winName, 0)
    cv2.imshow(winName, srcimg)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
--------------------------------------------------------------------------------
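As a closing usage note, the same class can be pointed at a webcam rather than a still image. A minimal hedged sketch; it assumes main.py is on the import path and that camera index 0 is the desired source:

import cv2
from main import yolov5  # the script part of main.py is behind the __main__ guard, so importing is safe

net = yolov5('yolov5s', confThreshold=0.3, nmsThreshold=0.5, objThreshold=0.3)
cap = cv2.VideoCapture(0)  # default webcam; a video file path works here too
while True:
    ok, frame = cap.read()
    if not ok:
        break
    dets = net.detect(frame)
    frame = net.postprocess(frame, dets)
    cv2.imshow('yolov5-face webcam', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):  # press q to quit
        break
cap.release()
cv2.destroyAllWindows()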