├── tests ├── __init__.py ├── utils │ ├── __init__.py │ └── dist_test.py ├── conftest.py ├── model_test.py └── cli_test.py ├── tf_bodypix ├── __init__.py ├── utils │ ├── __init__.py │ ├── typing.py │ ├── s3.py │ ├── dist.py │ ├── io.py │ ├── v4l2.py │ ├── timer.py │ ├── image.py │ └── opencv.py ├── bodypix_js_utils │ ├── __init__.py │ ├── multi_person │ │ ├── __init__.py │ │ ├── util.py │ │ ├── decode_multiple_poses.py │ │ └── decode_pose.py │ ├── output_rendering_util.py │ ├── decode_part_map.py │ ├── types.py │ ├── part_channels.py │ ├── keypoints.py │ ├── build_part_with_score_queue.py │ └── util.py ├── __main__.py ├── api.py ├── tflite.py ├── sink.py ├── source.py ├── draw.py ├── download.py ├── model.py └── cli.py ├── .python-version ├── .flake8 ├── requirements.tflite.txt ├── requirements.build.txt ├── MANIFEST.in ├── constraints.txt ├── .dockerignore ├── .gitignore ├── .github ├── dependabot.yml └── workflows │ └── ci.yml ├── requirements.dev.txt ├── requirements.txt ├── docker └── entrypoint.sh ├── .pylintrc ├── Dockerfile ├── LICENSE ├── setup.py ├── Makefile └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tf_bodypix/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.8.7 2 | -------------------------------------------------------------------------------- /tf_bodypix/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | -------------------------------------------------------------------------------- /requirements.tflite.txt: -------------------------------------------------------------------------------- 1 | tflite-runtime==2.11.0 2 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/multi_person/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.build.txt: -------------------------------------------------------------------------------- 1 | pip==22.3.1 2 | wheel==0.38.4 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include tf_bodypix 2 | prune tests 3 | prune .* 4 | -------------------------------------------------------------------------------- /constraints.txt: -------------------------------------------------------------------------------- 1 | urllib3>=1.26.14 2 | watchdog>=2.2.1 3 | wrapt>=1.14.1 4 | 
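# Note: pip treats this as a constraints file: it pins versions of (transitive)
# dependencies without adding requirements of its own, applied via e.g.
# `pip install -r requirements.txt -c constraints.txt`.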
-------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | venv 2 | venv_temp 3 | data 4 | dist 5 | build 6 | *.egg-info 7 | 8 | *.pyc 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | venv_temp 3 | data 4 | dist 5 | build 6 | *.egg-info 7 | 8 | *.pyc 9 | *.tflite 10 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: pip 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | open-pull-requests-limit: 10 8 | -------------------------------------------------------------------------------- /tf_bodypix/__main__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from tf_bodypix.cli import main 3 | 4 | 5 | if __name__ == '__main__': 6 | logging.basicConfig(level='INFO') 7 | main() 8 | -------------------------------------------------------------------------------- /tf_bodypix/api.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-import 2 | from tf_bodypix.model import load_model # noqa 3 | from tf_bodypix.download import download_model, BodyPixModelPaths # noqa 4 | -------------------------------------------------------------------------------- /requirements.dev.txt: -------------------------------------------------------------------------------- 1 | flake8==5.0.4 2 | mypy==0.991 3 | pylint==2.15.10 4 | pytest==7.2.1 5 | pytest-watch==4.2.0 6 | setuptools-scm==7.1.0 7 | types-requests==2.28.11.8 8 | types-urllib3==1.26.25.4 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<1.24 2 | opencv-python==4.7.0.68 3 | Pillow==9.4.0 4 | pyfakewebcam==0.1.0 5 | tensorflow==2.11.0 6 | tensorflow-estimator==2.11.0 7 | tfjs-graph-converter==1.6.1 8 | requests>=2.26.0 9 | -------------------------------------------------------------------------------- /tf_bodypix/utils/typing.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-import 2 | # flake8: noqa: F401 3 | 4 | try: 5 | # Python 3.8+ 6 | from typing import Protocol 7 | except ImportError: 8 | Protocol = object # type: ignore 9 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture(scope='session', autouse=True) 7 | def setup_logging(): 8 | for name in ['tests', 'tf_bodypix']: 9 | logging.getLogger(name).setLevel('DEBUG') 10 | -------------------------------------------------------------------------------- /docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SUB_COMMAND="${1}" 4 | 5 | if [[ ${SUB_COMMAND} == "bash" ]]; then 6 | shift 7 | exec "/bin/bash" "${@}" 8 | elif [[ ${SUB_COMMAND} == "python" ]]; then 9 | shift 10 | exec "python" "${@}" 11 | fi 12 | 13 | exec python -m tf_bodypix "${@}" 14 | 
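# Usage sketch (assumed invocations; image name taken from .github/workflows/ci.yml):
#   docker run --rm de4code/tf-bodypix list-models        # default branch: tf_bodypix CLI sub-command
#   docker run --rm de4code/tf-bodypix python --version   # "python" dispatches to the Python interpreter
#   docker run --rm -it de4code/tf-bodypix bash           # "bash" dispatches to an interactive shell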
-------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | extension-pkg-whitelist= 3 | cv2 4 | 5 | [TYPECHECK] 6 | generated-members=cv2 7 | 8 | [MESSAGES CONTROL] 9 | disable= 10 | missing-docstring, 11 | too-few-public-methods, 12 | too-many-arguments, 13 | too-many-instance-attributes, 14 | duplicate-code, 15 | consider-using-f-string, 16 | invalid-name 17 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/output_rendering_util.py: -------------------------------------------------------------------------------- 1 | # based on: 2 | # https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.4/body-pix/src/output_rendering_util.ts 3 | 4 | 5 | RAINBOW_PART_COLORS = [ 6 | (110, 64, 170), (143, 61, 178), (178, 60, 178), (210, 62, 167), 7 | (238, 67, 149), (255, 78, 125), (255, 94, 99), (255, 115, 75), 8 | (255, 140, 56), (239, 167, 47), (217, 194, 49), (194, 219, 64), 9 | (175, 240, 91), (135, 245, 87), (96, 247, 96), (64, 243, 115), 10 | (40, 234, 141), (28, 219, 169), (26, 199, 194), (33, 176, 213), 11 | (47, 150, 224), (65, 125, 224), (84, 101, 214), (99, 81, 195) 12 | ] 13 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/decode_part_map.py: -------------------------------------------------------------------------------- 1 | # based on: 2 | # https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.4/body-pix/src/decode_part_map.ts 3 | 4 | try: 5 | import tensorflow as tf 6 | except ImportError: 7 | tf = None 8 | 9 | import numpy as np 10 | 11 | 12 | DEFAULT_DTYPE = ( 13 | tf.int32 if tf is not None else np.int32 14 | ) 15 | 16 | 17 | def to_mask_tensor( 18 | segment_scores: np.ndarray, 19 | threshold: float, 20 | dtype: type = DEFAULT_DTYPE 21 | ) -> np.ndarray: 22 | if tf is None: 23 | return (segment_scores > threshold).astype(dtype) 24 | return tf.cast( 25 | tf.greater(segment_scores, threshold), 26 | dtype 27 | ) 28 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/types.py: -------------------------------------------------------------------------------- 1 | # based on: 2 | # https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.5/body-pix/src/types.ts 3 | 4 | 5 | from typing import Dict, NamedTuple 6 | 7 | import numpy as np 8 | 9 | 10 | class Part(NamedTuple): 11 | heatmap_x: int 12 | heatmap_y: int 13 | keypoint_id: int 14 | 15 | 16 | class Vector2D(NamedTuple): 17 | y: float 18 | x: float 19 | 20 | 21 | TensorBuffer3D = np.ndarray 22 | T_ArrayLike_3D = TensorBuffer3D 23 | 24 | 25 | class PartWithScore(NamedTuple): 26 | score: float 27 | part: Part 28 | 29 | 30 | class Keypoint(NamedTuple): 31 | score: float 32 | position: Vector2D 33 | part: str 34 | 35 | 36 | class Pose(NamedTuple): 37 | keypoints: Dict[int, Keypoint] 38 | score: float 39 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/part_channels.py: -------------------------------------------------------------------------------- 1 | # based on: 2 | # https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.4/body-pix/src/part_channels.ts 3 | 4 | 5 | PART_CHANNELS = [ 6 | 'left_face', 7 | 'right_face', 8 | 'left_upper_arm_front', 9 | 'left_upper_arm_back', 10 | 'right_upper_arm_front', 11 | 'right_upper_arm_back', 12 |
'left_lower_arm_front', 13 | 'left_lower_arm_back', 14 | 'right_lower_arm_front', 15 | 'right_lower_arm_back', 16 | 'left_hand', 17 | 'right_hand', 18 | 'torso_front', 19 | 'torso_back', 20 | 'left_upper_leg_front', 21 | 'left_upper_leg_back', 22 | 'right_upper_leg_front', 23 | 'right_upper_leg_back', 24 | 'left_lower_leg_front', 25 | 'left_lower_leg_back', 26 | 'right_lower_leg_front', 27 | 'right_lower_leg_back', 28 | 'left_feet', 29 | 'right_feet' 30 | ] 31 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8.8-slim-buster as base 2 | 3 | 4 | # shared between builder and runtime image 5 | RUN apt-get update \ 6 | && apt-get install -y --no-install-recommends \ 7 | dumb-init \ 8 | libgl1 \ 9 | libglib2.0-0 \ 10 | libsm6 \ 11 | && rm -rf /var/lib/apt/lists/* 12 | 13 | WORKDIR /opt/tf-bodypix 14 | 15 | 16 | # builder 17 | FROM base as builder 18 | 19 | COPY requirements.build.txt ./ 20 | RUN pip install --disable-pip-version-check --no-warn-script-location --user -r requirements.build.txt 21 | 22 | COPY requirements.txt ./ 23 | RUN pip install --disable-pip-version-check --no-warn-script-location --user -r requirements.txt 24 | 25 | 26 | # runtime image 27 | FROM base 28 | 29 | COPY --from=builder /root/.local /root/.local 30 | 31 | COPY tf_bodypix ./tf_bodypix 32 | 33 | COPY docker/entrypoint.sh ./docker/entrypoint.sh 34 | 35 | ENTRYPOINT ["/usr/bin/dumb-init", "--", "/opt/tf-bodypix/docker/entrypoint.sh"] 36 | -------------------------------------------------------------------------------- /tf_bodypix/tflite.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | try: 4 | import tensorflow as tf 5 | except ImportError: 6 | tf = None 7 | 8 | try: 9 | import tfjs_graph_converter 10 | except ImportError: 11 | tfjs_graph_converter = None 12 | 13 | 14 | LOGGER = logging.getLogger(__name__) 15 | 16 | 17 | def get_tflite_converter_for_tfjs_model_path(model_path: str) -> 'tf.lite.TFLiteConverter': 18 | if tfjs_graph_converter is None: 19 | raise ImportError('tfjs_graph_converter required') 20 | graph = tfjs_graph_converter.api.load_graph_model(model_path) 21 | tf_fn = tfjs_graph_converter.api.graph_to_function_v2(graph) 22 | return tf.lite.TFLiteConverter.from_concrete_functions([tf_fn]) 23 | 24 | 25 | def get_tflite_converter_for_model_path(model_path: str) -> 'tf.lite.TFLiteConverter': 26 | LOGGER.debug('converting model_path: %s', model_path) 27 | # if model_path.endswith('.json'): 28 | return get_tflite_converter_for_tfjs_model_path(model_path) 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 Daniel Ecer 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /tests/utils/dist_test.py: -------------------------------------------------------------------------------- 1 | from tf_bodypix.utils.dist import get_required_and_extras 2 | 3 | 4 | TENSORFLOW_REQUIREMENT = 'tensorflow==2.1.3' 5 | 6 | 7 | class TestGetRequiredAndExtras: 8 | def test_should_group_single_requirement(self): 9 | assert get_required_and_extras( 10 | [('req1==1.2.3', ['group1'])] 11 | ) == ( 12 | [], 13 | {'group1': ['req1==1.2.3'], 'all': ['req1==1.2.3']} 14 | ) 15 | 16 | def test_should_fallback_to_default(self): 17 | assert get_required_and_extras( 18 | [('req1==1.2.3', [None])] 19 | ) == ( 20 | ['req1==1.2.3'], 21 | {'all': ['req1==1.2.3']} 22 | ) 23 | 24 | def test_should_group_multiple_requirement(self): 25 | assert get_required_and_extras( 26 | [('req1==1.2.3', ['group1']), ('req2==1.2.3', ['group2']), ('req3==1.2.3', [None])] 27 | ) == ( 28 | ['req3==1.2.3'], 29 | { 30 | 'group1': ['req1==1.2.3'], 31 | 'group2': ['req2==1.2.3'], 32 | 'all': ['req1==1.2.3', 'req2==1.2.3', 'req3==1.2.3'] 33 | } 34 | ) 35 | -------------------------------------------------------------------------------- /tf_bodypix/sink.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import contextmanager 3 | from functools import partial 4 | from typing import Callable, ContextManager, Iterator 5 | 6 | import numpy as np 7 | 8 | from tf_bodypix.utils.image import write_image_to 9 | 10 | # pylint: disable=import-outside-toplevel 11 | 12 | 13 | LOGGER = logging.getLogger(__name__) 14 | 15 | 16 | T_OutputSink = Callable[[np.ndarray], None] 17 | 18 | 19 | def get_v4l2_output_sink(device_name: str) -> ContextManager[T_OutputSink]: 20 | from tf_bodypix.utils.v4l2 import VideoLoopbackImageSink 21 | return VideoLoopbackImageSink(device_name) 22 | 23 | 24 | @contextmanager 25 | def get_image_file_output_sink(path: str) -> Iterator[T_OutputSink]: 26 | yield partial(write_image_to, path=path) 27 | 28 | 29 | def get_image_output_sink_for_path(path: str) -> ContextManager[T_OutputSink]: 30 | if path.startswith('/dev/video'): 31 | return get_v4l2_output_sink(path) 32 | return get_image_file_output_sink(path) 33 | 34 | 35 | def get_show_image_output_sink() -> ContextManager[T_OutputSink]: 36 | from tf_bodypix.utils.opencv import ShowImageSink 37 | return ShowImageSink( 38 | window_name='image', 39 | window_title='tf-bodypix' 40 | ) 41 | -------------------------------------------------------------------------------- /tf_bodypix/utils/s3.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import urllib.request 4 | from xml.etree import ElementTree 5 | from typing import Iterable 6 | 7 | 8 | LOGGER = logging.getLogger(__name__) 9 | 10 | 11 | S3_NS = 'http://doc.s3.amazonaws.com/2006-03-01' 12 | S3_PREFIX = '{%s}' % S3_NS 13 | S3_CONTENTS = S3_PREFIX + 'Contents' 14 | S3_KEY = S3_PREFIX + 'Key' 15 | S3_NEXT_MARKER = S3_PREFIX + 'NextMarker' 16 | 
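# The bucket listing addressed by these tags is plain S3 ListBucketResult XML,
# roughly of this shape (a sketch of only the elements used below):
#
#   <ListBucketResult xmlns="http://doc.s3.amazonaws.com/2006-03-01">
#     <Contents><Key>some/object/key</Key>...</Contents>
#     <NextMarker>last-key-of-this-page</NextMarker>  (present when the listing is truncated)
#   </ListBucketResult>
#
# iter_s3_file_urls() below collects each Contents/Key and follows NextMarker
# pagination via the ?marker= query parameter.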
17 | 18 | def iter_s3_file_urls(base_url: str) -> Iterable[str]: 19 | if not base_url.endswith('/'): 20 | base_url += '/' 21 | marker = None 22 | while True: 23 | current_url = base_url 24 | if marker: 25 | current_url += '?marker=' + marker 26 | with urllib.request.urlopen(current_url) as url_fp: 27 | response_data = url_fp.read() 28 | LOGGER.debug('response_data: %r', response_data) 29 | root = ElementTree.fromstring(response_data) 30 | for item in root.findall(S3_CONTENTS): 31 | key = item.findtext(S3_KEY) 32 | LOGGER.debug('key: %s', key) 33 | if key: 34 | yield base_url + key 35 | next_marker = root.findtext(S3_NEXT_MARKER) 36 | if not next_marker or next_marker == marker: 37 | break 38 | marker = next_marker 39 | -------------------------------------------------------------------------------- /tf_bodypix/utils/dist.py: -------------------------------------------------------------------------------- 1 | def get_requirement_groups(requirement): # pylint: disable=too-many-return-statements 2 | requirement_lower = requirement.lower() 3 | if 'tensorflow' in requirement_lower: 4 | return ['tf'] 5 | if 'tfjs' in requirement_lower: 6 | return ['tfjs'] 7 | if 'numpy' in requirement_lower: 8 | return ['tfjs'] 9 | if 'pillow' in requirement_lower: 10 | return ['image'] 11 | if 'opencv' in requirement_lower: 12 | return ['webcam', 'video'] 13 | if 'pyfakewebcam' in requirement_lower: 14 | return ['webcam'] 15 | return [None] 16 | 17 | 18 | def get_requirements_with_groups(all_required_packages): 19 | return [ 20 | (requirement, get_requirement_groups(requirement)) 21 | for requirement in all_required_packages 22 | ] 23 | 24 | 25 | def get_required_and_extras(required_packages_with_groups, include_all=True): 26 | grouped_extras = {} 27 | all_groups = ['all'] if include_all else [] 28 | for requirement, groups in required_packages_with_groups: 29 | for group in groups + all_groups: 30 | grouped_extras.setdefault(group, []).append(requirement) 31 | return ( 32 | grouped_extras.get(None, []), 33 | {key: value for key, value in grouped_extras.items() if key} 34 | ) 35 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/keypoints.py: -------------------------------------------------------------------------------- 1 | # based on: 2 | # https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.5/body-pix/src/keypoints.ts 3 | 4 | 5 | PART_NAMES = [ 6 | 'nose', 'leftEye', 'rightEye', 'leftEar', 'rightEar', 'leftShoulder', 7 | 'rightShoulder', 'leftElbow', 'rightElbow', 'leftWrist', 'rightWrist', 8 | 'leftHip', 'rightHip', 'leftKnee', 'rightKnee', 'leftAnkle', 'rightAnkle' 9 | ] 10 | 11 | NUM_KEYPOINTS = len(PART_NAMES) 12 | 13 | 14 | PART_IDS = { 15 | part_name: part_id 16 | for part_id, part_name in enumerate(PART_NAMES) 17 | } 18 | 19 | 20 | CONNECTED_PART_NAMES = [ 21 | ['leftHip', 'leftShoulder'], ['leftElbow', 'leftShoulder'], 22 | ['leftElbow', 'leftWrist'], ['leftHip', 'leftKnee'], 23 | ['leftKnee', 'leftAnkle'], ['rightHip', 'rightShoulder'], 24 | ['rightElbow', 'rightShoulder'], ['rightElbow', 'rightWrist'], 25 | ['rightHip', 'rightKnee'], ['rightKnee', 'rightAnkle'], 26 | ['leftShoulder', 'rightShoulder'], ['leftHip', 'rightHip'] 27 | ] 28 | 29 | 30 | POSE_CHAIN = [ 31 | ['nose', 'leftEye'], ['leftEye', 'leftEar'], ['nose', 'rightEye'], 32 | ['rightEye', 'rightEar'], ['nose', 'leftShoulder'], 33 | ['leftShoulder', 'leftElbow'], ['leftElbow', 'leftWrist'], 34 | ['leftShoulder', 'leftHip'], ['leftHip', 'leftKnee'], 35 | ['leftKnee', 
'leftAnkle'], ['nose', 'rightShoulder'], 36 | ['rightShoulder', 'rightElbow'], ['rightElbow', 'rightWrist'], 37 | ['rightShoulder', 'rightHip'], ['rightHip', 'rightKnee'], 38 | ['rightKnee', 'rightAnkle'] 39 | ] 40 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/multi_person/util.py: -------------------------------------------------------------------------------- 1 | # based on: 2 | # https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.5/body-pix/src/multi_person/util.ts 3 | 4 | import logging 5 | 6 | from ..types import Part, TensorBuffer3D, Vector2D 7 | from ..keypoints import NUM_KEYPOINTS 8 | 9 | 10 | LOGGER = logging.getLogger(__name__) 11 | 12 | 13 | def getOffsetPoint( 14 | y: float, x: float, keypoint_id: int, offsets: TensorBuffer3D 15 | ) -> Vector2D: 16 | return Vector2D( 17 | y=offsets[int(y), int(x), keypoint_id], 18 | x=offsets[int(y), int(x), keypoint_id + NUM_KEYPOINTS] 19 | ) 20 | 21 | 22 | def getImageCoords( 23 | part: Part, outputStride: int, offsets: TensorBuffer3D 24 | ) -> Vector2D: 25 | LOGGER.debug('part: %s', part) 26 | offset_point = getOffsetPoint( 27 | part.heatmap_y, part.heatmap_x, part.keypoint_id, offsets 28 | ) 29 | LOGGER.debug('offset_point: %s', offset_point) 30 | LOGGER.debug('offsets.shape: %s', offsets.shape) 31 | return Vector2D( 32 | x=part.heatmap_x * outputStride + offset_point.x, 33 | y=part.heatmap_y * outputStride + offset_point.y 34 | ) 35 | 36 | 37 | def clamp(a: int, min_value: int, max_value: int) -> int: 38 | return min(max_value, max(min_value, a)) 39 | 40 | 41 | def squaredDistance( 42 | y1: float, x1: float, y2: float, x2: float 43 | ) -> float: 44 | dy = y2 - y1 45 | dx = x2 - x1 46 | return dy * dy + dx * dx 47 | 48 | 49 | def squared_distance_vector(a: Vector2D, b: Vector2D) -> float: 50 | return squaredDistance(a.y, a.x, b.y, b.x) 51 | 52 | 53 | def addVectors(a: Vector2D, b: Vector2D) -> Vector2D: 54 | return Vector2D(x=a.x + b.x, y=a.y + b.y) 55 | -------------------------------------------------------------------------------- /tests/model_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from unittest.mock import MagicMock 3 | 4 | import numpy as np 5 | 6 | from tf_bodypix.model import BodyPixModelWrapper 7 | 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | 12 | ANY_INT_FACTOR_1 = 5 13 | 14 | 15 | class TestBodyPixModelWrapper: 16 | def test_should_be_able_to_padded_and_resized_image_matching_output_stride_plus_one(self): 17 | predict_fn = MagicMock(name='predict_fn') 18 | output_stride = 16 19 | internal_resolution = 0.5 20 | model = BodyPixModelWrapper( 21 | predict_fn=predict_fn, 22 | output_stride=output_stride, 23 | internal_resolution=internal_resolution 24 | ) 25 | default_tensor_names = { 26 | 'float_segments', 27 | 'float_part_heatmaps', 28 | 'float_heatmaps', 29 | 'float_short_offsets', 30 | 'float_long_offsets', 31 | 'float_part_offsets', 32 | 'displacement_fwd', 33 | 'displacement_bwd' 34 | } 35 | predict_fn.return_value = { 36 | key: np.array([]) 37 | for key in default_tensor_names 38 | } 39 | resolution_matching_output_stride_plus_1 = int( 40 | (output_stride * ANY_INT_FACTOR_1 + 1) / internal_resolution 41 | ) 42 | LOGGER.debug( 43 | 'resolution_matching_output_stride_plus_1: %s', 44 | resolution_matching_output_stride_plus_1 45 | ) 46 | image = np.ones( 47 | shape=( 48 | resolution_matching_output_stride_plus_1, 49 | resolution_matching_output_stride_plus_1, 50 | 3 51 | )
52 | ) 53 | model.predict_single(image) 54 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | from setuptools import find_packages, setup 4 | 5 | from tf_bodypix.utils.dist import ( 6 | get_requirements_with_groups, 7 | get_required_and_extras 8 | ) 9 | 10 | 11 | with open('requirements.txt', 'r', encoding='utf-8') as f: 12 | REQUIRED_PACKAGES = f.readlines() 13 | 14 | 15 | with open('requirements.tflite.txt', 'r', encoding='utf-8') as f: 16 | TFLITE_REQUIRED_PACKAGES = f.readlines() 17 | 18 | 19 | with open('README.md', 'r', encoding='utf-8') as f: 20 | LONG_DESCRIPTION = '\n'.join([ 21 | line.rstrip() 22 | for line in f 23 | if not line.startswith('[![') 24 | ]) 25 | 26 | 27 | def local_scheme(version): 28 | if not version.distance and not version.dirty: 29 | return "" 30 | return str(int(time())) 31 | 32 | 33 | DEFAULT_REQUIRED_PACKAGES, EXTRAS = get_required_and_extras( 34 | get_requirements_with_groups(REQUIRED_PACKAGES) 35 | ) 36 | 37 | ALL_EXTRAS = { 38 | **EXTRAS, 39 | 'tflite': TFLITE_REQUIRED_PACKAGES 40 | } 41 | 42 | packages = find_packages(exclude=["tests", "tests.*"]) 43 | 44 | setup( 45 | name="tf-bodypix", 46 | use_scm_version={ 47 | "local_scheme": local_scheme 48 | }, 49 | setup_requires=['setuptools_scm'], 50 | author="Daniel Ecer", 51 | url="https://github.com/de-code/python-tf-bodypix", 52 | install_requires=DEFAULT_REQUIRED_PACKAGES, 53 | extras_require=ALL_EXTRAS, 54 | packages=packages, 55 | include_package_data=True, 56 | description='Python implementation of the TensorFlow BodyPix model.', 57 | long_description=LONG_DESCRIPTION, 58 | long_description_content_type='text/markdown', 59 | classifiers=[ 60 | "Programming Language :: Python :: 3", 61 | "License :: OSI Approved :: MIT License", 62 | "Operating System :: OS Independent", 63 | ] 64 | ) 65 | -------------------------------------------------------------------------------- /tf_bodypix/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | from hashlib import md5 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import requests 7 | 8 | 9 | DEFAULT_KERAS_CACHE_DIR = '~/.keras' 10 | DEFAULT_USER_AGENT = 'tf-bodypix' 11 | 12 | 13 | def strip_url_suffix(path: str) -> str: 14 | qs_index = path.find('?') 15 | if qs_index > 0: 16 | return path[:qs_index] 17 | return path 18 | 19 | 20 | def get_default_cache_dir( 21 | cache_dir: Optional[str] = None, 22 | cache_subdir: Optional[str] = None 23 | ): 24 | result = os.path.expanduser(cache_dir or DEFAULT_KERAS_CACHE_DIR) 25 | if cache_subdir: 26 | result = os.path.join(result, cache_subdir) 27 | return result 28 | 29 | 30 | def download_file_to( 31 | source_url: str, 32 | local_path: str, 33 | user_agent: str = DEFAULT_USER_AGENT, 34 | skip_if_exists: bool = True, 35 | timeout: float = 60 * 60 # default to 1h 36 | ): 37 | if skip_if_exists and os.path.exists(local_path): 38 | return local_path 39 | response = requests.get(source_url, timeout=timeout, headers={ 40 | 'User-Agent': user_agent 41 | }) 42 | response.raise_for_status() 43 | local_path_path = Path(local_path) 44 | local_path_path.parent.mkdir(parents=True, exist_ok=True) 45 | local_path_path.write_bytes(response.content) 46 | return local_path 47 | 48 | 49 | def get_file(file_path: str, download: bool = True) -> str: 50 | if not download: 51 | return file_path 52 | if
os.path.exists(file_path): 53 | return file_path 54 | cache_dir = get_default_cache_dir() 55 | local_path = os.path.join( 56 | cache_dir, 57 | ( 58 | md5(file_path.encode('utf-8')).hexdigest() 59 | + '-' 60 | + os.path.basename(strip_url_suffix(file_path)) 61 | ) 62 | ) 63 | return download_file_to( 64 | source_url=file_path, 65 | local_path=local_path 66 | ) 67 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/build_part_with_score_queue.py: -------------------------------------------------------------------------------- 1 | # based on: 2 | # https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.5/body-pix/src/multi_person/build_part_with_score_queue.ts 3 | 4 | import logging 5 | from collections import deque 6 | from typing import Deque 7 | 8 | from tf_bodypix.bodypix_js_utils.types import PartWithScore, Part, T_ArrayLike_3D 9 | 10 | 11 | LOGGER = logging.getLogger(__name__) 12 | 13 | 14 | def score_is_maximum_in_local_window( 15 | keypoint_id: int, 16 | score: float, 17 | heatmap_y: int, 18 | heatmap_x: int, 19 | local_maximum_radius: float, 20 | scores: T_ArrayLike_3D 21 | ) -> bool: 22 | height, width = scores.shape[:2] 23 | y_start = int(max(heatmap_y - local_maximum_radius, 0)) 24 | y_end = int(min(heatmap_y + local_maximum_radius + 1, height)) 25 | for y_current in range(y_start, y_end): 26 | x_start = int(max(heatmap_x - local_maximum_radius, 0)) 27 | x_end = int(min(heatmap_x + local_maximum_radius + 1, width)) 28 | for x_current in range(x_start, x_end): 29 | if scores[y_current, x_current, keypoint_id] > score: 30 | return False 31 | return True 32 | 33 | 34 | def build_part_with_score_queue( 35 | score_threshold: float, 36 | local_maximum_radius: float, 37 | scores: T_ArrayLike_3D 38 | ) -> Deque[PartWithScore]: 39 | height, width, num_keypoints = scores.shape[:3] 40 | part_with_scores = [] 41 | 42 | LOGGER.debug('num_keypoints=%s', num_keypoints) 43 | 44 | for heatmap_y in range(height): 45 | for heatmap_x in range(width): 46 | for keypoint_id in range(num_keypoints): 47 | score = scores[heatmap_y, heatmap_x, keypoint_id] 48 | 49 | # Only consider parts with score greater or equal to threshold as 50 | # root candidates. 51 | if score < score_threshold: 52 | continue 53 | 54 | # Only consider keypoints whose score is maximum in a local window.
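# (per-keypoint non-maximum suppression: the cell survives only if no
# neighbouring cell within local_maximum_radius scores higher for the
# same keypoint channel; see score_is_maximum_in_local_window above)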
55 | if not score_is_maximum_in_local_window( 56 | keypoint_id, score, heatmap_y, heatmap_x, local_maximum_radius, 57 | scores 58 | ): 59 | continue 60 | 61 | part_with_scores.append(PartWithScore( 62 | score=score, 63 | part=Part(heatmap_y=heatmap_y, heatmap_x=heatmap_x, keypoint_id=keypoint_id) 64 | )) 65 | 66 | return deque( 67 | sorted( 68 | part_with_scores, 69 | key=lambda part_with_score: part_with_score.score, 70 | reverse=True 71 | ) 72 | ) 73 | -------------------------------------------------------------------------------- /tf_bodypix/utils/v4l2.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from pyfakewebcam import FakeWebcam 5 | 6 | import numpy as np 7 | import cv2 8 | 9 | 10 | # pylint: disable=protected-access 11 | 12 | 13 | LOGGER = logging.getLogger(__name__) 14 | 15 | 16 | def create_fakewebcam( 17 | device_name: str, 18 | preferred_width: int, 19 | preferred_height: int 20 | ) -> FakeWebcam: 21 | fakewebcam_instance = FakeWebcam( 22 | device_name, 23 | width=preferred_width, 24 | height=preferred_height 25 | ) 26 | fakewebcam_settings = fakewebcam_instance._settings 27 | actual_width = fakewebcam_settings.fmt.pix.width 28 | actual_height = fakewebcam_settings.fmt.pix.height 29 | if actual_height != preferred_height or actual_width != preferred_width: 30 | LOGGER.warning( 31 | 'unable to set virtual webcam resolution, using: width=%d, height=%d', 32 | actual_width, actual_height 33 | ) 34 | fakewebcam_instance._buffer = np.zeros( 35 | (actual_height, 2 * actual_width), 36 | dtype=np.uint8 37 | ) 38 | fakewebcam_instance._yuv = np.zeros( 39 | (actual_height, actual_width, 3), 40 | dtype=np.uint8 41 | ) 42 | fakewebcam_instance._ones = np.ones( 43 | (actual_height, actual_width, 1), 44 | dtype=np.uint8 45 | ) 46 | return fakewebcam_instance 47 | 48 | 49 | def close_fakewebcam(fakewebcam_instance: FakeWebcam): 50 | os.close(fakewebcam_instance._video_device) 51 | 52 | 53 | class VideoLoopbackImageSink: 54 | def __init__(self, device_name: str): 55 | self.device_name = device_name 56 | self.fakewebcam_instance = None 57 | self.width = None 58 | self.height = None 59 | 60 | def __enter__(self): 61 | return self 62 | 63 | def __exit__(self, *_, **__): 64 | if self.fakewebcam_instance is not None: 65 | close_fakewebcam(self.fakewebcam_instance) 66 | 67 | def initialize_fakewebcam(self, preferred_width: int, preferred_height: int): 68 | fakewebcam_instance = create_fakewebcam( 69 | self.device_name, 70 | preferred_width=preferred_width, 71 | preferred_height=preferred_height 72 | ) 73 | self.fakewebcam_instance = fakewebcam_instance 74 | self.width = fakewebcam_instance._settings.fmt.pix.width 75 | self.height = fakewebcam_instance._settings.fmt.pix.height 76 | 77 | def __call__(self, image_array: np.ndarray): 78 | image_array = np.asarray(image_array).astype(np.uint8) 79 | height, width, *_ = image_array.shape 80 | if self.fakewebcam_instance is None: 81 | LOGGER.info('initializing, width=%d, height=%d', width, height) 82 | self.initialize_fakewebcam( 83 | preferred_width=width, 84 | preferred_height=height 85 | ) 86 | if height != self.height or width != self.width: 87 | LOGGER.info('resizing to: width=%d, height=%d', self.width, self.height) 88 | image_array = cv2.resize( 89 | image_array, 90 | (self.width, self.height), 91 | interpolation=cv2.INTER_AREA 92 | ) 93 | LOGGER.info('resized image_array.shape=%s', image_array.shape) 94 | assert self.fakewebcam_instance is not None 95 | 
self.fakewebcam_instance.schedule_frame(image_array) 96 | -------------------------------------------------------------------------------- /tf_bodypix/utils/timer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from time import time 3 | from typing import Dict, List, Optional 4 | 5 | 6 | LOGGER = logging.getLogger(__name__) 7 | 8 | 9 | def _mean(a: List[float]) -> float: 10 | if not a: 11 | return 0 12 | return sum(a) / len(a) 13 | 14 | 15 | class LoggingTimer: 16 | def __init__(self, min_interval: float = 1): 17 | self.min_interval = min_interval 18 | self.interval_start_time: Optional[float] = None 19 | self.frame_start_time: Optional[float] = None 20 | self.frame_durations: List[float] = [] 21 | self.step_durations_map: Dict[Optional[str], List[float]] = {} 22 | self.current_step_name: Optional[str] = None 23 | self.current_step_start_time: Optional[float] = None 24 | self.ordered_step_names: List[Optional[str]] = [] 25 | 26 | def start(self): 27 | current_time = time() 28 | self.interval_start_time = current_time 29 | 30 | def _set_current_step_name(self, step_name: str, current_time: Optional[float] = None): 31 | if step_name == self.current_step_name: 32 | return 33 | if current_time is None: 34 | current_time = time() 35 | assert self.current_step_start_time is not None 36 | duration = current_time - self.current_step_start_time 37 | if duration > 0 or self.current_step_name: 38 | self.step_durations_map.setdefault(self.current_step_name, []).append( 39 | duration 40 | ) 41 | if self.current_step_name not in self.ordered_step_names: 42 | self.ordered_step_names.append(self.current_step_name) 43 | self.current_step_name = step_name 44 | self.current_step_start_time = current_time 45 | 46 | def on_frame_start(self, initial_step_name: Optional[str] = None): 47 | self.frame_start_time = time() 48 | self.current_step_name = initial_step_name 49 | self.current_step_start_time = self.frame_start_time 50 | 51 | def on_step_start(self, step_name: str): 52 | if step_name == self.current_step_name: 53 | return 54 | self._set_current_step_name(step_name) 55 | 56 | def on_step_end(self): 57 | self._set_current_step_name(None) 58 | 59 | def on_frame_end(self): 60 | frame_end_time = time() 61 | self._set_current_step_name(None, current_time=frame_end_time) 62 | self.frame_durations.append(frame_end_time - self.frame_start_time) 63 | self.check_log(frame_end_time) 64 | 65 | def check_log(self, current_time: float): 66 | assert self.interval_start_time is not None 67 | interval_duration = current_time - self.interval_start_time 68 | if self.frame_durations and interval_duration >= self.min_interval: 69 | step_info = ', '.join([ 70 | '%s=%0.3f' % (step_name, _mean( 71 | self.step_durations_map.get(step_name, []) 72 | )) 73 | for step_name in self.ordered_step_names 74 | ]) 75 | LOGGER.info( 76 | '%0.3fs per frame (%0.1ffps%s)', 77 | _mean(self.frame_durations), 78 | len(self.frame_durations) / interval_duration, 79 | ', ' + step_info if step_info else '' 80 | ) 81 | self.frame_durations.clear() 82 | self.step_durations_map.clear() 83 | self.ordered_step_names.clear() 84 | self.interval_start_time = current_time 85 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/multi_person/decode_multiple_poses.py: -------------------------------------------------------------------------------- 1 | # based on: 2 | #
https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.5/body-pix/src/multi_person/decode_multiple_poses.ts 3 | 4 | import logging 5 | 6 | from typing import Dict, List 7 | 8 | from tf_bodypix.bodypix_js_utils.types import ( 9 | Pose, TensorBuffer3D, Vector2D, 10 | Keypoint 11 | ) 12 | from tf_bodypix.bodypix_js_utils.build_part_with_score_queue import ( 13 | build_part_with_score_queue 14 | ) 15 | 16 | from .util import getImageCoords, squared_distance_vector 17 | from .decode_pose import decodePose 18 | 19 | 20 | LOGGER = logging.getLogger(__name__) 21 | 22 | 23 | kLocalMaximumRadius = 1 24 | 25 | 26 | def withinNmsRadiusOfCorrespondingPoint( 27 | poses: List[Pose], 28 | squaredNmsRadius: float, 29 | vector: Vector2D, 30 | keypointId: int 31 | ) -> bool: 32 | return any( 33 | squared_distance_vector( 34 | vector, pose.keypoints[keypointId].position 35 | ) <= squaredNmsRadius 36 | for pose in poses 37 | ) 38 | 39 | 40 | def getInstanceScore( 41 | existingPoses: List[Pose], 42 | squaredNmsRadius: float, 43 | instanceKeypoints: Dict[int, Keypoint] 44 | ) -> float: 45 | LOGGER.debug('instanceKeypoints: %s', instanceKeypoints) 46 | notOverlappedKeypointScores = sum(( 47 | keypoint.score 48 | for keypointId, keypoint in instanceKeypoints.items() 49 | if not withinNmsRadiusOfCorrespondingPoint( 50 | existingPoses, squaredNmsRadius, 51 | keypoint.position, keypointId 52 | ) 53 | )) 54 | 55 | return notOverlappedKeypointScores / len(instanceKeypoints) 56 | 57 | 58 | def decodeMultiplePoses( 59 | scoresBuffer: TensorBuffer3D, offsetsBuffer: TensorBuffer3D, 60 | displacementsFwdBuffer: TensorBuffer3D, 61 | displacementsBwdBuffer: TensorBuffer3D, outputStride: int, 62 | maxPoseDetections: int, scoreThreshold: float = 0.5, nmsRadius: float = 20 63 | ) -> List[Pose]: 64 | poses: List[Pose] = [] 65 | 66 | queue = build_part_with_score_queue( 67 | scoreThreshold, kLocalMaximumRadius, scoresBuffer 68 | ) 69 | # LOGGER.debug('queue: %s', queue) 70 | 71 | squaredNmsRadius = nmsRadius * nmsRadius 72 | 73 | # Generate at most maxDetections object instances per image in 74 | # decreasing root part score order. 75 | while len(poses) < maxPoseDetections and queue: 76 | # The top element in the queue is the next root candidate. 77 | root = queue.popleft() 78 | 79 | # Part-based non-maximum suppression: We reject a root candidate if it 80 | # is within a disk of `nmsRadius` pixels from the corresponding part of 81 | # a previously detected instance. 82 | rootImageCoords = getImageCoords( 83 | root.part, outputStride, offsetsBuffer 84 | ) 85 | if withinNmsRadiusOfCorrespondingPoint( 86 | poses, squaredNmsRadius, rootImageCoords, root.part.keypoint_id 87 | ): 88 | continue 89 | 90 | # Start a new detection instance at the position of the root. 
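# (decodePose, defined in decode_pose.py, walks outwards from this root
# keypoint along the POSE_CHAIN edges using the displacement maps)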
91 | keypoints = decodePose( 92 | root, scoresBuffer, offsetsBuffer, outputStride, displacementsFwdBuffer, 93 | displacementsBwdBuffer 94 | ) 95 | 96 | # LOGGER.debug('keypoints: %s', keypoints) 97 | 98 | score = getInstanceScore(poses, squaredNmsRadius, keypoints) 99 | 100 | poses.append(Pose(keypoints=keypoints, score=score)) 101 | 102 | return poses 103 | -------------------------------------------------------------------------------- /tf_bodypix/source.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | import os 4 | from contextlib import contextmanager 5 | from queue import Queue 6 | from threading import Thread 7 | from typing import ContextManager, Iterable, Iterator, Optional 8 | 9 | from tf_bodypix.utils.image import load_image, ImageSize, ImageArray 10 | from tf_bodypix.utils.io import get_file, strip_url_suffix 11 | 12 | 13 | # pylint: disable=import-outside-toplevel 14 | 15 | 16 | LOGGER = logging.getLogger(__name__) 17 | 18 | 19 | T_ImageSource = ContextManager[Iterable[ImageArray]] 20 | 21 | 22 | def is_video_path(path: str) -> bool: 23 | ext = os.path.splitext(os.path.basename(strip_url_suffix(path)))[-1] 24 | LOGGER.debug('ext: %s', ext) 25 | return ext.lower() in {'.webm', '.mkv', '.mp4'} 26 | 27 | 28 | def get_webcam_number(path: str) -> Optional[int]: 29 | match = re.match(r'(?:/dev/video|webcam:)(\d+)', path) 30 | if not match: 31 | return None 32 | return int(match.group(1)) 33 | 34 | 35 | def get_video_image_source(path: str, **kwargs) -> T_ImageSource: 36 | from tf_bodypix.utils.opencv import get_video_image_source as _get_video_image_source 37 | return _get_video_image_source(path, **kwargs) 38 | 39 | 40 | def get_webcam_image_source(webcam_number: int, **kwargs) -> T_ImageSource: 41 | from tf_bodypix.utils.opencv import get_webcam_image_source as _get_webcam_image_source 42 | return _get_webcam_image_source(webcam_number, **kwargs) 43 | 44 | 45 | @contextmanager 46 | def get_simple_image_source( 47 | path: str, 48 | image_size: Optional[ImageSize] = None, 49 | **_ 50 | ) -> Iterator[Iterable[ImageArray]]: 51 | local_image_path = get_file(path) 52 | LOGGER.debug('local_image_path: %r', local_image_path) 53 | image_array = load_image(local_image_path, image_size=image_size) 54 | yield [image_array] 55 | 56 | 57 | def get_image_source(path: str, **kwargs) -> T_ImageSource: 58 | webcam_number = get_webcam_number(path) 59 | if webcam_number is not None: 60 | return get_webcam_image_source(webcam_number, **kwargs) 61 | if is_video_path(path): 62 | return get_video_image_source(path, **kwargs) 63 | return get_simple_image_source(path, **kwargs) 64 | 65 | 66 | class ThreadedImageSource: 67 | def __init__(self, image_source: T_ImageSource, queue_size: int = 1): 68 | self.image_source = image_source 69 | self.image_source_iterator = None 70 | self.queue: 'Queue[ImageArray]' = Queue(queue_size) 71 | self.thread = None 72 | self.stopped = False 73 | 74 | def __enter__(self): 75 | self.stopped = False 76 | self.thread = Thread(target=self.run) 77 | self.image_source_iterator = iter(self.image_source.__enter__()) 78 | self.thread.start() 79 | LOGGER.info('using threaded image source') 80 | return self 81 | 82 | def __exit__(self, *args, **kwargs): 83 | self.stop() 84 | self.image_source.__exit__(*args, **kwargs) 85 | 86 | def __iter__(self): 87 | return self 88 | 89 | def __next__(self): 90 | LOGGER.debug('reading from queue, qsize: %d', self.queue.qsize()) 91 | return self.queue.get() 92 | 93 | def 
stop(self): 94 | self.stopped = True 95 | self.thread.join() 96 | 97 | def run(self): 98 | while not self.stopped: 99 | try: 100 | data = next(self.image_source_iterator) 101 | except StopIteration: 102 | self.stopped = True 103 | return 104 | self.queue.put(data) 105 | 106 | 107 | def get_threaded_image_source(image_source: T_ImageSource) -> T_ImageSource: 108 | return ThreadedImageSource(image_source) 109 | -------------------------------------------------------------------------------- /tf_bodypix/draw.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Iterable, Tuple, Optional 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from tf_bodypix.utils.image import ImageArray 8 | from tf_bodypix.bodypix_js_utils.types import Pose, Keypoint 9 | from tf_bodypix.bodypix_js_utils.keypoints import CONNECTED_PART_NAMES 10 | 11 | 12 | LOGGER = logging.getLogger(__name__) 13 | 14 | 15 | T_Color = Tuple[int, int, int] 16 | 17 | 18 | def get_filtered_keypoints_by_score( 19 | keypoints: Iterable[Keypoint], 20 | min_score: float 21 | ) -> List[Keypoint]: 22 | return [ 23 | keypoint 24 | for keypoint in keypoints 25 | if keypoint.score >= min_score 26 | ] 27 | 28 | 29 | def get_adjacent_keypoints( 30 | keypoints: Iterable[Keypoint] 31 | ) -> List[Tuple[Keypoint, Keypoint]]: 32 | keypoint_by_name = { 33 | keypoint.part: keypoint 34 | for keypoint in keypoints 35 | } 36 | return [ 37 | (keypoint_by_name[part_name_1], keypoint_by_name[part_name_2]) 38 | for part_name_1, part_name_2 in CONNECTED_PART_NAMES 39 | if keypoint_by_name.get(part_name_1) and keypoint_by_name.get(part_name_2) 40 | ] 41 | 42 | 43 | def get_cv_keypoints(keypoints: Iterable[Keypoint]) -> List[cv2.KeyPoint]: 44 | try: 45 | return [ 46 | cv2.KeyPoint( 47 | x=keypoint.position.x, 48 | y=keypoint.position.y, 49 | size=3 50 | ) 51 | for keypoint in keypoints 52 | ] 53 | except TypeError: 54 | # backwards compatibility with opencv 4.5.2 and below 55 | return [ 56 | cv2.KeyPoint( 57 | x=keypoint.position.x, 58 | y=keypoint.position.y, 59 | _size=3 60 | ) 61 | for keypoint in keypoints 62 | ] 63 | 64 | 65 | def draw_skeleton( 66 | image: ImageArray, 67 | keypoints: Iterable[Keypoint], 68 | color: T_Color, 69 | thickness: int = 3 70 | ): 71 | adjacent_keypoints = get_adjacent_keypoints(keypoints) 72 | for keypoint_1, keypoint_2 in adjacent_keypoints: 73 | cv2.line( 74 | image, 75 | (round(keypoint_1.position.x), round(keypoint_1.position.y)), 76 | (round(keypoint_2.position.x), round(keypoint_2.position.y)), 77 | color=color, 78 | thickness=thickness 79 | ) 80 | return image 81 | 82 | 83 | def draw_keypoints( 84 | image: ImageArray, 85 | keypoints: Iterable[Keypoint], 86 | color: T_Color 87 | ): 88 | image = cv2.drawKeypoints( 89 | image, 90 | get_cv_keypoints(keypoints), 91 | outImage=image, 92 | color=color 93 | ) 94 | return image 95 | 96 | 97 | def draw_pose( 98 | image: ImageArray, 99 | pose: Pose, 100 | min_score: float = 0.1, 101 | keypoints_color: Optional[T_Color] = None, 102 | skeleton_color: Optional[T_Color] = None): 103 | keypoints_to_draw = get_filtered_keypoints_by_score( 104 | pose.keypoints.values(), 105 | min_score=min_score 106 | ) 107 | LOGGER.debug('keypoints_to_draw: %s', keypoints_to_draw) 108 | if keypoints_color: 109 | image = draw_keypoints(image, keypoints_to_draw, color=keypoints_color) 110 | if skeleton_color: 111 | image = draw_skeleton(image, keypoints_to_draw, color=skeleton_color) 112 | return image 113 | 114 | 115 | def 
draw_poses(image: ImageArray, poses: List[Pose], **kwargs): 116 | if not poses: 117 | return image 118 | output_image = image.astype(np.uint8) 119 | for pose in poses: 120 | output_image = draw_pose(output_image, pose, **kwargs) 121 | return output_image 122 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | branches: [ develop ] 6 | tags: 7 | - 'v*' # Push events to matching v*, i.e. v1.0, v20.15.10 8 | pull_request: 9 | branches: [ develop ] 10 | 11 | jobs: 12 | # https://github.com/actions/runner/issues/1138 13 | check_secrets: 14 | runs-on: ubuntu-latest 15 | outputs: 16 | HAS_TEST_PYPI_PASSWORD: ${{ steps.check.outputs.HAS_TEST_PYPI_PASSWORD }} 17 | steps: 18 | - run: > 19 | echo "::set-output name=HAS_TEST_PYPI_PASSWORD::${{ env.TEST_PYPI_PASSWORD != '' }}"; 20 | id: check 21 | env: 22 | TEST_PYPI_PASSWORD: ${{ secrets.test_pypi_password }} 23 | 24 | build_tflite: 25 | needs: [] 26 | runs-on: ${{ matrix.os }} 27 | strategy: 28 | matrix: 29 | os: [ubuntu-latest] 30 | python-version: ['3.8'] 31 | include: 32 | - python-version: '3.8' 33 | 34 | steps: 35 | - uses: actions/checkout@v2 36 | - name: Set up Python ${{ matrix.python-version }} 37 | uses: actions/setup-python@v2 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | - name: Install dependencies 41 | run: | 42 | make venv-create SYSTEM_PYTHON=python 43 | make dev-install-tflite 44 | - name: Test with pytest 45 | run: | 46 | make dev-pytest-tflite 47 | 48 | build: 49 | needs: ["check_secrets"] 50 | runs-on: ${{ matrix.os }} 51 | strategy: 52 | matrix: 53 | os: [ubuntu-latest] 54 | python-version: ['3.7', '3.8', '3.9', '3.10'] 55 | include: 56 | - python-version: '3.8' 57 | push-package: true 58 | - os: windows-2019 59 | python-version: '3.8' 60 | - os: macos-latest 61 | python-version: '3.8' 62 | 63 | steps: 64 | - uses: actions/checkout@v2 65 | - name: Set up Python ${{ matrix.python-version }} 66 | uses: actions/setup-python@v2 67 | with: 68 | python-version: ${{ matrix.python-version }} 69 | - name: Install dependencies 70 | run: | 71 | make dev-venv SYSTEM_PYTHON=python 72 | - name: Lint 73 | run: | 74 | make dev-lint 75 | - name: Test with pytest 76 | run: | 77 | make dev-pytest 78 | - name: Build dist 79 | if: matrix.push-package == true 80 | run: | 81 | make dev-remove-dist dev-build-dist dev-list-dist-contents dev-test-install-dist 82 | - name: Publish distribution to Test PyPI 83 | if: > 84 | matrix.push-package == true 85 | && needs.check_secrets.outputs.HAS_TEST_PYPI_PASSWORD == 'true' 86 | uses: pypa/gh-action-pypi-publish@master 87 | with: 88 | password: ${{ secrets.test_pypi_password }} 89 | repository_url: https://test.pypi.org/legacy/ 90 | - name: Publish distribution to PyPI 91 | if: matrix.push-package == true && startsWith(github.ref, 'refs/tags') 92 | uses: pypa/gh-action-pypi-publish@master 93 | with: 94 | password: ${{ secrets.pypi_password }} 95 | 96 | docker-build: 97 | 98 | runs-on: ubuntu-latest 99 | 100 | steps: 101 | - name: Set tags 102 | id: set_tags 103 | run: | 104 | DOCKER_IMAGE=de4code/tf-bodypix 105 | VERSION="" 106 | if [[ $GITHUB_REF == refs/tags/v* ]]; then 107 | VERSION=${GITHUB_REF#refs/tags/v} 108 | fi 109 | if [[ $VERSION =~ ^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$ ]]; then 110 | TAGS="${DOCKER_IMAGE}:${VERSION},${DOCKER_IMAGE}:latest" 111 | else 112 | 
TAGS="${DOCKER_IMAGE}_unstable:${GITHUB_SHA},${DOCKER_IMAGE}_unstable:latest" 113 | fi 114 | echo "TAGS=${TAGS}" 115 | echo ::set-output name=tags::${TAGS} 116 | - name: Set up QEMU 117 | uses: docker/setup-qemu-action@v1 118 | - name: Set up Docker Buildx 119 | uses: docker/setup-buildx-action@v1 120 | - name: Login to DockerHub 121 | env: 122 | DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} 123 | if: ${{ env.DOCKERHUB_USERNAME != '' }} 124 | uses: docker/login-action@v1 125 | with: 126 | username: ${{ secrets.DOCKERHUB_USERNAME }} 127 | password: ${{ secrets.DOCKERHUB_TOKEN }} 128 | - name: Build and push 129 | id: docker_build 130 | uses: docker/build-push-action@v2 131 | with: 132 | push: ${{ github.event_name != 'pull_request' }} 133 | tags: ${{ steps.set_tags.outputs.tags }} 134 | - name: Image digest 135 | run: echo ${{ steps.docker_build.outputs.digest }} 136 | -------------------------------------------------------------------------------- /tests/cli_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | from tf_bodypix.download import ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS, BodyPixModelPaths 5 | from tf_bodypix.model import ModelArchitectureNames 6 | from tf_bodypix.cli import DEFAULT_MODEL_TFLITE_PATH, main 7 | 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | 12 | EXAMPLE_IMAGE_URL = ( 13 | r'https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/' 14 | r'Person_Of_Interest_-_Panel_%289353656298%29.jpg/' 15 | r'640px-Person_Of_Interest_-_Panel_%289353656298%29.jpg' 16 | ) 17 | 18 | 19 | EXAMPLE_BACKGROUND_IMAGE_URL = ( 20 | r'https://upload.wikimedia.org/wikipedia/commons/thumb/a/aa' 21 | r'/Gold_Coast_skyline.jpg/640px-Gold_Coast_skyline.jpg' 22 | ) 23 | 24 | 25 | class TestMain: 26 | def test_should_not_fail_to_draw_mask(self, tmp_path: Path): 27 | output_image_path = tmp_path / 'mask.jpg' 28 | main([ 29 | 'draw-mask', 30 | '--source=%s' % EXAMPLE_IMAGE_URL, 31 | '--output=%s' % output_image_path 32 | ]) 33 | 34 | def test_should_not_fail_to_draw_selected_mask(self, tmp_path: Path): 35 | output_image_path = tmp_path / 'mask.jpg' 36 | main([ 37 | 'draw-mask', 38 | '--source=%s' % EXAMPLE_IMAGE_URL, 39 | '--output=%s' % output_image_path, 40 | '--parts', 'left_face', 'right_face' 41 | ]) 42 | 43 | def test_should_not_fail_to_draw_colored_mask(self, tmp_path: Path): 44 | output_image_path = tmp_path / 'mask.jpg' 45 | main([ 46 | 'draw-mask', 47 | '--source=%s' % EXAMPLE_IMAGE_URL, 48 | '--output=%s' % output_image_path, 49 | '--colored' 50 | ]) 51 | 52 | def test_should_not_fail_to_draw_selected_colored_mask(self, tmp_path: Path): 53 | output_image_path = tmp_path / 'mask.jpg' 54 | main([ 55 | 'draw-mask', 56 | '--source=%s' % EXAMPLE_IMAGE_URL, 57 | '--output=%s' % output_image_path, 58 | '--parts', 'left_face', 'right_face', 59 | '--colored' 60 | ]) 61 | 62 | def test_should_not_fail_to_draw_single_person_pose(self, tmp_path: Path): 63 | output_image_path = tmp_path / 'output.jpg' 64 | main([ 65 | 'draw-pose', 66 | '--source=%s' % EXAMPLE_IMAGE_URL, 67 | '--output=%s' % output_image_path 68 | ]) 69 | 70 | def test_should_not_fail_to_blur_background(self, tmp_path: Path): 71 | output_image_path = tmp_path / 'output.jpg' 72 | main([ 73 | 'blur-background', 74 | '--source=%s' % EXAMPLE_IMAGE_URL, 75 | '--output=%s' % output_image_path 76 | ]) 77 | 78 | def test_should_not_fail_to_replace_background(self, tmp_path: Path): 79 | output_image_path = tmp_path / 'output.jpg' 80 | 
main([ 81 | 'replace-background', 82 | '--source=%s' % EXAMPLE_IMAGE_URL, 83 | '--background=%s' % EXAMPLE_IMAGE_URL, 84 | '--output=%s' % output_image_path 85 | ]) 86 | 87 | def test_should_list_all_default_model_urls(self, capsys): 88 | expected_urls = [ 89 | value 90 | for key, value in BodyPixModelPaths.__dict__.items() 91 | if not key.startswith('_') 92 | ] 93 | main(['list-models']) 94 | captured = capsys.readouterr() 95 | output_urls = captured.out.splitlines() 96 | LOGGER.debug('output_urls: %s', output_urls) 97 | missing_urls = set(expected_urls) - set(output_urls) 98 | assert not missing_urls 99 | 100 | def test_should_list_all_default_tflite_models(self, capsys): 101 | expected_urls = ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS 102 | main(['list-tflite-models']) 103 | captured = capsys.readouterr() 104 | output_urls = captured.out.splitlines() 105 | LOGGER.debug('output_urls: %s', output_urls) 106 | missing_urls = set(expected_urls) - set(output_urls) 107 | assert not missing_urls 108 | 109 | def test_should_be_able_to_convert_to_tflite_and_use_model(self, tmp_path: Path): 110 | output_model_file = tmp_path / 'model.tflite' 111 | main([ 112 | 'convert-to-tflite', 113 | '--model-path=%s' % BodyPixModelPaths.MOBILENET_FLOAT_75_STRIDE_16, 114 | '--optimize', 115 | '--quantization-type=int8', 116 | '--output-model-file=%s' % output_model_file 117 | ]) 118 | output_image_path = tmp_path / 'mask.jpg' 119 | main([ 120 | 'draw-mask', 121 | '--model-path=%s' % output_model_file, 122 | '--model-architecture=%s' % ModelArchitectureNames.MOBILENET_V1, 123 | '--output-stride=16', 124 | '--source=%s' % EXAMPLE_IMAGE_URL, 125 | '--output=%s' % output_image_path 126 | ]) 127 | 128 | def test_should_be_able_to_use_existing_tflite_model(self, tmp_path: Path): 129 | output_image_path = tmp_path / 'mask.jpg' 130 | main([ 131 | 'draw-mask', 132 | '--model-path=%s' % DEFAULT_MODEL_TFLITE_PATH, 133 | '--model-architecture=%s' % ModelArchitectureNames.MOBILENET_V1, 134 | '--output-stride=16', 135 | '--source=%s' % EXAMPLE_IMAGE_URL, 136 | '--output=%s' % output_image_path 137 | ]) 138 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/multi_person/decode_pose.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from typing import Dict 4 | 5 | from ..types import PartWithScore, TensorBuffer3D, Keypoint, Vector2D 6 | from ..keypoints import PART_NAMES, PART_IDS, POSE_CHAIN 7 | from .util import getImageCoords, clamp, addVectors, getOffsetPoint 8 | 9 | 10 | LOGGER = logging.getLogger(__name__) 11 | 12 | 13 | parentChildrenTuples = [ 14 | (PART_IDS[parentJoinName], PART_IDS[childJoinName]) 15 | for parentJoinName, childJoinName in POSE_CHAIN 16 | ] 17 | 18 | parentToChildEdges = [ 19 | childJointId 20 | for _, childJointId in parentChildrenTuples 21 | ] 22 | 23 | childToParentEdges = [ 24 | parentJointId 25 | for parentJointId, _ in parentChildrenTuples 26 | ] 27 | 28 | 29 | def getDisplacement( 30 | edgeId: int, point: Vector2D, displacements: TensorBuffer3D 31 | ) -> Vector2D: 32 | numEdges = displacements.shape[2] // 2 33 | # LOGGER.debug('point=%s, edgeId=%s, numEdges=%s', point, edgeId, numEdges) 34 | x_int = int(point.x) 35 | y_int = int(point.y) 36 | return Vector2D( 37 | y=displacements[y_int, x_int, edgeId], 38 | x=displacements[y_int, x_int, numEdges + edgeId] 39 | ) 40 | 41 | 42 | def getStridedIndexNearPoint( 43 | point: Vector2D, outputStride: int, height: int, 44 | width: int 45 
| ) -> Vector2D: 46 | # LOGGER.debug('point: %s', point) 47 | return Vector2D( 48 | y=clamp(round(point.y / outputStride), 0, height - 1), 49 | x=clamp(round(point.x / outputStride), 0, width - 1) 50 | ) 51 | 52 | 53 | def traverseToTargetKeypoint( # pylint: disable=too-many-locals 54 | edgeId: int, 55 | sourceKeypoint: Keypoint, 56 | targetKeypointId: int, 57 | scoresBuffer: TensorBuffer3D, 58 | offsets: TensorBuffer3D, outputStride: int, 59 | displacements: TensorBuffer3D, 60 | offsetRefineStep: int = 2 61 | ) -> Keypoint: 62 | height, width = scoresBuffer.shape[:2] 63 | 64 | # Nearest neighbor interpolation for the source->target displacements. 65 | sourceKeypointIndices = getStridedIndexNearPoint( 66 | sourceKeypoint.position, outputStride, height, width 67 | ) 68 | 69 | displacement = getDisplacement( 70 | edgeId, sourceKeypointIndices, displacements 71 | ) 72 | 73 | displacedPoint = addVectors(sourceKeypoint.position, displacement) 74 | targetKeypoint = displacedPoint 75 | for _ in range(offsetRefineStep): 76 | targetKeypointIndices = getStridedIndexNearPoint( 77 | targetKeypoint, outputStride, height, width 78 | ) 79 | 80 | offsetPoint = getOffsetPoint( 81 | targetKeypointIndices.y, targetKeypointIndices.x, targetKeypointId, 82 | offsets 83 | ) 84 | 85 | targetKeypoint = addVectors( 86 | Vector2D( 87 | x=targetKeypointIndices.x * outputStride, 88 | y=targetKeypointIndices.y * outputStride 89 | ), 90 | Vector2D( 91 | x=offsetPoint.x, y=offsetPoint.y 92 | ) 93 | ) 94 | 95 | targetKeyPointIndices = getStridedIndexNearPoint( 96 | targetKeypoint, outputStride, height, width 97 | ) 98 | score = scoresBuffer[ 99 | int(targetKeyPointIndices.y), int(targetKeyPointIndices.x), targetKeypointId 100 | ] 101 | 102 | return Keypoint( 103 | position=targetKeypoint, 104 | part=PART_NAMES[targetKeypointId], 105 | score=score 106 | ) 107 | 108 | 109 | def decodePose( 110 | root: PartWithScore, scores: TensorBuffer3D, offsets: TensorBuffer3D, 111 | outputStride: int, displacementsFwd: TensorBuffer3D, 112 | displacementsBwd: TensorBuffer3D 113 | ) -> Dict[int, Keypoint]: 114 | # numParts = scores.shape[2] 115 | numEdges = len(parentToChildEdges) 116 | 117 | instanceKeypoints: Dict[int, Keypoint] = {} 118 | # Start a new detection instance at the position of the root. 119 | # const {part: rootPart, score: rootScore} = root; 120 | rootPoint = getImageCoords(root.part, outputStride, offsets) 121 | 122 | instanceKeypoints[root.part.keypoint_id] = Keypoint( 123 | score=root.score, 124 | part=PART_NAMES[root.part.keypoint_id], 125 | position=rootPoint 126 | ) 127 | 128 | # Decode the part positions upwards in the tree, following the backward 129 | # displacements. 130 | for edge in reversed(range(numEdges)): 131 | sourceKeypointId = parentToChildEdges[edge] 132 | targetKeypointId = childToParentEdges[edge] 133 | if ( 134 | instanceKeypoints.get(sourceKeypointId) 135 | and not instanceKeypoints.get(targetKeypointId) 136 | ): 137 | instanceKeypoints[targetKeypointId] = traverseToTargetKeypoint( 138 | edge, instanceKeypoints[sourceKeypointId], targetKeypointId, scores, 139 | offsets, outputStride, displacementsBwd 140 | ) 141 | 142 | # Decode the part positions downwards in the tree, following the forward 143 | # displacements. 
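# displacements.
# note (added for clarity, not in the original source): this forward pass
# mirrors the backward pass above, walking edges parent -> child, so each
# already-decoded keypoint seeds its still-missing neighbour via the
# forward displacement tensor.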
144 | for edge in range(numEdges): 145 | sourceKeypointId = childToParentEdges[edge] 146 | targetKeypointId = parentToChildEdges[edge] 147 | if ( 148 | instanceKeypoints.get(sourceKeypointId) 149 | and not instanceKeypoints.get(targetKeypointId) 150 | ): 151 | instanceKeypoints[targetKeypointId] = traverseToTargetKeypoint( 152 | edge, instanceKeypoints[sourceKeypointId], targetKeypointId, scores, 153 | offsets, outputStride, displacementsFwd 154 | ) 155 | 156 | return instanceKeypoints 157 | -------------------------------------------------------------------------------- /tf_bodypix/download.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | import os 4 | import re 5 | from urllib.parse import urlparse 6 | 7 | from hashlib import md5 8 | 9 | from tf_bodypix.utils.io import download_file_to, get_default_cache_dir 10 | 11 | 12 | LOGGER = logging.getLogger(__name__) 13 | 14 | 15 | _DOWNLOAD_URL_PREFIX = r'https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/' 16 | 17 | 18 | class BodyPixModelPaths: 19 | MOBILENET_FLOAT_50_STRIDE_8 = ( 20 | _DOWNLOAD_URL_PREFIX + 'mobilenet/float/050/model-stride8.json' 21 | ) 22 | MOBILENET_FLOAT_50_STRIDE_16 = ( 23 | _DOWNLOAD_URL_PREFIX + 'mobilenet/float/050/model-stride16.json' 24 | ) 25 | MOBILENET_FLOAT_75_STRIDE_8 = ( 26 | _DOWNLOAD_URL_PREFIX + 'mobilenet/float/075/model-stride8.json' 27 | ) 28 | MOBILENET_FLOAT_75_STRIDE_16 = ( 29 | _DOWNLOAD_URL_PREFIX + 'mobilenet/float/075/model-stride16.json' 30 | ) 31 | MOBILENET_FLOAT_100_STRIDE_8 = ( 32 | _DOWNLOAD_URL_PREFIX + 'mobilenet/float/100/model-stride8.json' 33 | ) 34 | MOBILENET_FLOAT_100_STRIDE_16 = ( 35 | _DOWNLOAD_URL_PREFIX + 'mobilenet/float/100/model-stride16.json' 36 | ) 37 | 38 | RESNET50_FLOAT_STRIDE_16 = ( 39 | _DOWNLOAD_URL_PREFIX + 'resnet50/float/model-stride16.json' 40 | ) 41 | RESNET50_FLOAT_STRIDE_32 = ( 42 | _DOWNLOAD_URL_PREFIX + 'resnet50/float/model-stride32.json' 43 | ) 44 | 45 | # deprecated (shouldn't have mobilenet in the name) 46 | MOBILENET_RESNET50_FLOAT_STRIDE_16 = ( 47 | _DOWNLOAD_URL_PREFIX + 'resnet50/float/model-stride16.json' 48 | ) 49 | MOBILENET_RESNET50_FLOAT_STRIDE_32 = ( 50 | _DOWNLOAD_URL_PREFIX + 'resnet50/float/model-stride32.json' 51 | ) 52 | 53 | 54 | _TFLITE_DOWNLOAD_URL_PREFIX = r'https://www.dropbox.com/sh/d6tqb3gfrugs7ne/' 55 | 56 | 57 | class TensorFlowLiteBodyPixModelPaths: 58 | MOBILENET_FLOAT_50_STRIDE_8_FLOAT16 = ( 59 | _TFLITE_DOWNLOAD_URL_PREFIX 60 | + 'AADUtMGoDO6vzOfRLP0Dg7ira/mobilenet-float-multiplier-050-stride8-float16.tflite?dl=1' 61 | ) 62 | MOBILENET_FLOAT_50_STRIDE_16_FLOAT16 = ( 63 | _TFLITE_DOWNLOAD_URL_PREFIX 64 | + 'AAAhnozSEO07xzgL495dW3h8a/mobilenet-float-multiplier-050-stride16-float16.tflite?dl=1' 65 | ) 66 | 67 | MOBILENET_FLOAT_75_STRIDE_8_FLOAT16 = ( 68 | _TFLITE_DOWNLOAD_URL_PREFIX 69 | + 'AADBYGO2xj2v9Few4qBq62wZa/mobilenet-float-multiplier-075-stride8-float16.tflite?dl=1' 70 | ) 71 | MOBILENET_FLOAT_75_STRIDE_16_FLOAT16 = ( 72 | _TFLITE_DOWNLOAD_URL_PREFIX 73 | + 'AAAGYNAOTTWBl9ZDhALv7rEOa/mobilenet-float-multiplier-075-stride16-float16.tflite?dl=1' 74 | ) 75 | 76 | MOBILENET_FLOAT_100_STRIDE_8_FLOAT16 = ( 77 | _TFLITE_DOWNLOAD_URL_PREFIX 78 | + 'AADr8zOtPZz2cWlQEvKgIbdTa/mobilenet-float-multiplier-100-stride8-float16.tflite?dl=1' 79 | ) 80 | MOBILENET_FLOAT_100_STRIDE_16_FLOAT16 = ( 81 | _TFLITE_DOWNLOAD_URL_PREFIX 82 | + 'AAAo-hkaCqx2pN99cCvDPcosa/mobilenet-float-multiplier-100-stride16-float16.tflite?dl=1' 83 | ) 84 | 85 | 
RESNET50_FLOAT_STRIDE_16 = ( 86 | _TFLITE_DOWNLOAD_URL_PREFIX 87 | + 'AADvvgLyPXMPOeRyRY9WQ9Mva/resnet50-float-stride16-float16.tflite?dl=1' 88 | ) 89 | MOBILENET_RESNET50_FLOAT_STRIDE_32 = ( 90 | _TFLITE_DOWNLOAD_URL_PREFIX 91 | + 'AADGlTuMQQeL8vm6BuOwObKTa/resnet50-float-stride32-float16.tflite?dl=1' 92 | ) 93 | 94 | 95 | ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS = [ 96 | value 97 | for key, value in TensorFlowLiteBodyPixModelPaths.__dict__.items() 98 | if key.isupper() and isinstance(value, str) 99 | ] 100 | 101 | 102 | class DownloadError(RuntimeError): 103 | pass 104 | 105 | 106 | def download_model(model_path: str) -> str: 107 | if os.path.exists(model_path): 108 | return model_path 109 | parsed_model_path = urlparse(model_path) 110 | local_name_part = re.sub( 111 | r'[^a-zA-Z0-9]+', 112 | r'-', 113 | os.path.splitext(parsed_model_path.path)[0] 114 | ) 115 | local_name = ( 116 | md5(model_path.encode('utf-8')).hexdigest() + '-' 117 | + os.path.basename(local_name_part) 118 | ) 119 | LOGGER.debug('local_name: %r', local_name) 120 | cache_dir = get_default_cache_dir( 121 | cache_subdir=os.path.join('tf-bodypix', local_name) 122 | ) 123 | if parsed_model_path.path.endswith('.tflite'): 124 | return download_file_to( 125 | source_url=model_path, 126 | local_path=os.path.join( 127 | cache_dir, 128 | os.path.basename(parsed_model_path.path) 129 | ), 130 | skip_if_exists=True 131 | ) 132 | if not parsed_model_path.path.endswith('.json'): 133 | raise ValueError('remote model path needs to end with .json') 134 | model_base_path = os.path.dirname(model_path) 135 | local_model_json_path = download_file_to( 136 | source_url=model_path, 137 | local_path=os.path.join(cache_dir, 'model.json'), 138 | skip_if_exists=True 139 | ) 140 | local_model_path = os.path.dirname(local_model_json_path) 141 | LOGGER.debug('local_model_json_path: %r', local_model_json_path) 142 | try: 143 | with open(local_model_json_path, 'r', encoding='utf-8') as model_json_fp: 144 | model_json = json.load(model_json_fp) 145 | except UnicodeDecodeError as exc: 146 | LOGGER.error( 147 | 'failed to process %r due to %r', 148 | local_model_json_path, exc, exc_info=True 149 | ) 150 | raise DownloadError( 151 | 'failed to process %r due to %r' % ( 152 | local_model_json_path, exc 153 | ) 154 | ) from exc 155 | LOGGER.debug('model_json.keys: %s', model_json.keys()) 156 | weights_manifest = model_json['weightsManifest'] 157 | weights_manifest_paths = sorted({ 158 | path 159 | for item in weights_manifest 160 | for path in item.get('paths', []) 161 | }) 162 | LOGGER.debug('weights_manifest_paths: %s', weights_manifest_paths) 163 | for weights_manifest_path in weights_manifest_paths: 164 | local_model_json_path = download_file_to( 165 | source_url=model_base_path + '/' + weights_manifest_path, 166 | local_path=os.path.join(cache_dir, os.path.basename(weights_manifest_path)), 167 | skip_if_exists=True 168 | ) 169 | return local_model_path 170 | -------------------------------------------------------------------------------- /tf_bodypix/utils/image.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from collections import namedtuple 4 | from typing import Optional, Sequence 5 | 6 | import numpy as np 7 | 8 | try: 9 | import tensorflow as tf 10 | except ImportError: 11 | tf = None 12 | 13 | try: 14 | import cv2 15 | except ImportError: 16 | cv2 = None 17 | 18 | try: 19 | import PIL.Image 20 | except ImportError: 21 | PIL = None 22 | 23 | 24 | LOGGER = 
logging.getLogger(__name__) 25 | 26 | 27 | ImageSize = namedtuple('ImageSize', ('height', 'width')) 28 | 29 | 30 | ImageArray = np.ndarray 31 | 32 | 33 | class ResizeMethod: 34 | BILINEAR = 'bilinear' 35 | 36 | 37 | def require_opencv(): 38 | if cv2 is None: 39 | raise ImportError('OpenCV is required') 40 | 41 | 42 | def box_blur_image(image: np.ndarray, blur_size: int) -> np.ndarray: 43 | if not blur_size: 44 | return image 45 | require_opencv() 46 | if len(image.shape) == 4: 47 | image = image[0] 48 | result = cv2.blur(np.asarray(image), (blur_size, blur_size)) 49 | if len(result.shape) == 2: 50 | result = np.expand_dims(result, axis=-1) 51 | result = result.astype(np.float32) 52 | return result 53 | 54 | 55 | def get_image_size(image: np.ndarray): 56 | height, width, *_ = image.shape 57 | return ImageSize(height=height, width=width) 58 | 59 | 60 | def _resize_image_to_using_tf( 61 | image_array: np.ndarray, 62 | image_size: ImageSize, 63 | resize_method: Optional[str] = None 64 | ) -> np.ndarray: 65 | if not resize_method: 66 | resize_method = tf.image.ResizeMethod.BILINEAR 67 | LOGGER.debug('resizing image: %r -> %r', image_array.shape, image_size) 68 | return tf.image.resize( 69 | image_array, 70 | (image_size.height, image_size.width), 71 | method=resize_method 72 | ) 73 | 74 | 75 | def _get_pil_image(image_array: np.ndarray) -> 'PIL.Image': 76 | if image_array.shape[-1] == 1: 77 | pil_mode = 'L' 78 | image_array = np.reshape(image_array, image_array.shape[:2]) 79 | else: 80 | pil_mode = 'RGB' 81 | image_array = image_array.astype(np.uint8) 82 | pil_image = PIL.Image.fromarray(image_array, mode=pil_mode) 83 | return pil_image 84 | 85 | 86 | # copied from: 87 | # https://chao-ji.github.io/jekyll/update/2018/07/19/BilinearResize.html 88 | def _numpy_bilinear_resize_2d( # pylint: disable=too-many-locals 89 | image: np.ndarray, 90 | height: int, 91 | width: int 92 | ) -> np.ndarray: 93 | """ 94 | `image` is a 2-D numpy array 95 | `height` and `width` are the desired spatial dimension of the new 2-D array. 
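(added note, not in the original docstring) Each output pixel is the bilinear blend of its four nearest input pixels:
    f = a*(1-wx)*(1-wy) + b*wx*(1-wy) + c*wy*(1-wx) + d*wx*wy
where a, b, c, d are the top-left, top-right, bottom-left and bottom-right neighbours and wx, wy the fractional offsets, matching the expression computed below.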
96 | """ 97 | img_height, img_width = image.shape 98 | 99 | image = image.ravel() 100 | 101 | x_ratio = float(img_width - 1) / (width - 1) if width > 1 else 0 102 | y_ratio = float(img_height - 1) / (height - 1) if height > 1 else 0 103 | 104 | y, x = np.divmod(np.arange(height * width), width) 105 | 106 | x_l = np.floor(x_ratio * x).astype('int32') 107 | y_l = np.floor(y_ratio * y).astype('int32') 108 | 109 | x_h = np.ceil(x_ratio * x).astype('int32') 110 | y_h = np.ceil(y_ratio * y).astype('int32') 111 | 112 | x_weight = (x_ratio * x) - x_l 113 | y_weight = (y_ratio * y) - y_l 114 | 115 | a = image[y_l * img_width + x_l] 116 | b = image[y_l * img_width + x_h] 117 | c = image[y_h * img_width + x_l] 118 | d = image[y_h * img_width + x_h] 119 | 120 | resized = ( 121 | a * (1 - x_weight) * (1 - y_weight) + 122 | b * x_weight * (1 - y_weight) + 123 | c * y_weight * (1 - x_weight) + 124 | d * x_weight * y_weight 125 | ) 126 | 127 | return resized.reshape(height, width) 128 | 129 | 130 | def _numpy_bilinear_resize_3d(image: np.ndarray, height: int, width: int) -> np.ndarray: 131 | _, _, dimensions = image.shape 132 | return np.stack( 133 | [ 134 | _numpy_bilinear_resize_2d( 135 | image[:, :, dimension], height, width 136 | ) 137 | for dimension in range(dimensions) 138 | ], 139 | axis=-1 140 | ) 141 | 142 | 143 | def _resize_image_to_using_numpy( 144 | image_array: np.ndarray, 145 | image_size: ImageSize, 146 | resize_method: Optional[str] = None 147 | ) -> np.ndarray: 148 | assert not resize_method or resize_method == 'bilinear' 149 | if len(image_array.shape) == 4: 150 | assert image_array.shape[0] == 1 151 | image_array = image_array[0] 152 | LOGGER.debug( 153 | 'resizing image: %r (%r) -> %r', image_array.shape, image_array.dtype, image_size 154 | ) 155 | resize_image_array = ( 156 | _numpy_bilinear_resize_3d( 157 | np.asarray(image_array), image_size.height, image_size.width 158 | ).astype(image_array.dtype) 159 | ) 160 | LOGGER.debug( 161 | 'resize_image_array image: %r (%r)', image_array.shape, resize_image_array.dtype 162 | ) 163 | return resize_image_array 164 | 165 | 166 | def resize_image_to( 167 | image_array: np.ndarray, 168 | image_size: ImageSize, 169 | resize_method: Optional[str] = None 170 | ) -> np.ndarray: 171 | if get_image_size(image_array) == image_size: 172 | LOGGER.debug('image has already desired size: %s', image_size) 173 | return image_array 174 | 175 | if tf is not None: 176 | return _resize_image_to_using_tf(image_array, image_size, resize_method) 177 | return _resize_image_to_using_numpy(image_array, image_size, resize_method) 178 | 179 | 180 | def crop_and_resize_batch( # pylint: disable=too-many-locals 181 | image_array_batch: np.ndarray, 182 | boxes: Sequence[Sequence[float]], 183 | box_indices: Sequence[int], 184 | crop_size: Sequence[int], 185 | method='bilinear', 186 | ) -> np.ndarray: 187 | if tf is not None: 188 | return tf.image.crop_and_resize( 189 | image_array_batch, 190 | boxes=boxes, 191 | box_indices=box_indices, 192 | crop_size=crop_size, 193 | method=method 194 | ) 195 | assert list(box_indices) == [0] 196 | assert len(boxes) == 1 197 | assert len(crop_size) == 2 198 | box = np.array(boxes[0]) 199 | assert np.min(box) >= 0 200 | assert np.max(box) <= 1 201 | y1, x1, y2, x2 = list(box) 202 | assert y1 <= y2 203 | assert x1 <= x2 204 | assert len(image_array_batch) == 1 205 | image_size = get_image_size(image_array_batch[0]) 206 | image_y1 = int(y1 * (image_size.height - 1)) 207 | image_y2 = int(y2 * (image_size.height - 1)) 208 | image_x1 = int(x1 
* (image_size.width - 1)) 209 | image_x2 = int(x2 * (image_size.width - 1)) 210 | LOGGER.debug('image y1, x1, y2, x2: %r', (image_y1, image_x1, image_y2, image_x2)) 211 | cropped_image_array = image_array_batch[0][ 212 | image_y1:(1 + image_y2), image_x1: (1 + image_x2), : 213 | ] 214 | LOGGER.debug('cropped_image_array: %r', cropped_image_array.shape) 215 | resized_cropped_image_array = resize_image_to( 216 | cropped_image_array, ImageSize(height=crop_size[0], width=crop_size[1]) 217 | ) 218 | return np.expand_dims(resized_cropped_image_array, 0) 219 | 220 | 221 | def bgr_to_rgb(image: np.ndarray) -> np.ndarray: 222 | # see https://www.scivision.dev/numpy-image-bgr-to-rgb/ 223 | return image[..., ::-1] 224 | 225 | 226 | def rgb_to_bgr(image: np.ndarray) -> np.ndarray: 227 | return bgr_to_rgb(image) 228 | 229 | 230 | def _load_image_using_tf( 231 | local_image_path: str, 232 | image_size: Optional[ImageSize] = None 233 | ) -> np.ndarray: 234 | image = tf.keras.preprocessing.image.load_img( 235 | local_image_path 236 | ) 237 | image_array = tf.keras.preprocessing.image.img_to_array(image) 238 | if image_size is not None: 239 | image_array = resize_image_to(image_array, image_size) 240 | return image_array 241 | 242 | 243 | def _load_image_using_pillow( 244 | local_image_path: str, 245 | image_size: Optional[ImageSize] = None 246 | ) -> np.ndarray: 247 | with PIL.Image.open(local_image_path) as image: 248 | image_array = np.asarray(image) 249 | if image_size is not None: 250 | image_array = resize_image_to(image_array, image_size) 251 | return image_array 252 | 253 | 254 | def load_image( 255 | local_image_path: str, 256 | image_size: Optional[ImageSize] = None 257 | ) -> np.ndarray: 258 | if tf is not None: 259 | return _load_image_using_tf(local_image_path, image_size=image_size) 260 | return _load_image_using_pillow(local_image_path, image_size=image_size) 261 | 262 | 263 | def save_image_using_tf(image_array: np.ndarray, path: str): 264 | tf.keras.preprocessing.image.save_img(path, image_array) 265 | 266 | 267 | def save_image_using_pillow(image_array: np.ndarray, path: str): 268 | pil_image = _get_pil_image(image_array) 269 | pil_image.save(path) 270 | 271 | 272 | def write_image_to(image_array: np.ndarray, path: str): 273 | LOGGER.info('writing image to: %r', path) 274 | os.makedirs(os.path.dirname(path), exist_ok=True) 275 | if tf is not None: 276 | save_image_using_tf(image_array, path) 277 | else: 278 | save_image_using_pillow(image_array, path) 279 | -------------------------------------------------------------------------------- /tf_bodypix/utils/opencv.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import deque 3 | from contextlib import contextmanager 4 | from time import monotonic, sleep 5 | from typing import Callable, ContextManager, Deque, Iterable, Iterator, Optional, Union 6 | 7 | import cv2 8 | import numpy as np 9 | 10 | from tf_bodypix.utils.io import get_file 11 | from tf_bodypix.utils.image import ( 12 | ImageSize, ImageArray, bgr_to_rgb, rgb_to_bgr, 13 | get_image_size 14 | ) 15 | 16 | 17 | LOGGER = logging.getLogger(__name__) 18 | 19 | 20 | DEFAULT_WEBCAM_FOURCC = 'MJPG' 21 | 22 | 23 | def iter_read_raw_video_images( 24 | video_capture: cv2.VideoCapture, 25 | repeat: bool = False, 26 | is_stopped: Optional[Callable[[], bool]] = None 27 | ) -> Iterable[ImageArray]: 28 | while is_stopped is None or not is_stopped(): 29 | grabbed, image_array = video_capture.read() 30 | if not grabbed: 31 | 
LOGGER.info('video end reached') 32 | if not repeat: 33 | return 34 | video_capture.set(cv2.CAP_PROP_POS_FRAMES, 0) 35 | grabbed, image_array = video_capture.read() 36 | if not grabbed: 37 | LOGGER.info('unable to rewind video') 38 | return 39 | yield image_array 40 | 41 | 42 | def iter_resize_video_images( 43 | video_images: Iterable[ImageArray], 44 | image_size: Optional[ImageSize] = None, 45 | interpolation: int = cv2.INTER_LINEAR 46 | ) -> Iterable[ImageArray]: 47 | is_first = True 48 | for image_array in video_images: 49 | LOGGER.debug('video image_array.shape: %s', image_array.shape) 50 | if is_first: 51 | LOGGER.info( 52 | 'received video image shape: %s (requested: %s)', 53 | image_array.shape, image_size 54 | ) 55 | is_first = False 56 | if image_size and get_image_size(image_array) != image_size: 57 | image_array = cv2.resize( 58 | image_array, 59 | (image_size.width, image_size.height), 60 | interpolation=interpolation 61 | ) 62 | yield image_array 63 | 64 | 65 | def iter_convert_video_images_to_rgb( 66 | video_images: Iterable[ImageArray] 67 | ) -> Iterable[ImageArray]: 68 | return (bgr_to_rgb(image_array) for image_array in video_images) 69 | 70 | 71 | def iter_delay_video_images_to_fps( 72 | video_images: Iterable[ImageArray], 73 | fps: Optional[float] = None 74 | ) -> Iterable[np.ndarray]: 75 | if not fps or fps <= 0: 76 | LOGGER.info('no fps requested, providing images from source (without delay)') 77 | yield from video_images 78 | return 79 | desired_frame_time = 1 / fps 80 | LOGGER.info( 81 | 'limiting frame rate to %.3f fps (%.1f ms per frame)', 82 | fps, desired_frame_time * 1000 83 | ) 84 | last_frame_time = None 85 | frame_times: Deque[float] = deque(maxlen=10) 86 | current_fps = 0.0 87 | additional_frame_adjustment = 0.0 88 | end_frame_time = monotonic() 89 | video_images_iterator = iter(video_images) 90 | while True: 91 | start_frame_time = end_frame_time 92 | # attempt to retrieve the next frame (that may vary in time) 93 | try: 94 | image_array = next(video_images_iterator) 95 | except StopIteration: 96 | return 97 | # wait before delivery in order to achieve the desired fps 98 | current_time = monotonic() 99 | if last_frame_time: 100 | desired_wait_time = ( 101 | desired_frame_time 102 | - (current_time - last_frame_time) 103 | + additional_frame_adjustment 104 | ) 105 | if desired_wait_time > 0: 106 | LOGGER.debug( 107 | 'sleeping for desired fps: %s (desired_frame_time: %s, fps: %.3f)', 108 | desired_wait_time, desired_frame_time, current_fps 109 | ) 110 | sleep(desired_wait_time) 111 | last_frame_time = monotonic() 112 | # emit the frame (post-processing may add to the overall frame time) 113 | yield image_array 114 | end_frame_time = monotonic() 115 | frame_time = end_frame_time - start_frame_time 116 | additional_frame_adjustment = desired_frame_time - frame_time 117 | frame_times.append(frame_time) 118 | current_fps = 1 / (sum(frame_times) / len(frame_times)) 119 | 120 | 121 | def iter_read_video_images( 122 | video_capture: cv2.VideoCapture, 123 | image_size: Optional[ImageSize] = None, 124 | interpolation: int = cv2.INTER_LINEAR, 125 | repeat: bool = True, 126 | fps: Optional[float] = None 127 | ) -> Iterable[np.ndarray]: 128 | video_images: Iterable[np.ndarray] 129 | video_images = iter_read_raw_video_images(video_capture, repeat=repeat) 130 | video_images = iter_delay_video_images_to_fps(video_images, fps) 131 | video_images = iter_resize_video_images( 132 | video_images, image_size=image_size, interpolation=interpolation 133 | ) 134 | video_images =
iter_convert_video_images_to_rgb(video_images) 135 | return video_images 136 | 137 | 138 | @contextmanager 139 | def get_video_image_source( # pylint: disable=too-many-locals 140 | path: Union[str, int], 141 | image_size: Optional[ImageSize] = None, 142 | download: bool = True, 143 | fps: Optional[float] = None, 144 | fourcc: Optional[str] = None, 145 | buffer_size: Optional[int] = None, 146 | **_ 147 | ) -> Iterator[Iterable[ImageArray]]: 148 | local_path: Union[str, int] 149 | if isinstance(path, str): 150 | local_path = get_file(path, download=download) 151 | else: 152 | local_path = path 153 | if local_path != path: 154 | LOGGER.info('loading video: %r (downloaded from %r)', local_path, path) 155 | else: 156 | LOGGER.info('loading video: %r', path) 157 | video_capture = cv2.VideoCapture(local_path) 158 | if fourcc: 159 | LOGGER.info('setting video fourcc to %r', fourcc) 160 | video_capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*fourcc)) 161 | if buffer_size: 162 | video_capture.set(cv2.CAP_PROP_BUFFERSIZE, buffer_size) 163 | if image_size: 164 | LOGGER.info('attempting to set video image size to: %s', image_size) 165 | video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, image_size.width) 166 | video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, image_size.height) 167 | if fps: 168 | LOGGER.info('attempting to set video fps to %r', fps) 169 | video_capture.set(cv2.CAP_PROP_FPS, fps) 170 | actual_image_size = ImageSize( 171 | width=video_capture.get(cv2.CAP_PROP_FRAME_WIDTH), 172 | height=video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT) 173 | ) 174 | actual_fps = video_capture.get(cv2.CAP_PROP_FPS) 175 | frame_count = video_capture.get(cv2.CAP_PROP_FRAME_COUNT) 176 | LOGGER.info( 177 | 'video reported image size: %s (%s fps, %s frames)', 178 | actual_image_size, actual_fps, frame_count 179 | ) 180 | try: 181 | yield iter_read_video_images( 182 | video_capture, 183 | image_size=image_size, 184 | fps=fps if fps is not None else actual_fps 185 | ) 186 | finally: 187 | LOGGER.debug('releasing video capture: %s', path) 188 | video_capture.release() 189 | 190 | 191 | def get_webcam_image_source( 192 | path: Union[str, int], 193 | fourcc: Optional[str] = None, 194 | buffer_size: int = 1, 195 | **kwargs 196 | ) -> ContextManager[Iterable[ImageArray]]: 197 | if fourcc is None: 198 | fourcc = DEFAULT_WEBCAM_FOURCC 199 | return get_video_image_source(path, fourcc=fourcc, buffer_size=buffer_size, **kwargs) 200 | 201 | 202 | class ShowImageSink: 203 | def __init__( 204 | self, 205 | window_name: str, 206 | window_title: str = '' 207 | ): 208 | self.window_name = window_name 209 | self.window_title = window_title 210 | self.was_opened = False 211 | 212 | def __enter__(self): 213 | return self 214 | 215 | def __exit__(self, *_, **__): 216 | if self.was_opened: 217 | cv2.destroyWindow(self.window_name) 218 | 219 | @property 220 | def is_closed(self): 221 | if not self.was_opened: 222 | return False 223 | cv2.waitKey(1) 224 | return cv2.getWindowProperty(self.window_name, cv2.WND_PROP_VISIBLE) <= 0 225 | 226 | def create_window(self, image_size: ImageSize): 227 | cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL) 228 | cv2.resizeWindow(self.window_name, image_size.width, image_size.height) 229 | if self.window_title: 230 | cv2.setWindowTitle(self.window_name, self.window_title) 231 | self.was_opened = True 232 | 233 | def __call__(self, image_array: np.ndarray): 234 | if self.is_closed: 235 | LOGGER.info('window closed') 236 | raise KeyboardInterrupt('window closed') 237 | image_array = 
np.asarray(image_array).astype(np.uint8) 238 | if not self.was_opened: 239 | self.create_window(get_image_size(image_array)) 240 | cv2.imshow(self.window_name, rgb_to_bgr(image_array)) 241 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | VENV = venv 2 | 3 | ifeq ($(OS),Windows_NT) 4 | VENV_BIN = $(VENV)/Scripts 5 | else 6 | VENV_BIN = $(VENV)/bin 7 | endif 8 | 9 | PYTHON = $(VENV_BIN)/python 10 | PIP = $(VENV_BIN)/python -m pip 11 | 12 | SYSTEM_PYTHON = python3 13 | 14 | VENV_TEMP = venv_temp 15 | 16 | ARGS = 17 | 18 | 19 | IMAGE_URL = https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/Person_Of_Interest_-_Panel_%289353656298%29.jpg/640px-Person_Of_Interest_-_Panel_%289353656298%29.jpg 20 | BACKGROUND_IMAGE_URL = https://upload.wikimedia.org/wikipedia/commons/thumb/a/aa/Gold_Coast_skyline.jpg/640px-Gold_Coast_skyline.jpg 21 | OUTPUT_MASK_PATH = data/example-mask.jpg 22 | OUTPUT_SELECTED_MASK_PATH = data/example-selected-mask.jpg 23 | OUTPUT_COLORED_MASK_PATH = data/example-colored-mask.jpg 24 | OUTPUT_SELECTED_COLORED_MASK_PATH = data/example-selected-colored-mask.jpg 25 | OUTPUT_WEBCAM_MASK_PATH = data/webcam-mask.jpg 26 | MASK_THRESHOLD = 0.75 27 | ADD_OVERLAY_ALPHA = 0.5 28 | 29 | SELECTED_PARTS = left_face right_face 30 | 31 | WEBCAM_PATH = webcam:0 32 | VIRTUAL_VIDEO_DEVICE = /dev/video2 33 | 34 | IMAGE_NAME = de4code/tf-bodypix_unstable 35 | IMAGE_TAG = develop 36 | 37 | 38 | venv-clean: 39 | @if [ -d "$(VENV)" ]; then \ 40 | rm -rf "$(VENV)"; \ 41 | fi 42 | 43 | 44 | venv-create: 45 | $(SYSTEM_PYTHON) -m venv $(VENV) 46 | 47 | 48 | dev-install-build-dependencies: 49 | $(PIP) install --requirement=requirements.build.txt 50 | 51 | 52 | dev-install: dev-install-build-dependencies 53 | $(PIP) install \ 54 | --constraint=constraints.txt \ 55 | --requirement=requirements.dev.txt \ 56 | --requirement=requirements.txt 57 | 58 | 59 | dev-install-tflite: dev-install-build-dependencies 60 | $(PIP) install \ 61 | --constraint=constraints.txt \ 62 | --requirement=requirements.dev.txt 63 | $(PIP) install .[tflite,image] 64 | 65 | 66 | dev-run-pip: 67 | $(PIP) $(ARGS) 68 | 69 | 70 | dev-venv: venv-create dev-install 71 | 72 | 73 | dev-flake8: 74 | $(PYTHON) -m flake8 tf_bodypix tests setup.py 75 | 76 | 77 | dev-pylint: 78 | $(PYTHON) -m pylint tf_bodypix tests setup.py 79 | 80 | 81 | dev-mypy: 82 | $(PYTHON) -m mypy --ignore-missing-imports --show-error-codes \ 83 | tf_bodypix tests setup.py 84 | 85 | 86 | dev-lint: dev-flake8 dev-pylint dev-mypy 87 | 88 | 89 | dev-pytest: 90 | $(PYTHON) -m pytest -p no:cacheprovider $(ARGS) 91 | 92 | 93 | dev-pytest-tflite: 94 | $(MAKE) dev-pytest \ 95 | ARGS='tests/cli_test.py -k test_should_be_able_to_use_existing_tflite_model' 96 | 97 | 98 | dev-watch: 99 | $(PYTHON) -m pytest_watch -- -p no:cacheprovider -p no:warnings $(ARGS) 100 | 101 | 102 | dev-watch-tflite: 103 | $(MAKE) dev-watch \ 104 | ARGS='tests/cli_test.py -k test_should_be_able_to_use_existing_tflite_model' 105 | 106 | 107 | dev-test: dev-lint dev-pytest 108 | 109 | 110 | dev-remove-dist: 111 | rm -rf ./dist 112 | 113 | 114 | dev-build-dist: 115 | $(PYTHON) setup.py sdist bdist_wheel 116 | 117 | 118 | dev-list-dist-contents: 119 | tar -ztvf dist/tf-bodypix-*.tar.gz 120 | 121 | 122 | dev-get-version: 123 | $(PYTHON) setup.py --version 124 | 125 | 126 | dev-test-install-dist: 127 | $(MAKE) VENV=$(VENV_TEMP) venv-create 128 | $(VENV_TEMP)/bin/pip install -r 
requirements.build.txt 129 | $(VENV_TEMP)/bin/pip install --force-reinstall ./dist/*.tar.gz 130 | $(VENV_TEMP)/bin/pip install --force-reinstall ./dist/*.whl 131 | 132 | 133 | run: 134 | $(PYTHON) -m tf_bodypix $(ARGS) 135 | 136 | 137 | list-models: 138 | $(PYTHON) -m tf_bodypix \ 139 | list-models 140 | 141 | 142 | list-tflite-models: 143 | $(PYTHON) -m tf_bodypix \ 144 | list-tflite-models 145 | 146 | 147 | convert-example-draw-mask: 148 | $(PYTHON) -m tf_bodypix \ 149 | draw-mask \ 150 | --source \ 151 | "$(IMAGE_URL)" \ 152 | --output \ 153 | "$(OUTPUT_MASK_PATH)" \ 154 | --threshold=$(MASK_THRESHOLD) \ 155 | $(ARGS) 156 | 157 | 158 | convert-example-draw-selected-mask: 159 | $(PYTHON) -m tf_bodypix \ 160 | draw-mask \ 161 | --source \ 162 | "$(IMAGE_URL)" \ 163 | --output \ 164 | "$(OUTPUT_SELECTED_MASK_PATH)" \ 165 | --threshold=$(MASK_THRESHOLD) \ 166 | --parts $(SELECTED_PARTS) \ 167 | $(ARGS) 168 | 169 | 170 | convert-example-draw-colored-mask: 171 | $(PYTHON) -m tf_bodypix \ 172 | draw-mask \ 173 | --source \ 174 | "$(IMAGE_URL)" \ 175 | --output \ 176 | "$(OUTPUT_COLORED_MASK_PATH)" \ 177 | --threshold=$(MASK_THRESHOLD) \ 178 | --colored \ 179 | $(ARGS) 180 | 181 | 182 | convert-example-draw-selected-colored-mask: 183 | $(PYTHON) -m tf_bodypix \ 184 | draw-mask \ 185 | --source \ 186 | "$(IMAGE_URL)" \ 187 | --output \ 188 | "$(OUTPUT_SELECTED_COLORED_MASK_PATH)" \ 189 | --threshold=$(MASK_THRESHOLD) \ 190 | --colored \ 191 | --parts $(SELECTED_PARTS) \ 192 | $(ARGS) 193 | 194 | 195 | webcam-draw-mask: 196 | $(PYTHON) -m tf_bodypix \ 197 | draw-mask \ 198 | --source \ 199 | "$(WEBCAM_PATH)" \ 200 | --show-output \ 201 | --threshold=$(MASK_THRESHOLD) \ 202 | --add-overlay-alpha=$(ADD_OVERLAY_ALPHA) \ 203 | $(ARGS) 204 | 205 | 206 | webcam-blur-background: 207 | $(PYTHON) -m tf_bodypix \ 208 | blur-background \ 209 | --source \ 210 | "$(WEBCAM_PATH)" \ 211 | --show-output \ 212 | --threshold=$(MASK_THRESHOLD) \ 213 | $(ARGS) 214 | 215 | 216 | webcam-replace-background: 217 | $(PYTHON) -m tf_bodypix \ 218 | replace-background \ 219 | --source \ 220 | "$(WEBCAM_PATH)" \ 221 | --background \ 222 | "$(BACKGROUND_IMAGE_URL)" \ 223 | --show-output \ 224 | --threshold=$(MASK_THRESHOLD) \ 225 | $(ARGS) 226 | 227 | 228 | webcam-v4l2-draw-mask: 229 | $(PYTHON) -m tf_bodypix \ 230 | draw-mask \ 231 | --source \ 232 | "$(WEBCAM_PATH)" \ 233 | --output=$(VIRTUAL_VIDEO_DEVICE) \ 234 | --threshold=$(MASK_THRESHOLD) \ 235 | --add-overlay-alpha=$(ADD_OVERLAY_ALPHA) \ 236 | $(ARGS) 237 | 238 | 239 | webcam-v4l2-draw-mask-colored: 240 | $(PYTHON) -m tf_bodypix \ 241 | draw-mask \ 242 | --source \ 243 | "$(WEBCAM_PATH)" \ 244 | --output=$(VIRTUAL_VIDEO_DEVICE) \ 245 | --threshold=$(MASK_THRESHOLD) \ 246 | --add-overlay-alpha=$(ADD_OVERLAY_ALPHA) \ 247 | --colored \ 248 | $(ARGS) 249 | 250 | 251 | webcam-v4l2-blur-background: 252 | $(PYTHON) -m tf_bodypix \ 253 | blur-background \ 254 | --source \ 255 | "$(WEBCAM_PATH)" \ 256 | --output=$(VIRTUAL_VIDEO_DEVICE) \ 257 | --threshold=$(MASK_THRESHOLD) \ 258 | $(ARGS) 259 | 260 | 261 | webcam-v4l2-replace-background: 262 | $(PYTHON) -m tf_bodypix \ 263 | replace-background \ 264 | --source \ 265 | "$(WEBCAM_PATH)" \ 266 | --background \ 267 | "$(BACKGROUND_IMAGE_URL)" \ 268 | --output=$(VIRTUAL_VIDEO_DEVICE) \ 269 | --threshold=$(MASK_THRESHOLD) \ 270 | $(ARGS) 271 | 272 | 273 | convert-tfjs-models-to-tflite: 274 | mkdir -p "./data/tflite-models" 275 | $(PYTHON) -m tf_bodypix \ 276 | convert-to-tflite \ 277 | --model-path \ 278 | 
"https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/050/model-stride8.json" \ 279 | --optimize \ 280 | --quantization-type=float16 \ 281 | --output-model-file "./data/tflite-models/mobilenet-float-multiplier-050-stride8-float16.tflite" 282 | $(PYTHON) -m tf_bodypix \ 283 | convert-to-tflite \ 284 | --model-path \ 285 | "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/050/model-stride16.json" \ 286 | --optimize \ 287 | --quantization-type=float16 \ 288 | --output-model-file "./data/tflite-models/mobilenet-float-multiplier-050-stride16-float16.tflite" 289 | $(PYTHON) -m tf_bodypix \ 290 | convert-to-tflite \ 291 | --model-path \ 292 | "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride8.json" \ 293 | --optimize \ 294 | --quantization-type=float16 \ 295 | --output-model-file "./data/tflite-models/mobilenet-float-multiplier-075-stride8-float16.tflite" 296 | $(PYTHON) -m tf_bodypix \ 297 | convert-to-tflite \ 298 | --model-path \ 299 | "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride16.json" \ 300 | --optimize \ 301 | --quantization-type=float16 \ 302 | --output-model-file "./data/tflite-models/mobilenet-float-multiplier-075-stride16-float16.tflite" 303 | $(PYTHON) -m tf_bodypix \ 304 | convert-to-tflite \ 305 | --model-path \ 306 | "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/100/model-stride8.json" \ 307 | --optimize \ 308 | --quantization-type=float16 \ 309 | --output-model-file "./data/tflite-models/mobilenet-float-multiplier-100-stride8-float16.tflite" 310 | $(PYTHON) -m tf_bodypix \ 311 | convert-to-tflite \ 312 | --model-path \ 313 | "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/100/model-stride16.json" \ 314 | --optimize \ 315 | --quantization-type=float16 \ 316 | --output-model-file "./data/tflite-models/mobilenet-float-multiplier-100-stride16-float16.tflite" 317 | $(PYTHON) -m tf_bodypix \ 318 | convert-to-tflite \ 319 | --model-path \ 320 | "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/resnet50/float/model-stride16.json" \ 321 | --optimize \ 322 | --quantization-type=float16 \ 323 | --output-model-file "./data/tflite-models/resnet50-float-stride16-float16.tflite" 324 | $(PYTHON) -m tf_bodypix \ 325 | convert-to-tflite \ 326 | --model-path \ 327 | "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/resnet50/float/model-stride32.json" \ 328 | --optimize \ 329 | --quantization-type=float16 \ 330 | --output-model-file "./data/tflite-models/resnet50-float-stride32-float16.tflite" 331 | 332 | 333 | docker-build: 334 | docker build . 
-t $(IMAGE_NAME):$(IMAGE_TAG) 335 | 336 | 337 | docker-run: 338 | docker run \ 339 | -v /tmp/.X11-unix:/tmp/.X11-unix \ 340 | -e DISPLAY=unix$$DISPLAY \ 341 | -v /dev/shm:/dev/shm \ 342 | --rm $(IMAGE_NAME):$(IMAGE_TAG) $(ARGS) 343 | -------------------------------------------------------------------------------- /tf_bodypix/bodypix_js_utils/util.py: -------------------------------------------------------------------------------- 1 | # based on: 2 | # https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.4/body-pix/src/util.ts 3 | 4 | import logging 5 | import math 6 | from collections import namedtuple 7 | from typing import List, Optional, Tuple, Union 8 | 9 | try: 10 | import tensorflow as tf 11 | except ImportError: 12 | tf = None 13 | 14 | import numpy as np 15 | 16 | from tf_bodypix.utils.image import ( 17 | ResizeMethod, 18 | crop_and_resize_batch, 19 | resize_image_to, 20 | ImageSize 21 | ) 22 | 23 | from .types import Keypoint, Pose, Vector2D 24 | 25 | 26 | LOGGER = logging.getLogger(__name__) 27 | 28 | 29 | Padding = namedtuple('Padding', ('top', 'bottom', 'left', 'right')) 30 | 31 | 32 | # see isValidInputResolution 33 | def is_valid_input_resolution( 34 | resolution: Union[int, float], output_stride: int 35 | ) -> bool: 36 | return (resolution - 1) % output_stride == 0 37 | 38 | 39 | # see toValidInputResolution 40 | def to_valid_input_resolution( 41 | input_resolution: Union[int, float], output_stride: int 42 | ) -> int: 43 | if is_valid_input_resolution(input_resolution, output_stride): 44 | return int(input_resolution) 45 | 46 | return int(math.floor(input_resolution / output_stride) * output_stride + 1) 47 | 48 | 49 | # see toInputResolutionHeightAndWidth 50 | def get_bodypix_input_resolution_height_and_width( 51 | internal_resolution_percentage: float, 52 | output_stride: int, 53 | input_height: int, 54 | input_width: int 55 | ) -> Tuple[int, int]: 56 | return ( 57 | to_valid_input_resolution( 58 | input_height * internal_resolution_percentage, output_stride), 59 | to_valid_input_resolution( 60 | input_width * internal_resolution_percentage, output_stride) 61 | ) 62 | 63 | 64 | def _pad_image_like_tensorflow( 65 | image: np.ndarray, 66 | padding: Padding 67 | ) -> np.ndarray: 68 | """ 69 | Pure NumPy replacement for tf.image.pad_to_bounding_box (used when TensorFlow is unavailable) 70 | :param image: 71 | :param padding: 72 | :return: 73 | """ 74 | 75 | padded = np.copy(image) 76 | dims = padded.shape 77 | dtype = image.dtype 78 | 79 | if padding.top != 0: 80 | top_zero_row = np.zeros(shape=(padding.top, dims[1], dims[2]), dtype=dtype) 81 | padded = np.vstack([top_zero_row, padded]) 82 | 83 | if padding.bottom != 0: 84 | bottom_zero_row = np.zeros(shape=(padding.bottom, dims[1], dims[2]), dtype=dtype) 85 | padded = np.vstack([padded, bottom_zero_row]) 86 | 87 | dims = padded.shape 88 | if padding.left != 0: 89 | left_zero_column = np.zeros(shape=(dims[0], padding.left, dims[2]), dtype=dtype) 90 | padded = np.hstack([left_zero_column, padded]) 91 | 92 | if padding.right != 0: 93 | right_zero_column = np.zeros(shape=(dims[0], padding.right, dims[2]), dtype=dtype) 94 | padded = np.hstack([padded, right_zero_column]) 95 | 96 | return padded 97 | 98 | 99 | # see padAndResizeTo 100 | def pad_and_resize_to( 101 | image: np.ndarray, 102 | target_height: int, target_width: int 103 | ) -> Tuple[np.ndarray, Padding]: 104 | input_height, input_width = image.shape[:2] 105 | target_aspect = target_width / target_height 106 | aspect = input_width / input_height 107 | if aspect < target_aspect: 108 | # pads
the width 109 | padding = Padding( 110 | top=0, 111 | bottom=0, 112 | left=round(0.5 * (target_aspect * input_height - input_width)), 113 | right=round(0.5 * (target_aspect * input_height - input_width)) 114 | ) 115 | else: 116 | # pads the height 117 | padding = Padding( 118 | top=round(0.5 * ((1.0 / target_aspect) * input_width - input_height)), 119 | bottom=round(0.5 * ((1.0 / target_aspect) * input_width - input_height)), 120 | left=0, 121 | right=0 122 | ) 123 | 124 | if tf is not None: 125 | padded = tf.image.pad_to_bounding_box( 126 | image, 127 | offset_height=padding.top, 128 | offset_width=padding.left, 129 | target_height=padding.top + input_height + padding.bottom, 130 | target_width=padding.left + input_width + padding.right 131 | ) 132 | resized = tf.image.resize([padded], [target_height, target_width])[0] 133 | else: 134 | padded = _pad_image_like_tensorflow(image, padding) 135 | LOGGER.debug( 136 | 'padded: %r (%r) -> %r (%r)', 137 | image.shape, image.dtype, padded.shape, padded.dtype 138 | ) 139 | resized = resize_image_to( 140 | padded, ImageSize(width=target_width, height=target_height) 141 | ) 142 | LOGGER.debug( 143 | 'resized: %r (%r) -> %r (%r)', 144 | padded.shape, padded.dtype, resized.shape, resized.dtype 145 | ) 146 | return resized, padding 147 | 148 | 149 | def get_images_batch(image: np.ndarray) -> np.ndarray: 150 | if len(image.shape) == 4: 151 | return image 152 | if len(image.shape) == 3: 153 | if tf is not None: 154 | return image[tf.newaxis, ...] 155 | return np.expand_dims(image, axis=0) 156 | raise ValueError('invalid dimension, shape=%s' % str(image.shape)) 157 | 158 | 159 | # reverse of pad_and_resize_to 160 | def remove_padding_and_resize_back( 161 | resized_and_padded: np.ndarray, 162 | original_height: int, 163 | original_width: int, 164 | padding: Padding, 165 | resize_method: Optional[str] = None 166 | ) -> np.ndarray: 167 | if not resize_method: 168 | resize_method = ResizeMethod.BILINEAR 169 | boxes = [[ 170 | padding.top / (original_height + padding.top + padding.bottom - 1.0), 171 | padding.left / (original_width + padding.left + padding.right - 1.0), 172 | ( 173 | (padding.top + original_height - 1.0) 174 | / (original_height + padding.top + padding.bottom - 1.0) 175 | ), 176 | ( 177 | (padding.left + original_width - 1.0) 178 | / (original_width + padding.left + padding.right - 1.0) 179 | ) 180 | ]] 181 | return crop_and_resize_batch( 182 | get_images_batch(resized_and_padded), 183 | boxes=boxes, 184 | box_indices=[0], 185 | crop_size=[original_height, original_width], 186 | method=resize_method 187 | )[0] 188 | 189 | 190 | def remove_padding_and_resize_back_simple( 191 | resized_and_padded: np.ndarray, 192 | original_height: int, 193 | original_width: int, 194 | padding: Padding, 195 | resize_method: Optional[str] = None 196 | ) -> np.ndarray: 197 | padded_height = padding.top + original_height + padding.bottom 198 | padded_width = padding.left + original_width + padding.right 199 | padded = resize_image_to( 200 | resized_and_padded, 201 | ImageSize(height=padded_height, width=padded_width), 202 | resize_method=resize_method 203 | ) 204 | cropped = tf.image.crop_to_bounding_box( 205 | padded, 206 | offset_height=padding.top, 207 | offset_width=padding.left, 208 | target_height=original_height, 209 | target_width=original_width 210 | ) 211 | return cropped[0] 212 | 213 | 214 | def _get_sigmoid_using_tf(x: np.ndarray): 215 | return tf.math.sigmoid(x) 216 | 217 | 218 | def _get_sigmoid_using_numpy(x: np.ndarray): 219 | return 1/(1 + 
np.exp(-x)) 220 | 221 | 222 | def get_sigmoid(x: np.ndarray): 223 | if tf is not None: 224 | return _get_sigmoid_using_tf(x) 225 | return _get_sigmoid_using_numpy(x) 226 | 227 | 228 | # see scaleAndCropToInputTensorShape 229 | def scale_and_crop_to_input_tensor_shape( 230 | image: np.ndarray, 231 | input_height: int, 232 | input_width: int, 233 | resized_height: int, 234 | resized_width: int, 235 | padding: Padding, 236 | apply_sigmoid_activation: bool = False, 237 | resize_method: Optional[str] = None 238 | ) -> np.ndarray: 239 | resized_and_padded = resize_image_to( 240 | image, 241 | ImageSize(height=resized_height, width=resized_width), 242 | resize_method=resize_method 243 | ) 244 | if apply_sigmoid_activation: 245 | resized_and_padded = get_sigmoid(resized_and_padded) 246 | LOGGER.debug('after sigmoid: %r', resized_and_padded.shape) 247 | return remove_padding_and_resize_back( 248 | resized_and_padded, 249 | input_height, input_width, 250 | padding, 251 | resize_method=resize_method 252 | ) 253 | 254 | 255 | ZERO_VECTOR_2D = Vector2D(x=0, y=0) 256 | 257 | 258 | def _scale_and_offset_vector( 259 | vector: Vector2D, scale_vector: Vector2D, offset_vector: Vector2D 260 | ) -> Vector2D: 261 | return Vector2D( 262 | x=vector.x * scale_vector.x + offset_vector.x, 263 | y=vector.y * scale_vector.y + offset_vector.y 264 | ) 265 | 266 | 267 | def scalePose( 268 | pose: Pose, scale_vector: Vector2D, offset_vector: Vector2D 269 | ) -> Pose: 270 | return Pose( 271 | score=pose.score, 272 | keypoints={ 273 | keypoint_id: Keypoint( 274 | score=keypoint.score, 275 | part=keypoint.part, 276 | position=_scale_and_offset_vector( 277 | keypoint.position, 278 | scale_vector, 279 | offset_vector 280 | ) 281 | ) 282 | for keypoint_id, keypoint in pose.keypoints.items() 283 | } 284 | ) 285 | 286 | 287 | def scalePoses( 288 | poses: List[Pose], scale_vector: Vector2D, offset_vector: Vector2D 289 | ) -> List[Pose]: 290 | if ( 291 | scale_vector.x == 1 292 | and scale_vector.y == 1 293 | and offset_vector.x == 0 294 | and offset_vector.y == 0 295 | ): 296 | return poses 297 | return [ 298 | scalePose(pose, scale_vector, offset_vector) 299 | for pose in poses 300 | ] 301 | 302 | 303 | def flipPosesHorizontal(poses: List[Pose], imageWidth: int) -> List[Pose]: 304 | if imageWidth <= 0: 305 | return poses 306 | scale_vector = Vector2D(x=-1, y=1) 307 | offset_vector = Vector2D(x=imageWidth - 1, y=0) 308 | return scalePoses( 309 | poses, 310 | scale_vector, 311 | offset_vector 312 | ) 313 | 314 | 315 | def scaleAndFlipPoses( 316 | poses: List[Pose], 317 | height: int, 318 | width: int, 319 | inputResolutionHeight: int, 320 | inputResolutionWidth: int, 321 | padding: Padding, 322 | flipHorizontal: bool 323 | ) -> List[Pose]: 324 | scale_vector = Vector2D( 325 | y=(height + padding.top + padding.bottom) / (inputResolutionHeight), 326 | x=(width + padding.left + padding.right) / (inputResolutionWidth) 327 | ) 328 | offset_vector = Vector2D( 329 | x=-padding.left, 330 | y=-padding.top 331 | ) 332 | 333 | LOGGER.debug('height: %s', height) 334 | LOGGER.debug('width: %s', width) 335 | LOGGER.debug('inputResolutionHeight: %s', inputResolutionHeight) 336 | LOGGER.debug('inputResolutionWidth: %s', inputResolutionWidth) 337 | LOGGER.debug('scale_vector: %s', scale_vector) 338 | LOGGER.debug('offset_vector: %s', offset_vector) 339 | 340 | scaledPoses = scalePoses( 341 | poses, scale_vector, offset_vector 342 | ) 343 | 344 | if flipHorizontal: 345 | return flipPosesHorizontal(scaledPoses, width) 346 | return scaledPoses 347 | 
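(Added usage sketch, not part of the repository.) The padding helpers above are designed to round-trip: `pad_and_resize_to` letter-boxes the input to a valid BodyPix resolution, and `remove_padding_and_resize_back` undoes both steps. A minimal example under assumed values (a random 640x480 input, 0.5 internal resolution, output stride 16):

```python
import numpy as np

from tf_bodypix.bodypix_js_utils.util import (
    pad_and_resize_to,
    remove_padding_and_resize_back,
    to_valid_input_resolution
)

OUTPUT_STRIDE = 16
image = np.random.rand(480, 640, 3).astype(np.float32)

# BodyPix expects (resolution - 1) to be divisible by the output stride
target_height = to_valid_input_resolution(480 * 0.5, OUTPUT_STRIDE)  # 241
target_width = to_valid_input_resolution(640 * 0.5, OUTPUT_STRIDE)  # 321

resized, padding = pad_and_resize_to(image, target_height, target_width)
# ... run the BodyPix model on `resized` here ...
restored = remove_padding_and_resize_back(resized, 480, 640, padding)
assert tuple(restored.shape)[:2] == (480, 640)
```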
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow BodyPix (TF BodyPix) 2 | 3 | [![PyPi version](https://img.shields.io/pypi/v/tf-bodypix)](https://pypi.org/project/tf-bodypix/) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 5 | 6 | A Python implementation of [body-pix](https://github.com/tensorflow/tfjs-models/tree/body-pix-v2.0.4/body-pix). 7 | 8 | The goals of this project are: 9 | 10 | * Python library, making it easy to integrate the BodyPix model 11 | * CLI with limited functionality, mostly for demonstration purposes 12 | 13 | ## Prerequisites 14 | 15 | * Python 3.7+ 16 | 17 | ## Install 18 | 19 | Install with all dependencies: 20 | 21 | ```bash 22 | pip install tf-bodypix[all] 23 | ``` 24 | 25 | Install with minimal or no dependencies: 26 | 27 | ```bash 28 | pip install tf-bodypix 29 | ``` 30 | 31 | Extras are provided to make it easier to include or exclude dependencies 32 | when using this project as a library: 33 | 34 | | extra name | description 35 | | ---------- | ----------- 36 | | tf | [TensorFlow](https://pypi.org/project/tensorflow/) (required). But you may use your own build. 37 | | tfjs | TensorFlow JS Model support, using [tfjs-graph-converter](https://pypi.org/project/tfjs-graph-converter/) 38 | | tflite | [tflite-runtime](https://pypi.org/project/tflite-runtime/) 39 | | image | Image loading via [Pillow](https://pypi.org/project/Pillow/), required by the CLI. 40 | | video | Video support via [OpenCV](https://pypi.org/project/opencv-python/) 41 | | webcam | Webcam support via [OpenCV](https://pypi.org/project/opencv-python/) and [pyfakewebcam](https://pypi.org/project/pyfakewebcam/) 42 | | all | All of the libraries (except `tflite-runtime`) 43 | 44 | ## Python API 45 | 46 | ```python 47 | from pathlib import Path 48 | import tensorflow as tf 49 | from tf_bodypix.api import download_model, load_model, BodyPixModelPaths 50 | 51 | # setup input and output paths 52 | output_path = Path('./data/example-output') 53 | output_path.mkdir(parents=True, exist_ok=True) 54 | input_url = ( 55 | 'https://www.dropbox.com/s/7tsaqgdp149d8aj/serious-black-businesswoman-sitting-at-desk-in-office-5669603.jpg?dl=1' 56 | ) 57 | local_input_path = tf.keras.utils.get_file(origin=input_url) 58 | 59 | # load model (once) 60 | bodypix_model = load_model(download_model( 61 | BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_16 62 | )) 63 | 64 | # get prediction result 65 | image = tf.keras.preprocessing.image.load_img(local_input_path) 66 | image_array = tf.keras.preprocessing.image.img_to_array(image) 67 | result = bodypix_model.predict_single(image_array) 68 | 69 | # simple mask 70 | mask = result.get_mask(threshold=0.75) 71 | tf.keras.preprocessing.image.save_img( 72 | f'{output_path}/output-mask.jpg', 73 | mask 74 | ) 75 | 76 | # colored mask (separate colour for each body part) 77 | colored_mask = result.get_colored_part_mask(mask) 78 | tf.keras.preprocessing.image.save_img( 79 | f'{output_path}/output-colored-mask.jpg', 80 | colored_mask 81 | ) 82 | 83 | # poses 84 | from tf_bodypix.draw import draw_poses # utility function using OpenCV 85 | 86 | poses = result.get_poses() 87 | image_with_poses = draw_poses( 88 | image_array.copy(), # create a copy to ensure we are not modifying the source image 89 | poses, 90 | keypoints_color=(255, 100, 100), 91 | skeleton_color=(100, 100, 255) 92 | ) 93 |
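# note (added): draw_poses needs OpenCV, covered by the `video` / `webcam` extras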
tf.keras.preprocessing.image.save_img( 94 | f'{output_path}/output-poses.jpg', 95 | image_with_poses 96 | ) 97 | ``` 98 | 99 | ## CLI 100 | 101 | ### CLI Help 102 | 103 | ```bash 104 | python -m tf_bodypix --help 105 | ``` 106 | 107 | or 108 | 109 | ```bash 110 | tf-bodypix --help 111 | ``` 112 | 113 | ### List Available Models 114 | 115 | ```bash 116 | python -m tf_bodypix list-models 117 | ``` 118 | 119 | The result will be a list of all of the `bodypix` TensorFlow JS models available in the [tfjs-models bucket](https://storage.googleapis.com/tfjs-models/). 120 | 121 | Those URLs can be passed as the `--model-path` arguments below, or to the `download_model` method of the Python API. 122 | 123 | The CLI will download and cache the model from the provided path. If no `--model-path` is provided, it will use a default model (mobilenet). 124 | 125 | To list TensorFlow Lite models instead: 126 | 127 | ```bash 128 | python -m tf_bodypix list-tflite-models 129 | ``` 130 | 131 | ### Inputs and Outputs 132 | 133 | Most commands will work with inputs (source) and outputs. 134 | 135 | The source path can be specified via the `--source` parameter. 136 | 137 | The following inputs are supported: 138 | 139 | | type | description | 140 | | -----| ----------- | 141 | | image | Static image (e.g. `.png`) | 142 | | video | Video (e.g. `.mp4`) | 143 | | webcam | Linux Webcam (`/dev/videoN` or `webcam:0`) | 144 | 145 | If the source path points to an external file (e.g. `https://`), then it will be downloaded and locally cached. 146 | 147 | The output path can be specified via `--output`, unless `--show-output` is used. 148 | 149 | The following outputs are supported: 150 | 151 | | type | description | 152 | | -----| ----------- | 153 | | image_writer | Write to a static image (e.g.
`.png`) | 154 | | v4l2 | Linux Virtual Webcam (`/dev/videoN`) | 155 | | window | Display a window (by using `--show-output`) | 156 | 157 | ### Example commands 158 | 159 | #### Creating a simple body mask 160 | 161 | ```bash 162 | python -m tf_bodypix \ 163 | draw-mask \ 164 | --source \ 165 | "https://www.dropbox.com/s/7tsaqgdp149d8aj/serious-black-businesswoman-sitting-at-desk-in-office-5669603.jpg?dl=1" \ 166 | --show-output \ 167 | --threshold=0.75 168 | ``` 169 | 170 | Image Source: [Serious black businesswoman sitting at desk in office](https://www.pexels.com/photo/serious-black-businesswoman-sitting-at-desk-in-office-5669603/) 171 | 172 | #### Add the mask over the original image using `--mask-alpha` 173 | 174 | ```bash 175 | python -m tf_bodypix \ 176 | draw-mask \ 177 | --source \ 178 | "https://www.dropbox.com/s/7tsaqgdp149d8aj/serious-black-businesswoman-sitting-at-desk-in-office-5669603.jpg?dl=1" \ 179 | --show-output \ 180 | --threshold=0.75 \ 181 | --mask-alpha=0.5 182 | ``` 183 | 184 | Image Source: [Serious black businesswoman sitting at desk in office](https://www.pexels.com/photo/serious-black-businesswoman-sitting-at-desk-in-office-5669603/) 185 | 186 | #### Colorize the body mask depending on the body part 187 | 188 | ```bash 189 | python -m tf_bodypix \ 190 | draw-mask \ 191 | --source \ 192 | "https://www.dropbox.com/s/7tsaqgdp149d8aj/serious-black-businesswoman-sitting-at-desk-in-office-5669603.jpg?dl=1" \ 193 | --show-output \ 194 | --threshold=0.75 \ 195 | --mask-alpha=0.5 \ 196 | --colored 197 | ``` 198 | 199 | Image Source: [Serious black businesswoman sitting at desk in office](https://www.pexels.com/photo/serious-black-businesswoman-sitting-at-desk-in-office-5669603/) 200 | 201 | #### Additionally select the body parts 202 | 203 | ```bash 204 | python -m tf_bodypix \ 205 | draw-mask \ 206 | --source \ 207 | "https://www.dropbox.com/s/7tsaqgdp149d8aj/serious-black-businesswoman-sitting-at-desk-in-office-5669603.jpg?dl=1" \ 208 | --show-output \ 209 | --threshold=0.75 \ 210 | --mask-alpha=0.5 \ 211 | --parts left_face right_face \ 212 | --colored 213 | ``` 214 | 215 | Image Source: [Serious black businesswoman sitting at desk in office](https://www.pexels.com/photo/serious-black-businesswoman-sitting-at-desk-in-office-5669603/) 216 | 217 | #### Add mask overlay to a video 218 | 219 | ```bash 220 | python -m tf_bodypix \ 221 | draw-mask \ 222 | --source \ 223 | "https://www.dropbox.com/s/s7jga3f0dreavlb/video-of-a-man-laughing-and-happy-1608393-360p.mp4?dl=1" \ 224 | --show-output \ 225 | --threshold=0.75 \ 226 | --mask-alpha=0.5 \ 227 | --colored 228 | ``` 229 | 230 | Video Source: [Video Of A Man Laughing And Happy](https://www.pexels.com/video/video-of-a-man-laughing-and-happy-1608393/) 231 | 232 | #### Add pose overlay to a video 233 | 234 | ```bash 235 | python -m tf_bodypix \ 236 | draw-pose \ 237 | --source \ 238 | "https://www.dropbox.com/s/pv5v8dkpj5wung7/an-old-man-doing-a-tai-chi-exercise-2882799-360p.mp4?dl=1" \ 239 | --show-output \ 240 | --threshold=0.75 241 | ``` 242 | 243 | #### Blur background of a video 244 | 245 | ```bash 246 | python -m tf_bodypix \ 247 | blur-background \ 248 | --source \ 249 | "https://www.dropbox.com/s/s7jga3f0dreavlb/video-of-a-man-laughing-and-happy-1608393-360p.mp4?dl=1" \ 250 | --show-output \ 251 | --threshold=0.75 \ 252 | --mask-blur=5 \ 253 | --background-blur=20 254 | ``` 255 | 256 | Video Source: [Video Of A Man Laughing And Happy](https://www.pexels.com/video/video-of-a-man-laughing-and-happy-1608393/) 257 | 258 | 
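The blur effect can also be approximated with the Python API (an illustrative sketch building on the Python API example above; the blending below is a simplification, not necessarily the CLI's exact implementation):

```python
import cv2
import numpy as np

# `bodypix_model` and `image_array` as in the Python API example above
result = bodypix_model.predict_single(image_array)
mask = result.get_mask(threshold=0.75)
# blur the whole frame, then keep the person from the original image
background = cv2.blur(np.asarray(image_array), (20, 20))
composite = np.where(np.asarray(mask) > 0, image_array, background)
```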
#### Replace the background of a video 259 | 260 | ```bash 261 | python -m tf_bodypix \ 262 | replace-background \ 263 | --source \ 264 | "https://www.dropbox.com/s/s7jga3f0dreavlb/video-of-a-man-laughing-and-happy-1608393-360p.mp4?dl=1" \ 265 | --background \ 266 | "https://www.dropbox.com/s/b22ss59j6pp83zy/brown-landscape-under-grey-sky-3244513.jpg?dl=1" \ 267 | --show-output \ 268 | --threshold=0.75 \ 269 | --mask-blur=5 270 | ``` 271 | 272 | Video Source: [Video Of A Man Laughing And Happy](https://www.pexels.com/video/video-of-a-man-laughing-and-happy-1608393/) 273 | 274 | Background: [Brown Landscape Under Grey Sky](https://www.pexels.com/photo/brown-landscape-under-grey-sky-3244513/) 275 | 276 | #### Capture Webcam and adding mask overlay 277 | 278 | ```bash 279 | python -m tf_bodypix \ 280 | draw-mask \ 281 | --source webcam:0 \ 282 | --show-output \ 283 | --threshold=0.75 \ 284 | --mask-alpha=0.5 \ 285 | --colored 286 | ``` 287 | 288 | #### Capture Webcam and adding mask overlay, writing to v4l2loopback device 289 | 290 | (replace `/dev/videoN` with the actual virtual video device) 291 | 292 | ```bash 293 | python -m tf_bodypix \ 294 | draw-mask \ 295 | --source webcam:0 \ 296 | --output /dev/videoN \ 297 | --threshold=0.75 \ 298 | --mask-alpha=0.5 \ 299 | --colored 300 | ``` 301 | 302 | #### Capture Webcam and blur background, writing to v4l2loopback device 303 | 304 | (replace `/dev/videoN` with the actual virtual video device) 305 | 306 | ```bash 307 | python -m tf_bodypix \ 308 | blur-background \ 309 | --source webcam:0 \ 310 | --background-blur 20 \ 311 | --output /dev/videoN \ 312 | --threshold=0.75 313 | ``` 314 | 315 | #### Capture Webcam and replace background, writing to v4l2loopback device 316 | 317 | (replace `/dev/videoN` with the actual virtual video device) 318 | 319 | ```bash 320 | python -m tf_bodypix \ 321 | replace-background \ 322 | --source webcam:0 \ 323 | --background \ 324 | "https://www.dropbox.com/s/b22ss59j6pp83zy/brown-landscape-under-grey-sky-3244513.jpg?dl=1" \ 325 | --threshold=0.75 \ 326 | --output /dev/videoN 327 | ``` 328 | 329 | Background: [Brown Landscape Under Grey Sky](https://www.pexels.com/photo/brown-landscape-under-grey-sky-3244513/) 330 | 331 | ## TensorFlow Lite Model support (experimental) 332 | 333 | The model path may also point to a TensorFlow Lite model (`.tflite` extension). Whether that actually improves performance may depend on the platform and available hardware. 334 | 335 | You could convert one of the available TensorFlow JS models to TensorFlow Lite using the following command: 336 | 337 | ```bash 338 | python -m tf_bodypix \ 339 | convert-to-tflite \ 340 | --model-path \ 341 | "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride16.json" \ 342 | --optimize \ 343 | --quantization-type=float16 \ 344 | --output-model-file "./mobilenet-float-multiplier-075-stride16-float16.tflite" 345 | ``` 346 | 347 | The above command is provided for convenience. 348 | You may use alternative methods depending on your preference and requirements. 349 | 350 | Relevant links: 351 | 352 | * [TensorFlow Lite converter](https://www.tensorflow.org/lite/convert/) 353 | * [TF Lite post_training_quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) 354 | * [TF GitHub #40183](https://github.com/tensorflow/tensorflow/issues/40183). 
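The converted file can then be used like any other model path. A hedged Python sketch mirroring `tests/cli_test.py` (the literal architecture name `mobilenet_v1` and the file paths are illustrative assumptions):

```python
from tf_bodypix.cli import main

main([
    'draw-mask',
    '--model-path=./mobilenet-float-multiplier-075-stride16-float16.tflite',
    '--model-architecture=mobilenet_v1',  # assumed value of ModelArchitectureNames.MOBILENET_V1
    '--output-stride=16',
    '--source=./input.jpg',  # placeholder input image
    '--output=./mask.jpg'
])
```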
355 | 356 | ## TensorFlow Lite Runtime support (experimental) 357 | 358 | This project can also be used with [tflite-runtime](https://pypi.org/project/tflite-runtime/) instead of full TensorFlow (e.g. by using the `tflite` extra). 359 | However, the [TensorFlow Lite converter](https://www.tensorflow.org/lite/convert/) itself requires full TensorFlow. 360 | To avoid that dependency, use a model that has already been converted to TensorFlow Lite (see the previous section). 361 | 362 | ## Docker Usage 363 | 364 | You could also use the Docker image if you prefer. 365 | By default the entrypoint delegates to the CLI, except for `python` or `bash` commands. 366 | 367 | ```bash 368 | # pull latest image (you may also use tags) 369 | docker pull de4code/tf-bodypix 370 | ``` 371 | 372 | ```bash 373 | # mount real and virtual webcam devices on linux 374 | docker run --rm \ 375 | --device /dev/video0 \ 376 | --device /dev/video2 \ 377 | de4code/tf-bodypix \ 378 | blur-background \ 379 | --source /dev/video0 \ 380 | --output /dev/video2 \ 381 | --background-blur 20 \ 382 | --threshold=0.75 383 | ``` 384 | 385 | ```bash 386 | # mount x11 display on linux 387 | docker run --rm \ 388 | --net=host \ 389 | --volume /tmp/.X11-unix:/tmp/.X11-unix \ 390 | --volume ${HOME}/.Xauthority:/root/.Xauthority \ 391 | --env DISPLAY \ 392 | de4code/tf-bodypix \ 393 | replace-background \ 394 | --source \ 395 | "https://www.dropbox.com/s/s7jga3f0dreavlb/video-of-a-man-laughing-and-happy-1608393-360p.mp4?dl=1" \ 396 | --background \ 397 | "https://www.dropbox.com/s/b22ss59j6pp83zy/brown-landscape-under-grey-sky-3244513.jpg?dl=1" \ 398 | --show-output \ 399 | --threshold=0.75 \ 400 | --mask-blur=5 401 | ``` 402 | 403 | ## Example Media 404 | 405 | Here are a few example media files you could try. 406 | 407 | Images: 408 | 409 | * [Serious black businesswoman sitting at desk in office](https://www.dropbox.com/s/7tsaqgdp149d8aj/serious-black-businesswoman-sitting-at-desk-in-office-5669603.jpg?dl=1) ([Source](https://www.pexels.com/photo/serious-black-businesswoman-sitting-at-desk-in-office-5669603/)) 410 | * [Woman Wearing Gray Notch Lapel Suit Jacket](https://www.dropbox.com/s/ygfudebvbm1pksk/woman-wearing-gray-notch-lapel-suit-jacket-2381069-small.jpg?dl=1) ([Source](https://www.pexels.com/photo/woman-wearing-gray-notch-lapel-suit-jacket-2381069/)) 411 | * [Smiling Woman Standing In Front Of A Colorful Flag](https://www.dropbox.com/s/ddyj89vkz7cmzmg/smiling-woman-standing-in-front-of-a-colorful-flag-5255422-small.jpg?dl=1) ([Source](https://www.pexels.com/photo/smiling-woman-standing-in-front-of-a-colorful-flag-5255422/)) 412 | * [Man and Woman Smiling Inside Building](https://www.dropbox.com/s/5z7v5wtwx3dmrdu/man-and-woman-smiling-inside-building-1367269-small.jpg?dl=1) ([Source](https://www.pexels.com/photo/man-and-woman-smiling-inside-building-1367269/)) 413 | * [Two Woman in Black Sits on Chair Near Table](https://www.dropbox.com/s/dq9e2dv86qd9ror/two-woman-in-black-sits-on-chair-near-table-1181605-small.jpg?dl=1) ([Source](https://www.pexels.com/photo/two-woman-in-black-sits-on-chair-near-table-1181605/)) 414 | * [Female barista in beanie and apron resting chin on had](https://www.dropbox.com/s/88qb3yldsb4l2id/female-barista-in-beanie-and-apron-resting-chin-on-had-4350057-small.jpg?dl=1) ([Source](https://www.pexels.com/photo/female-barista-in-beanie-and-apron-resting-chin-on-had-4350057/)) 415 | * [Smiling Woman Holding White Android Smartphone While Sitting Front of
Table](https://www.dropbox.com/s/43awel6e1mxja5v/smiling-woman-holding-white-android-smartphone-while-sitting-front-of-table-1462631-small.jpg?dl=1) ([Source](https://www.pexels.com/photo/smiling-woman-holding-white-android-smartphone-while-sitting-front-of-table-1462631/)) 416 | * [Woman Having Coffee and Rice Bowl](https://www.dropbox.com/s/zndltp65n93poy2/woman-having-coffee-and-rice-bowl-4058316-small.jpg?dl=1) ([Source](https://www.pexels.com/photo/woman-having-coffee-and-rice-bowl-4058316/)) 417 | * [Woman Smiling While Holding a Coffee Cup](https://www.dropbox.com/s/0txws4j79o9hewr/woman-smiling-while-holding-a-coffee-cup-6787913-small.jpg?dl=1) ([Source](https://www.pexels.com/photo/woman-smiling-while-holding-a-coffee-cup-6787913/)) 418 | 419 | Videos: 420 | 421 | * [Video Of A Man Laughing And Happy](https://www.dropbox.com/s/s7jga3f0dreavlb/video-of-a-man-laughing-and-happy-1608393-360p.mp4?dl=1) ([Source](https://www.pexels.com/video/video-of-a-man-laughing-and-happy-1608393/)) 422 | * [A Group Of People In A Business Meeting](https://www.dropbox.com/s/6pc6m9b0zd2mpsv/a-group-of-people-in-a-business-meeting-6774216-360p.mp4?dl=1) ([Source](https://www.pexels.com/video/a-group-of-people-in-a-business-meeting-6774216/)) 423 | * [An Old Man Doing A Tai Chi Exercise](https://www.dropbox.com/s/pv5v8dkpj5wung7/an-old-man-doing-a-tai-chi-exercise-2882799-360p.mp4?dl=1) ([Source](https://www.pexels.com/video/an-old-man-doing-a-tai-chi-exercise-2882799/)) 424 | 425 | Background: 426 | 427 | * [Brown Landscape Under Grey Sky](https://www.dropbox.com/s/b22ss59j6pp83zy/brown-landscape-under-grey-sky-3244513.jpg?dl=1) ([Source](https://www.pexels.com/photo/brown-landscape-under-grey-sky-3244513/)) 428 | 429 | ## Experimental Downstream Projects 430 | 431 | * [Layered Vision](https://github.com/de-code/layered-vision) is an experimental project using the `tf-bodypix` Python API. 432 | 433 | ## Acknowledgements 434 | 435 | * [Original TensorFlow JS Implementation of BodyPix](https://github.com/tensorflow/tfjs-models/tree/body-pix-v2.0.4/body-pix) 436 | * [Linux-Fake-Background-Webcam](https://github.com/fangfufu/Linux-Fake-Background-Webcam), an implementation of the [blog post](https://elder.dev/posts/open-source-virtual-background/) that describes using the TensorFlow JS implementation from Python via a Socket API.
437 | * [tfjs-to-tf](https://github.com/patlevin/tfjs-to-tf) for providing an easy way to convert TensorFlow JS models 438 | * [virtual_webcam_background](https://github.com/allo-/virtual_webcam_background) for a great pure Python implementation 439 | -------------------------------------------------------------------------------- /tf_bodypix/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from abc import ABC, abstractmethod 4 | from collections import namedtuple 5 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 6 | 7 | import numpy as np 8 | 9 | try: 10 | import tensorflow as tf 11 | tflite = tf.lite 12 | except ImportError: 13 | tf = None 14 | import tflite_runtime.interpreter as tflite # type: ignore 15 | 16 | try: 17 | import tfjs_graph_converter 18 | except ImportError: 19 | tfjs_graph_converter = None 20 | 21 | 22 | from tf_bodypix.bodypix_js_utils.decode_part_map import ( 23 | to_mask_tensor 24 | ) 25 | 26 | from tf_bodypix.bodypix_js_utils.part_channels import PART_CHANNELS 27 | 28 | from tf_bodypix.bodypix_js_utils.output_rendering_util import ( 29 | RAINBOW_PART_COLORS 30 | ) 31 | 32 | from tf_bodypix.bodypix_js_utils.util import ( 33 | get_bodypix_input_resolution_height_and_width, 34 | pad_and_resize_to, 35 | scale_and_crop_to_input_tensor_shape, 36 | scaleAndFlipPoses, 37 | Padding 38 | ) 39 | 40 | from tf_bodypix.bodypix_js_utils.types import Pose 41 | 42 | from tf_bodypix.bodypix_js_utils.multi_person.decode_multiple_poses import ( 43 | decodeMultiplePoses 44 | ) 45 | 46 | 47 | LOGGER = logging.getLogger(__name__) 48 | 49 | 50 | PART_CHANNEL_INDEX_BY_NAME = { 51 | name: index 52 | for index, name in enumerate(PART_CHANNELS) 53 | } 54 | 55 | 56 | ImageSize = namedtuple('ImageSize', ('height', 'width')) 57 | 58 | 59 | T_Color = Union[Tuple[int, int, int], Tuple[int, int, int, int]] 60 | 61 | 62 | class ModelArchitectureNames: 63 | MOBILENET_V1 = 'mobilenet_v1' 64 | RESNET_50 = 'resnet50' 65 | 66 | 67 | VALID_MODEL_ARCHITECTURE_NAMES = { 68 | ModelArchitectureNames.MOBILENET_V1, 69 | ModelArchitectureNames.RESNET_50 70 | } 71 | 72 | 73 | # see https://github.com/tensorflow/tfjs-models/blob/body-pix-v2.0.4/body-pix/src/resnet.ts 74 | IMAGE_NET_MEAN = [-123.15, -115.90, -103.06] 75 | 76 | 77 | class DictPredictWrapper: 78 | def __init__( 79 | self, 80 | wrapped: Callable[[np.ndarray], Union[dict, list]], 81 | output_names: List[str] 82 | ): 83 | self.wrapped = wrapped 84 | self.output_names = output_names 85 | 86 | def __call__(self, *args, **kwargs): 87 | result = self.wrapped(*args, **kwargs) 88 | if isinstance(result, list): 89 | return dict(zip(self.output_names, result)) 90 | return result 91 | 92 | 93 | class BodyPixArchitecture(ABC): 94 | def __init__(self, architecture_name: str): 95 | self.architecture_name = architecture_name 96 | 97 | @abstractmethod 98 | def __call__(self, image: np.ndarray) -> dict: 99 | pass 100 | 101 | 102 | def _get_imagenet_preprocessed_image_using_numpy( 103 | image_array: np.ndarray 104 | ) -> np.ndarray: 105 | result = np.divide(image_array, 127.5, dtype=np.float32) 106 | result = np.subtract(result, 1, out=result) 107 | LOGGER.debug( 108 | 'imagenet preprocessed: %r (%r) -> %r (%r)', 109 | image_array.shape, image_array.dtype, 110 | result.shape, result.dtype 111 | ) 112 | return result 113 | 114 | 115 | def _get_mobilenet_preprocessed_image( 116 | image_array: np.ndarray 117 | ) -> np.ndarray: 118 | if tf is not None: 119 | return 
tf.keras.applications.mobilenet.preprocess_input(image_array) 120 | return _get_imagenet_preprocessed_image_using_numpy(image_array) 121 | 122 | 123 | class MobileNetBodyPixPredictWrapper(BodyPixArchitecture): 124 | def __init__(self, predict_fn: Callable[[np.ndarray], dict]): 125 | super().__init__(ModelArchitectureNames.MOBILENET_V1) 126 | self.predict_fn = predict_fn 127 | 128 | def __call__(self, image: np.ndarray) -> dict: 129 | if len(image.shape) == 3: 130 | if tf is not None: 131 | image = image[tf.newaxis, ...] 132 | else: 133 | image = np.expand_dims(image, axis=0) 134 | return self.predict_fn( 135 | _get_mobilenet_preprocessed_image(image) 136 | ) 137 | 138 | 139 | class ResNet50BodyPixPredictWrapper(BodyPixArchitecture): 140 | def __init__(self, predict_fn: Callable[[np.ndarray], dict]): 141 | super().__init__(ModelArchitectureNames.RESNET_50) 142 | self.predict_fn = predict_fn 143 | 144 | def __call__(self, image: np.ndarray) -> dict: 145 | image = np.add(image, np.array(IMAGE_NET_MEAN)) 146 | # Note: tf.keras.applications.resnet50.preprocess_input is rotating the image as well? 147 | if len(image.shape) == 3: 148 | if tf is not None: 149 | image = image[tf.newaxis, ...] 150 | else: 151 | image = np.expand_dims(image, axis=0) 152 | if tf is not None: 153 | image = tf.constant(tf.cast(image, tf.float32)) 154 | else: 155 | image = np.asarray(image).astype(np.float32) 156 | LOGGER.debug('image.shape: %s (%s)', image.shape, image.dtype) 157 | predictions = self.predict_fn(image) 158 | return predictions 159 | 160 | 161 | def get_colored_part_mask_for_segmentation( 162 | part_segmentation: np.ndarray, 163 | part_colors: Optional[List[T_Color]] = None, 164 | default_color: Optional[T_Color] = None 165 | ): 166 | _part_colors = ( 167 | part_colors if part_colors is not None 168 | else RAINBOW_PART_COLORS 169 | ) 170 | part_colors_array = np.asarray(_part_colors) 171 | if default_color is None: 172 | default_color = (0, 0, 0) 173 | # np.take will take the last value if the index is -1 174 | part_colors_with_default_array = np.append( 175 | part_colors_array, 176 | np.asarray([default_color]), 177 | axis=-2 178 | ) 179 | LOGGER.debug('part_colors_with_default_array.shape: %s', part_colors_with_default_array.shape) 180 | LOGGER.debug('part_segmentation.shape: %s', part_segmentation.shape) 181 | part_segmentation_colored = np.take( 182 | part_colors_with_default_array, 183 | part_segmentation, 184 | axis=-2 185 | ) 186 | LOGGER.debug('part_segmentation_colored.shape: %s', part_segmentation_colored.shape) 187 | return part_segmentation_colored 188 | 189 | 190 | def is_all_part_names(part_names: Optional[List[str]]) -> bool: 191 | if not part_names: 192 | return True 193 | part_names_set = set(part_names) 194 | if len(part_names_set) == len(PART_CHANNELS): 195 | return True 196 | return False 197 | 198 | 199 | def get_filtered_part_segmentation( 200 | part_segmentation: np.ndarray, 201 | part_names: Optional[List[str]] = None 202 | ): 203 | if is_all_part_names(part_names): 204 | return part_segmentation 205 | assert part_names 206 | part_names_set = set(part_names) 207 | part_filter_mask = np.asarray([ 208 | ( 209 | part_index 210 | if part_name in part_names_set 211 | else -1 212 | ) 213 | for part_index, part_name in enumerate(PART_CHANNELS) 214 | ]) 215 | LOGGER.debug('part_filter_mask: %s', part_filter_mask) 216 | return part_filter_mask[part_segmentation] 217 | 218 | 219 | class BodyPixResultWrapper: 220 | def __init__( 221 | self, 222 | segments_logits: np.ndarray, 223 | 
part_heatmap_logits: np.ndarray, 224 | heatmap_logits: Optional[np.ndarray], 225 | short_offsets: Optional[np.ndarray], 226 | long_offsets: Optional[np.ndarray], 227 | part_offsets: Optional[np.ndarray], 228 | displacement_fwd: Optional[np.ndarray], 229 | displacement_bwd: Optional[np.ndarray], 230 | output_stride: int, 231 | original_size: ImageSize, 232 | model_input_size: ImageSize, 233 | padding: Padding): 234 | self.segments_logits = segments_logits 235 | self.part_heatmap_logits = part_heatmap_logits 236 | self.heatmap_logits = heatmap_logits 237 | self.short_offsets = short_offsets 238 | self.long_offsets = long_offsets 239 | self.part_offsets = part_offsets 240 | self.displacement_fwd = displacement_fwd 241 | self.displacement_bwd = displacement_bwd 242 | self.output_stride = output_stride 243 | self.original_size = original_size 244 | self.model_input_size = model_input_size 245 | self.padding = padding 246 | 247 | def _get_scaled_scores( 248 | self, 249 | logits: np.ndarray, 250 | resize_method: Optional[str] = None 251 | ) -> np.ndarray: 252 | LOGGER.debug('logits: %r', logits.shape) 253 | return scale_and_crop_to_input_tensor_shape( 254 | logits, 255 | self.original_size.height, 256 | self.original_size.width, 257 | self.model_input_size.height, 258 | self.model_input_size.width, 259 | padding=self.padding, 260 | apply_sigmoid_activation=True, 261 | resize_method=resize_method 262 | ) 263 | 264 | def get_scaled_segment_scores(self, **kwargs) -> np.ndarray: 265 | return self._get_scaled_scores(self.segments_logits, **kwargs) 266 | 267 | def get_scaled_part_heatmap_scores(self, **kwargs) -> np.ndarray: 268 | return self._get_scaled_scores(self.part_heatmap_logits, **kwargs) 269 | 270 | def get_scaled_part_segmentation( 271 | self, 272 | mask: Optional[np.ndarray] = None, 273 | part_names: Optional[List[str]] = None, 274 | outside_mask_value: int = -1, 275 | resize_method: Optional[str] = None 276 | ) -> np.ndarray: 277 | scaled_part_heatmap_argmax = np.argmax( 278 | self.get_scaled_part_heatmap_scores(resize_method=resize_method), 279 | -1 280 | ) 281 | LOGGER.debug('scaled_part_heatmap_argmax.shape: %s', scaled_part_heatmap_argmax.shape) 282 | if part_names: 283 | scaled_part_heatmap_argmax = get_filtered_part_segmentation( 284 | scaled_part_heatmap_argmax, 285 | part_names 286 | ) 287 | LOGGER.debug( 288 | 'scaled_part_heatmap_argmax.shape (filtered): %s', 289 | scaled_part_heatmap_argmax.shape 290 | ) 291 | if mask is not None: 292 | LOGGER.debug('mask.shape: %s', mask.shape) 293 | return np.where( 294 | np.squeeze(mask, axis=-1), 295 | scaled_part_heatmap_argmax, 296 | np.asarray([outside_mask_value]) 297 | ) 298 | return scaled_part_heatmap_argmax 299 | 300 | def get_mask( 301 | self, 302 | threshold: float, 303 | resize_method: Optional[str] = None, 304 | **kwargs 305 | ) -> np.ndarray: 306 | return to_mask_tensor( 307 | self.get_scaled_segment_scores(resize_method=resize_method), 308 | threshold, 309 | **kwargs 310 | ) 311 | 312 | def get_part_mask( 313 | self, 314 | mask: np.ndarray, 315 | part_names: Optional[List[str]] = None, 316 | resize_method: Optional[str] = None 317 | ) -> np.ndarray: 318 | if is_all_part_names(part_names): 319 | return mask 320 | part_segmentation = self.get_scaled_part_segmentation( 321 | mask, part_names=part_names, resize_method=resize_method 322 | ) 323 | part_mask = np.where( 324 | np.expand_dims(part_segmentation, -1) >= 0, 325 | mask, 326 | 0 327 | ) 328 | LOGGER.debug('part_mask.shape: %s', part_mask.shape) 329 | return part_mask 330 | 
331 | def get_colored_part_mask( 332 | self, 333 | mask: np.ndarray, 334 | part_colors: Optional[List[T_Color]] = None, 335 | part_names: Optional[List[str]] = None, 336 | resize_method: Optional[str] = None 337 | ) -> np.ndarray: 338 | part_segmentation = self.get_scaled_part_segmentation( 339 | mask, part_names=part_names, resize_method=resize_method 340 | ) 341 | return get_colored_part_mask_for_segmentation( 342 | part_segmentation, 343 | part_colors=part_colors 344 | ) 345 | 346 | def get_poses(self) -> List[Pose]: 347 | assert self.heatmap_logits is not None 348 | assert self.short_offsets is not None 349 | assert self.displacement_fwd is not None 350 | assert self.displacement_bwd is not None 351 | poses = decodeMultiplePoses( 352 | scoresBuffer=np.asarray(self.heatmap_logits[0]), 353 | offsetsBuffer=np.asarray(self.short_offsets[0]), 354 | displacementsFwdBuffer=np.asarray(self.displacement_fwd[0]), 355 | displacementsBwdBuffer=np.asarray(self.displacement_bwd[0]), 356 | outputStride=self.output_stride, 357 | maxPoseDetections=2 358 | ) 359 | scaled_poses = scaleAndFlipPoses( 360 | poses, 361 | height=self.original_size.height, 362 | width=self.original_size.width, 363 | inputResolutionHeight=self.model_input_size.height, 364 | inputResolutionWidth=self.model_input_size.width, 365 | padding=self.padding, 366 | flipHorizontal=False 367 | ) 368 | return scaled_poses 369 | 370 | 371 | class BodyPixModelWrapper: 372 | def __init__( 373 | self, 374 | predict_fn: Callable[[np.ndarray], Dict[str, Any]], 375 | output_stride: int, 376 | internal_resolution: float = 0.5): 377 | self.predict_fn = predict_fn 378 | self.internal_resolution = internal_resolution 379 | self.output_stride = output_stride 380 | 381 | def get_bodypix_input_size(self, original_size: ImageSize) -> ImageSize: 382 | return ImageSize( 383 | *get_bodypix_input_resolution_height_and_width( 384 | self.internal_resolution, self.output_stride, 385 | original_size.height, original_size.width 386 | ) 387 | ) 388 | 389 | def get_padded_and_resized( 390 | self, image: np.ndarray, model_input_size: ImageSize 391 | ) -> Tuple[np.ndarray, Padding]: 392 | LOGGER.debug( 393 | 'pad_and_resize_to: image.shape=%s (%r), model_input_size=%s', 394 | image.shape, image.dtype, model_input_size 395 | ) 396 | return pad_and_resize_to( 397 | image, 398 | model_input_size.height, 399 | model_input_size.width 400 | ) 401 | 402 | def find_optional_tensor_in_map( 403 | self, 404 | tensor_map: Dict[str, np.ndarray], 405 | name: str 406 | ) -> Optional[np.ndarray]: 407 | if name in tensor_map: 408 | return tensor_map[name] 409 | for key, value in tensor_map.items(): 410 | if name in key: 411 | return value 412 | return None 413 | 414 | def find_required_tensor_in_map( 415 | self, 416 | tensor_map: Dict[str, np.ndarray], 417 | name: str 418 | ) -> np.ndarray: 419 | value = self.find_optional_tensor_in_map(tensor_map, name) 420 | if value is not None: 421 | return value 422 | raise ValueError('tensor with name %r not found in %s' % ( 423 | name, tensor_map.keys() 424 | )) 425 | 426 | def predict_single(self, image: np.ndarray) -> BodyPixResultWrapper: 427 | original_size = ImageSize(*image.shape[:2]) 428 | LOGGER.debug('original_size: %r (%r)', original_size, image.dtype) 429 | model_input_size = self.get_bodypix_input_size(original_size) 430 | LOGGER.debug('model_input_size: %r', model_input_size) 431 | model_input_image, padding = self.get_padded_and_resized(image, model_input_size) 432 | LOGGER.debug( 433 | 'model_input_image: %r (%r)', 
model_input_image.shape, model_input_image.dtype 434 | ) 435 | LOGGER.debug('predict_fn: %r', self.predict_fn) 436 | 437 | tensor_map = self.predict_fn(model_input_image) 438 | 439 | LOGGER.debug('tensor_map type: %s', type(tensor_map)) 440 | LOGGER.debug('tensor_map keys: %s', tensor_map.keys()) 441 | LOGGER.debug('tensor_map shapes: %s', { 442 | k: t.shape 443 | for k, t in tensor_map.items() 444 | }) 445 | 446 | return BodyPixResultWrapper( 447 | segments_logits=self.find_required_tensor_in_map( 448 | tensor_map, 'float_segments' 449 | ), 450 | part_heatmap_logits=self.find_required_tensor_in_map( 451 | tensor_map, 'float_part_heatmaps' 452 | ), 453 | heatmap_logits=self.find_required_tensor_in_map( 454 | tensor_map, 'float_heatmaps' 455 | ), 456 | short_offsets=self.find_required_tensor_in_map( 457 | tensor_map, 'float_short_offsets' 458 | ), 459 | long_offsets=self.find_required_tensor_in_map( 460 | tensor_map, 'float_long_offsets' 461 | ), 462 | part_offsets=self.find_required_tensor_in_map( 463 | tensor_map, 'float_part_offsets' 464 | ), 465 | displacement_fwd=self.find_required_tensor_in_map( 466 | tensor_map, 'displacement_fwd' 467 | ), 468 | displacement_bwd=self.find_required_tensor_in_map( 469 | tensor_map, 'displacement_bwd' 470 | ), 471 | original_size=original_size, 472 | model_input_size=model_input_size, 473 | output_stride=self.output_stride, 474 | padding=padding 475 | ) 476 | 477 | 478 | def get_structured_output_names(structured_outputs: List['tf.Tensor']) -> List[str]: 479 | return [ 480 | tensor.name.replace(':0', '') 481 | for tensor in structured_outputs 482 | ] 483 | 484 | 485 | def to_number_of_dimensions(data: np.ndarray, dimension_count: int) -> np.ndarray: 486 | while len(data.shape) > dimension_count: 487 | data = data[0] 488 | while len(data.shape) < dimension_count: 489 | data = np.expand_dims(data, axis=0) 490 | return data 491 | 492 | 493 | def load_tflite_model(model_path: str): 494 | # Load TFLite model and allocate tensors. 
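# Note: the predict function defined below dynamically resizes the interpreter's
# input tensor (via resize_tensor_input + allocate_tensors) whenever the incoming
# image shape differs from the model's declared input shape.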
495 | interpreter = tflite.Interpreter(model_path=model_path) 496 | interpreter.allocate_tensors() 497 | 498 | input_details = interpreter.get_input_details() 499 | LOGGER.debug('input_details: %s', input_details) 500 | input_names = [item['name'] for item in input_details] 501 | LOGGER.debug('input_names: %s', input_names) 502 | input_details_map = dict(zip(input_names, input_details)) 503 | 504 | output_details = interpreter.get_output_details() 505 | LOGGER.debug('output_details: %s', output_details) 506 | output_names = [item['name'] for item in output_details] 507 | LOGGER.debug('output_names: %s', output_names) 508 | 509 | try: 510 | image_input = input_details_map['image'] 511 | except KeyError: 512 | assert len(input_details_map) == 1 513 | image_input = list(input_details_map.values())[0] 514 | input_shape = image_input['shape'] 515 | LOGGER.debug('input_shape: %s', input_shape) 516 | 517 | def predict(image_data: np.ndarray): 518 | nonlocal input_shape 519 | LOGGER.debug( 520 | 'tflite predict, original image_data.shape=%s (%s)', 521 | image_data.shape, image_data.dtype 522 | ) 523 | image_data = to_number_of_dimensions(image_data, len(input_shape)) 524 | LOGGER.debug('tflite predict, image_data.shape=%s (%s)', image_data.shape, image_data.dtype) 525 | *_, height, width, _ = image_data.shape  # take height/width from the trailing (..., height, width, channels) dims, skipping the batch dim 526 | if tuple(image_data.shape) != tuple(input_shape): 527 | LOGGER.info('resizing input tensor: %s -> %s', tuple(input_shape), image_data.shape) 528 | interpreter.resize_tensor_input(image_input['index'], list(image_data.shape)) 529 | interpreter.allocate_tensors() 530 | input_shape = image_data.shape 531 | interpreter.set_tensor(image_input['index'], image_data) 532 | if 'image_size' in input_details_map: 533 | interpreter.set_tensor( 534 | input_details_map['image_size']['index'], 535 | np.array([height, width], dtype=np.float_) 536 | ) 537 | 538 | interpreter.invoke() 539 | 540 | # The function `get_tensor()` returns a copy of the tensor data. 541 | # Use `tensor()` in order to get a pointer to the tensor.
542 | return { 543 | item['name']: interpreter.get_tensor(item['index']) 544 | for item in output_details 545 | } 546 | return predict 547 | 548 | 549 | def load_using_saved_model_and_get_predict_function(model_path): 550 | loaded = tf.saved_model.load(model_path) 551 | LOGGER.debug('loaded: %s', loaded) 552 | LOGGER.debug('signature keys: %s', list(loaded.signatures.keys())) 553 | infer = loaded.signatures["serving_default"] 554 | LOGGER.info('structured_outputs: %s', infer.structured_outputs) 555 | return infer 556 | 557 | 558 | def load_using_tfjs_graph_converter_and_get_predict_function( 559 | model_path: str 560 | ) -> Callable[[np.ndarray], dict]: 561 | if tfjs_graph_converter is None: 562 | raise ImportError('tfjs_graph_converter required') 563 | graph = tfjs_graph_converter.api.load_graph_model(model_path) 564 | tf_fn = tfjs_graph_converter.api.graph_to_function_v2(graph) 565 | return DictPredictWrapper( 566 | tf_fn, 567 | get_structured_output_names(tf_fn.structured_outputs) 568 | ) 569 | 570 | 571 | def load_model_and_get_predict_function( 572 | model_path: str 573 | ) -> Callable[[np.ndarray], dict]: 574 | if model_path.endswith('.tflite'): 575 | return load_tflite_model(model_path) 576 | try: 577 | return load_using_saved_model_and_get_predict_function(model_path) 578 | except OSError: 579 | return load_using_tfjs_graph_converter_and_get_predict_function(model_path) 580 | 581 | 582 | def get_output_stride_from_model_path(model_path: str) -> int: 583 | match = re.search(r'stride(\d+)|_(\d+)_quant', model_path) 584 | if not match: 585 | raise ValueError('cannot extract output stride from model path: %r' % model_path) 586 | return int(match.group(1) or match.group(2)) 587 | 588 | 589 | def get_architecture_from_model_path(model_path: str) -> str: 590 | model_path_lower = model_path.lower() 591 | if 'mobilenet' in model_path_lower: 592 | return ModelArchitectureNames.MOBILENET_V1 593 | if 'resnet' in model_path_lower: 594 | return ModelArchitectureNames.RESNET_50 595 | raise ValueError('cannot extract model architecture from model path: %r' % model_path) 596 | 597 | 598 | def load_model( 599 | model_path: str, 600 | output_stride: Optional[int] = None, 601 | architecture_name: Optional[str] = None, 602 | **kwargs 603 | ): 604 | if not output_stride: 605 | output_stride = get_output_stride_from_model_path(model_path) 606 | if not architecture_name: 607 | architecture_name = get_architecture_from_model_path(model_path) 608 | predict_fn = load_model_and_get_predict_function(model_path) 609 | architecture_wrapper: BodyPixArchitecture 610 | if architecture_name == ModelArchitectureNames.MOBILENET_V1: 611 | architecture_wrapper = MobileNetBodyPixPredictWrapper(predict_fn) 612 | elif architecture_name == ModelArchitectureNames.RESNET_50: 613 | architecture_wrapper = ResNet50BodyPixPredictWrapper(predict_fn) 614 | else: 615 | raise ValueError('unsupported architecture: %s' % architecture_name) 616 | return BodyPixModelWrapper( 617 | architecture_wrapper, 618 | output_stride=output_stride, 619 | **kwargs 620 | ) 621 | -------------------------------------------------------------------------------- /tf_bodypix/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import re 5 | from abc import ABC, abstractmethod 6 | from contextlib import ExitStack 7 | from itertools import cycle 8 | from pathlib import Path 9 | from time import time, sleep 10 | from typing import ContextManager, Dict, List, Optional,
Sequence 11 | 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3" 13 | 14 | # pylint: disable=wrong-import-position 15 | # flake8: noqa: E402 16 | 17 | try: 18 | import tensorflow as tf 19 | except ImportError: 20 | tf = None 21 | import numpy as np 22 | 23 | from tf_bodypix.utils.timer import LoggingTimer 24 | from tf_bodypix.utils.image import ( 25 | ImageSize, 26 | resize_image_to, 27 | get_image_size, 28 | box_blur_image 29 | ) 30 | from tf_bodypix.utils.s3 import iter_s3_file_urls 31 | from tf_bodypix.download import ( 32 | ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS, 33 | BodyPixModelPaths, 34 | TensorFlowLiteBodyPixModelPaths, 35 | download_model 36 | ) 37 | from tf_bodypix.tflite import get_tflite_converter_for_model_path 38 | from tf_bodypix.model import ( 39 | load_model, 40 | VALID_MODEL_ARCHITECTURE_NAMES, 41 | PART_CHANNELS, 42 | BodyPixModelWrapper, 43 | BodyPixResultWrapper 44 | ) 45 | from tf_bodypix.source import get_image_source, get_threaded_image_source, T_ImageSource 46 | from tf_bodypix.sink import ( 47 | T_OutputSink, 48 | get_image_output_sink_for_path, 49 | get_show_image_output_sink 50 | ) 51 | try: 52 | from tf_bodypix.draw import draw_poses 53 | except ImportError as exc: 54 | _draw_import_exc = exc 55 | def draw_poses(*_, **__): # type: ignore 56 | raise _draw_import_exc 57 | 58 | 59 | LOGGER = logging.getLogger(__name__) 60 | 61 | 62 | DEFAULT_MODEL_TF_PATH = BodyPixModelPaths.MOBILENET_FLOAT_50_STRIDE_16 63 | 64 | 65 | DEFAULT_MODEL_TFLITE_PATH = TensorFlowLiteBodyPixModelPaths.MOBILENET_FLOAT_75_STRIDE_16_FLOAT16 66 | 67 | 68 | DEFAULT_MODEL_PATH = ( 69 | DEFAULT_MODEL_TF_PATH if tf is not None 70 | else DEFAULT_MODEL_TFLITE_PATH 71 | ) 72 | 73 | 74 | class SubCommand(ABC): 75 | def __init__(self, name, description): 76 | self.name = name 77 | self.description = description 78 | 79 | @abstractmethod 80 | def add_arguments(self, parser: argparse.ArgumentParser): 81 | pass 82 | 83 | @abstractmethod 84 | def run(self, args: argparse.Namespace): 85 | pass 86 | 87 | 88 | def add_common_arguments(parser: argparse.ArgumentParser): 89 | parser.add_argument( 90 | "--debug", 91 | action="store_true", 92 | help="Enable debug logging" 93 | ) 94 | 95 | 96 | def add_model_arguments(parser: argparse.ArgumentParser): 97 | parser.add_argument( 98 | "--model-path", 99 | default=DEFAULT_MODEL_PATH, 100 | help="The path or URL to the bodypix model." 101 | ) 102 | parser.add_argument( 103 | "--model-architecture", 104 | choices=VALID_MODEL_ARCHITECTURE_NAMES, 105 | help=( 106 | "The model architecture." 107 | " It will be guessed from the model path if not specified." 108 | ) 109 | ) 110 | parser.add_argument( 111 | "--output-stride", 112 | type=int, 113 | help=( 114 | "The output stride to use." 115 | " It will be guessed from the model path if not specified." 116 | ) 117 | ) 118 | parser.add_argument( 119 | "--internal-resolution", 120 | type=float, 121 | default=0.5, 122 | help=( 123 | "The internal resolution factor to resize the input image to" 124 | " before passing it to the model." 125 | ) 126 | ) 127 | parser.add_argument( 128 | "--threshold", 129 | type=float, 130 | default=0.75, 131 | help="The mask threshold." 132 | ) 133 | parser.add_argument( 134 | "--mask-blur", 135 | type=int, 136 | default=0, 137 | help="The blur radius for the mask." 138 | ) 139 | parser.add_argument( 140 | "--mask-mean-count", 141 | type=int, 142 | default=0, 143 | help="The number of masks to average to smooth the results."
144 | ) 145 | parser.add_argument( 146 | "--mask-cache-time", 147 | type=float, 148 | default=0, 149 | help=( 150 | "For how long, in seconds, the model result used for the mask should be cached." 151 | " E.g. if the model is very slow, you could let it recalculate only once per second." 152 | " Of course, the cached mask would be visible when moving quickly." 153 | ) 154 | ) 155 | 156 | 157 | def _fourcc_type(text: str) -> str: 158 | if not text: 159 | return text 160 | if len(text) != 4: 161 | raise TypeError( 162 | 'fourcc code must have exactly four characters, e.g. MJPG; but was: %r' % text 163 | ) 164 | return text 165 | 166 | 167 | def add_source_arguments(parser: argparse.ArgumentParser): 168 | source_group = parser.add_argument_group('source') 169 | source_group.add_argument( 170 | "--source", 171 | required=True, 172 | help="The path or URL to the source image or webcam source." 173 | ) 174 | image_size_help = ( 175 | "If width and height are specified, the source will be resized." 176 | " In the case of a webcam, it will be asked to produce that resolution if possible." 177 | ) 178 | source_group.add_argument( 179 | "--source-width", 180 | type=int, 181 | help=image_size_help 182 | ) 183 | source_group.add_argument( 184 | "--source-height", 185 | type=int, 186 | help=image_size_help 187 | ) 188 | source_group.add_argument( 189 | "--source-fourcc", 190 | type=_fourcc_type, 191 | default="MJPG", 192 | help="The fourcc code to configure the source with, e.g. MJPG" 193 | ) 194 | source_group.add_argument( 195 | "--source-fps", 196 | type=int, 197 | default=None, 198 | help=( 199 | "Limit the source frame rate to the desired FPS." 200 | " If provided, it will attempt to set the frame rate on the source device if supported." 201 | " Otherwise it will slow down the frame rate." 202 | " Use '0' to run as fast as possible." 203 | ) 204 | ) 205 | source_group.add_argument( 206 | "--source-threaded", 207 | action='store_true', 208 | help="If set, reads from the source in a separate thread (experimental)." 209 | ) 210 | 211 | 212 | def add_output_arguments(parser: argparse.ArgumentParser): 213 | output_group = parser.add_mutually_exclusive_group(required=True) 214 | output_group.add_argument( 215 | "--show-output", 216 | action="store_true", 217 | help="Shows the output in a window." 218 | ) 219 | output_group.add_argument( 220 | "--output", 221 | help="The path to the output file."
222 | ) 223 | 224 | 225 | def get_image_source_for_args(args: argparse.Namespace) -> T_ImageSource: 226 | image_size = None 227 | if args.source_width and args.source_height: 228 | image_size = ImageSize(height=args.source_height, width=args.source_width) 229 | image_source = get_image_source( 230 | args.source, 231 | image_size=image_size, 232 | fourcc=args.source_fourcc, 233 | fps=args.source_fps 234 | ) 235 | if args.source_threaded: 236 | return get_threaded_image_source(image_source) 237 | return image_source 238 | 239 | 240 | def get_output_sink(args: argparse.Namespace) -> ContextManager[T_OutputSink]: 241 | if args.show_output: 242 | return get_show_image_output_sink() 243 | if args.output: 244 | return get_image_output_sink_for_path(args.output) 245 | raise RuntimeError('no output sink') 246 | 247 | 248 | def load_bodypix_model(args: argparse.Namespace) -> BodyPixModelWrapper: 249 | local_model_path = download_model(args.model_path) 250 | if args.model_path != local_model_path: 251 | LOGGER.info('loading model: %r (downloaded from %r)', local_model_path, args.model_path) 252 | else: 253 | LOGGER.info('loading model: %r', local_model_path) 254 | return load_model( 255 | local_model_path, 256 | internal_resolution=args.internal_resolution, 257 | output_stride=args.output_stride, 258 | architecture_name=args.model_architecture 259 | ) 260 | 261 | 262 | def get_mask( 263 | bodypix_result: BodyPixResultWrapper, 264 | masks: List[np.ndarray], 265 | timer: LoggingTimer, 266 | args: argparse.Namespace, 267 | resize_method: Optional[str] = None 268 | ) -> np.ndarray: 269 | mask = bodypix_result.get_mask(args.threshold, dtype=np.float32, resize_method=resize_method) 270 | if args.mask_blur: 271 | timer.on_step_start('mblur') 272 | mask = box_blur_image(mask, args.mask_blur) 273 | if args.mask_mean_count >= 2: 274 | timer.on_step_start('mmean') 275 | masks.append(mask) 276 | if len(masks) > args.mask_mean_count: 277 | masks.pop(0) 278 | if len(masks) >= 2: 279 | mask = np.mean(masks, axis=0) 280 | LOGGER.debug('mask.shape: %s (%s)', mask.shape, mask.dtype) 281 | return mask 282 | 283 | 284 | class ListModelsSubCommand(SubCommand): 285 | def __init__(self): 286 | super().__init__("list-models", "Lists available bodypix models (original models)") 287 | 288 | def add_arguments(self, parser: argparse.ArgumentParser): 289 | add_common_arguments(parser) 290 | parser.add_argument( 291 | "--storage-url", 292 | default="https://storage.googleapis.com/tfjs-models", 293 | help="The base URL for the storage containing the models" 294 | ) 295 | 296 | def get_model_paths(self, storage_url: str) -> Sequence[str]: 297 | return [ 298 | file_url 299 | for file_url in iter_s3_file_urls(storage_url) 300 | if re.match(r'.*/bodypix/.*/model.*\.json', file_url) 301 | ] 302 | 303 | def run(self, args: argparse.Namespace): # pylint: disable=unused-argument 304 | print('\n'.join(self.get_model_paths(storage_url=args.storage_url))) 305 | 306 | 307 | class ListTensorFlowLiteModelsSubCommand(SubCommand): 308 | def __init__(self): 309 | super().__init__("list-tflite-models", "Lists available tflite bodypix models") 310 | 311 | def add_arguments(self, parser: argparse.ArgumentParser): 312 | add_common_arguments(parser) 313 | 314 | def get_model_paths(self) -> Sequence[str]: 315 | return ALL_TENSORFLOW_LITE_BODYPIX_MODEL_PATHS 316 | 317 | def run(self, args: argparse.Namespace): # pylint: disable=unused-argument 318 | print('\n'.join(self.get_model_paths())) 319 | 320 | 321 | class ConvertToTFLiteSubCommand(SubCommand): 
322 | def __init__(self): 323 | super().__init__("convert-to-tflite", "Converts the model to a tflite model") 324 | 325 | def add_arguments(self, parser: argparse.ArgumentParser): 326 | add_common_arguments(parser) 327 | parser.add_argument( 328 | "--model-path", 329 | default=DEFAULT_MODEL_TF_PATH, 330 | help="The path or URL to the bodypix model." 331 | ) 332 | parser.add_argument( 333 | "--output-model-file", 334 | required=True, 335 | help="The path to the output file (tflite model)." 336 | ) 337 | parser.add_argument( 338 | "--optimize", 339 | action='store_true', 340 | help="Enable optimization (quantization)." 341 | ) 342 | parser.add_argument( 343 | "--quantization-type", 344 | choices=['float16', 'float32', 'int8'], 345 | help="The quantization type to use." 346 | ) 347 | 348 | def run(self, args: argparse.Namespace): 349 | LOGGER.info('converting model: %s', args.model_path) 350 | converter = get_tflite_converter_for_model_path(download_model( 351 | args.model_path 352 | )) 353 | if args.optimize: 354 | LOGGER.info('enabled optimization') 355 | converter.optimizations = [tf.lite.Optimize.DEFAULT] 356 | if args.quantization_type: 357 | LOGGER.info('quantization type: %s', args.quantization_type) 358 | quantization_type = getattr(tf, args.quantization_type) 359 | converter.target_spec.supported_types = [quantization_type] 360 | converter.inference_input_type = quantization_type 361 | converter.inference_output_type = quantization_type 362 | tflite_model = converter.convert()  # convert only after all converter options are configured 363 | LOGGER.info('saving tflite model to: %s', args.output_model_file) 364 | Path(args.output_model_file).write_bytes(tflite_model) 365 | 366 | 367 | class AbstractWebcamFilterApp(ABC): 368 | def __init__(self, args: argparse.Namespace): 369 | self.args = args 370 | self.bodypix_model = None 371 | self.output_sink = None 372 | self.image_source = None 373 | self.image_iterator = None 374 | self.timer = LoggingTimer() 375 | self.masks: List[np.ndarray] = [] 376 | self.exit_stack = ExitStack() 377 | self.bodypix_result_cache_time = None 378 | self.bodypix_result_cache = None 379 | 380 | @abstractmethod 381 | def get_output_image(self, image_array: np.ndarray) -> np.ndarray: 382 | pass 383 | 384 | def get_mask(self, *args, **kwargs): 385 | return get_mask( 386 | *args, masks=self.masks, timer=self.timer, args=self.args, **kwargs 387 | ) 388 | 389 | def get_bodypix_result(self, image_array: np.ndarray) -> BodyPixResultWrapper: 390 | assert self.bodypix_model is not None 391 | current_time = time() 392 | if ( 393 | self.bodypix_result_cache is not None 394 | and current_time < self.bodypix_result_cache_time + self.args.mask_cache_time 395 | ): 396 | return self.bodypix_result_cache 397 | self.bodypix_result_cache = self.bodypix_model.predict_single(image_array) 398 | self.bodypix_result_cache_time = current_time 399 | return self.bodypix_result_cache 400 | 401 | def __enter__(self): 402 | self.exit_stack.__enter__() 403 | self.bodypix_model = load_bodypix_model(self.args) 404 | self.output_sink = self.exit_stack.enter_context(get_output_sink(self.args)) 405 | self.image_source = self.exit_stack.enter_context(get_image_source_for_args(self.args)) 406 | self.image_iterator = iter(self.image_source) 407 | return self 408 | 409 | def __exit__(self, *args, **kwargs): 410 | self.exit_stack.__exit__(*args, **kwargs) 411 | 412 | def next_frame(self): 413 | self.timer.on_frame_start(initial_step_name='in') 414 | try: 415 | image_array = next(self.image_iterator) 416 | except StopIteration:
417 | return False 418 | LOGGER.debug('image_array: %r (%r)', image_array.shape, image_array.dtype) 419 | self.timer.on_step_start('model') 420 | output_image = self.get_output_image(image_array) 421 | self.timer.on_step_start('out') 422 | self.output_sink(output_image) 423 | self.timer.on_frame_end() 424 | return True 425 | 426 | def run(self): 427 | try: 428 | self.timer.start() 429 | while self.next_frame(): 430 | pass 431 | if self.args.show_output: 432 | LOGGER.info('waiting for window to be closed') 433 | while not self.output_sink.is_closed: 434 | sleep(0.5) 435 | except KeyboardInterrupt: 436 | LOGGER.info('exiting') 437 | 438 | 439 | class AbstractWebcamFilterSubCommand(SubCommand): 440 | def add_arguments(self, parser: argparse.ArgumentParser): 441 | add_common_arguments(parser) 442 | add_model_arguments(parser) 443 | add_source_arguments(parser) 444 | add_output_arguments(parser) 445 | 446 | @abstractmethod 447 | def get_app(self, args: argparse.Namespace) -> AbstractWebcamFilterApp: 448 | pass 449 | 450 | def run(self, args: argparse.Namespace): 451 | with self.get_app(args) as app: 452 | app.run() 453 | 454 | 455 | class DrawMaskApp(AbstractWebcamFilterApp): 456 | def get_output_image(self, image_array: np.ndarray) -> np.ndarray: 457 | resize_method = None 458 | result = self.get_bodypix_result(image_array) 459 | self.timer.on_step_start('get_mask') 460 | mask = self.get_mask(result, resize_method=resize_method) 461 | if self.args.colored: 462 | self.timer.on_step_start('get_cpart_mask') 463 | mask_image = result.get_colored_part_mask( 464 | mask, part_names=self.args.parts, resize_method=resize_method 465 | ) 466 | elif self.args.parts: 467 | self.timer.on_step_start('get_part_mask') 468 | mask_image = result.get_part_mask( 469 | mask, part_names=self.args.parts, resize_method=resize_method 470 | ) * 255 471 | else: 472 | if LOGGER.isEnabledFor(logging.DEBUG): 473 | LOGGER.debug( 474 | 'mask: %r (%r, %r) (%s)', 475 | mask.shape, np.min(mask), np.max(mask), mask.dtype 476 | ) 477 | mask_image = mask * 255.0 478 | if self.args.mask_alpha is not None: 479 | self.timer.on_step_start('overlay') 480 | LOGGER.debug('mask.shape: %s (%s)', mask.shape, mask.dtype) 481 | alpha = self.args.mask_alpha 482 | try: 483 | if tf is not None: 484 | if mask_image.dtype == tf.int32: 485 | mask_image = tf.cast(mask_image, tf.float32) 486 | else: 487 | image_array = np.asarray(image_array).astype(np.float32) 488 | if LOGGER.isEnabledFor(logging.DEBUG): 489 | LOGGER.debug( 490 | 'mask_image: %r (%r, %r) (%s)', 491 | mask_image.shape, np.min(mask_image), np.max(mask_image), mask_image.dtype 492 | ) 493 | except TypeError: 494 | pass 495 | output = np.clip( 496 | image_array * (1 - alpha) + mask_image * alpha, 497 | 0.0, 255.0 498 | ) 499 | return output 500 | return mask_image 501 | 502 | 503 | class DrawMaskSubCommand(AbstractWebcamFilterSubCommand): 504 | def __init__(self): 505 | super().__init__("draw-mask", "Draws the mask for the input") 506 | 507 | def add_arguments(self, parser: argparse.ArgumentParser): 508 | super().add_arguments(parser) 509 | parser.add_argument( 510 | "--mask-alpha", 511 | type=float, 512 | help="The opacity of mask overlay to add." 513 | ) 514 | parser.add_argument( 515 | "--add-overlay-alpha", 516 | dest='mask_alpha', 517 | type=float, 518 | help="Deprecated, please use --mask-alpha instead." 
519 | ) 520 | parser.add_argument( 521 | "--colored", 522 | action="store_true", 523 | help="Enable generating the colored part mask" 524 | ) 525 | parser.add_argument( 526 | "--parts", 527 | nargs="*", 528 | choices=PART_CHANNELS, 529 | help="Select the parts to output" 530 | ) 531 | 532 | def get_app(self, args: argparse.Namespace) -> AbstractWebcamFilterApp: 533 | return DrawMaskApp(args) 534 | 535 | 536 | class DrawPoseApp(AbstractWebcamFilterApp): 537 | def get_output_image(self, image_array: np.ndarray) -> np.ndarray: 538 | result = self.get_bodypix_result(image_array) 539 | self.timer.on_step_start('get_pose') 540 | poses = result.get_poses() 541 | LOGGER.debug('number of poses: %d', len(poses)) 542 | output_image = draw_poses( 543 | image_array.copy(), poses, 544 | keypoints_color=(255, 100, 100), 545 | skeleton_color=(100, 100, 255) 546 | ) 547 | return output_image 548 | 549 | 550 | class DrawPoseSubCommand(AbstractWebcamFilterSubCommand): 551 | def __init__(self): 552 | super().__init__("draw-pose", "Draws the pose estimation") 553 | 554 | def get_app(self, args: argparse.Namespace) -> AbstractWebcamFilterApp: 555 | return DrawPoseApp(args) 556 | 557 | 558 | class BlurBackgroundApp(AbstractWebcamFilterApp): 559 | def get_output_image(self, image_array: np.ndarray) -> np.ndarray: 560 | result = self.get_bodypix_result(image_array) 561 | self.timer.on_step_start('get_mask') 562 | mask = self.get_mask(result) 563 | self.timer.on_step_start('bblur') 564 | background_image_array = box_blur_image(image_array, self.args.background_blur) 565 | self.timer.on_step_start('compose') 566 | output = np.clip( 567 | background_image_array * (1 - mask) 568 | + image_array * mask, 569 | 0.0, 255.0 570 | ) 571 | return output 572 | 573 | 574 | class BlurBackgroundSubCommand(AbstractWebcamFilterSubCommand): 575 | def __init__(self): 576 | super().__init__("blur-background", "Blurs the background of the webcam image") 577 | 578 | def add_arguments(self, parser: argparse.ArgumentParser): 579 | super().add_arguments(parser) 580 | parser.add_argument( 581 | "--background-blur", 582 | type=int, 583 | default=15, 584 | help="The blur radius for the background." 
585 | ) 586 | 587 | def get_app(self, args: argparse.Namespace) -> AbstractWebcamFilterApp: 588 | return BlurBackgroundApp(args) 589 | 590 | 591 | class ReplaceBackgroundApp(AbstractWebcamFilterApp): 592 | def __init__(self, *args, **kwargs): 593 | self.background_image_iterator = None 594 | super().__init__(*args, **kwargs) 595 | 596 | def get_next_background_image(self, image_array: np.ndarray) -> np.ndarray: 597 | if self.background_image_iterator is None: 598 | background_image_source = self.exit_stack.enter_context(get_image_source( 599 | self.args.background, 600 | image_size=get_image_size(image_array) 601 | )) 602 | self.background_image_iterator = iter(cycle(background_image_source)) 603 | return next(self.background_image_iterator) 604 | 605 | def get_output_image(self, image_array: np.ndarray) -> np.ndarray: 606 | background_image_array = self.get_next_background_image(image_array) 607 | result = self.get_bodypix_result(image_array) 608 | self.timer.on_step_start('get_mask') 609 | mask = self.get_mask(result) 610 | self.timer.on_step_start('compose') 611 | background_image_array = resize_image_to( 612 | background_image_array, get_image_size(image_array) 613 | ) 614 | output = np.clip( 615 | background_image_array * (1 - mask) 616 | + image_array * mask, 617 | 0.0, 255.0 618 | ) 619 | return output 620 | 621 | 622 | class ReplaceBackgroundSubCommand(AbstractWebcamFilterSubCommand): 623 | def __init__(self): 624 | super().__init__("replace-background", "Replaces the background of a person") 625 | 626 | def add_arguments(self, parser: argparse.ArgumentParser): 627 | add_common_arguments(parser) 628 | add_model_arguments(parser) 629 | add_source_arguments(parser) 630 | 631 | parser.add_argument( 632 | "--background", 633 | required=True, 634 | help="The path or URL to the background image." 
635 | ) 636 | 637 | add_output_arguments(parser) 638 | 639 | def get_app(self, args: argparse.Namespace) -> AbstractWebcamFilterApp: 640 | return ReplaceBackgroundApp(args) 641 | 642 | 643 | SUB_COMMANDS: List[SubCommand] = [ 644 | ListModelsSubCommand(), 645 | ListTensorFlowLiteModelsSubCommand(), 646 | ConvertToTFLiteSubCommand(), 647 | DrawMaskSubCommand(), 648 | DrawPoseSubCommand(), 649 | BlurBackgroundSubCommand(), 650 | ReplaceBackgroundSubCommand() 651 | ] 652 | 653 | SUB_COMMAND_BY_NAME: Dict[str, SubCommand] = { 654 | sub_command.name: sub_command for sub_command in SUB_COMMANDS 655 | } 656 | 657 | 658 | def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace: 659 | parser = argparse.ArgumentParser( 660 | 'TensorFlow BodyPix (TF BodyPix)', 661 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 662 | ) 663 | subparsers = parser.add_subparsers(dest="command") 664 | subparsers.required = True 665 | for sub_command in SUB_COMMANDS: 666 | sub_parser = subparsers.add_parser( 667 | sub_command.name, help=sub_command.description, 668 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 669 | ) 670 | sub_command.add_arguments(sub_parser) 671 | 672 | args = parser.parse_args(argv) 673 | return args 674 | 675 | 676 | def run(args: argparse.Namespace): 677 | sub_command = SUB_COMMAND_BY_NAME[args.command] 678 | sub_command.run(args) 679 | 680 | 681 | def main(argv: Optional[List[str]] = None): 682 | args = parse_args(argv) 683 | if args.debug: 684 | logging.getLogger().setLevel(logging.DEBUG) 685 | LOGGER.debug("args: %s", args) 686 | run(args) 687 | 688 | 689 | if __name__ == '__main__': 690 | logging.basicConfig(level='INFO') 691 | main() 692 | --------------------------------------------------------------------------------