├── .gitignore
├── LICENSE
├── README.md
├── batchsize_clear.py
├── convert_script.txt
├── demo_video.py
├── make_post_process.py
├── make_pre_process.py
└── test.jpg

/.gitignore:
--------------------------------------------------------------------------------
*.tar.gz
__pycache__
saved_model*
*.pb
*.xml
*.bin
*.onnx
*.tflite
*.trt
*.mlmodel
.vscode/
*.json
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2020. Huawei Technologies Co., Ltd.
All rights reserved.
Copyright (c) 2021, Katsuya Hyodo
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HeadPoseEstimation-WHENet-yolov4-onnx-openvino
ONNX, OpenVINO, TFLite, TensorRT, EdgeTPU, CoreML, TFJS, YOLOv4/YOLOv4-tiny-3L

![ezgif com-gif-maker (3)](https://user-images.githubusercontent.com/33194443/141761520-28038c2a-e89a-4887-a9de-0fdaa972005b.gif)

- PINTO Special Custom Model
  https://github.com/PINTO0309/DMHead

https://user-images.githubusercontent.com/33194443/216770749-be03a5ee-cdc9-4390-aeea-2aa2908cb6a4.mp4

## 1. Usage
```bash
$ git clone https://github.com/PINTO0309/HeadPoseEstimation-WHENet-yolov4-onnx-openvino
$ cd HeadPoseEstimation-WHENet-yolov4-onnx-openvino
$ wget https://github.com/PINTO0309/HeadPoseEstimation-WHENet-yolov4-onnx-openvino/releases/download/v1.0.3/saved_model_224x224.tar.gz
$ tar -zxvf saved_model_224x224.tar.gz && rm saved_model_224x224.tar.gz
$ wget https://github.com/PINTO0309/HeadPoseEstimation-WHENet-yolov4-onnx-openvino/releases/download/v1.0.4/whenet_1x3x224x224_prepost.onnx
$ mv whenet_1x3x224x224_prepost.onnx saved_model_224x224/

$ python3 demo_video.py
```
```bash
usage: demo_video.py \
  [-h] \
  [--whenet_mode {onnx,openvino}] \
  [--device DEVICE] \
  [--height_width HEIGHT_WIDTH]

optional arguments:
  -h, --help
      show this help message and exit
  --whenet_mode {onnx,openvino}
      Choose whether to infer WHENet with ONNX or OpenVINO. Default: onnx
  --device DEVICE
      Path of the mp4 file or device number of the USB camera. Default: 0
  --height_width HEIGHT_WIDTH
      {H}x{W} Default: 480x640
```
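
For example, to run the WHENet stage through OpenVINO on a recorded video instead of a USB camera, the same options can be combined as follows (the file name `movie.mp4` is only a placeholder):
```bash
$ python3 demo_video.py \
    --whenet_mode openvino \
    --device movie.mp4 \
    --height_width 720x1280
```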

## 2. Reference
1. https://github.com/Ascend-Research/HeadPoseEstimation-WHENet
2. https://github.com/AlexeyAB/darknet
3. https://github.com/jkjung-avt/yolov4_crowdhuman
4. https://github.com/linghu8812/tensorrt_inference/tree/master/Yolov4
5. https://github.com/Tianxiaomo/pytorch-YOLOv4
6. https://github.com/PINTO0309/PINTO_model_zoo
7. https://github.com/PINTO0309/openvino2tensorflow
8. [Exporting WHENet](https://zenn.dev/pinto0309/scraps/1849b6909db13b)
9. [Darknet to ONNX to OpenVINO to TensorFlow to TFLite and Others](https://zenn.dev/pinto0309/scraps/b33883e3951605)
10. [Dual-model head pose estimation (6DRepNet + WHENet): fusion of SOTA models for 360° 6D head-pose detection. All pre- and post-processing are fused into the graph, allowing end-to-end processing in a single inference.](https://github.com/PINTO0309/DMHead)

## 3. Special Custom Model Structure
![whenet_1x3x224x224_prepost onnx](https://user-images.githubusercontent.com/33194443/174461110-32171aae-a11d-4329-99c5-3872aba70429.png)
--------------------------------------------------------------------------------
/batchsize_clear.py:
--------------------------------------------------------------------------------
import onnx
import struct

from argparse import ArgumentParser

def rebatch(infile, outfile, batch_size):
    model = onnx.load(infile)
    graph = model.graph

    # Change the batch size in input, output and value_info
    for tensor in list(graph.input) + list(graph.value_info) + list(graph.output):
        tensor.type.tensor_type.shape.dim[0].dim_param = batch_size

    # Set a dynamic batch size (-1) in Reshape shape initializers
    for node in graph.node:
        if node.op_type != 'Reshape':
            continue
        for init in graph.initializer:
            # node.input[1] is expected to be the shape initializer of the Reshape
            if init.name != node.input[1]:
                continue
            # Shape is stored as a list of ints
            if len(init.int64_data) > 0:
                # This also overwrites the reshape shape of bias nodes, but that is fine
                init.int64_data[0] = -1
            # Shape is stored as raw bytes
            elif len(init.raw_data) > 0:
                shape = bytearray(init.raw_data)
                struct.pack_into('q', shape, 0, -1)
                init.raw_data = bytes(shape)

    onnx.save(model, outfile)

if __name__ == '__main__':
    parser = ArgumentParser('Replace batch size with \'N\'')
    parser.add_argument('infile')
    parser.add_argument('outfile')
    args = parser.parse_args()

    rebatch(args.infile, args.outfile, 'N')
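
# A minimal usage sketch for the script above (the file names below are
# placeholders, not files shipped in this repository): the first positional
# argument is the source ONNX file, the second is where the copy whose batch
# dimension has been rewritten to 'N' is saved.
#
#   python3 batchsize_clear.py model_float32.onnx model_float32_Nbatch.onnx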
--------------------------------------------------------------------------------
/convert_script.txt:
--------------------------------------------------------------------------------
xhost +local: && \
docker run --gpus all -it --rm \
-v `pwd`:/home/user/workdir \
-v /tmp/.X11-unix/:/tmp/.X11-unix:rw \
--device /dev/video0:/dev/video0:mwr \
--net=host \
-e XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \
-e DISPLAY=$DISPLAY \
--privileged \
ghcr.io/pinto0309/openvino2tensorflow:latest

H=224
W=224
MODEL=whenet
$INTEL_OPENVINO_DIR/deployment_tools/model_optimizer/mo_tf.py \
--input_model frozen_model.pb \
--input_shape [1,224,224,3] \
--data_type FP32 \
--output_dir openvino/FP32 \
--model_name ${MODEL}_${H}x${W}
$INTEL_OPENVINO_DIR/deployment_tools/model_optimizer/mo_tf.py \
--input_model frozen_model.pb \
--input_shape [1,224,224,3] \
--data_type FP16 \
--output_dir openvino/FP16 \
--model_name ${MODEL}_${H}x${W}
mkdir -p openvino/myriad
${INTEL_OPENVINO_DIR}/deployment_tools/inference_engine/lib/intel64/myriad_compile \
-m openvino/FP16/${MODEL}_${H}x${W}.xml \
-ip U8 \
-VPU_NUMBER_OF_SHAVES 4 \
-VPU_NUMBER_OF_CMX_SLICES 4 \
-o openvino/myriad/${MODEL}_${H}x${W}.blob

openvino2tensorflow \
--model_path openvino/FP32/${MODEL}_${H}x${W}.xml \
--output_saved_model \
--output_pb \
--output_no_quant_float32_tflite \
--output_dynamic_range_quant_tflite \
--output_weight_quant_tflite \
--output_float16_quant_tflite \
--output_integer_quant_tflite \
--output_integer_quant_typ 'uint8' \
--string_formulas_for_normalization 'data / 255' \
--output_tfjs \
--output_coreml

mv saved_model saved_model_${H}x${W}

openvino2tensorflow \
--model_path openvino/FP32/${MODEL}_${H}x${W}.xml \
--output_saved_model \
--output_pb \
--output_edgetpu

mv saved_model/model_full_integer_quant.tflite saved_model_${H}x${W}
mv saved_model/model_full_integer_quant_edgetpu.tflite saved_model_${H}x${W}
rm -rf saved_model

openvino2tensorflow \
--model_path openvino/FP32/${MODEL}_${H}x${W}.xml \
--output_saved_model \
--output_pb \
--output_onnx \
--onnx_opset 11 \
--keep_input_tensor_in_nchw

mv saved_model/model_float32.onnx saved_model_${H}x${W}

onnx2trt saved_model_${H}x${W}/model_float32.onnx -o saved_model_${H}x${W}/whenet_rtx3070.trt -b 1 -d 16 -v


###############################################################

python make_pre_process.py

mv saved_model_preprocess/test.tflite pre_process_whenet.tflite

python -m tf2onnx.convert \
--opset 11 \
--inputs-as-nchw input_1 \
--tflite pre_process_whenet.tflite \
--output pre_process_whenet.onnx

onnxsim pre_process_whenet.onnx pre_process_whenet.onnx
onnxsim pre_process_whenet.onnx pre_process_whenet.onnx

snd4onnx \
--remove_node_names Transpose__11 \
--input_onnx_file_path pre_process_whenet.onnx \
--output_onnx_file_path pre_process_whenet.onnx

sor4onnx \
--input_onnx_file_path pre_process_whenet.onnx \
--old_new "input_1" "input" \
--output_onnx_file_path pre_process_whenet.onnx \
--mode inputs

sor4onnx \
--input_onnx_file_path pre_process_whenet.onnx \
--old_new "Identity_raw_output___3:0" "pre_output" \
--output_onnx_file_path pre_process_whenet.onnx \
--mode outputs


python make_post_process.py

mv saved_model_postprocess/test.tflite post_process_whenet.tflite

docker run --gpus all -it --rm \
-v `pwd`:/home/user/workdir \
ghcr.io/pinto0309/tflite2tensorflow:latest


tflite2tensorflow \
--model_path post_process_whenet.tflite \
--flatc_path ../flatc \
--schema_path ../schema.fbs \
--output_pb \
--optimizing_for_openvino_and_myriad

tflite2tensorflow \
--model_path post_process_whenet.tflite \
--flatc_path ../flatc \
--schema_path ../schema.fbs \
--output_onnx \
--onnx_opset 11

mv saved_model/model_float32.onnx post_process_whenet.onnx

onnxsim post_process_whenet.onnx post_process_whenet.onnx
onnxsim post_process_whenet.onnx post_process_whenet.onnx

sor4onnx \
--input_onnx_file_path post_process_whenet.onnx \
--old_new "input_1" "post_yaw" \
--output_onnx_file_path post_process_whenet.onnx \
--mode inputs

sor4onnx \
--input_onnx_file_path post_process_whenet.onnx \
--old_new "input_2" "post_pitch" \
--output_onnx_file_path post_process_whenet.onnx \
--mode inputs

sor4onnx \
--input_onnx_file_path post_process_whenet.onnx \
--old_new "input_3" "post_roll" \
--output_onnx_file_path post_process_whenet.onnx \
--mode inputs

sor4onnx \
--input_onnx_file_path post_process_whenet.onnx \
--old_new "Identity" "yaw_roll_pitch" \
--output_onnx_file_path post_process_whenet.onnx \
--mode outputs

sor4onnx \
--input_onnx_file_path post_process_whenet.onnx \
--old_new "Identity" "post_sub" \
--output_onnx_file_path post_process_whenet.onnx

exit


snc4onnx \
--input_onnx_file_paths pre_process_whenet.onnx model_float32.onnx \
--srcop_destop pre_output input_1 \
--output_onnx_file_path whenet_1x3x224x224_prepost.onnx

snc4onnx \
--input_onnx_file_paths whenet_1x3x224x224_prepost.onnx post_process_whenet.onnx \
--srcop_destop tf.identity post_yaw tf.identity_1 post_pitch tf.identity_2 post_roll \
--output_onnx_file_path whenet_1x3x224x224_prepost.onnx
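
# (Illustrative sanity check, not part of the original script: after the final
#  snc4onnx merge, a short onnxruntime one-liner can confirm the merged model's
#  I/O names, which should match the names assigned by the sor4onnx steps above.)
python3 -c "
import onnxruntime
sess = onnxruntime.InferenceSession('whenet_1x3x224x224_prepost.onnx', providers=['CPUExecutionProvider'])
print([i.name for i in sess.get_inputs()])   # e.g. ['input']
print([o.name for o in sess.get_outputs()])  # e.g. ['yaw_roll_pitch']
"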
--------------------------------------------------------------------------------
/demo_video.py:
--------------------------------------------------------------------------------
"""
xhost +local: && \
docker run --gpus all -it --rm \
-v `pwd`:/home/user/workdir \
-v /tmp/.X11-unix/:/tmp/.X11-unix:rw \
--device /dev/video0:/dev/video0:mwr \
--device /dev/video1:/dev/video1:mwr \
--device /dev/video2:/dev/video2:mwr \
--device /dev/video3:/dev/video3:mwr \
--device /dev/video4:/dev/video4:mwr \
--device /dev/video5:/dev/video5:mwr \
--net=host \
-e XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \
-e DISPLAY=$DISPLAY \
--privileged \
ghcr.io/pinto0309/openvino2tensorflow:latest

sudo chmod 777 /dev/video4 && python3 demo_video.py
"""

import numpy as np
import cv2
import os
import argparse
from math import cos, sin
import onnxruntime
import numba as nb

idx_tensor_yaw = [np.array(idx, dtype=np.float32) for idx in range(120)]
idx_tensor = [np.array(idx, dtype=np.float32) for idx in range(66)]


def softmax(x):
    x -= np.max(x, axis=1, keepdims=True)
    a = np.exp(x)
    b = np.sum(np.exp(x), axis=1, keepdims=True)
    return a/b


def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size=100):
    # Referenced from HopeNet https://github.com/natanielruiz/deep-head-pose
    pitch = pitch * np.pi / 180
    yaw = -(yaw * np.pi / 180)
    roll = roll * np.pi / 180
    if tdx is None or tdy is None:
        height, width = img.shape[:2]
        tdx = width / 2
        tdy = height / 2
    # X-Axis pointing to the right, drawn in red
    x1 = size * (cos(yaw) * cos(roll)) + tdx
    y1 = size * (cos(pitch) * sin(roll) + cos(roll) * sin(pitch) * sin(yaw)) + tdy
    # Y-Axis pointing down, drawn in green
    x2 = size * (-cos(yaw) * sin(roll)) + tdx
    y2 = size * (cos(pitch) * cos(roll) - sin(pitch) * sin(yaw) * sin(roll)) + tdy
    # Z-Axis (out of the screen), drawn in blue
    x3 = size * (sin(yaw)) + tdx
    y3 = size * (-cos(yaw) * sin(pitch)) + tdy
    cv2.line(img, (int(tdx), int(tdy)), (int(x1), int(y1)), (0, 0, 255), 2)
    cv2.line(img, (int(tdx), int(tdy)), (int(x2), int(y2)), (0, 255, 0), 2)
    cv2.line(img, (int(tdx), int(tdy)), (int(x3), int(y3)), (255, 0, 0), 2)
    return img


def resize_and_pad(src, size, pad_color=0):
    img = src.copy()
    h, w = img.shape[:2]
    sh, sw = size
    if h > sh or w > sw:
        interp = cv2.INTER_AREA
    else:
        interp = cv2.INTER_CUBIC
    aspect = w/h
    if aspect > 1:
        new_w = sw
        new_h = np.round(new_w/aspect).astype(int)
        pad_vert = (sh-new_h)/2
        pad_top, pad_bot = np.floor(pad_vert).astype(int), np.ceil(pad_vert).astype(int)
        pad_left, pad_right = 0, 0
    elif aspect < 1:
        new_h = sh
        new_w = np.round(new_h*aspect).astype(int)
        pad_horz = (sw-new_w)/2
        pad_left, pad_right = np.floor(pad_horz).astype(int), np.ceil(pad_horz).astype(int)
        pad_top, pad_bot = 0, 0
    else:
        new_h, new_w = sh, sw
        pad_left, pad_right, pad_top, pad_bot = 0, 0, 0, 0
    if len(img.shape) == 3 and not isinstance(pad_color, (list, tuple, np.ndarray)):
        pad_color = [pad_color]*3
    scaled_img = cv2.resize(
        img,
        (new_w, new_h),
        interpolation=interp
    )
    scaled_img = cv2.copyMakeBorder(
        scaled_img,
        pad_top,
        pad_bot,
        pad_left,
        pad_right,
        borderType=cv2.BORDER_CONSTANT,
        value=pad_color
    )
    return scaled_img


@nb.njit('i8[:](f4[:,:],f4[:], f4, b1)', fastmath=True, cache=True)
def nms_cpu(boxes, confs, nms_thresh, min_mode):
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = confs.argsort()[::-1]
    keep = []
    while order.size > 0:
        idx_self = order[0]
        idx_other = order[1:]
        keep.append(idx_self)
        xx1 = np.maximum(x1[idx_self], x1[idx_other])
        yy1 = np.maximum(y1[idx_self], y1[idx_other])
        xx2 = np.minimum(x2[idx_self], x2[idx_other])
        yy2 = np.minimum(y2[idx_self], y2[idx_other])
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        inter = w * h
        if min_mode:
            over = inter / np.minimum(areas[order[0]], areas[order[1:]])
        else:
            over = inter / (areas[order[0]] + areas[order[1:]] - inter)
        inds = np.where(over <= nms_thresh)[0]
        order = order[inds + 1]
    return np.array(keep)


def main(args):
    yolov4_head_H = 480
    yolov4_head_W = 640
    whenet_H = 224
    whenet_W = 224

    # YOLOv4-Head
    yolov4_model_name = 'yolov4_headdetection'
    yolov4_head = onnxruntime.InferenceSession(
        f'saved_model_{whenet_H}x{whenet_W}/{yolov4_model_name}_{yolov4_head_H}x{yolov4_head_W}.onnx',
        providers=[
            'CUDAExecutionProvider',
            'CPUExecutionProvider',
        ]
    )
    yolov4_head_input_name = yolov4_head.get_inputs()[0].name
    yolov4_head_output_names = [output.name for output in yolov4_head.get_outputs()]
    yolov4_head_output_shapes = [output.shape for output in yolov4_head.get_outputs()]
    assert yolov4_head_output_shapes[0] == [1, 18900, 1, 4]  # boxes[N, num, classes, boxes]
    assert yolov4_head_output_shapes[1] == [1, 18900, 1]  # confs[N, num, classes]

    # WHENet
    whenet_input_name = None
    whenet_output_names = None
    whenet_output_shapes = None
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    if args.whenet_mode == 'onnx':
        whenet = onnxruntime.InferenceSession(
            f'saved_model_{whenet_H}x{whenet_W}/whenet_1x3x224x224_prepost.onnx',
            providers=[
                'CUDAExecutionProvider',
                'CPUExecutionProvider',
            ]
        )
        whenet_input_name = whenet.get_inputs()[0].name
        whenet_output_names = [output.name for output in whenet.get_outputs()]

    exec_net = None
    input_name = None
    if args.whenet_mode == 'openvino':
        from openvino.inference_engine import IECore
        model_path = f'saved_model_{whenet_H}x{whenet_W}/openvino/FP16/whenet_{whenet_H}x{whenet_W}.xml'
        ie = IECore()
        net = ie.read_network(model_path, os.path.splitext(model_path)[0] + ".bin")
        exec_net = ie.load_network(network=net, device_name='CPU', num_requests=2)
        input_name = next(iter(net.input_info))

    cap_width = int(args.height_width.split('x')[1])
    cap_height = int(args.height_width.split('x')[0])
    if args.device.isdecimal():
        cap = cv2.VideoCapture(int(args.device))
    else:
        cap = cv2.VideoCapture(args.device)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, cap_width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, cap_height)
    WINDOWS_NAME = 'Demo'
    cv2.namedWindow(WINDOWS_NAME, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(WINDOWS_NAME, cap_width, cap_height)

    while True:
        ret, frame = cap.read()
        if not ret:
            continue

        # ============================================================= YOLOv4
        conf_thresh = 0.60
        nms_thresh = 0.50

        # Resize
        resized_frame = resize_and_pad(
            frame,
            (yolov4_head_H, yolov4_head_W)
        )
        width = resized_frame.shape[1]
        height = resized_frame.shape[0]
        # BGR to RGB
        rgb = resized_frame[..., ::-1]
        # HWC -> CHW
        chw = rgb.transpose(2, 0, 1)
        # Normalize to the [0, 1] interval
        chw = np.asarray(chw / 255., dtype=np.float32)
        # CHW -> NCHW
        nchw = chw[np.newaxis, ...]

        boxes, confs = yolov4_head.run(
            output_names=yolov4_head_output_names,
            input_feed={yolov4_head_input_name: nchw}
        )
        # [1, boxcount, 1, 4] --> [boxcount, 4]
        boxes = boxes[0][:, 0, :]
        # [1, boxcount, 1] --> [boxcount]
        confs = confs[0][:, 0]

        argwhere = confs > conf_thresh
        boxes = boxes[argwhere, :]
        confs = confs[argwhere]
        # NMS
        heads = []
        keep = nms_cpu(
            boxes=boxes,
            confs=confs,
            nms_thresh=nms_thresh,
            min_mode=False
        )
        if keep.size > 0:
            boxes = boxes[keep, :]
            confs = confs[keep]
            for k in range(boxes.shape[0]):
                heads.append(
                    [
                        int(boxes[k, 0] * width),
                        int(boxes[k, 1] * height),
                        int(boxes[k, 2] * width),
                        int(boxes[k, 3] * height),
                        confs[k],
                    ]
                )

        canvas = resized_frame.copy()
        # ============================================================= WHENet
        cropped_resized_frame = None
        if len(heads) > 0:
            for head in heads:
                x_min = head[0]
                y_min = head[1]
                x_max = head[2]
                y_max = head[3]

                # Enlarge the bbox to include more background margin
                y_min = max(0, y_min - abs(y_min - y_max) / 10)
                y_max = min(resized_frame.shape[0], y_max + abs(y_min - y_max) / 10)
                x_min = max(0, x_min - abs(x_min - x_max) / 5)
                x_max = min(resized_frame.shape[1], x_max + abs(x_min - x_max) / 5)
                cropped_frame = resized_frame[int(y_min):int(y_max), int(x_min):int(x_max)]
                # h,w -> 224,224
                cropped_resized_frame = cv2.resize(cropped_frame, (whenet_W, whenet_H))
                # BGR --> RGB
                rgb = cropped_resized_frame[..., ::-1]
                # HWC --> CHW
                chw = rgb.transpose(2, 0, 1)
                # CHW --> NCHW
                nchw = np.asarray(chw[np.newaxis, :, :, :], dtype=np.float32)

                yaw = 0.0
                pitch = 0.0
                roll = 0.0
                if args.whenet_mode == 'onnx':
                    outputs = whenet.run(
                        output_names=whenet_output_names,
                        input_feed={whenet_input_name: nchw}
                    )
                    yaw = outputs[0][0][0]
                    roll = outputs[0][0][1]
                    pitch = outputs[0][0][2]
                elif args.whenet_mode == 'openvino':
                    # Normalization (the OpenVINO IR does not contain the fused
                    # pre-processing, so the normalized crop has to be rebuilt
                    # into NCHW before inference)
                    normalized = ((rgb / 255.0) - mean) / std
                    nchw = np.asarray(normalized.transpose(2, 0, 1)[np.newaxis, ...], dtype=np.float32)
                    output = exec_net.infer(inputs={input_name: nchw})
                    yaw = output['yaw_new/BiasAdd/Add']
                    roll = output['roll_new/BiasAdd/Add']
                    pitch = output['pitch_new/BiasAdd/Add']
                    # Decode the raw bin scores into degrees
                    # (same math as make_post_process.py)
                    yaw = np.sum(softmax(yaw) * idx_tensor_yaw, axis=1) * 3 - 180
                    pitch = np.sum(softmax(pitch) * idx_tensor, axis=1) * 3 - 99
                    roll = np.sum(softmax(roll) * idx_tensor, axis=1) * 3 - 99

                yaw, pitch, roll = np.squeeze([yaw, pitch, roll])

                print(f'yaw: {yaw}, pitch: {pitch}, roll: {roll}')

                # BBox draw
                deg_norm = 1.0 - abs(yaw / 180)
                blue = int(255 * deg_norm)
                cv2.rectangle(
                    canvas,
                    (int(x_min), int(y_min)),
                    (int(x_max), int(y_max)),
                    color=(blue, 0, 255-blue),
                    thickness=2
                )

                # Draw axes and angle labels
                draw_axis(
                    canvas,
                    yaw,
                    pitch,
                    roll,
                    tdx=(x_min+x_max)/2,
                    tdy=(y_min+y_max)/2,
                    size=abs(x_max-x_min)//2
                )
                cv2.putText(
                    canvas,
                    f'yaw: {np.round(yaw)}',
                    (int(x_min), int(y_min)),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.4,
                    (100, 255, 0),
                    1
                )
                cv2.putText(
                    canvas,
                    f'pitch: {np.round(pitch)}',
                    (int(x_min), int(y_min) - 15),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.4,
                    (100, 255, 0),
                    1
                )
                cv2.putText(
                    canvas,
                    f'roll: {np.round(roll)}',
                    (int(x_min), int(y_min) - 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.4,
                    (100, 255, 0),
                    1
                )

                # cv2.imshow('Face', cropped_resized_frame)

        key = cv2.waitKey(1)
        if key == 27:  # ESC
            break

        cv2.imshow(WINDOWS_NAME, canvas)

    cv2.destroyAllWindows()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--whenet_mode",
        type=str,
        default='onnx',
        choices=['onnx', 'openvino'],
        help='Choose whether to infer WHENet with ONNX or OpenVINO. Default: onnx',
    )
    parser.add_argument(
        "--device",
        type=str,
        default='0',
        help='Path of the mp4 file or device number of the USB camera. Default: 0',
    )
    parser.add_argument(
        "--height_width",
        type=str,
        default='480x640',
        help='{H}x{W}. Default: 480x640',
    )
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/make_post_process.py:
--------------------------------------------------------------------------------
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
import numpy as np
np.random.seed(0)

# Create a model
yaw = tf.keras.layers.Input(
    shape=[
        120,
    ],
    batch_size=1,
    dtype=tf.float32,
)

pitch = tf.keras.layers.Input(
    shape=[
        66,
    ],
    batch_size=1,
    dtype=tf.float32,
)

roll = tf.keras.layers.Input(
    shape=[
        66,
    ],
    batch_size=1,
    dtype=tf.float32,
)

idx_tensor_yaw = [np.array(idx, dtype=np.float32) for idx in range(120)]
idx_tensor = [np.array(idx, dtype=np.float32) for idx in range(66)]
output_yaw = tf.math.reduce_sum(tf.nn.softmax(yaw, axis=1) * idx_tensor_yaw, axis=1, keepdims=True) * 3 - 180
output_pitch = tf.math.reduce_sum(tf.nn.softmax(pitch, axis=1) * idx_tensor, axis=1, keepdims=True) * 3 - 99
output_roll = tf.math.reduce_sum(tf.nn.softmax(roll, axis=1) * idx_tensor, axis=1, keepdims=True) * 3 - 99
outputs = tf.concat([output_yaw, output_pitch, output_roll], axis=1)

model = tf.keras.models.Model(inputs=[yaw, pitch, roll], outputs=[outputs])
model.summary()
output_path = 'saved_model_postprocess'
tf.saved_model.save(model, output_path)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()
open(f"{output_path}/test.tflite", "wb").write(tflite_model)
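
# The graph above implements the usual expectation-over-bins decoding: each
# head's scores are softmaxed into a distribution over 3-degree bins, the
# expected bin index is taken, and that is mapped back to degrees (yaw uses
# 120 bins over +/-180 degrees, pitch/roll use 66 bins with a 99-degree offset).
# A small NumPy sketch of the yaw branch, using random logits purely as
# stand-in input:
#
#   import numpy as np
#   logits = np.random.randn(1, 120).astype(np.float32)   # stand-in yaw scores
#   exp = np.exp(logits - logits.max(axis=1, keepdims=True))
#   probs = exp / exp.sum(axis=1, keepdims=True)           # softmax over bins
#   expected_bin = (probs * np.arange(120, dtype=np.float32)).sum(axis=1, keepdims=True)
#   yaw_deg = expected_bin * 3 - 180                        # degrees in [-180, +180]
#   print(yaw_deg)                                          # shape (1, 1)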
--------------------------------------------------------------------------------
/make_pre_process.py:
--------------------------------------------------------------------------------
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
import numpy as np
np.random.seed(0)

# Create a model
input = tf.keras.layers.Input(
    shape=[
        224,
        224,
        3,
    ],
    batch_size=1,
    dtype=tf.float32,
)

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

outputs = (input / 255.0 - mean) / std

model = tf.keras.models.Model(inputs=[input], outputs=[outputs])
model.summary()
output_path = 'saved_model_preprocess'
tf.saved_model.save(model, output_path)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()
open(f"{output_path}/test.tflite", "wb").write(tflite_model)
--------------------------------------------------------------------------------
/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PINTO0309/HeadPoseEstimation-WHENet-yolov4-onnx-openvino/e11fc99488cbcaaa06b5cf6cc730e058f03ab279/test.jpg
--------------------------------------------------------------------------------