├── posenet
│   ├── __init__.py
│   ├── converter
│   │   ├── config.py
│   │   ├── wget.py
│   │   ├── config.yaml
│   │   └── tfjs2python.py
│   ├── model.py
│   ├── constants.py
│   ├── decode.py
│   ├── utils.py
│   └── decode_multi.py
├── NOTICE.txt
├── get_test_images.py
├── benchmark.py
├── image_demo.py
├── webcam_demo.py
├── .gitignore
├── README.md
└── LICENSE.txt
/posenet/__init__.py: -------------------------------------------------------------------------------- 1 | from posenet.constants import * 2 | from posenet.decode_multi import decode_multiple_poses 3 | from posenet.model import load_model 4 | from posenet.utils import * 5 | -------------------------------------------------------------------------------- /posenet/converter/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | BASE_DIR = os.path.dirname(__file__) 5 | 6 | 7 | def load_config(config_name='config.yaml'): 8 | with open(os.path.join(BASE_DIR, config_name), "r") as cfg_f: 9 | cfg = yaml.safe_load(cfg_f) 10 | return cfg 11 | -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- 1 | PoseNet Python 2 | Copyright 2018 Ross Wightman 3 | 4 | Posenet tfjs converter (code in posenet/converter) 5 | Copyright (c) 2017 Infocom TPO (https://lab.infocom.co.jp/) 6 | Modified (c) 2018 Ross Wightman 7 | 8 | tfjs PoseNet weights and original JS code 9 | Copyright 2018 Google LLC. All Rights Reserved. 10 | 11 | 12 | -------------------------------------------------------------------------------- /posenet/converter/wget.py: -------------------------------------------------------------------------------- 1 | import urllib.request 2 | import posixpath 3 | import json 4 | import os 5 | 6 | from posenet.converter.config import load_config 7 | 8 | CFG = load_config() 9 | GOOGLE_CLOUD_STORAGE_DIR = CFG['GOOGLE_CLOUD_STORAGE_DIR'] 10 | CHECKPOINTS = CFG['checkpoints'] 11 | CHK = CFG['chk'] 12 | 13 | 14 | def download_file(checkpoint, filename, base_dir): 15 | url = posixpath.join(GOOGLE_CLOUD_STORAGE_DIR, checkpoint, filename) 16 | urllib.request.urlretrieve(url, os.path.join(base_dir, checkpoint, filename)) 17 | 18 | 19 | def download(checkpoint, base_dir='./weights/'): 20 | save_dir = os.path.join(base_dir, checkpoint) 21 | if not os.path.exists(save_dir): 22 | os.makedirs(save_dir) 23 | 24 | download_file(checkpoint, 'manifest.json', base_dir) 25 | 26 | f = open(os.path.join(save_dir, 'manifest.json'), 'r') 27 | json_dict = json.load(f) 28 | 29 | for x in json_dict: 30 | filename = json_dict[x]['filename'] 31 | print('Downloading', filename) 32 | download_file(checkpoint, filename, base_dir) 33 | 34 | 35 | def main(): 36 | checkpoint = CHECKPOINTS[CHK] 37 | download(checkpoint) 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /get_test_images.py: -------------------------------------------------------------------------------- 1 | import urllib.request 2 | import os 3 | import argparse 4 | 5 | GOOGLE_CLOUD_IMAGE_BUCKET = 'https://storage.googleapis.com/tfjs-models/assets/posenet/' 6 | 7 | TEST_IMAGES = [ 8 | 'frisbee.jpg', 9 | 'frisbee_2.jpg', 10 | 'backpackman.jpg', 11 | 'boy_doughnut.jpg', 12 | 'soccer.png', 13 | 'with_computer.jpg', 14 | 'snowboard.jpg', 15 | 'person_bench.jpg', 16 | 'skiing.jpg', 17 | 'fire_hydrant.jpg', 18 | 'kyte.jpg', 19 | 'looking_at_computer.jpg', 20 | 'tennis.jpg', 21 |
'tennis_standing.jpg', 22 | 'truck.jpg', 23 | 'on_bus.jpg', 24 | 'tie_with_beer.jpg', 25 | 'baseball.jpg', 26 | 'multi_skiing.jpg', 27 | 'riding_elephant.jpg', 28 | 'skate_park_venice.jpg', 29 | 'skate_park.jpg', 30 | 'tennis_in_crowd.jpg', 31 | 'two_on_bench.jpg', 32 | ] 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--image_dir', type=str, default='./images') 36 | args = parser.parse_args() 37 | 38 | 39 | def main(): 40 | if not os.path.exists(args.image_dir): 41 | os.makedirs(args.image_dir) 42 | 43 | for f in TEST_IMAGES: 44 | url = os.path.join(GOOGLE_CLOUD_IMAGE_BUCKET, f) 45 | print('Downloading %s' % f) 46 | urllib.request.urlretrieve(url, os.path.join(args.image_dir, f)) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /posenet/converter/config.yaml: -------------------------------------------------------------------------------- 1 | chk: 3 # 3=mobilenet_v1_101 2 | imageSize: 513 3 | GOOGLE_CLOUD_STORAGE_DIR: 'https://storage.googleapis.com/tfjs-models/weights/posenet/' 4 | checkpoints: [ 'mobilenet_v1_050', 'mobilenet_v1_075', 'mobilenet_v1_100', 'mobilenet_v1_101'] 5 | outputStride: 16 6 | mobileNet100Architecture: [ 7 | ['conv2d', 2], 8 | ['separableConv', 1], 9 | ['separableConv', 2], 10 | ['separableConv', 1], 11 | ['separableConv', 2], 12 | ['separableConv', 1], 13 | ['separableConv', 2], 14 | ['separableConv', 1], 15 | ['separableConv', 1], 16 | ['separableConv', 1], 17 | ['separableConv', 1], 18 | ['separableConv', 1], 19 | ['separableConv', 2], 20 | ['separableConv', 1] 21 | ] 22 | mobileNet75Architecture: [ 23 | ['conv2d', 2], 24 | ['separableConv', 1], 25 | ['separableConv', 2], 26 | ['separableConv', 1], 27 | ['separableConv', 2], 28 | ['separableConv', 1], 29 | ['separableConv', 2], 30 | ['separableConv', 1], 31 | ['separableConv', 1], 32 | ['separableConv', 1], 33 | ['separableConv', 1], 34 | ['separableConv', 1], 35 | ['separableConv', 1], 36 | ['separableConv', 1] 37 | ] 38 | mobileNet50Architecture: [ 39 | ['conv2d', 2], 40 | ['separableConv', 1], 41 | ['separableConv', 2], 42 | ['separableConv', 1], 43 | ['separableConv', 2], 44 | ['separableConv', 1], 45 | ['separableConv', 2], 46 | ['separableConv', 1], 47 | ['separableConv', 1], 48 | ['separableConv', 1], 49 | ['separableConv', 1], 50 | ['separableConv', 1], 51 | ['separableConv', 1], 52 | ['separableConv', 1] 53 | ] 54 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import time 3 | import argparse 4 | import os 5 | 6 | import posenet 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--model', type=int, default=101) 11 | parser.add_argument('--image_dir', type=str, default='./images') 12 | parser.add_argument('--num_images', type=int, default=1000) 13 | args = parser.parse_args() 14 | 15 | 16 | def main(): 17 | 18 | with tf.Session() as sess: 19 | model_cfg, model_outputs = posenet.load_model(args.model, sess) 20 | output_stride = model_cfg['output_stride'] 21 | num_images = args.num_images 22 | 23 | filenames = [ 24 | f.path for f in os.scandir(args.image_dir) if f.is_file() and f.path.endswith(('.png', '.jpg'))] 25 | if len(filenames) > num_images: 26 | filenames = filenames[:num_images] 27 | 28 | images = {f: posenet.read_imgfile(f, 1.0, output_stride)[0] for f in filenames} 29 | 30 | start = time.time() 31 | 
for i in range(num_images): 32 | heatmaps_result, offsets_result, displacement_fwd_result, displacement_bwd_result = sess.run( 33 | model_outputs, 34 | feed_dict={'image:0': images[filenames[i % len(filenames)]]} 35 | ) 36 | 37 | output = posenet.decode_multiple_poses( 38 | heatmaps_result.squeeze(axis=0), 39 | offsets_result.squeeze(axis=0), 40 | displacement_fwd_result.squeeze(axis=0), 41 | displacement_bwd_result.squeeze(axis=0), 42 | output_stride=output_stride, 43 | max_pose_detections=10, 44 | min_pose_score=0.25) 45 | 46 | print('Average FPS:', num_images / (time.time() - start)) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /posenet/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import posenet.converter.config 4 | 5 | MODEL_DIR = './_models' 6 | DEBUG_OUTPUT = False 7 | 8 | 9 | def model_id_to_ord(model_id): 10 | if 0 <= model_id < 4: 11 | return model_id # id is already ordinal 12 | elif model_id == 50: 13 | return 0 14 | elif model_id == 75: 15 | return 1 16 | elif model_id == 100: 17 | return 2 18 | else: # 101 19 | return 3 20 | 21 | 22 | def load_config(model_ord): 23 | converter_cfg = posenet.converter.config.load_config() 24 | checkpoints = converter_cfg['checkpoints'] 25 | output_stride = converter_cfg['outputStride'] 26 | checkpoint_name = checkpoints[model_ord] 27 | 28 | model_cfg = { 29 | 'output_stride': output_stride, 30 | 'checkpoint_name': checkpoint_name, 31 | } 32 | return model_cfg 33 | 34 | 35 | def load_model(model_id, sess, model_dir=MODEL_DIR): 36 | model_ord = model_id_to_ord(model_id) 37 | model_cfg = load_config(model_ord) 38 | model_path = os.path.join(model_dir, 'model-%s.pb' % model_cfg['checkpoint_name']) 39 | if not os.path.exists(model_path): 40 | print('Cannot find model file %s, converting from tfjs...' 
% model_path) 41 | from posenet.converter.tfjs2python import convert 42 | convert(model_ord, model_dir, check=False) 43 | assert os.path.exists(model_path) 44 | 45 | with tf.gfile.GFile(model_path, 'rb') as f: 46 | graph_def = tf.GraphDef() 47 | graph_def.ParseFromString(f.read()) 48 | sess.graph.as_default() 49 | tf.import_graph_def(graph_def, name='') 50 | 51 | if DEBUG_OUTPUT: 52 | graph_nodes = [n for n in graph_def.node] 53 | names = [] 54 | for t in graph_nodes: 55 | names.append(t.name) 56 | print('Loaded graph node:', t.name) 57 | 58 | offsets = sess.graph.get_tensor_by_name('offset_2:0') 59 | displacement_fwd = sess.graph.get_tensor_by_name('displacement_fwd_2:0') 60 | displacement_bwd = sess.graph.get_tensor_by_name('displacement_bwd_2:0') 61 | heatmaps = sess.graph.get_tensor_by_name('heatmap:0') 62 | 63 | return model_cfg, [heatmaps, offsets, displacement_fwd, displacement_bwd] 64 | -------------------------------------------------------------------------------- /posenet/constants.py: -------------------------------------------------------------------------------- 1 | 2 | PART_NAMES = [ 3 | "nose", "leftEye", "rightEye", "leftEar", "rightEar", "leftShoulder", 4 | "rightShoulder", "leftElbow", "rightElbow", "leftWrist", "rightWrist", 5 | "leftHip", "rightHip", "leftKnee", "rightKnee", "leftAnkle", "rightAnkle" 6 | ] 7 | 8 | NUM_KEYPOINTS = len(PART_NAMES) 9 | 10 | PART_IDS = {pn: pid for pid, pn in enumerate(PART_NAMES)} 11 | 12 | CONNECTED_PART_NAMES = [ 13 | ("leftHip", "leftShoulder"), ("leftElbow", "leftShoulder"), 14 | ("leftElbow", "leftWrist"), ("leftHip", "leftKnee"), 15 | ("leftKnee", "leftAnkle"), ("rightHip", "rightShoulder"), 16 | ("rightElbow", "rightShoulder"), ("rightElbow", "rightWrist"), 17 | ("rightHip", "rightKnee"), ("rightKnee", "rightAnkle"), 18 | ("leftShoulder", "rightShoulder"), ("leftHip", "rightHip") 19 | ] 20 | 21 | CONNECTED_PART_INDICES = [(PART_IDS[a], PART_IDS[b]) for a, b in CONNECTED_PART_NAMES] 22 | 23 | LOCAL_MAXIMUM_RADIUS = 1 24 | 25 | POSE_CHAIN = [ 26 | ("nose", "leftEye"), ("leftEye", "leftEar"), ("nose", "rightEye"), 27 | ("rightEye", "rightEar"), ("nose", "leftShoulder"), 28 | ("leftShoulder", "leftElbow"), ("leftElbow", "leftWrist"), 29 | ("leftShoulder", "leftHip"), ("leftHip", "leftKnee"), 30 | ("leftKnee", "leftAnkle"), ("nose", "rightShoulder"), 31 | ("rightShoulder", "rightElbow"), ("rightElbow", "rightWrist"), 32 | ("rightShoulder", "rightHip"), ("rightHip", "rightKnee"), 33 | ("rightKnee", "rightAnkle") 34 | ] 35 | 36 | PARENT_CHILD_TUPLES = [(PART_IDS[parent], PART_IDS[child]) for parent, child in POSE_CHAIN] 37 | 38 | PART_CHANNELS = [ 39 | 'left_face', 40 | 'right_face', 41 | 'right_upper_leg_front', 42 | 'right_lower_leg_back', 43 | 'right_upper_leg_back', 44 | 'left_lower_leg_front', 45 | 'left_upper_leg_front', 46 | 'left_upper_leg_back', 47 | 'left_lower_leg_back', 48 | 'right_feet', 49 | 'right_lower_leg_front', 50 | 'left_feet', 51 | 'torso_front', 52 | 'torso_back', 53 | 'right_upper_arm_front', 54 | 'right_upper_arm_back', 55 | 'right_lower_arm_back', 56 | 'left_lower_arm_front', 57 | 'left_upper_arm_front', 58 | 'left_upper_arm_back', 59 | 'left_lower_arm_back', 60 | 'right_hand', 61 | 'right_lower_arm_front', 62 | 'left_hand' 63 | ] -------------------------------------------------------------------------------- /posenet/decode.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from posenet.constants import * 4 | 5 | 6 | def traverse_to_targ_keypoint( 
7 | edge_id, source_keypoint, target_keypoint_id, scores, offsets, output_stride, displacements 8 | ): 9 | height = scores.shape[0] 10 | width = scores.shape[1] 11 | 12 | source_keypoint_indices = np.clip( 13 | np.round(source_keypoint / output_stride), a_min=0, a_max=[height - 1, width - 1]).astype(np.int32) 14 | 15 | displaced_point = source_keypoint + displacements[ 16 | source_keypoint_indices[0], source_keypoint_indices[1], edge_id] 17 | 18 | displaced_point_indices = np.clip( 19 | np.round(displaced_point / output_stride), a_min=0, a_max=[height - 1, width - 1]).astype(np.int32) 20 | 21 | score = scores[displaced_point_indices[0], displaced_point_indices[1], target_keypoint_id] 22 | 23 | image_coord = displaced_point_indices * output_stride + offsets[ 24 | displaced_point_indices[0], displaced_point_indices[1], target_keypoint_id] 25 | 26 | return score, image_coord 27 | 28 | 29 | def decode_pose( 30 | root_score, root_id, root_image_coord, 31 | scores, 32 | offsets, 33 | output_stride, 34 | displacements_fwd, 35 | displacements_bwd 36 | ): 37 | num_parts = scores.shape[2] 38 | num_edges = len(PARENT_CHILD_TUPLES) 39 | 40 | instance_keypoint_scores = np.zeros(num_parts) 41 | instance_keypoint_coords = np.zeros((num_parts, 2)) 42 | instance_keypoint_scores[root_id] = root_score 43 | instance_keypoint_coords[root_id] = root_image_coord 44 | 45 | for edge in reversed(range(num_edges)): 46 | target_keypoint_id, source_keypoint_id = PARENT_CHILD_TUPLES[edge] 47 | if (instance_keypoint_scores[source_keypoint_id] > 0.0 and 48 | instance_keypoint_scores[target_keypoint_id] == 0.0): 49 | score, coords = traverse_to_targ_keypoint( 50 | edge, 51 | instance_keypoint_coords[source_keypoint_id], 52 | target_keypoint_id, 53 | scores, offsets, output_stride, displacements_bwd) 54 | instance_keypoint_scores[target_keypoint_id] = score 55 | instance_keypoint_coords[target_keypoint_id] = coords 56 | 57 | for edge in range(num_edges): 58 | source_keypoint_id, target_keypoint_id = PARENT_CHILD_TUPLES[edge] 59 | if (instance_keypoint_scores[source_keypoint_id] > 0.0 and 60 | instance_keypoint_scores[target_keypoint_id] == 0.0): 61 | score, coords = traverse_to_targ_keypoint( 62 | edge, 63 | instance_keypoint_coords[source_keypoint_id], 64 | target_keypoint_id, 65 | scores, offsets, output_stride, displacements_fwd) 66 | instance_keypoint_scores[target_keypoint_id] = score 67 | instance_keypoint_coords[target_keypoint_id] = coords 68 | 69 | return instance_keypoint_scores, instance_keypoint_coords 70 | -------------------------------------------------------------------------------- /image_demo.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import cv2 3 | import time 4 | import argparse 5 | import os 6 | 7 | import posenet 8 | 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model', type=int, default=101) 12 | parser.add_argument('--scale_factor', type=float, default=1.0) 13 | parser.add_argument('--notxt', action='store_true') 14 | parser.add_argument('--image_dir', type=str, default='./images') 15 | parser.add_argument('--output_dir', type=str, default='./output') 16 | args = parser.parse_args() 17 | 18 | 19 | def main(): 20 | 21 | with tf.Session() as sess: 22 | model_cfg, model_outputs = posenet.load_model(args.model, sess) 23 | output_stride = model_cfg['output_stride'] 24 | 25 | if args.output_dir: 26 | if not os.path.exists(args.output_dir): 27 | os.makedirs(args.output_dir) 28 | 29 | filenames = [ 30 | 
f.path for f in os.scandir(args.image_dir) if f.is_file() and f.path.endswith(('.png', '.jpg'))] 31 | 32 | start = time.time() 33 | for f in filenames: 34 | input_image, draw_image, output_scale = posenet.read_imgfile( 35 | f, scale_factor=args.scale_factor, output_stride=output_stride) 36 | 37 | heatmaps_result, offsets_result, displacement_fwd_result, displacement_bwd_result = sess.run( 38 | model_outputs, 39 | feed_dict={'image:0': input_image} 40 | ) 41 | 42 | pose_scores, keypoint_scores, keypoint_coords = posenet.decode_multiple_poses( 43 | heatmaps_result.squeeze(axis=0), 44 | offsets_result.squeeze(axis=0), 45 | displacement_fwd_result.squeeze(axis=0), 46 | displacement_bwd_result.squeeze(axis=0), 47 | output_stride=output_stride, 48 | max_pose_detections=10, 49 | min_pose_score=0.25) 50 | 51 | keypoint_coords *= output_scale 52 | 53 | if args.output_dir: 54 | draw_image = posenet.draw_skel_and_kp( 55 | draw_image, pose_scores, keypoint_scores, keypoint_coords, 56 | min_pose_score=0.25, min_part_score=0.25) 57 | 58 | cv2.imwrite(os.path.join(args.output_dir, os.path.relpath(f, args.image_dir)), draw_image) 59 | 60 | if not args.notxt: 61 | print() 62 | print("Results for image: %s" % f) 63 | for pi in range(len(pose_scores)): 64 | if pose_scores[pi] == 0.: 65 | break 66 | print('Pose #%d, score = %f' % (pi, pose_scores[pi])) 67 | for ki, (s, c) in enumerate(zip(keypoint_scores[pi, :], keypoint_coords[pi, :, :])): 68 | print('Keypoint %s, score = %f, coord = %s' % (posenet.PART_NAMES[ki], s, c)) 69 | 70 | print('Average FPS:', len(filenames) / (time.time() - start)) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /webcam_demo.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import cv2 3 | import time 4 | import argparse 5 | 6 | import posenet 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--model', type=int, default=101) 10 | parser.add_argument('--cam_id', type=int, default=0) 11 | parser.add_argument('--cam_width', type=int, default=1280) 12 | parser.add_argument('--cam_height', type=int, default=720) 13 | parser.add_argument('--scale_factor', type=float, default=0.7125) 14 | parser.add_argument('--file', type=str, default=None, help="Optionally use a video file instead of a live camera") 15 | args = parser.parse_args() 16 | 17 | def gstreamer_pipeline (capture_width=1280, capture_height=720, display_width=1280, display_height=720, framerate=60, flip_method=2) : 18 | return ('nvarguscamerasrc ! ' 19 | 'video/x-raw(memory:NVMM), ' 20 | 'width=(int)%d, height=(int)%d, ' 21 | 'format=(string)NV12, framerate=(fraction)%d/1 ! ' 22 | 'nvvidconv flip-method=%d ! ' 23 | 'video/x-raw, width=(int)%d, height=(int)%d, format=(string)BGRx ! ' 24 | 'videoconvert ! ' 25 | 'video/x-raw, format=(string)BGR ! 
appsink' % (capture_width,capture_height,framerate,flip_method,display_width,display_height)) 26 | 27 | def main(): 28 | with tf.Session() as sess: 29 | model_cfg, model_outputs = posenet.load_model(args.model, sess) 30 | output_stride = model_cfg['output_stride'] 31 | 32 | if args.file is not None: 33 | cap = cv2.VideoCapture(args.file) 34 | else: 35 | cap = cv2.VideoCapture(gstreamer_pipeline(flip_method=2), cv2.CAP_GSTREAMER) 36 | #cap.set(3, args.cam_width) 37 | #cap.set(4, args.cam_height) 38 | 39 | start = time.time() 40 | frame_count = 0 41 | while True: 42 | input_image, display_image, output_scale = posenet.read_cap( 43 | cap, scale_factor=args.scale_factor, output_stride=output_stride) 44 | 45 | heatmaps_result, offsets_result, displacement_fwd_result, displacement_bwd_result = sess.run( 46 | model_outputs, 47 | feed_dict={'image:0': input_image} 48 | ) 49 | 50 | pose_scores, keypoint_scores, keypoint_coords = posenet.decode_multi.decode_multiple_poses( 51 | heatmaps_result.squeeze(axis=0), 52 | offsets_result.squeeze(axis=0), 53 | displacement_fwd_result.squeeze(axis=0), 54 | displacement_bwd_result.squeeze(axis=0), 55 | output_stride=output_stride, 56 | max_pose_detections=10, 57 | min_pose_score=0.15) 58 | 59 | keypoint_coords *= output_scale 60 | 61 | # TODO this isn't particularly fast, use GL for drawing and display someday... 62 | overlay_image = posenet.draw_skel_and_kp( 63 | display_image, pose_scores, keypoint_scores, keypoint_coords, 64 | min_pose_score=0.15, min_part_score=0.1) 65 | 66 | 67 | cv2.namedWindow("posenet", cv2.WND_PROP_FULLSCREEN) 68 | cv2.setWindowProperty("posenet", cv2.WND_PROP_FULLSCREEN,cv2.WINDOW_FULLSCREEN) 69 | cv2.imshow('posenet', overlay_image) 70 | frame_count += 1 71 | if cv2.waitKey(1) & 0xFF == ord('q'): 72 | break 73 | 74 | print('Average FPS: ', frame_count / (time.time() - start)) 75 | 76 | 77 | if __name__ == "__main__": 78 | main() 79 | 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | images/* 2 | output/* 3 | .idea/* 4 | .idea 5 | _models/* 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # celery beat schedule file 99 | celerybeat-schedule 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 133 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 134 | 135 | # User-specific stuff 136 | .idea/**/workspace.xml 137 | .idea/**/tasks.xml 138 | .idea/**/usage.statistics.xml 139 | .idea/**/dictionaries 140 | .idea/**/shelf 141 | 142 | # Generated files 143 | .idea/**/contentModel.xml 144 | 145 | # Sensitive or high-churn files 146 | .idea/**/dataSources/ 147 | .idea/**/dataSources.ids 148 | .idea/**/dataSources.local.xml 149 | .idea/**/sqlDataSources.xml 150 | .idea/**/dynamic.xml 151 | .idea/**/uiDesigner.xml 152 | .idea/**/dbnavigator.xml 153 | 154 | # Gradle 155 | .idea/**/gradle.xml 156 | .idea/**/libraries 157 | 158 | # Gradle and Maven with auto-import 159 | # When using Gradle or Maven with auto-import, you should exclude module files, 160 | # since they will be recreated, and may cause churn. Uncomment if using 161 | # auto-import. 
162 | .idea/modules.xml 163 | .idea/*.iml 164 | .idea/modules 165 | 166 | # CMake 167 | cmake-build-*/ 168 | 169 | # Mongo Explorer plugin 170 | .idea/**/mongoSettings.xml 171 | 172 | # File-based project format 173 | *.iws 174 | 175 | # IntelliJ 176 | out/ 177 | 178 | # mpeltonen/sbt-idea plugin 179 | .idea_modules/ 180 | 181 | # JIRA plugin 182 | atlassian-ide-plugin.xml 183 | 184 | # Cursive Clojure plugin 185 | .idea/replstate.xml 186 | 187 | # Crashlytics plugin (for Android Studio and IntelliJ) 188 | com_crashlytics_export_strings.xml 189 | crashlytics.properties 190 | crashlytics-build.properties 191 | fabric.properties 192 | 193 | # Editor-based Rest Client 194 | .idea/httpRequests 195 | 196 | # Android studio 3.1+ serialized cache file 197 | .idea/caches/build_file_checksums.ser -------------------------------------------------------------------------------- /posenet/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import posenet.constants 5 | 6 | 7 | def valid_resolution(width, height, output_stride=16): 8 | target_width = (int(width) // output_stride) * output_stride + 1 9 | target_height = (int(height) // output_stride) * output_stride + 1 10 | return target_width, target_height 11 | 12 | 13 | def _process_input(source_img, scale_factor=1.0, output_stride=16): 14 | target_width, target_height = valid_resolution( 15 | source_img.shape[1] * scale_factor, source_img.shape[0] * scale_factor, output_stride=output_stride) 16 | scale = np.array([source_img.shape[0] / target_height, source_img.shape[1] / target_width]) 17 | 18 | input_img = cv2.resize(source_img, (target_width, target_height), interpolation=cv2.INTER_LINEAR) 19 | input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB).astype(np.float32) 20 | input_img = input_img * (2.0 / 255.0) - 1.0 21 | input_img = input_img.reshape(1, target_height, target_width, 3) 22 | return input_img, source_img, scale 23 | 24 | 25 | def read_cap(cap, scale_factor=1.0, output_stride=16): 26 | res, img = cap.read() 27 | if not res: 28 | raise IOError("webcam failure") 29 | return _process_input(img, scale_factor, output_stride) 30 | 31 | 32 | def read_imgfile(path, scale_factor=1.0, output_stride=16): 33 | img = cv2.imread(path) 34 | return _process_input(img, scale_factor, output_stride) 35 | 36 | 37 | def draw_keypoints( 38 | img, instance_scores, keypoint_scores, keypoint_coords, 39 | min_pose_confidence=0.5, min_part_confidence=0.5): 40 | cv_keypoints = [] 41 | for ii, score in enumerate(instance_scores): 42 | if score < min_pose_confidence: 43 | continue 44 | for ks, kc in zip(keypoint_scores[ii, :], keypoint_coords[ii, :, :]): 45 | if ks < min_part_confidence: 46 | continue 47 | cv_keypoints.append(cv2.KeyPoint(kc[1], kc[0], 10. 
* ks)) 48 | out_img = cv2.drawKeypoints(img, cv_keypoints, outImage=np.array([])) 49 | return out_img 50 | 51 | 52 | def get_adjacent_keypoints(keypoint_scores, keypoint_coords, min_confidence=0.1): 53 | results = [] 54 | for left, right in posenet.CONNECTED_PART_INDICES: 55 | if keypoint_scores[left] < min_confidence or keypoint_scores[right] < min_confidence: 56 | continue 57 | results.append( 58 | np.array([keypoint_coords[left][::-1], keypoint_coords[right][::-1]]).astype(np.int32), 59 | ) 60 | return results 61 | 62 | 63 | def draw_skeleton( 64 | img, instance_scores, keypoint_scores, keypoint_coords, 65 | min_pose_confidence=0.5, min_part_confidence=0.5): 66 | out_img = img 67 | adjacent_keypoints = [] 68 | for ii, score in enumerate(instance_scores): 69 | if score < min_pose_confidence: 70 | continue 71 | new_keypoints = get_adjacent_keypoints( 72 | keypoint_scores[ii, :], keypoint_coords[ii, :, :], min_part_confidence) 73 | adjacent_keypoints.extend(new_keypoints) 74 | out_img = cv2.polylines(out_img, adjacent_keypoints, isClosed=False, color=(255, 255, 0)) 75 | return out_img 76 | 77 | 78 | def draw_skel_and_kp( 79 | img, instance_scores, keypoint_scores, keypoint_coords, 80 | min_pose_score=0.5, min_part_score=0.5): 81 | out_img = img 82 | adjacent_keypoints = [] 83 | cv_keypoints = [] 84 | for ii, score in enumerate(instance_scores): 85 | if score < min_pose_score: 86 | continue 87 | 88 | new_keypoints = get_adjacent_keypoints( 89 | keypoint_scores[ii, :], keypoint_coords[ii, :, :], min_part_score) 90 | adjacent_keypoints.extend(new_keypoints) 91 | 92 | for ks, kc in zip(keypoint_scores[ii, :], keypoint_coords[ii, :, :]): 93 | if ks < min_part_score: 94 | continue 95 | cv_keypoints.append(cv2.KeyPoint(kc[1], kc[0], 10. * ks)) 96 | 97 | out_img = cv2.drawKeypoints( 98 | out_img, cv_keypoints, outImage=np.array([]), color=(255, 255, 0), 99 | flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS) 100 | out_img = cv2.polylines(out_img, adjacent_keypoints, isClosed=False, color=(255, 255, 0)) 101 | return out_img 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## OBSOLETE ## 2 | 3 | If you're looking to do pose estimation on the Jetson Nano, check out the official NVIDIA project for the job. I'll guarantee it's better optimised than this one! 4 | 5 | https://github.com/NVIDIA-AI-IOT/trt_pose 6 | 7 | 8 | 9 | 10 | ## Posenet Jetson Nano ## 11 | 12 | A fork of rwightman's excellent work. Modified slightly to use the OpenCV/GStreamer capabilities preinstalled on the NVIDIA Jetson Nano. 13 | 14 | Requires: 15 | * Tensorflow (https://devtalk.nvidia.com/default/topic/1048776/official-tensorflow-for-jetson-nano-/) 16 | * ``` sudo apt-get install python3-pip libhdf5-serial-dev hdf5-tools ``` 17 | * ```pip3 install --extra-index-url https://developer.download.nvidia.com/compute/redist/jp/v42 tensorflow-gpu==1.13.1+nv19.5 --user --no-cache-dir``` 18 | * Scipy (May require lapack/blas update prior to install) 19 | * ```sudo apt-get update``` 20 | * ```sudo apt-get install -y build-essential libatlas-base-dev gfortran``` 21 | * ```pip3 install scipy``` 22 | 23 | Does not require any updates to the OpenCV version provided in the standard Jetson Nano build. 24 | 25 | Performance increases might be observed by setting up a memory swapfile.
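As a rough, untested sketch (the 4 GB size and /swapfile path are just a suggestion), a swapfile can be created with the standard Linux commands:
* ```sudo fallocate -l 4G /swapfile```
* ```sudo chmod 600 /swapfile```
* ```sudo mkswap /swapfile```
* ```sudo swapon /swapfile```

To keep the swapfile enabled across reboots, add a corresponding entry to /etc/fstab.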
26 | 27 | 28 | 29 | 30 | Original ReadMe follows: 31 | 32 | ## PoseNet Python 33 | 34 | This repository contains a pure Python implementation (multi-pose only) of the Google TensorFlow.js Posenet model. For a (slightly faster) PyTorch implementation that followed from this, see https://github.com/rwightman/posenet-pytorch 35 | 36 | I first adapted the JS code more or less verbatim and found the performance was low, so I made vectorized numpy/scipy versions of a few key functions (named `_fast`). 37 | 38 | Further optimization is possible: 39 | * The base MobileNet models have a throughput of 200-300 fps on a GTX 1080 Ti (or better) 40 | * The multi-pose post-processing code brings this rate down significantly. With a fast CPU and a GTX 1080+: 41 | * A literal translation of the JS post-processing code dropped performance to approx 30 fps 42 | * My 'fast' post-processing results in 90-110 fps 43 | * A Cython or pure C++ port would be even better... 44 | 45 | ### Install 46 | 47 | A suitable Python 3.x environment with a recent version of Tensorflow is required. 48 | 49 | Development and testing were done with Conda Python 3.6.8 and Tensorflow 1.12.0 on Linux. 50 | 51 | Windows 10 with the latest (as of 2019-01-19) 64-bit Python 3.7 Anaconda installer was also tested. 52 | 53 | If you want to use the webcam demo, a pip version of opencv (`pip install opencv-python`) is required instead of the conda version. Anaconda's default opencv does not include ffmpeg/VideoCapture support. Also, you may have to force install version 3.4.x, as 4.x has a broken drawKeypoints binding. 54 | 55 | A conda environment setup as below should suffice: 56 | ``` 57 | conda install tensorflow-gpu scipy pyyaml python=3.6 58 | pip install opencv-python==3.4.5.20 59 | 60 | ``` 61 | 62 | ### Usage 63 | 64 | There are three demo apps in the root that utilize the PoseNet model. They are very basic and could definitely be improved. 65 | 66 | The first time these apps are run (or the library is used), model weights will be downloaded from the TensorFlow.js version and converted on the fly. 67 | 68 | For all demos, the model can be specified with the `--model` argument by using its ordinal id (0-3) or integer depth multiplier (50, 75, 100, 101). The default is the 101 model. 69 | 70 | #### image_demo.py 71 | 72 | Image demo runs inference on an input folder of images and outputs those images with the keypoints and skeleton overlaid. 73 | 74 | `python image_demo.py --model 101 --image_dir ./images --output_dir ./output` 75 | 76 | A folder of suitable test images can be downloaded by first running the `get_test_images.py` script. 77 | 78 | #### benchmark.py 79 | 80 | A minimal performance benchmark based on image_demo. Images in `--image_dir` are pre-loaded and inference is run `--num_images` times with no drawing and no text output. 81 | 82 | #### webcam_demo.py 83 | 84 | The webcam demo uses OpenCV to capture images from a connected webcam. The result is overlaid with the keypoints and skeletons and rendered to the screen. The default args for the webcam_demo assume device_id=0 for the camera and that 1280x720 resolution is possible. 85 | 86 | ### Credits 87 | 88 | The original model, weights, code, etc. were created by Google and can be found at https://github.com/tensorflow/tfjs-models/tree/master/posenet 89 | 90 | This port and my work are in no way related to Google.
91 | 92 | The Python conversion code that started me on my way was adapted from the CoreML port at https://github.com/infocom-tpo/PoseNet-CoreML 93 | 94 | ### TODO (someday, maybe) 95 | * More stringent verification of correctness against the original implementation 96 | * Performance improvements (especially edge loops in 'decode.py') 97 | * OpenGL rendering/drawing 98 | * Comment interfaces, tensor dimensions, etc 99 | * Implement batch inference for image_demo 100 | 101 | -------------------------------------------------------------------------------- /posenet/decode_multi.py: -------------------------------------------------------------------------------- 1 | from posenet.decode import * 2 | from posenet.constants import * 3 | import time 4 | import scipy.ndimage as ndi 5 | 6 | 7 | def within_nms_radius(poses, squared_nms_radius, point, keypoint_id): 8 | for _, _, pose_coord in poses: 9 | if np.sum((pose_coord[keypoint_id] - point) ** 2) <= squared_nms_radius: 10 | return True 11 | return False 12 | 13 | 14 | def within_nms_radius_fast(pose_coords, squared_nms_radius, point): 15 | if not pose_coords.shape[0]: 16 | return False 17 | return np.any(np.sum((pose_coords - point) ** 2, axis=1) <= squared_nms_radius) 18 | 19 | 20 | def get_instance_score( 21 | existing_poses, squared_nms_radius, 22 | keypoint_scores, keypoint_coords): 23 | not_overlapped_scores = 0. 24 | for keypoint_id in range(len(keypoint_scores)): 25 | if not within_nms_radius( 26 | existing_poses, squared_nms_radius, 27 | keypoint_coords[keypoint_id], keypoint_id): 28 | not_overlapped_scores += keypoint_scores[keypoint_id] 29 | return not_overlapped_scores / len(keypoint_scores) 30 | 31 | 32 | def get_instance_score_fast( 33 | exist_pose_coords, 34 | squared_nms_radius, 35 | keypoint_scores, keypoint_coords): 36 | 37 | if exist_pose_coords.shape[0]: 38 | s = np.sum((exist_pose_coords - keypoint_coords) ** 2, axis=2) > squared_nms_radius 39 | not_overlapped_scores = np.sum(keypoint_scores[np.all(s, axis=0)]) 40 | else: 41 | not_overlapped_scores = np.sum(keypoint_scores) 42 | return not_overlapped_scores / len(keypoint_scores) 43 | 44 | 45 | def score_is_max_in_local_window(keypoint_id, score, hmy, hmx, local_max_radius, scores): 46 | height = scores.shape[0] 47 | width = scores.shape[1] 48 | 49 | y_start = max(hmy - local_max_radius, 0) 50 | y_end = min(hmy + local_max_radius + 1, height) 51 | x_start = max(hmx - local_max_radius, 0) 52 | x_end = min(hmx + local_max_radius + 1, width) 53 | 54 | for y in range(y_start, y_end): 55 | for x in range(x_start, x_end): 56 | if scores[y, x, keypoint_id] > score: 57 | return False 58 | return True 59 | 60 | 61 | def build_part_with_score(score_threshold, local_max_radius, scores): 62 | parts = [] 63 | height = scores.shape[0] 64 | width = scores.shape[1] 65 | num_keypoints = scores.shape[2] 66 | 67 | for hmy in range(height): 68 | for hmx in range(width): 69 | for keypoint_id in range(num_keypoints): 70 | score = scores[hmy, hmx, keypoint_id] 71 | if score < score_threshold: 72 | continue 73 | if score_is_max_in_local_window(keypoint_id, score, hmy, hmx, 74 | local_max_radius, scores): 75 | parts.append(( 76 | score, keypoint_id, np.array((hmy, hmx)) 77 | )) 78 | return parts 79 | 80 | 81 | def build_part_with_score_fast(score_threshold, local_max_radius, scores): 82 | parts = [] 83 | num_keypoints = scores.shape[2] 84 | lmd = 2 * local_max_radius + 1 85 | 86 | # NOTE it seems faster to iterate over the keypoints and perform maximum_filter 87 | # on each subarray vs doing the op 
on the full score array with size=(lmd, lmd, 1) 88 | for keypoint_id in range(num_keypoints): 89 | kp_scores = scores[:, :, keypoint_id].copy() 90 | kp_scores[kp_scores < score_threshold] = 0. 91 | max_vals = ndi.maximum_filter(kp_scores, size=lmd, mode='constant') 92 | max_loc = np.logical_and(kp_scores == max_vals, kp_scores > 0) 93 | max_loc_idx = max_loc.nonzero() 94 | for y, x in zip(*max_loc_idx): 95 | parts.append(( 96 | scores[y, x, keypoint_id], 97 | keypoint_id, 98 | np.array((y, x)) 99 | )) 100 | 101 | return parts 102 | 103 | 104 | def decode_multiple_poses( 105 | scores, offsets, displacements_fwd, displacements_bwd, output_stride, 106 | max_pose_detections=10, score_threshold=0.5, nms_radius=20, min_pose_score=0.5): 107 | 108 | pose_count = 0 109 | pose_scores = np.zeros(max_pose_detections) 110 | pose_keypoint_scores = np.zeros((max_pose_detections, NUM_KEYPOINTS)) 111 | pose_keypoint_coords = np.zeros((max_pose_detections, NUM_KEYPOINTS, 2)) 112 | 113 | squared_nms_radius = nms_radius ** 2 114 | 115 | scored_parts = build_part_with_score_fast(score_threshold, LOCAL_MAXIMUM_RADIUS, scores) 116 | scored_parts = sorted(scored_parts, key=lambda x: x[0], reverse=True) 117 | 118 | # change dimensions from (h, w, x) to (h, w, x//2, 2) to allow return of complete coord array 119 | height = scores.shape[0] 120 | width = scores.shape[1] 121 | offsets = offsets.reshape(height, width, 2, -1).swapaxes(2, 3) 122 | displacements_fwd = displacements_fwd.reshape(height, width, 2, -1).swapaxes(2, 3) 123 | displacements_bwd = displacements_bwd.reshape(height, width, 2, -1).swapaxes(2, 3) 124 | 125 | for root_score, root_id, root_coord in scored_parts: 126 | root_image_coords = root_coord * output_stride + offsets[ 127 | root_coord[0], root_coord[1], root_id] 128 | 129 | if within_nms_radius_fast( 130 | pose_keypoint_coords[:pose_count, root_id, :], squared_nms_radius, root_image_coords): 131 | continue 132 | 133 | keypoint_scores, keypoint_coords = decode_pose( 134 | root_score, root_id, root_image_coords, 135 | scores, offsets, output_stride, 136 | displacements_fwd, displacements_bwd) 137 | 138 | pose_score = get_instance_score_fast( 139 | pose_keypoint_coords[:pose_count, :, :], squared_nms_radius, keypoint_scores, keypoint_coords) 140 | 141 | # NOTE this isn't in the original implementation, but it appears that by initially ordering by 142 | # part scores, and having a max # of detections, we can end up populating the returned poses with 143 | # lower scored poses than if we discard 'bad' ones and continue (higher pose scores can still come later). 144 | # Set min_pose_score to 0. to revert to original behaviour 145 | if min_pose_score == 0. 
or pose_score >= min_pose_score: 146 | pose_scores[pose_count] = pose_score 147 | pose_keypoint_scores[pose_count, :] = keypoint_scores 148 | pose_keypoint_coords[pose_count, :, :] = keypoint_coords 149 | pose_count += 1 150 | 151 | if pose_count >= max_pose_detections: 152 | break 153 | 154 | return pose_scores, pose_keypoint_scores, pose_keypoint_coords 155 | -------------------------------------------------------------------------------- /posenet/converter/tfjs2python.py: -------------------------------------------------------------------------------- 1 | import json 2 | import struct 3 | import tensorflow as tf 4 | from tensorflow.python.tools.freeze_graph import freeze_graph 5 | import cv2 6 | import numpy as np 7 | import os 8 | import tempfile 9 | 10 | from posenet.converter.config import load_config 11 | 12 | BASE_DIR = os.path.join(tempfile.gettempdir(), '_posenet_weights') 13 | 14 | 15 | def to_output_strided_layers(convolution_def, output_stride): 16 | current_stride = 1 17 | rate = 1 18 | block_id = 0 19 | buff = [] 20 | for _a in convolution_def: 21 | conv_type = _a[0] 22 | stride = _a[1] 23 | 24 | if current_stride == output_stride: 25 | layer_stride = 1 26 | layer_rate = rate 27 | rate *= stride 28 | else: 29 | layer_stride = stride 30 | layer_rate = 1 31 | current_stride *= stride 32 | 33 | buff.append({ 34 | 'blockId': block_id, 35 | 'convType': conv_type, 36 | 'stride': layer_stride, 37 | 'rate': layer_rate, 38 | 'outputStride': current_stride 39 | }) 40 | block_id += 1 41 | 42 | return buff 43 | 44 | 45 | def load_variables(chkpoint, base_dir=BASE_DIR): 46 | manifest_path = os.path.join(base_dir, chkpoint, "manifest.json") 47 | if not os.path.exists(manifest_path): 48 | print('Weights for checkpoint %s are not downloaded. Downloading to %s ...' 
% (chkpoint, base_dir)) 49 | from posenet.converter.wget import download 50 | download(chkpoint, base_dir) 51 | assert os.path.exists(manifest_path) 52 | 53 | f = open(manifest_path) 54 | variables = json.load(f) 55 | f.close() 56 | 57 | # with tf.variable_scope(None, 'MobilenetV1'): 58 | for x in variables: 59 | filename = variables[x]["filename"] 60 | byte = open(os.path.join(base_dir, chkpoint, filename), 'rb').read() 61 | fmt = str(int(len(byte) / struct.calcsize('f'))) + 'f' 62 | d = struct.unpack(fmt, byte) 63 | d = tf.cast(d, tf.float32) 64 | d = tf.reshape(d, variables[x]["shape"]) 65 | variables[x]["x"] = tf.Variable(d, name=x) 66 | 67 | return variables 68 | 69 | 70 | def _read_imgfile(path, width, height): 71 | img = cv2.imread(path) 72 | img = cv2.resize(img, (width, height)) 73 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 74 | img = img.astype(float) 75 | img = img * (2.0 / 255.0) - 1.0 76 | return img 77 | 78 | 79 | def build_network(image, layers, variables): 80 | 81 | def _weights(layer_name): 82 | return variables["MobilenetV1/" + layer_name + "/weights"]['x'] 83 | 84 | def _biases(layer_name): 85 | return variables["MobilenetV1/" + layer_name + "/biases"]['x'] 86 | 87 | def _depthwise_weights(layer_name): 88 | return variables["MobilenetV1/" + layer_name + "/depthwise_weights"]['x'] 89 | 90 | def _conv_to_output(mobile_net_output, output_layer_name): 91 | w = tf.nn.conv2d(mobile_net_output, _weights(output_layer_name), [1, 1, 1, 1], padding='SAME') 92 | w = tf.nn.bias_add(w, _biases(output_layer_name), name=output_layer_name) 93 | return w 94 | 95 | def _conv(inputs, stride, block_id): 96 | return tf.nn.relu6( 97 | tf.nn.conv2d(inputs, _weights("Conv2d_" + str(block_id)), stride, padding='SAME') 98 | + _biases("Conv2d_" + str(block_id))) 99 | 100 | def _separable_conv(inputs, stride, block_id, dilations): 101 | if dilations is None: 102 | dilations = [1, 1] 103 | 104 | dw_layer = "Conv2d_" + str(block_id) + "_depthwise" 105 | pw_layer = "Conv2d_" + str(block_id) + "_pointwise" 106 | 107 | w = tf.nn.depthwise_conv2d( 108 | inputs, _depthwise_weights(dw_layer), stride, 'SAME', rate=dilations, data_format='NHWC') 109 | w = tf.nn.bias_add(w, _biases(dw_layer)) 110 | w = tf.nn.relu6(w) 111 | 112 | w = tf.nn.conv2d(w, _weights(pw_layer), [1, 1, 1, 1], padding='SAME') 113 | w = tf.nn.bias_add(w, _biases(pw_layer)) 114 | w = tf.nn.relu6(w) 115 | 116 | return w 117 | 118 | x = image 119 | buff = [] 120 | with tf.variable_scope(None, 'MobilenetV1'): 121 | 122 | for m in layers: 123 | stride = [1, m['stride'], m['stride'], 1] 124 | rate = [m['rate'], m['rate']] 125 | if m['convType'] == "conv2d": 126 | x = _conv(x, stride, m['blockId']) 127 | buff.append(x) 128 | elif m['convType'] == "separableConv": 129 | x = _separable_conv(x, stride, m['blockId'], rate) 130 | buff.append(x) 131 | 132 | heatmaps = _conv_to_output(x, 'heatmap_2') 133 | offsets = _conv_to_output(x, 'offset_2') 134 | displacement_fwd = _conv_to_output(x, 'displacement_fwd_2') 135 | displacement_bwd = _conv_to_output(x, 'displacement_bwd_2') 136 | heatmaps = tf.sigmoid(heatmaps, 'heatmap') 137 | 138 | return heatmaps, offsets, displacement_fwd, displacement_bwd 139 | 140 | 141 | def convert(model_id, model_dir, check=False): 142 | cfg = load_config() 143 | checkpoints = cfg['checkpoints'] 144 | image_size = cfg['imageSize'] 145 | output_stride = cfg['outputStride'] 146 | chkpoint = checkpoints[model_id] 147 | 148 | if chkpoint == 'mobilenet_v1_050': 149 | mobile_net_arch = cfg['mobileNet50Architecture'] 150 | elif 
chkpoint == 'mobilenet_v1_075': 151 | mobile_net_arch = cfg['mobileNet75Architecture'] 152 | else: 153 | mobile_net_arch = cfg['mobileNet100Architecture'] 154 | 155 | width = image_size 156 | height = image_size 157 | 158 | if not os.path.exists(model_dir): 159 | os.makedirs(model_dir) 160 | 161 | cg = tf.Graph() 162 | with cg.as_default(): 163 | layers = to_output_strided_layers(mobile_net_arch, output_stride) 164 | variables = load_variables(chkpoint) 165 | 166 | init = tf.global_variables_initializer() 167 | with tf.Session() as sess: 168 | sess.run(init) 169 | saver = tf.train.Saver() 170 | 171 | image_ph = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='image') 172 | outputs = build_network(image_ph, layers, variables) 173 | 174 | sess.run( 175 | [outputs], 176 | feed_dict={ 177 | image_ph: [np.ndarray(shape=(height, width, 3), dtype=np.float32)] 178 | } 179 | ) 180 | 181 | save_path = os.path.join(model_dir, 'checkpoints', 'model-%s.ckpt' % chkpoint) 182 | if not os.path.exists(os.path.dirname(save_path)): 183 | os.makedirs(os.path.dirname(save_path)) 184 | checkpoint_path = saver.save(sess, save_path, write_state=False) 185 | 186 | tf.train.write_graph(cg, model_dir, "model-%s.pbtxt" % chkpoint) 187 | 188 | # Freeze graph and write our final model file 189 | freeze_graph( 190 | input_graph=os.path.join(model_dir, "model-%s.pbtxt" % chkpoint), 191 | input_saver="", 192 | input_binary=False, 193 | input_checkpoint=checkpoint_path, 194 | output_node_names='heatmap,offset_2,displacement_fwd_2,displacement_bwd_2', 195 | restore_op_name="save/restore_all", 196 | filename_tensor_name="save/Const:0", 197 | output_graph=os.path.join(model_dir, "model-%s.pb" % chkpoint), 198 | clear_devices=True, 199 | initializer_nodes="") 200 | 201 | if check and os.path.exists("./images/tennis_in_crowd.jpg"): 202 | # Result 203 | input_image = _read_imgfile("./images/tennis_in_crowd.jpg", width, height) 204 | input_image = np.array(input_image, dtype=np.float32) 205 | input_image = input_image.reshape(1, height, width, 3) 206 | 207 | heatmaps_result, offsets_result, displacement_fwd_result, displacement_bwd_result = sess.run( 208 | outputs, 209 | feed_dict={image_ph: input_image} 210 | ) 211 | 212 | print("Test image stats") 213 | print(input_image) 214 | print(input_image.shape) 215 | print(np.mean(input_image)) 216 | 217 | heatmaps_result = heatmaps_result[0] 218 | 219 | print("Heatmaps") 220 | print(heatmaps_result[0:1, 0:1, :]) 221 | print(heatmaps_result.shape) 222 | print(np.mean(heatmaps_result)) 223 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018 Ross Wightman 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. --------------------------------------------------------------------------------