├── data └── .gitignore ├── mediapipe_crop_estimate ├── __init__.py ├── mlp │ ├── center.pt │ ├── size.pt │ └── rotation.pt ├── evaluation_utils.py ├── estimate_poses.py ├── train_kan.py ├── train_dataset.py ├── train_mlp.py ├── mediapipe_utils.py ├── evaluate.py └── collect_hands.py ├── assets ├── cropped │ ├── bad.jpg │ ├── ok.jpg │ ├── best.jpg │ ├── good.jpg │ └── worst.jpg └── original │ ├── ok.jpg │ ├── bad.jpg │ ├── best.jpg │ ├── good.jpg │ └── worst.jpg ├── .gitignore ├── .github └── workflows │ ├── test.yaml │ └── lint.yaml ├── LICENSE ├── pyproject.toml └── README.md /data/.gitignore: -------------------------------------------------------------------------------- 1 | hand_labels/ 2 | temp/ -------------------------------------------------------------------------------- /mediapipe_crop_estimate/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/cropped/bad.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/assets/cropped/bad.jpg -------------------------------------------------------------------------------- /assets/cropped/ok.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/assets/cropped/ok.jpg -------------------------------------------------------------------------------- /assets/original/ok.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/assets/original/ok.jpg -------------------------------------------------------------------------------- /assets/cropped/best.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/assets/cropped/best.jpg -------------------------------------------------------------------------------- /assets/cropped/good.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/assets/cropped/good.jpg -------------------------------------------------------------------------------- /assets/cropped/worst.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/assets/cropped/worst.jpg -------------------------------------------------------------------------------- /assets/original/bad.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/assets/original/bad.jpg -------------------------------------------------------------------------------- /assets/original/best.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/assets/original/best.jpg -------------------------------------------------------------------------------- /assets/original/good.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/assets/original/good.jpg
--------------------------------------------------------------------------------
/assets/original/worst.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/assets/original/worst.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | build/
3 | mediapipe_crop_estimate.egg-info/
4 | mediapipe_crop_estimate/lightning_logs
5 | mediapipe_crop_estimate/figures
--------------------------------------------------------------------------------
/mediapipe_crop_estimate/mlp/center.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/mediapipe_crop_estimate/mlp/center.pt
--------------------------------------------------------------------------------
/mediapipe_crop_estimate/mlp/size.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/mediapipe_crop_estimate/mlp/size.pt
--------------------------------------------------------------------------------
/mediapipe_crop_estimate/mlp/rotation.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sign-language-processing/mediapipe-hand-crop-fix/main/mediapipe_crop_estimate/mlp/rotation.pt
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Test
2 | 
3 | 
4 | on:
5 |   push:
6 |     branches: [ master, main ]
7 |   pull_request:
8 |     branches: [ master, main ]
9 | 
10 | 
11 | jobs:
12 |   test:
13 |     name: Test
14 |     runs-on: ubuntu-latest
15 | 
16 |     steps:
17 |       - uses: actions/checkout@v3
18 |       - uses: actions/setup-python@v4
19 |         with:
20 |           python-version: '3.10'
21 | 
22 |       - name: Install Requirements
23 |         run: pip install .[dev]
24 | 
25 |       - name: Test Code
26 |         run: pytest mediapipe_crop_estimate
27 | 
--------------------------------------------------------------------------------
/.github/workflows/lint.yaml:
--------------------------------------------------------------------------------
1 | name: Lint
2 | 
3 | 
4 | on:
5 |   push:
6 |     branches: [ master, main ]
7 |   pull_request:
8 |     branches: [ master, main ]
9 | 
10 | 
11 | jobs:
12 |   test:
13 |     name: Lint
14 |     runs-on: ubuntu-latest
15 | 
16 |     steps:
17 |       - uses: actions/checkout@v3
18 |       - uses: actions/setup-python@v4
19 |         with:
20 |           python-version: '3.10'
21 | 
22 |       - name: Install Requirements
23 |         run: pip install .[dev]
24 | 
25 |       - name: Lint Code
26 |         run: pylint mediapipe_crop_estimate
27 | 
--------------------------------------------------------------------------------
/mediapipe_crop_estimate/evaluation_utils.py:
--------------------------------------------------------------------------------
1 | from shapely.geometry import Polygon
2 | 
3 | from mediapipe_crop_estimate.mediapipe_utils import Rect
4 | 
5 | 
6 | def intersection_over_union(rect1: Rect, rect2: Rect): # IoU including rotation
7 |     box1 = rect1.get_corners()
8 |     box2 = rect2.get_corners()
9 | 
10 |     poly1 = Polygon(box1)
11 |     poly2 = Polygon(box2)
12 | 
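    # Both rectangles may be rotated (their corners come from cv2.boxPoints), so the
    # overlap is computed exactly on the shapely polygons rather than on axis-aligned boxes.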
13 | # Intersection polygon 14 | inter_poly = poly1.intersection(poly2) 15 | 16 | if inter_poly.is_empty: 17 | return 0.0 # No intersection 18 | 19 | # Areas of the intersection and the union 20 | inter_area = inter_poly.area 21 | area1 = poly1.area 22 | area2 = poly2.area 23 | union_area = area1 + area2 - inter_area 24 | 25 | # Compute IoU 26 | return inter_area / union_area 27 | -------------------------------------------------------------------------------- /mediapipe_crop_estimate/estimate_poses.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import cv2 4 | from pose_format.utils.holistic import load_holistic 5 | from tqdm import tqdm 6 | 7 | hand_labels_dir = Path(__file__).parent.parent / "data" / "hand_labels" 8 | panopticdb_dir = Path(__file__).parent.parent / "data" / "hand143_panopticdb" 9 | 10 | dataset_dirs = [ 11 | hand_labels_dir / "manual_test", 12 | hand_labels_dir / "manual_train", 13 | panopticdb_dir / "imgs" 14 | ] 15 | 16 | for dataset_dir in dataset_dirs: 17 | jpg_files = list(dataset_dir.glob("*.jpg")) 18 | for file in tqdm(jpg_files): 19 | pose_file = file.with_suffix(".pose") 20 | if pose_file.exists(): 21 | continue 22 | image = cv2.cvtColor(cv2.imread(str(file)), cv2.COLOR_BGR2RGB) 23 | pose = load_holistic([image], fps=30, width=image.shape[1], height=image.shape[0]) 24 | with open(pose_file, "wb") as f: 25 | pose.write(f) 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sign Language Processing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "mediapipe-hand-crop-fix"
3 | description = "Code for Optimizing Hand Region Detection in MediaPipe Holistic Full-Body Pose Estimation to Improve Accuracy and Avoid Downstream Errors"
4 | version = "0.0.1"
5 | authors = [
6 |     { name = "Amit Moryossef", email = "amitmoryossef@gmail.com" },
7 | ]
8 | readme = "README.md"
9 | dependencies = [
10 |     "pose-format",
11 |     "mediapipe",
12 |     "tqdm",
13 |     "opencv-python",
14 |     "pykan",
15 |     "numpy",
16 |     "torch",
17 |     "pytorch-lightning",
18 |     "matplotlib",
19 |     "shapely"
20 | ]
21 | 
22 | [project.optional-dependencies]
23 | dev = [
24 |     "pytest",
25 |     "pylint"
26 | ]
27 | 
28 | [tool.yapf]
29 | based_on_style = "google"
30 | column_limit = 120
31 | 
32 | [tool.pylint]
33 | max-line-length = 120
34 | disable = [
35 |     "C0114", # Missing module docstring
36 |     "C0115", # Missing class docstring
37 |     "C0116", # Missing function or method docstring
38 | ]
39 | good-names = ["i", "f", "x", "y", "p1", "p2"]
40 | 
41 | [tool.pylint.typecheck]
42 | generated-members = ["cv2.*", "torch.*"]
43 | 
44 | [tool.setuptools]
45 | packages = [
46 |     "mediapipe_crop_estimate",
47 | ]
48 | 
49 | [tool.pytest.ini_options]
50 | addopts = "-v"
51 | testpaths = ["mediapipe_crop_estimate"]
52 | 
--------------------------------------------------------------------------------
/mediapipe_crop_estimate/train_kan.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/KindXiaoming/pykan/blob/master/hellokan.ipynb
2 | 
3 | from kan import KAN
4 | from matplotlib import pyplot as plt
5 | 
6 | from mediapipe_crop_estimate.train_dataset import get_dataset
7 | 
8 | dataset = get_dataset()
9 | 
10 | train_features = dataset["train_input"].shape[-1]
11 | label_features = dataset["train_label"].shape[-1]
12 | 
13 | for i in range(label_features):
14 |     feature_dataset = {
15 |         "train_input": dataset["train_input"],
16 |         "train_label": dataset["train_label"][:, [i]],
17 |         "test_input": dataset["test_input"],
18 |         "test_label": dataset["test_label"][:, [i]]
19 |     }
20 | 
21 |     train_losses = []
22 |     test_losses = []
23 | 
24 |     last_model = None
25 | 
26 |     # for grid in [5, 10, 20]: # Train and refine the grid
27 |     for grid in [5]: # Train and refine the grid
28 |         model = KAN(width=[train_features, 5, 1], # Features input, 5 hidden neurons, 1D output
29 |                     grid=grid, # 5 grid intervals
30 |                     k=3) # cubic spline
31 |         if last_model is not None:
32 |             model = model.initialize_from_another_model(last_model, feature_dataset["train_input"])
33 |         last_model = model
34 | 
35 |         # Train on the single-target dataset built above for this output feature
36 |         results = model.train(feature_dataset, opt="LBFGS", steps=50, stop_grid_update_step=30)
37 |         train_losses += results['train_loss']
38 |         test_losses += results['test_loss']
39 | 
40 |     plt.plot(train_losses)
41 |     plt.plot(test_losses)
42 |     plt.legend(['train', 'test'])
43 |     plt.ylabel('RMSE')
44 |     plt.xlabel('step')
45 |     plt.yscale('log')
46 |     plt.show()
47 | 
48 |     model = model.prune()
49 |     model(dataset['train_input'])
50 |     model.plot()
51 |     plt.show()
52 | 
53 |     lib = ['x', 'x^2', 'exp', 'log', 'sqrt', 'sin']
54 |     model.auto_symbolic(lib=lib)
55 | 
56 |     model.train(feature_dataset, opt="LBFGS", steps=50)
57 | 
58 |     print(model.symbolic_formula()[0][0])
59 | 
--------------------------------------------------------------------------------
/mediapipe_crop_estimate/train_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | from itertools import chain
3 | from pathlib import Path
4 | import random
5 | 
6 | import numpy as np
7 | import torch
8 | 
9 | 
10 | def get_dataset(augment=True):
11 |     dataset_dict = {}
12 |     for split_name in ["train", "test"]:
13 |         annotations_path = Path(__file__).parent.parent / "data" / "processed" / f"{split_name}.json"
14 |         with open(annotations_path, "r") as f:
15 |             annotations = json.load(f)
16 | 
17 |         inputs = []
18 |         labels = []
19 | 
20 |         for annotation in annotations:
21 |             source = annotation["source"]
22 |             target = annotation["target"]
23 | 
24 |             aspect_ratio = source["aspect_ratio"]
25 |             points = list(chain.from_iterable(source["pose"]))
26 |             inputs.append([aspect_ratio] + points)
27 | 
28 |             labels.append([
29 |                 target["center"][0],
30 |                 target["center"][1],
31 |                 target["size"],
32 |                 target["rotation"] / 360
33 |             ])
34 | 
35 |             if augment and split_name == "train":
36 |                 for i in range(10):
37 |                     x_shift = random.gauss(mu=0.0, sigma=.2)
38 |                     y_shift = random.gauss(mu=0.0, sigma=.2)
39 |                     points = np.array(annotation["source"]["pose"]) + np.array([x_shift, y_shift, 0])
40 |                     inputs.append([aspect_ratio] + list(chain.from_iterable(points)))
41 |                     labels.append([
42 |                         target["center"][0] + x_shift,
43 |                         target["center"][1] + y_shift,
44 |                         target["size"],
45 |                         target["rotation"] / 360
46 |                     ])
47 | 
48 |         dataset_dict[f"{split_name}_input"] = torch.tensor(inputs, dtype=torch.float32)
49 |         dataset_dict[f"{split_name}_label"] = torch.tensor(labels, dtype=torch.float32)
50 | 
51 |     return dataset_dict
52 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MediaPipe Hand Crop Fix
2 | 
3 | Code for "Optimizing Hand Region Detection in MediaPipe Holistic Full-Body Pose Estimation to Improve Accuracy and Avoid
4 | Downstream Errors".
5 | 
6 | Fixing https://github.com/google/mediapipe/issues/5373
7 | 
8 | ## Motivation
9 | 
10 | The MediaPipe hand ROI estimation is often inaccurate. Here are a few examples:
11 | 
12 | Worst and Best are the edge cases as seen in the data. The rest were picked manually.
13 | We only look at the right hand. If the data is for the left hand, we flip the image.
14 | In the following table, for each image we show the hand keypoints in green lines,
15 | and two bounding boxes: the gold (ground-truth) ROI in green and the predicted ROI in red.
16 | To indicate the orientation of the bounding box, we draw a blue line on the bottom edge of each box.
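The percentages shown for the edge cases are intersection-over-union (IoU) scores between the gold and predicted ROIs. As a minimal sketch (the rectangle values below are made up for illustration), such a score can be computed with the repository's own utilities:

```python
# Minimal sketch: rotation-aware IoU between a gold and a predicted crop rectangle.
from mediapipe_crop_estimate.evaluation_utils import intersection_over_union
from mediapipe_crop_estimate.mediapipe_utils import Rect

# Hypothetical rectangles in normalized image coordinates (rotation in degrees).
gold = Rect(x_center=0.52, y_center=0.48, width=0.20, height=0.20, rotation=15.0)
pred = Rect(x_center=0.55, y_center=0.45, width=0.26, height=0.26, rotation=40.0)

print(f"IoU: {intersection_over_union(gold, pred) * 100:.2f}%")
```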
17 | 
18 | | Worst | Bad | OK | Good | Best |
19 | |------------------------------------|--------------------------------|------------------------------|----------------------------------|----------------------------------|
20 | | IoU: 0.08% | --- | --- | --- | IoU: 93.7% |
21 | | ![Worst](assets/cropped/worst.jpg) | ![Bad](assets/cropped/bad.jpg) | ![OK](assets/cropped/ok.jpg) | ![Good](assets/cropped/good.jpg) | ![Best](assets/cropped/best.jpg) |
22 | 
23 | ## Usage
24 | 
25 | ```bash
26 | git clone https://github.com/sign-language-processing/mediapipe-hand-crop-fix.git
27 | cd mediapipe-hand-crop-fix
28 | ```
29 | 
30 | Download and extract the Panoptic Hand Pose Dataset:
31 | 
32 | ```bash
33 | cd data
34 | wget http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand_labels.zip
35 | unzip hand_labels.zip
36 | ```
37 | 
38 | Estimate full body poses using MediaPipe Holistic:
39 | 
40 | ```bash
41 | python -m mediapipe_crop_estimate.estimate_poses
42 | ```
43 | 
44 | Collect the annotations as well as estimated regions of interest:
45 | ```bash
46 | python -m mediapipe_crop_estimate.collect_hands
47 | ```
48 | 
49 | Then, train an MLP using the annotations and estimated regions of interest:
50 | ```bash
51 | python -m mediapipe_crop_estimate.train_mlp
52 | ```
53 | 
54 | Finally, evaluate:
55 | ```bash
56 | python -m mediapipe_crop_estimate.evaluate
57 | ```
--------------------------------------------------------------------------------
/mediapipe_crop_estimate/train_mlp.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import pytorch_lightning as pl
4 | import torch
5 | from torch import nn
6 | from torch.utils.data import DataLoader, TensorDataset
7 | 
8 | from mediapipe_crop_estimate.train_dataset import get_dataset
9 | 
10 | 
11 | class SimpleMLP(pl.LightningModule):
12 |     def __init__(self, input_size, hidden_size, output_size):
13 |         super().__init__()
14 |         self.layers = nn.Sequential(
15 |             nn.Linear(input_size, hidden_size),
16 |             nn.ReLU(),
17 |             nn.Linear(hidden_size, hidden_size),
18 |             nn.ReLU(),
19 |             nn.Linear(hidden_size, output_size)
20 |         )
21 | 
22 |     def forward(self, x):
23 |         return self.layers(x)
24 | 
25 |     def configure_optimizers(self):
26 |         return torch.optim.Adam(self.parameters())
27 | 
28 |     def training_step(self, batch):
29 |         inputs, labels = batch
30 |         outputs = self(inputs)
31 |         loss = nn.MSELoss()(outputs, labels)
32 |         self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
33 |         return loss
34 | 
35 |     def validation_step(self, batch):
36 |         inputs, labels = batch
37 |         outputs = self(inputs)
38 |         loss = nn.MSELoss()(outputs, labels)
39 |         self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
40 |         return loss
41 | 
42 | 
43 | class DataModule(pl.LightningDataModule):
44 |     # pylint: disable=redefined-outer-name
45 |     def __init__(self, train_dataset, val_dataset, batch_size=32):
46 |         super().__init__()
47 |         self.train_dataset = train_dataset
48 |         self.val_dataset = val_dataset
49 |         self.batch_size = batch_size
50 | 
51 |     def train_dataloader(self):
52 |         return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
53 | 
54 |     def val_dataloader(self):
55 |         return DataLoader(self.val_dataset, batch_size=self.batch_size)
56 | 
57 | 
58 | if __name__ == "__main__":
59 |     dataset = get_dataset()
60 | 
61 |     train_features = dataset["train_input"].shape[-1]
62 |     label_features = dataset["train_label"].shape[-1]
63 | 
64 |     models = {
65 |         "center": [0, 1],
66 |         "size": [2],
67 |         "rotation": [3]
68 |     }
69 | 
70 |
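# Three single-task MLP heads are trained over the same input features
# ([aspect_ratio] + flattened holistic keypoints): "center" regresses the crop
# center (x, y), "size" the square crop size, and "rotation" the rotation
# normalized by 360. Each trained head is traced to TorchScript and saved to
# mlp/<name>.pt, which evaluate.py later loads lazily via load_mlp().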
os.makedirs("mlp", exist_ok=True) 71 | 72 | for model_name, label_indices in models.items(): 73 | train_dataset = TensorDataset(dataset['train_input'], dataset['train_label'][:, label_indices]) 74 | val_dataset = TensorDataset(dataset['test_input'][:10], dataset['test_label'][:, label_indices][:10]) 75 | data_module = DataModule(train_dataset, val_dataset) 76 | 77 | model = SimpleMLP(input_size=train_features, hidden_size=10, output_size=len(label_indices)) 78 | trainer = pl.Trainer(max_epochs=100, progress_bar_refresh_rate=20) 79 | trainer.fit(model, datamodule=data_module) 80 | 81 | # save model jit 82 | model.eval() 83 | example_input = dataset["train_input"][0] 84 | traced_model = torch.jit.trace(model, example_input) 85 | traced_model.save(f"mlp/{model_name}.pt") 86 | -------------------------------------------------------------------------------- /mediapipe_crop_estimate/mediapipe_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def normalize_radians(angle): 6 | return angle - 2 * np.pi * np.floor((angle + np.pi) / (2 * np.pi)) 7 | 8 | 9 | def compute_rotation(landmarks): 10 | wrist = landmarks[0] 11 | index = landmarks[5] 12 | middle = landmarks[9] 13 | ring = landmarks[13] 14 | 15 | fingers_center = (index + ring) / 2 16 | fingers_center = (fingers_center + middle) / 2 17 | 18 | rotation = normalize_radians(np.pi / 2 - np.arctan2(-(fingers_center[1] - wrist[1]), fingers_center[0] - wrist[0])) 19 | return rotation 20 | 21 | 22 | class Rect: 23 | def __init__(self, x_center, y_center, width, height, rotation): 24 | self.x_center = x_center 25 | self.y_center = y_center 26 | self.width = width 27 | self.height = height 28 | self.rotation = rotation 29 | 30 | def postprocess(self, scale, shift_y): 31 | size = max(self.width, self.height) # Square long 32 | 33 | # shift y by 0.1 * size, taking rotation into account 34 | xShift = shift_y * size * np.sin(self.rotation) 35 | yShift = shift_y * size * np.cos(self.rotation) 36 | self.x_center += xShift 37 | self.y_center += yShift 38 | 39 | # scale size 40 | size *= scale 41 | self.width = self.height = size 42 | 43 | def reflect_horizontal(self, width): 44 | self.x_center = width - self.x_center 45 | self.rotation = (-self.rotation) % 360 46 | 47 | def draw(self, image, color=(0, 255, 0)): 48 | box = cv2.boxPoints(((self.x_center, self.y_center), (self.width, self.height), self.rotation)) 49 | cv2.drawContours(image, [np.int0(box)], 0, color, 2) 50 | # draw a line at the bottom of the box 51 | cv2.line(image, (int(box[0][0]), int(box[0][1])), (int(box[-1][0]), int(box[-1][1])), (255, 0, 0), 2) 52 | 53 | def contains(self, x: float, y: float): 54 | return self.x_center - self.width / 2 < x < self.x_center + self.width / 2 and \ 55 | self.y_center - self.height / 2 < y < self.y_center + self.height / 2 56 | 57 | def get_corners(self): 58 | return cv2.boxPoints(((self.x_center, self.y_center), (self.width, self.height), self.rotation)) 59 | 60 | def __str__(self): 61 | return f"x_center: {self.x_center}, y_center: {self.y_center}, width: {self.width}, height: {self.height}, rotation: {self.rotation}" 62 | 63 | 64 | def landmarks_to_rect(landmarks): 65 | rotation = compute_rotation(landmarks) 66 | reverse_angle = normalize_radians(-rotation) 67 | 68 | min_coords = landmarks.min(axis=0) 69 | max_coords = landmarks.max(axis=0) 70 | center_coords = (max_coords + min_coords) / 2 71 | 72 | rotated_min = np.array([float('inf'), float('inf')]) 73 | rotated_max 
= np.array([-float('inf'), -float('inf')]) 74 | for x, y in (landmarks - center_coords): 75 | rotated_coords = np.array([ 76 | x * np.cos(reverse_angle) - y * np.sin(reverse_angle), 77 | x * np.sin(reverse_angle) + y * np.cos(reverse_angle) 78 | ]) 79 | rotated_min = np.minimum(rotated_min, rotated_coords) 80 | rotated_max = np.maximum(rotated_max, rotated_coords) 81 | 82 | rotated_center = (rotated_max + rotated_min) / 2 83 | final_center = ( 84 | rotated_center[0] * np.cos(rotation) - rotated_center[1] * np.sin(rotation) + center_coords[0], 85 | rotated_center[0] * np.sin(rotation) + rotated_center[1] * np.cos(rotation) + center_coords[1] 86 | ) 87 | width = (rotated_max[0] - rotated_min[0]) 88 | height = (rotated_max[1] - rotated_min[1]) 89 | 90 | return Rect(x_center=final_center[0], 91 | y_center=final_center[1], 92 | width=width, height=height, 93 | rotation=np.degrees(rotation)) 94 | 95 | 96 | def holistic_body_landmarks_to_rect(wrist, index, pinky): 97 | wrist = wrist[:2] 98 | index = index[:2] 99 | pinky = pinky[:2] 100 | 101 | # Estimate middle finger position 102 | center = (2 * index + pinky) / 3 103 | # Estimate hand size 104 | size = 2 * np.linalg.norm(center - wrist) 105 | # Estimate hand 2D rotation 106 | rotation = 90 + np.degrees(np.arctan2(center[1] - wrist[1], center[0] - wrist[0])) 107 | # Shift center 108 | rect = Rect(center[0], center[1], size, size, rotation) 109 | rect.postprocess(2.7, -0.1) 110 | return rect 111 | -------------------------------------------------------------------------------- /mediapipe_crop_estimate/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import lru_cache 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | import matplotlib.pyplot as plt 8 | 9 | from mediapipe_crop_estimate.evaluation_utils import intersection_over_union 10 | from mediapipe_crop_estimate.mediapipe_utils import holistic_body_landmarks_to_rect, Rect 11 | 12 | annotations_path = Path(__file__).parent.parent / "data" / "processed" / "test.json" 13 | 14 | 15 | def original_method(aspect_ratio: float, points): 16 | height = 1 # just a number 17 | width = height * aspect_ratio # Important for angle calculation 18 | 19 | points *= np.array([width, height, 1]) 20 | wrist, index, pinky = points[[2, 4, 5]] 21 | rect = holistic_body_landmarks_to_rect(wrist, index, pinky) 22 | 23 | size = rect.width / width 24 | return Rect(x_center=rect.x_center / width, 25 | y_center=rect.y_center / height, 26 | width=size, 27 | height=size, 28 | rotation=rect.rotation) 29 | 30 | 31 | @lru_cache(maxsize=None) 32 | def load_mlp(model_name: str): 33 | model = torch.jit.load(Path(__file__).parent / "mlp" / f"{model_name}.pt") 34 | model.eval() 35 | return model 36 | 37 | 38 | def mlp_method(aspect_ratio: float, points: np.ndarray): 39 | center_mlp = load_mlp("center") 40 | size_mlp = load_mlp("size") 41 | rotation_mlp = load_mlp("rotation") 42 | with torch.no_grad(): 43 | input_vector = torch.tensor([aspect_ratio] + points.flatten().tolist(), dtype=torch.float32) 44 | center = center_mlp(input_vector).numpy() 45 | size = float(size_mlp(input_vector).numpy()) 46 | rotation = float(rotation_mlp(input_vector).numpy()) * 360 47 | 48 | return Rect(x_center=float(center[0]), y_center=float(center[1]), 49 | width=size, height=size, rotation=rotation) 50 | 51 | 52 | methods = { 53 | "original": original_method, 54 | "mlp": mlp_method, 55 | } 56 | 57 | with open(annotations_path, "r") as f: 58 | 
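    # data/processed/test.json is produced by collect_hands.py: per image it stores the
    # normalized holistic keypoints (source) and the gold crop rectangle (target).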
annotations = json.load(f) 59 | 60 | # plot a histogram 61 | fig = plt.figure() 62 | 63 | methods_ious = {} 64 | for method_name, method in methods.items(): 65 | center_error = 0 66 | size_error = 0 67 | rotation_error = 0 68 | iou_total = 0 69 | 70 | min_iou = 1 71 | 72 | method_ious = [] 73 | 74 | for annotation in annotations: 75 | # if annotation["file"] not in ["data/hand_labels/manual_train/036362775_01_l.jpg", 76 | # "data/hand_labels/manual_train/ex1_3.flv_000006_r.jpg"]: 77 | # continue 78 | 79 | gold_center = annotation["target"]["center"] 80 | gold_rect = Rect(x_center=gold_center[0], y_center=gold_center[1], 81 | width=annotation["target"]["size"], height=annotation["target"]["size"], 82 | rotation=annotation["target"]["rotation"]) 83 | 84 | aspect_ratio = annotation["source"]["aspect_ratio"] 85 | points = np.array(annotation["source"]["pose"]) 86 | rect = method(aspect_ratio, points) 87 | 88 | iou = intersection_over_union(gold_rect, rect) 89 | if iou < min_iou: 90 | min_iou = iou 91 | 92 | method_ious.append(iou) 93 | iou_total += iou 94 | center_error += np.linalg.norm(np.array([rect.x_center * aspect_ratio, rect.y_center]) - 95 | np.array([gold_rect.x_center * aspect_ratio, gold_rect.y_center])) 96 | size_error += abs(gold_rect.width - rect.width) / gold_rect.width 97 | # rotation error is circular 98 | rotation_error += min(abs(gold_rect.rotation - rect.rotation), abs(gold_rect.rotation - rect.rotation + 360)) 99 | 100 | print(f"Method: {method_name}") 101 | print(f"IOU: {iou_total / len(annotations):.2f}") 102 | print(f"Center error: {center_error / len(annotations) * 100:.2f}%") 103 | print(f"Size error: {size_error / len(annotations) * 100:.2f}%") 104 | print(f"Rotation error: {rotation_error / len(annotations):.2f}") 105 | print(f"Min IOU: {min_iou:.2f}") 106 | print() 107 | 108 | plt.hist(method_ious, bins=20, alpha=0.5, label=method_name, density=True) 109 | methods_ious[method_name] = method_ious 110 | 111 | plt.legend(loc='upper left') 112 | plt.gca().axes.get_yaxis().set_visible(False) 113 | plt.tight_layout() 114 | plt.savefig("histogram.pdf") 115 | 116 | # count how many times each method wins 117 | wins = {method_name: 0 for method_name in methods} 118 | num_annotations = len(methods_ious["original"]) 119 | for i in range(num_annotations): 120 | ious = {method_name: method_ious[i] for method_name, method_ious in methods_ious.items()} 121 | best_method = max(ious, key=lambda k: ious[k]) 122 | wins[best_method] += 1 123 | print("Wins", {k: v / num_annotations for k, v in wins.items()}) -------------------------------------------------------------------------------- /mediapipe_crop_estimate/collect_hands.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import cv2 5 | import numpy as np 6 | from pose_format import Pose 7 | from tqdm import tqdm 8 | 9 | from mediapipe_crop_estimate.evaluation_utils import intersection_over_union 10 | from mediapipe_crop_estimate.mediapipe_utils import landmarks_to_rect, holistic_body_landmarks_to_rect 11 | 12 | if __name__ == "__main__": 13 | data_dir = Path(__file__).parent.parent / "data" 14 | 15 | temp_dir = data_dir / "temp" 16 | temp_dir.mkdir(exist_ok=True) 17 | 18 | processed_dir = data_dir / "processed" 19 | processed_dir.mkdir(exist_ok=True) 20 | 21 | splits = { 22 | "test": data_dir / "hand_labels" / "manual_test", 23 | "train": data_dir / "hand_labels" / "manual_train", 24 | } 25 | 26 | min_iou = 1 27 | min_iou_file = "" 28 | max_iou 
= 0 29 | max_iou_file = "" 30 | 31 | for split_name, directory in splits.items(): 32 | total_files = 0 33 | missing_poses = 0 34 | pose_out_of_rect = 0 35 | 36 | data = [] 37 | 38 | # find every json and matching pose file 39 | json_files = list(directory.glob("*.json")) 40 | for file in tqdm(json_files): 41 | total_files += 1 42 | 43 | pose_file = file.with_suffix(".pose") 44 | if not pose_file.exists(): 45 | print(f"Pose file {pose_file} does not exist") 46 | continue 47 | 48 | with open(file, "r") as f: 49 | content = json.load(f) 50 | 51 | with open(pose_file, "rb") as f: 52 | pose = Pose.read(f.read()) 53 | 54 | pose_w = pose.header.dimensions.width 55 | pose_h = pose.header.dimensions.height 56 | pose_max = max(pose_w, pose_h) 57 | aspect_ratio = pose_w / pose_h 58 | 59 | handedness = "LEFT" if content["is_left"] == 1 else "RIGHT" 60 | 61 | pose_points = [f"{handedness}_SHOULDER", f"{handedness}_ELBOW", f"{handedness}_WRIST", 62 | f"{handedness}_THUMB", f"{handedness}_INDEX", f"{handedness}_PINKY"] 63 | pose_body_indexes = [pose.header._get_point_index("POSE_LANDMARKS", point) for point in pose_points] 64 | pose_points = pose.body.data[0, 0, pose_body_indexes].filled(0) 65 | 66 | if pose.body.confidence[0, 0, pose_body_indexes].sum() == 0: 67 | missing_poses += 1 68 | continue 69 | 70 | hand_points = np.array(content["hand_pts"])[:, :2] 71 | rect = landmarks_to_rect(hand_points) 72 | 73 | # Estimate of number of points in rect 74 | points_in_rect = [point for point in pose_points if rect.contains(point[0], point[1])] 75 | if len(points_in_rect) < 3: 76 | pose_out_of_rect += 1 77 | continue 78 | 79 | rect.postprocess(scale=2, shift_y=-0.1) 80 | 81 | if handedness == "LEFT": 82 | pose_points = (pose_points - np.array([pose_w, 0, 0])) * np.array([-1, 1, 1]) 83 | rect.reflect_horizontal(pose_w) 84 | 85 | data.append({ 86 | "file": str(file.with_suffix(".jpg").relative_to(Path(__file__).parent.parent)), 87 | "source": { 88 | "aspect_ratio": aspect_ratio, 89 | "pose": (pose_points / np.array([pose_w, pose_h, 1])).tolist() 90 | }, 91 | "target": { 92 | "center": (rect.x_center / pose_w, rect.y_center / pose_h), 93 | "size": rect.width / pose_max, 94 | "rotation": rect.rotation 95 | } 96 | }) 97 | 98 | # Save the image in a temp directory to make sure the rect is correct 99 | temp_file = str(temp_dir / file.with_suffix(".jpg").name) 100 | 101 | image = cv2.imread(str(file.with_suffix(".jpg"))) 102 | if handedness == "LEFT": 103 | image = cv2.flip(image, 1) 104 | rect.draw(image) 105 | 106 | for point in pose_points: 107 | cv2.circle(image, (int(point[0]), int(point[1])), 3, (0, 0, 255), -1) 108 | 109 | 110 | # draw lines between shoulder, elbow, wrist, thumb, index, pinky 111 | def draw_line(p1, p2): 112 | cv2.line(image, (int(p1[0]), int(p1[1])), (int(p2[0]), int(p2[1])), (0, 255, 0), 2) 113 | 114 | 115 | draw_line(pose_points[0], pose_points[1]) 116 | draw_line(pose_points[1], pose_points[2]) 117 | draw_line(pose_points[2], pose_points[3]) 118 | draw_line(pose_points[2], pose_points[4]) 119 | draw_line(pose_points[2], pose_points[5]) 120 | draw_line(pose_points[4], pose_points[5]) 121 | 122 | estimated_rect = holistic_body_landmarks_to_rect(pose_points[2], pose_points[4], pose_points[5]) 123 | estimated_rect.draw(image, color=(0, 0, 255)) 124 | 125 | cv2.imwrite(temp_file, image) 126 | 127 | iou = intersection_over_union(rect, estimated_rect) 128 | if iou < min_iou: 129 | min_iou = iou 130 | min_iou_file = file 131 | if iou > max_iou: 132 | max_iou = iou 133 | max_iou_file = file 134 | 135 
| print(f"Total files: {total_files}") 136 | print(f"Missing poses: {missing_poses}") 137 | print(f"Poses out of rect: {pose_out_of_rect}") 138 | with open(processed_dir / f"{split_name}.json", "w") as f: 139 | json.dump(data, f, indent=2) 140 | 141 | print(f"Min IoU: {min_iou} for file {min_iou_file}") 142 | print(f"Max IoU: {max_iou} for file {max_iou_file}") 143 | 144 | # panopticdb_dir = Path(__file__).parent.parent / "data" / "hand143_panopticdb" 145 | --------------------------------------------------------------------------------