├── .gitignore ├── LICENSE ├── README.md ├── anchors_face.npy ├── anchors_face_back.npy ├── anchors_palm.npy ├── anchors_pose.npy ├── blazebase.py ├── blazeface.pth ├── blazeface.py ├── blazeface_landmark.pth ├── blazeface_landmark.py ├── blazefaceback.pth ├── blazehand_landmark.pth ├── blazehand_landmark.py ├── blazepalm.pth ├── blazepalm.py ├── blazepose.pth ├── blazepose.py ├── blazepose_landmark.pth ├── blazepose_landmark.py ├── demo.py ├── demo1.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Conversion of MediaPipe from TFLite to PyTorch done by Zak Murez 2 | in June 2020. 
Website: https://zak.murez.com 3 | 4 | This work builds upon https://github.com/hollance/BlazeFace-PyTorch 5 | 6 | This work is licensed under the same terms as MediaPipe (Apache License 2.0) 7 | https://github.com/google/mediapipe/blob/master/LICENSE 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MediaPipe in PyTorch 2 | 3 | Port of MediaPipe (https://github.com/google/mediapipe) tflite models to PyTorch 4 | 5 | Builds upon the work of https://github.com/hollance/BlazeFace-PyTorch 6 | 7 | ```python demo.py``` 8 | 9 | ## Models ported so far 10 | 1. Face detector (BlazeFace) 11 | 1. Face landmarks 12 | 1. Palm detector 13 | 1. Hand landmarks 14 | 15 | ## TODO 16 | 1. Add conversion and verification scripts 17 | 1. Verify face landmark pipeline 18 | 1. Improve README and samples 19 | -------------------------------------------------------------------------------- /anchors_face.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/anchors_face.npy -------------------------------------------------------------------------------- /anchors_face_back.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/anchors_face_back.npy -------------------------------------------------------------------------------- /anchors_palm.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/anchors_palm.npy -------------------------------------------------------------------------------- /anchors_pose.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/anchors_pose.npy -------------------------------------------------------------------------------- /blazebase.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize_pad(img): 9 | """ resize and pad images to be input to the detectors 10 | 11 | The face and palm detector networks take 256x256 and 128x128 images 12 | as input. As such the input image is padded and resized to fit the 13 | size while maintaing the aspect ratio. 
14 | 15 | Returns: 16 | img1: 256x256 17 | img2: 128x128 18 | scale: scale factor between original image and 256x256 image 19 | pad: pixels of padding in the original image 20 | """ 21 | 22 | size0 = img.shape 23 | if size0[0]>=size0[1]: 24 | h1 = 256 25 | w1 = 256 * size0[1] // size0[0] 26 | padh = 0 27 | padw = 256 - w1 28 | scale = size0[1] / w1 29 | else: 30 | h1 = 256 * size0[0] // size0[1] 31 | w1 = 256 32 | padh = 256 - h1 33 | padw = 0 34 | scale = size0[0] / h1 35 | padh1 = padh//2 36 | padh2 = padh//2 + padh%2 37 | padw1 = padw//2 38 | padw2 = padw//2 + padw%2 39 | img1 = cv2.resize(img, (w1,h1)) 40 | img1 = np.pad(img1, ((padh1, padh2), (padw1, padw2), (0,0))) 41 | pad = (int(padh1 * scale), int(padw1 * scale)) 42 | img2 = cv2.resize(img1, (128,128)) 43 | return img1, img2, scale, pad 44 | 45 | 46 | def denormalize_detections(detections, scale, pad): 47 | """ maps detection coordinates from [0,1] to image coordinates 48 | 49 | The face and palm detector networks take 256x256 and 128x128 images 50 | as input. As such the input image is padded and resized to fit the 51 | size while maintaing the aspect ratio. This function maps the 52 | normalized coordinates back to the original image coordinates. 53 | 54 | Inputs: 55 | detections: nxm tensor. n is the number of detections. 56 | m is 4+2*k where the first 4 valuse are the bounding 57 | box coordinates and k is the number of additional 58 | keypoints output by the detector. 59 | scale: scalar that was used to resize the image 60 | pad: padding in the x and y dimensions 61 | 62 | """ 63 | detections[:, 0] = detections[:, 0] * scale * 256 - pad[0] 64 | detections[:, 1] = detections[:, 1] * scale * 256 - pad[1] 65 | detections[:, 2] = detections[:, 2] * scale * 256 - pad[0] 66 | detections[:, 3] = detections[:, 3] * scale * 256 - pad[1] 67 | 68 | detections[:, 4::2] = detections[:, 4::2] * scale * 256 - pad[1] 69 | detections[:, 5::2] = detections[:, 5::2] * scale * 256 - pad[0] 70 | return detections 71 | 72 | 73 | 74 | 75 | class BlazeBlock(nn.Module): 76 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act='relu', skip_proj=False): 77 | super(BlazeBlock, self).__init__() 78 | 79 | self.stride = stride 80 | self.kernel_size = kernel_size 81 | self.channel_pad = out_channels - in_channels 82 | 83 | # TFLite uses slightly different padding than PyTorch 84 | # on the depthwise conv layer when the stride is 2. 
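A standalone shape check, not part of the repo, illustrating the padding trick referred to in the comment above and applied in BlazeBlock.forward below: zero-padding only the bottom/right edges and then running an unpadded stride-2 depthwise convolution reproduces the ceil(H/2) output size that TFLite's "SAME" padding gives.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for the depthwise conv inside a stride-2 BlazeBlock.
conv = nn.Conv2d(24, 24, kernel_size=3, stride=2, padding=0, groups=24)

x = torch.randn(1, 24, 128, 128)
h = F.pad(x, (0, 2, 0, 2))   # pad (left, right, top, bottom) = (0, 2, 0, 2)
print(conv(h).shape)         # torch.Size([1, 24, 64, 64]), i.e. ceil(128 / 2)
```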
85 | if stride == 2: 86 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 87 | padding = 0 88 | else: 89 | padding = (kernel_size - 1) // 2 90 | 91 | self.convs = nn.Sequential( 92 | nn.Conv2d(in_channels=in_channels, out_channels=in_channels, 93 | kernel_size=kernel_size, stride=stride, padding=padding, 94 | groups=in_channels, bias=True), 95 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 96 | kernel_size=1, stride=1, padding=0, bias=True), 97 | ) 98 | 99 | if skip_proj: 100 | self.skip_proj = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 101 | kernel_size=1, stride=1, padding=0, bias=True) 102 | else: 103 | self.skip_proj = None 104 | 105 | if act == 'relu': 106 | self.act = nn.ReLU(inplace=True) 107 | elif act == 'prelu': 108 | self.act = nn.PReLU(out_channels) 109 | else: 110 | raise NotImplementedError("unknown activation %s"%act) 111 | 112 | def forward(self, x): 113 | if self.stride == 2: 114 | if self.kernel_size==3: 115 | h = F.pad(x, (0, 2, 0, 2), "constant", 0) 116 | else: 117 | h = F.pad(x, (1, 2, 1, 2), "constant", 0) 118 | x = self.max_pool(x) 119 | else: 120 | h = x 121 | 122 | if self.skip_proj is not None: 123 | x = self.skip_proj(x) 124 | elif self.channel_pad > 0: 125 | x = F.pad(x, (0, 0, 0, 0, 0, self.channel_pad), "constant", 0) 126 | 127 | 128 | return self.act(self.convs(h) + x) 129 | 130 | 131 | class FinalBlazeBlock(nn.Module): 132 | def __init__(self, channels, kernel_size=3): 133 | super(FinalBlazeBlock, self).__init__() 134 | 135 | # TFLite uses slightly different padding than PyTorch 136 | # on the depthwise conv layer when the stride is 2. 137 | self.convs = nn.Sequential( 138 | nn.Conv2d(in_channels=channels, out_channels=channels, 139 | kernel_size=kernel_size, stride=2, padding=0, 140 | groups=channels, bias=True), 141 | nn.Conv2d(in_channels=channels, out_channels=channels, 142 | kernel_size=1, stride=1, padding=0, bias=True), 143 | ) 144 | 145 | self.act = nn.ReLU(inplace=True) 146 | 147 | def forward(self, x): 148 | h = F.pad(x, (0, 2, 0, 2), "constant", 0) 149 | 150 | return self.act(self.convs(h)) 151 | 152 | 153 | class BlazeBase(nn.Module): 154 | """ Base class for media pipe models. """ 155 | 156 | def _device(self): 157 | """Which device (CPU or GPU) is being used by this model?""" 158 | return self.classifier_8.weight.device 159 | 160 | def load_weights(self, path): 161 | self.load_state_dict(torch.load(path)) 162 | self.eval() 163 | 164 | 165 | class BlazeLandmark(BlazeBase): 166 | """ Base class for landmark models. 
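When a BlazeBlock widens the channel count without a 1x1 skip projection, BlazeBlock.forward above zero-pads the residual branch along the channel axis so it can be added to the conv output. A minimal illustration of that F.pad call:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 24, 64, 64)      # residual input with in_channels = 24
channel_pad = 48 - 24               # out_channels - in_channels
# F.pad orders its arguments from the last dimension inward:
# (W_left, W_right, H_top, H_bottom, C_front, C_back)
x = F.pad(x, (0, 0, 0, 0, 0, channel_pad))
print(x.shape)                      # torch.Size([1, 48, 64, 64])
```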
""" 167 | 168 | def extract_roi(self, frame, xc, yc, theta, scale): 169 | 170 | # take points on unit square and transform them according to the roi 171 | points = torch.tensor([[-1, -1, 1, 1], 172 | [-1, 1, -1, 1]], device=scale.device).view(1,2,4) 173 | points = points * scale.view(-1,1,1)/2 174 | theta = theta.view(-1, 1, 1) 175 | R = torch.cat(( 176 | torch.cat((torch.cos(theta), -torch.sin(theta)), 2), 177 | torch.cat((torch.sin(theta), torch.cos(theta)), 2), 178 | ), 1) 179 | center = torch.cat((xc.view(-1,1,1), yc.view(-1,1,1)), 1) 180 | points = R @ points + center 181 | 182 | # use the points to compute the affine transform that maps 183 | # these points back to the output square 184 | res = self.resolution 185 | points1 = np.array([[0, 0, res-1], 186 | [0, res-1, 0]], dtype=np.float32).T 187 | affines = [] 188 | imgs = [] 189 | for i in range(points.shape[0]): 190 | pts = points[i, :, :3].cpu().numpy().T 191 | M = cv2.getAffineTransform(pts, points1) 192 | img = cv2.warpAffine(frame, M, (res,res))#, borderValue=127.5) 193 | img = torch.tensor(img, device=scale.device) 194 | imgs.append(img) 195 | affine = cv2.invertAffineTransform(M).astype('float32') 196 | affine = torch.tensor(affine, device=scale.device) 197 | affines.append(affine) 198 | if imgs: 199 | imgs = torch.stack(imgs).permute(0,3,1,2).float() / 255.#/ 127.5 - 1.0 200 | affines = torch.stack(affines) 201 | else: 202 | imgs = torch.zeros((0, 3, res, res), device=scale.device) 203 | affines = torch.zeros((0, 2, 3), device=scale.device) 204 | 205 | return imgs, affines, points 206 | 207 | def denormalize_landmarks(self, landmarks, affines): 208 | landmarks[:,:,:2] *= self.resolution 209 | for i in range(len(landmarks)): 210 | landmark, affine = landmarks[i], affines[i] 211 | landmark = (affine[:,:2] @ landmark[:,:2].T + affine[:,2:]).T 212 | landmarks[i,:,:2] = landmark 213 | return landmarks 214 | 215 | 216 | 217 | class BlazeDetector(BlazeBase): 218 | """ Base class for detector models. 219 | 220 | Based on code from https://github.com/tkat0/PyTorch_BlazeFace/ and 221 | https://github.com/hollance/BlazeFace-PyTorch and 222 | https://github.com/google/mediapipe/ 223 | """ 224 | def load_anchors(self, path): 225 | self.anchors = torch.tensor(np.load(path), dtype=torch.float32, device=self._device()) 226 | assert(self.anchors.ndimension() == 2) 227 | assert(self.anchors.shape[0] == self.num_anchors) 228 | assert(self.anchors.shape[1] == 4) 229 | 230 | def _preprocess(self, x): 231 | """Converts the image pixels to the range [-1, 1].""" 232 | return x.float() / 255.# 127.5 - 1.0 233 | 234 | def predict_on_image(self, img): 235 | """Makes a prediction on a single image. 236 | 237 | Arguments: 238 | img: a NumPy array of shape (H, W, 3) or a PyTorch tensor of 239 | shape (3, H, W). The image's height and width should be 240 | 128 pixels. 241 | 242 | Returns: 243 | A tensor with face detections. 244 | """ 245 | if isinstance(img, np.ndarray): 246 | img = torch.from_numpy(img).permute((2, 0, 1)) 247 | 248 | return self.predict_on_batch(img.unsqueeze(0))[0] 249 | 250 | def predict_on_batch(self, x): 251 | """Makes a prediction on a batch of images. 252 | 253 | Arguments: 254 | x: a NumPy array of shape (b, H, W, 3) or a PyTorch tensor of 255 | shape (b, 3, H, W). The height and width should be 128 pixels. 256 | 257 | Returns: 258 | A list containing a tensor of face detections for each image in 259 | the batch. If no faces are found for an image, returns a tensor 260 | of shape (0, 17). 
261 | 262 | Each face detection is a PyTorch tensor consisting of 17 numbers: 263 | - ymin, xmin, ymax, xmax 264 | - x,y-coordinates for the 6 keypoints 265 | - confidence score 266 | """ 267 | if isinstance(x, np.ndarray): 268 | x = torch.from_numpy(x).permute((0, 3, 1, 2)) 269 | 270 | assert x.shape[1] == 3 271 | assert x.shape[2] == self.y_scale 272 | assert x.shape[3] == self.x_scale 273 | 274 | # 1. Preprocess the images into tensors: 275 | x = x.to(self._device()) 276 | x = self._preprocess(x) 277 | 278 | # 2. Run the neural network: 279 | with torch.no_grad(): 280 | out = self.__call__(x) 281 | 282 | # 3. Postprocess the raw predictions: 283 | detections = self._tensors_to_detections(out[0], out[1], self.anchors) 284 | 285 | # 4. Non-maximum suppression to remove overlapping detections: 286 | filtered_detections = [] 287 | for i in range(len(detections)): 288 | faces = self._weighted_non_max_suppression(detections[i]) 289 | faces = torch.stack(faces) if len(faces) > 0 else torch.zeros((0, self.num_coords+1)) 290 | filtered_detections.append(faces) 291 | 292 | return filtered_detections 293 | 294 | 295 | def detection2roi(self, detection): 296 | """ Convert detections from detector to an oriented bounding box. 297 | 298 | Adapted from: 299 | # mediapipe/modules/face_landmark/face_detection_front_detection_to_roi.pbtxt 300 | 301 | The center and size of the box is calculated from the center 302 | of the detected box. Rotation is calcualted from the vector 303 | between kp1 and kp2 relative to theta0. The box is scaled 304 | and shifted by dscale and dy. 305 | 306 | """ 307 | if self.detection2roi_method == 'box': 308 | # compute box center and scale 309 | # use mediapipe/calculators/util/detections_to_rects_calculator.cc 310 | xc = (detection[:,1] + detection[:,3]) / 2 311 | yc = (detection[:,0] + detection[:,2]) / 2 312 | scale = (detection[:,3] - detection[:,1]) # assumes square boxes 313 | 314 | elif self.detection2roi_method == 'alignment': 315 | # compute box center and scale 316 | # use mediapipe/calculators/util/alignment_points_to_rects_calculator.cc 317 | xc = detection[:,4+2*self.kp1] 318 | yc = detection[:,4+2*self.kp1+1] 319 | x1 = detection[:,4+2*self.kp2] 320 | y1 = detection[:,4+2*self.kp2+1] 321 | scale = ((xc-x1)**2 + (yc-y1)**2).sqrt() * 2 322 | else: 323 | raise NotImplementedError( 324 | "detection2roi_method [%s] not supported"%self.detection2roi_method) 325 | 326 | yc += self.dy * scale 327 | scale *= self.dscale 328 | 329 | # compute box rotation 330 | x0 = detection[:,4+2*self.kp1] 331 | y0 = detection[:,4+2*self.kp1+1] 332 | x1 = detection[:,4+2*self.kp2] 333 | y1 = detection[:,4+2*self.kp2+1] 334 | #theta = np.arctan2(y0-y1, x0-x1) - self.theta0 335 | theta = torch.atan2(y0-y1, x0-x1) - self.theta0 336 | return xc, yc, scale, theta 337 | 338 | 339 | def _tensors_to_detections(self, raw_box_tensor, raw_score_tensor, anchors): 340 | """The output of the neural network is a tensor of shape (b, 896, 16) 341 | containing the bounding box regressor predictions, as well as a tensor 342 | of shape (b, 896, 1) with the classification confidences. 343 | 344 | This function converts these two "raw" tensors into proper detections. 345 | Returns a list of (num_detections, 17) tensors, one for each image in 346 | the batch. 
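For the face detector, each row of those (num_detections, 17) tensors can be unpacked as in the sketch below; the palm and pose detectors use the same layout with their own keypoint counts (placeholder values only):

```python
import torch

det = torch.zeros(17)                 # one BlazeFace detection row (placeholder values)
ymin, xmin, ymax, xmax = det[:4]      # bounding box, normalized to [0, 1]
keypoints = det[4:16].reshape(6, 2)   # (x, y) for each of the 6 face keypoints
score = det[16]                       # confidence after sigmoid and weighted NMS
```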
347 | 348 | This is based on the source code from: 349 | mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc 350 | mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto 351 | """ 352 | assert raw_box_tensor.ndimension() == 3 353 | assert raw_box_tensor.shape[1] == self.num_anchors 354 | assert raw_box_tensor.shape[2] == self.num_coords 355 | 356 | assert raw_score_tensor.ndimension() == 3 357 | assert raw_score_tensor.shape[1] == self.num_anchors 358 | assert raw_score_tensor.shape[2] == self.num_classes 359 | 360 | assert raw_box_tensor.shape[0] == raw_score_tensor.shape[0] 361 | 362 | detection_boxes = self._decode_boxes(raw_box_tensor, anchors) 363 | 364 | thresh = self.score_clipping_thresh 365 | raw_score_tensor = raw_score_tensor.clamp(-thresh, thresh) 366 | detection_scores = raw_score_tensor.sigmoid().squeeze(dim=-1) 367 | 368 | # Note: we stripped off the last dimension from the scores tensor 369 | # because there is only has one class. Now we can simply use a mask 370 | # to filter out the boxes with too low confidence. 371 | mask = detection_scores >= self.min_score_thresh 372 | 373 | # Because each image from the batch can have a different number of 374 | # detections, process them one at a time using a loop. 375 | output_detections = [] 376 | for i in range(raw_box_tensor.shape[0]): 377 | boxes = detection_boxes[i, mask[i]] 378 | scores = detection_scores[i, mask[i]].unsqueeze(dim=-1) 379 | output_detections.append(torch.cat((boxes, scores), dim=-1)) 380 | 381 | return output_detections 382 | 383 | def _decode_boxes(self, raw_boxes, anchors): 384 | """Converts the predictions into actual coordinates using 385 | the anchor boxes. Processes the entire batch at once. 386 | """ 387 | boxes = torch.zeros_like(raw_boxes) 388 | 389 | x_center = raw_boxes[..., 0] / self.x_scale * anchors[:, 2] + anchors[:, 0] 390 | y_center = raw_boxes[..., 1] / self.y_scale * anchors[:, 3] + anchors[:, 1] 391 | 392 | w = raw_boxes[..., 2] / self.w_scale * anchors[:, 2] 393 | h = raw_boxes[..., 3] / self.h_scale * anchors[:, 3] 394 | 395 | boxes[..., 0] = y_center - h / 2. # ymin 396 | boxes[..., 1] = x_center - w / 2. # xmin 397 | boxes[..., 2] = y_center + h / 2. # ymax 398 | boxes[..., 3] = x_center + w / 2. # xmax 399 | 400 | for k in range(self.num_keypoints): 401 | offset = 4 + k*2 402 | keypoint_x = raw_boxes[..., offset ] / self.x_scale * anchors[:, 2] + anchors[:, 0] 403 | keypoint_y = raw_boxes[..., offset + 1] / self.y_scale * anchors[:, 3] + anchors[:, 1] 404 | boxes[..., offset ] = keypoint_x 405 | boxes[..., offset + 1] = keypoint_y 406 | 407 | return boxes 408 | 409 | def _weighted_non_max_suppression(self, detections): 410 | """The alternative NMS method as mentioned in the BlazeFace paper: 411 | 412 | "We replace the suppression algorithm with a blending strategy that 413 | estimates the regression parameters of a bounding box as a weighted 414 | mean between the overlapping predictions." 415 | 416 | The original MediaPipe code assigns the score of the most confident 417 | detection to the weighted detection, but we take the average score 418 | of the overlapping detections. 419 | 420 | The input detections should be a Tensor of shape (count, 17). 421 | 422 | Returns a list of PyTorch tensors, one for each detected face. 
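A worked example of the decoding performed by _decode_boxes above, using one anchor and one invented raw prediction (the anchor row is assumed to hold normalized [x_center, y_center, w, h], matching how the columns are indexed in that function):

```python
import torch

anchor = torch.tensor([0.5, 0.5, 1.0, 1.0])    # illustrative anchor
raw = torch.tensor([4.0, -8.0, 16.0, 16.0])    # invented network output
x_scale = y_scale = w_scale = h_scale = 128.0  # front face model scales

x_center = raw[0] / x_scale * anchor[2] + anchor[0]  # 0.53125
y_center = raw[1] / y_scale * anchor[3] + anchor[1]  # 0.4375
w = raw[2] / w_scale * anchor[2]                     # 0.125
h = raw[3] / h_scale * anchor[3]                     # 0.125

box = torch.stack([y_center - h / 2, x_center - w / 2,   # ymin, xmin
                   y_center + h / 2, x_center + w / 2])  # ymax, xmax
print(box)   # ymin, xmin, ymax, xmax = 0.375, 0.46875, 0.5, 0.59375
```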
423 | 424 | This is based on the source code from: 425 | mediapipe/calculators/util/non_max_suppression_calculator.cc 426 | mediapipe/calculators/util/non_max_suppression_calculator.proto 427 | """ 428 | if len(detections) == 0: return [] 429 | 430 | output_detections = [] 431 | 432 | # Sort the detections from highest to lowest score. 433 | remaining = torch.argsort(detections[:, self.num_coords], descending=True) 434 | 435 | while len(remaining) > 0: 436 | detection = detections[remaining[0]] 437 | 438 | # Compute the overlap between the first box and the other 439 | # remaining boxes. (Note that the other_boxes also include 440 | # the first_box.) 441 | first_box = detection[:4] 442 | other_boxes = detections[remaining, :4] 443 | ious = overlap_similarity(first_box, other_boxes) 444 | 445 | # If two detections don't overlap enough, they are considered 446 | # to be from different faces. 447 | mask = ious > self.min_suppression_threshold 448 | overlapping = remaining[mask] 449 | remaining = remaining[~mask] 450 | 451 | # Take an average of the coordinates from the overlapping 452 | # detections, weighted by their confidence scores. 453 | weighted_detection = detection.clone() 454 | if len(overlapping) > 1: 455 | coordinates = detections[overlapping, :self.num_coords] 456 | scores = detections[overlapping, self.num_coords:self.num_coords+1] 457 | total_score = scores.sum() 458 | weighted = (coordinates * scores).sum(dim=0) / total_score 459 | weighted_detection[:self.num_coords] = weighted 460 | weighted_detection[self.num_coords] = total_score / len(overlapping) 461 | 462 | output_detections.append(weighted_detection) 463 | 464 | return output_detections 465 | 466 | 467 | # IOU code from https://github.com/amdegroot/ssd.pytorch/blob/master/layers/box_utils.py 468 | 469 | def intersect(box_a, box_b): 470 | """ We resize both tensors to [A,B,2] without new malloc: 471 | [A,2] -> [A,1,2] -> [A,B,2] 472 | [B,2] -> [1,B,2] -> [A,B,2] 473 | Then we compute the area of intersect between box_a and box_b. 474 | Args: 475 | box_a: (tensor) bounding boxes, Shape: [A,4]. 476 | box_b: (tensor) bounding boxes, Shape: [B,4]. 477 | Return: 478 | (tensor) intersection area, Shape: [A,B]. 479 | """ 480 | A = box_a.size(0) 481 | B = box_b.size(0) 482 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 483 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 484 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 485 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 486 | inter = torch.clamp((max_xy - min_xy), min=0) 487 | return inter[:, :, 0] * inter[:, :, 1] 488 | 489 | 490 | def jaccard(box_a, box_b): 491 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 492 | is simply the intersection over union of two boxes. Here we operate on 493 | ground truth boxes and default boxes. 
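A small numeric illustration of the blending step in _weighted_non_max_suppression above: the boxes of overlapping detections are averaged with their scores as weights, and (in this port) the score of the blended detection is the plain mean of the overlapping scores.

```python
import torch

# Two overlapping detections, shortened to [ymin, xmin, ymax, xmax, score].
dets = torch.tensor([[0.20, 0.20, 0.60, 0.60, 0.90],
                     [0.22, 0.18, 0.62, 0.58, 0.60]])

coords, scores = dets[:, :4], dets[:, 4:5]
blended_box = (coords * scores).sum(dim=0) / scores.sum()   # score-weighted mean box
blended_score = scores.mean()                               # average of the scores
print(blended_box, blended_score)
```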
494 | E.g.: 495 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 496 | Args: 497 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 498 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 499 | Return: 500 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 501 | """ 502 | inter = intersect(box_a, box_b) 503 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 504 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 505 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 506 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 507 | union = area_a + area_b - inter 508 | return inter / union # [A,B] 509 | 510 | 511 | def overlap_similarity(box, other_boxes): 512 | """Computes the IOU between a bounding box and set of other boxes.""" 513 | return jaccard(box.unsqueeze(0), other_boxes).squeeze(0) 514 | -------------------------------------------------------------------------------- /blazeface.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/blazeface.pth -------------------------------------------------------------------------------- /blazeface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from blazebase import BlazeDetector, BlazeBlock, FinalBlazeBlock 7 | 8 | 9 | class BlazeFace(BlazeDetector): 10 | """The BlazeFace face detection model from MediaPipe. 11 | 12 | The version from MediaPipe is simpler than the one in the paper; 13 | it does not use the "double" BlazeBlocks. 14 | 15 | Because we won't be training this model, it doesn't need to have 16 | batchnorm layers. These have already been "folded" into the conv 17 | weights by TFLite. 18 | 19 | The conversion to PyTorch is fairly straightforward, but there are 20 | some small differences between TFLite and PyTorch in how they handle 21 | padding on conv layers with stride 2. 22 | 23 | This version works on batches, while the MediaPipe version can only 24 | handle a single image at a time. 25 | 26 | Based on code from https://github.com/tkat0/PyTorch_BlazeFace/ and 27 | https://github.com/hollance/BlazeFace-PyTorch and 28 | https://github.com/google/mediapipe/ 29 | 30 | """ 31 | def __init__(self, back_model=False): 32 | super(BlazeFace, self).__init__() 33 | 34 | # These are the settings from the MediaPipe example graph 35 | # mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt 36 | self.num_classes = 1 37 | self.num_anchors = 896 38 | self.num_coords = 16 39 | self.score_clipping_thresh = 100.0 40 | self.back_model = back_model 41 | if back_model: 42 | self.x_scale = 256.0 43 | self.y_scale = 256.0 44 | self.h_scale = 256.0 45 | self.w_scale = 256.0 46 | self.min_score_thresh = 0.65 47 | else: 48 | self.x_scale = 128.0 49 | self.y_scale = 128.0 50 | self.h_scale = 128.0 51 | self.w_scale = 128.0 52 | self.min_score_thresh = 0.75 53 | self.min_suppression_threshold = 0.3 54 | self.num_keypoints = 6 55 | 56 | # These settings are for converting detections to ROIs which can then 57 | # be extracted and feed into the landmark network 58 | # use mediapipe/calculators/util/detections_to_rects_calculator.cc 59 | self.detection2roi_method = 'box' 60 | # mediapipe/modules/face_landmark/face_detection_front_detection_to_roi.pbtxt 61 | self.kp1 = 1 62 | self.kp2 = 0 63 | self.theta0 = 0. 
64 | self.dscale = 1.5 65 | self.dy = 0. 66 | 67 | self._define_layers() 68 | 69 | def _define_layers(self): 70 | if self.back_model: 71 | self.backbone = nn.Sequential( 72 | nn.Conv2d(in_channels=3, out_channels=24, kernel_size=5, stride=2, padding=0, bias=True), 73 | nn.ReLU(inplace=True), 74 | 75 | BlazeBlock(24, 24), 76 | BlazeBlock(24, 24), 77 | BlazeBlock(24, 24), 78 | BlazeBlock(24, 24), 79 | BlazeBlock(24, 24), 80 | BlazeBlock(24, 24), 81 | BlazeBlock(24, 24), 82 | BlazeBlock(24, 24, stride=2), 83 | BlazeBlock(24, 24), 84 | BlazeBlock(24, 24), 85 | BlazeBlock(24, 24), 86 | BlazeBlock(24, 24), 87 | BlazeBlock(24, 24), 88 | BlazeBlock(24, 24), 89 | BlazeBlock(24, 24), 90 | BlazeBlock(24, 48, stride=2), 91 | BlazeBlock(48, 48), 92 | BlazeBlock(48, 48), 93 | BlazeBlock(48, 48), 94 | BlazeBlock(48, 48), 95 | BlazeBlock(48, 48), 96 | BlazeBlock(48, 48), 97 | BlazeBlock(48, 48), 98 | BlazeBlock(48, 96, stride=2), 99 | BlazeBlock(96, 96), 100 | BlazeBlock(96, 96), 101 | BlazeBlock(96, 96), 102 | BlazeBlock(96, 96), 103 | BlazeBlock(96, 96), 104 | BlazeBlock(96, 96), 105 | BlazeBlock(96, 96), 106 | ) 107 | self.final = FinalBlazeBlock(96) 108 | self.classifier_8 = nn.Conv2d(96, 2, 1, bias=True) 109 | self.classifier_16 = nn.Conv2d(96, 6, 1, bias=True) 110 | 111 | self.regressor_8 = nn.Conv2d(96, 32, 1, bias=True) 112 | self.regressor_16 = nn.Conv2d(96, 96, 1, bias=True) 113 | else: 114 | self.backbone1 = nn.Sequential( 115 | nn.Conv2d(in_channels=3, out_channels=24, kernel_size=5, stride=2, padding=0, bias=True), 116 | nn.ReLU(inplace=True), 117 | 118 | BlazeBlock(24, 24), 119 | BlazeBlock(24, 28), 120 | BlazeBlock(28, 32, stride=2), 121 | BlazeBlock(32, 36), 122 | BlazeBlock(36, 42), 123 | BlazeBlock(42, 48, stride=2), 124 | BlazeBlock(48, 56), 125 | BlazeBlock(56, 64), 126 | BlazeBlock(64, 72), 127 | BlazeBlock(72, 80), 128 | BlazeBlock(80, 88), 129 | ) 130 | 131 | self.backbone2 = nn.Sequential( 132 | BlazeBlock(88, 96, stride=2), 133 | BlazeBlock(96, 96), 134 | BlazeBlock(96, 96), 135 | BlazeBlock(96, 96), 136 | BlazeBlock(96, 96), 137 | ) 138 | 139 | self.classifier_8 = nn.Conv2d(88, 2, 1, bias=True) 140 | self.classifier_16 = nn.Conv2d(96, 6, 1, bias=True) 141 | 142 | self.regressor_8 = nn.Conv2d(88, 32, 1, bias=True) 143 | self.regressor_16 = nn.Conv2d(96, 96, 1, bias=True) 144 | 145 | def forward(self, x): 146 | # TFLite uses slightly different padding on the first conv layer 147 | # than PyTorch, so do it manually. 148 | x = F.pad(x, (1, 2, 1, 2), "constant", 0) 149 | 150 | b = x.shape[0] # batch size, needed for reshaping later 151 | 152 | if self.back_model: 153 | x = self.backbone(x) # (b, 16, 16, 96) 154 | h = self.final(x) # (b, 8, 8, 96) 155 | else: 156 | x = self.backbone1(x) # (b, 88, 16, 16) 157 | h = self.backbone2(x) # (b, 96, 8, 8) 158 | 159 | # Note: Because PyTorch is NCHW but TFLite is NHWC, we need to 160 | # permute the output from the conv layers before reshaping it. 
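The heads defined above are where the detector's 896 anchors come from: the 16x16 feature map predicts 2 anchors per cell and the 8x8 map predicts 6 per cell, and their outputs are reshaped and concatenated in forward() below (true for both the front and back face models):

```python
anchors_16x16 = 16 * 16 * 2          # classifier_8 / regressor_8 head
anchors_8x8 = 8 * 8 * 6              # classifier_16 / regressor_16 head
print(anchors_16x16 + anchors_8x8)   # 896 == self.num_anchors
```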
161 | 162 | c1 = self.classifier_8(x) # (b, 2, 16, 16) 163 | c1 = c1.permute(0, 2, 3, 1) # (b, 16, 16, 2) 164 | c1 = c1.reshape(b, -1, 1) # (b, 512, 1) 165 | 166 | c2 = self.classifier_16(h) # (b, 6, 8, 8) 167 | c2 = c2.permute(0, 2, 3, 1) # (b, 8, 8, 6) 168 | c2 = c2.reshape(b, -1, 1) # (b, 384, 1) 169 | 170 | c = torch.cat((c1, c2), dim=1) # (b, 896, 1) 171 | 172 | r1 = self.regressor_8(x) # (b, 32, 16, 16) 173 | r1 = r1.permute(0, 2, 3, 1) # (b, 16, 16, 32) 174 | r1 = r1.reshape(b, -1, 16) # (b, 512, 16) 175 | 176 | r2 = self.regressor_16(h) # (b, 96, 8, 8) 177 | r2 = r2.permute(0, 2, 3, 1) # (b, 8, 8, 96) 178 | r2 = r2.reshape(b, -1, 16) # (b, 384, 16) 179 | 180 | r = torch.cat((r1, r2), dim=1) # (b, 896, 16) 181 | return [r, c] 182 | 183 | -------------------------------------------------------------------------------- /blazeface_landmark.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/blazeface_landmark.pth -------------------------------------------------------------------------------- /blazeface_landmark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from blazebase import BlazeLandmark, BlazeBlock 7 | 8 | class BlazeFaceLandmark(BlazeLandmark): 9 | """The face landmark model from MediaPipe. 10 | 11 | """ 12 | def __init__(self): 13 | super(BlazeFaceLandmark, self).__init__() 14 | 15 | # size of ROIs used for input 16 | self.resolution = 192 17 | 18 | self._define_layers() 19 | 20 | def _define_layers(self): 21 | self.backbone1 = nn.Sequential( 22 | nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=0, bias=True), 23 | nn.PReLU(16), 24 | 25 | BlazeBlock(16, 16, 3, act='prelu'), 26 | BlazeBlock(16, 16, 3, act='prelu'), 27 | BlazeBlock(16, 32, 3, 2, act='prelu'), 28 | 29 | BlazeBlock(32, 32, 3, act='prelu'), 30 | BlazeBlock(32, 32, 3, act='prelu'), 31 | BlazeBlock(32, 64, 3, 2, act='prelu'), 32 | 33 | BlazeBlock(64, 64, 3, act='prelu'), 34 | BlazeBlock(64, 64, 3, act='prelu'), 35 | BlazeBlock(64, 128, 3, 2, act='prelu'), 36 | 37 | BlazeBlock(128, 128, 3, act='prelu'), 38 | BlazeBlock(128, 128, 3, act='prelu'), 39 | BlazeBlock(128, 128, 3, 2, act='prelu'), 40 | 41 | BlazeBlock(128, 128, 3, act='prelu'), 42 | BlazeBlock(128, 128, 3, act='prelu'), 43 | ) 44 | 45 | 46 | self.backbone2a = nn.Sequential( 47 | BlazeBlock(128, 128, 3, 2, act='prelu'), 48 | BlazeBlock(128, 128, 3, act='prelu'), 49 | BlazeBlock(128, 128, 3, act='prelu'), 50 | nn.Conv2d(128, 32, 1, padding=0, bias=True), 51 | nn.PReLU(32), 52 | BlazeBlock(32, 32, 3, act='prelu'), 53 | nn.Conv2d(32, 1404, 3, padding=0, bias=True) 54 | ) 55 | 56 | self.backbone2b = nn.Sequential( 57 | BlazeBlock(128, 128, 3, 2, act='prelu'), 58 | nn.Conv2d(128, 32, 1, padding=0, bias=True), 59 | nn.PReLU(32), 60 | BlazeBlock(32, 32, 3, act='prelu'), 61 | nn.Conv2d(32, 1, 3, padding=0, bias=True) 62 | ) 63 | 64 | def forward(self, x): 65 | if x.shape[0] == 0: 66 | return torch.zeros((0,)), torch.zeros((0, 468, 3)) 67 | 68 | x = F.pad(x, (0, 1, 0, 1), "constant", 0) 69 | 70 | x = self.backbone1(x) 71 | landmarks = self.backbone2a(x).view(-1, 468, 3) / 192 72 | flag = self.backbone2b(x).sigmoid().view(-1) 73 | 74 | return flag, landmarks -------------------------------------------------------------------------------- /blazefaceback.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/blazefaceback.pth -------------------------------------------------------------------------------- /blazehand_landmark.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/blazehand_landmark.pth -------------------------------------------------------------------------------- /blazehand_landmark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from blazebase import BlazeLandmark, BlazeBlock 7 | 8 | class BlazeHandLandmark(BlazeLandmark): 9 | """The hand landmark model from MediaPipe. 10 | 11 | """ 12 | def __init__(self): 13 | super(BlazeHandLandmark, self).__init__() 14 | 15 | # size of ROIs used for input 16 | self.resolution = 256 17 | 18 | self._define_layers() 19 | 20 | def _define_layers(self): 21 | self.backbone1 = nn.Sequential( 22 | nn.Conv2d(in_channels=3, out_channels=24, kernel_size=3, stride=2, padding=0, bias=True), 23 | nn.ReLU(inplace=True), 24 | 25 | BlazeBlock(24, 24, 5), 26 | BlazeBlock(24, 24, 5), 27 | BlazeBlock(24, 48, 5, 2), 28 | ) 29 | 30 | self.backbone2 = nn.Sequential( 31 | BlazeBlock(48, 48, 5), 32 | BlazeBlock(48, 48, 5), 33 | BlazeBlock(48, 96, 5, 2), 34 | ) 35 | 36 | self.backbone3 = nn.Sequential( 37 | BlazeBlock(96, 96, 5), 38 | BlazeBlock(96, 96, 5), 39 | BlazeBlock(96, 96, 5, 2), 40 | ) 41 | 42 | self.backbone4 = nn.Sequential( 43 | BlazeBlock(96, 96, 5), 44 | BlazeBlock(96, 96, 5), 45 | BlazeBlock(96, 96, 5, 2), 46 | ) 47 | 48 | self.blaze5 = BlazeBlock(96, 96, 5) 49 | self.blaze6 = BlazeBlock(96, 96, 5) 50 | self.conv7 = nn.Conv2d(96, 48, 1, bias=True) 51 | 52 | self.backbone8 = nn.Sequential( 53 | BlazeBlock(48, 48, 5), 54 | BlazeBlock(48, 48, 5), 55 | BlazeBlock(48, 48, 5), 56 | BlazeBlock(48, 48, 5), 57 | BlazeBlock(48, 96, 5, 2), 58 | BlazeBlock(96, 96, 5), 59 | BlazeBlock(96, 96, 5), 60 | BlazeBlock(96, 96, 5), 61 | BlazeBlock(96, 96, 5), 62 | BlazeBlock(96, 288, 5, 2), 63 | BlazeBlock(288, 288, 5), 64 | BlazeBlock(288, 288, 5), 65 | BlazeBlock(288, 288, 5), 66 | BlazeBlock(288, 288, 5), 67 | BlazeBlock(288, 288, 5, 2), 68 | BlazeBlock(288, 288, 5), 69 | BlazeBlock(288, 288, 5), 70 | BlazeBlock(288, 288, 5), 71 | BlazeBlock(288, 288, 5), 72 | BlazeBlock(288, 288, 5, 2), 73 | BlazeBlock(288, 288, 5), 74 | BlazeBlock(288, 288, 5), 75 | BlazeBlock(288, 288, 5), 76 | BlazeBlock(288, 288, 5), 77 | BlazeBlock(288, 288, 5, 2), 78 | BlazeBlock(288, 288, 5), 79 | BlazeBlock(288, 288, 5), 80 | BlazeBlock(288, 288, 5), 81 | BlazeBlock(288, 288, 5), 82 | ) 83 | 84 | self.hand_flag = nn.Conv2d(288, 1, 2, bias=True) 85 | self.handed = nn.Conv2d(288, 1, 2, bias=True) 86 | self.landmarks = nn.Conv2d(288, 63, 2, bias=True) 87 | 88 | 89 | def forward(self, x): 90 | if x.shape[0] == 0: 91 | return torch.zeros((0,)), torch.zeros((0,)), torch.zeros((0, 21, 3)) 92 | 93 | x = F.pad(x, (0, 1, 0, 1), "constant", 0) 94 | 95 | x = self.backbone1(x) 96 | y = self.backbone2(x) 97 | z = self.backbone3(y) 98 | w = self.backbone4(z) 99 | 100 | z = z + F.interpolate(w, scale_factor=2, mode='bilinear') 101 | z = self.blaze5(z) 102 | 103 | y = y + F.interpolate(z, scale_factor=2, mode='bilinear') 104 | y = self.blaze6(y) 105 | y = 
self.conv7(y) 106 | 107 | x = x + F.interpolate(y, scale_factor=2, mode='bilinear') 108 | 109 | x = self.backbone8(x) 110 | 111 | hand_flag = self.hand_flag(x).view(-1).sigmoid() 112 | handed = self.handed(x).view(-1).sigmoid() 113 | landmarks = self.landmarks(x).view(-1, 21, 3) / 256 114 | 115 | return hand_flag, handed, landmarks -------------------------------------------------------------------------------- /blazepalm.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/blazepalm.pth -------------------------------------------------------------------------------- /blazepalm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from blazebase import BlazeDetector, BlazeBlock 7 | 8 | 9 | class BlazePalm(BlazeDetector): 10 | """The palm detection model from MediaPipe. """ 11 | def __init__(self): 12 | super(BlazePalm, self).__init__() 13 | 14 | # These are the settings from the MediaPipe example graph 15 | # mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt 16 | self.num_classes = 1 17 | self.num_anchors = 2944 18 | self.num_coords = 18 19 | self.score_clipping_thresh = 100.0 20 | self.x_scale = 256.0 21 | self.y_scale = 256.0 22 | self.h_scale = 256.0 23 | self.w_scale = 256.0 24 | self.min_score_thresh = 0.5 25 | self.min_suppression_threshold = 0.3 26 | self.num_keypoints = 7 27 | 28 | # These settings are for converting detections to ROIs which can then 29 | # be extracted and feed into the landmark network 30 | # use mediapipe/calculators/util/detections_to_rects_calculator.cc 31 | self.detection2roi_method = 'box' 32 | # mediapipe/graphs/hand_tracking/subgraphs/hand_detection_cpu.pbtxt 33 | self.kp1 = 0 34 | self.kp2 = 2 35 | self.theta0 = np.pi/2 36 | self.dscale = 2.6 37 | self.dy = -0.5 38 | 39 | self._define_layers() 40 | 41 | def _define_layers(self): 42 | self.backbone1 = nn.Sequential( 43 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=0, bias=True), 44 | nn.ReLU(inplace=True), 45 | 46 | BlazeBlock(32, 32), 47 | BlazeBlock(32, 32), 48 | BlazeBlock(32, 32), 49 | BlazeBlock(32, 32), 50 | BlazeBlock(32, 32), 51 | BlazeBlock(32, 32), 52 | BlazeBlock(32, 32), 53 | 54 | BlazeBlock(32, 64, stride=2), 55 | BlazeBlock(64, 64), 56 | BlazeBlock(64, 64), 57 | BlazeBlock(64, 64), 58 | BlazeBlock(64, 64), 59 | BlazeBlock(64, 64), 60 | BlazeBlock(64, 64), 61 | BlazeBlock(64, 64), 62 | 63 | BlazeBlock(64, 128, stride=2), 64 | BlazeBlock(128, 128), 65 | BlazeBlock(128, 128), 66 | BlazeBlock(128, 128), 67 | BlazeBlock(128, 128), 68 | BlazeBlock(128, 128), 69 | BlazeBlock(128, 128), 70 | BlazeBlock(128, 128), 71 | 72 | ) 73 | 74 | self.backbone2 = nn.Sequential( 75 | BlazeBlock(128, 256, stride=2), 76 | BlazeBlock(256, 256), 77 | BlazeBlock(256, 256), 78 | BlazeBlock(256, 256), 79 | BlazeBlock(256, 256), 80 | BlazeBlock(256, 256), 81 | BlazeBlock(256, 256), 82 | BlazeBlock(256, 256), 83 | ) 84 | 85 | self.backbone3 = nn.Sequential( 86 | BlazeBlock(256, 256, stride=2), 87 | BlazeBlock(256, 256), 88 | BlazeBlock(256, 256), 89 | BlazeBlock(256, 256), 90 | BlazeBlock(256, 256), 91 | BlazeBlock(256, 256), 92 | BlazeBlock(256, 256), 93 | BlazeBlock(256, 256), 94 | ) 95 | 96 | self.conv_transpose1 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2, padding=0, 
bias=True) 97 | self.blaze1 = BlazeBlock(256, 256) 98 | 99 | self.conv_transpose2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2, padding=0, bias=True) 100 | self.blaze2 = BlazeBlock(128, 128) 101 | 102 | self.classifier_32 = nn.Conv2d(128, 2, 1, bias=True) 103 | self.classifier_16 = nn.Conv2d(256, 2, 1, bias=True) 104 | self.classifier_8 = nn.Conv2d(256, 6, 1, bias=True) 105 | 106 | self.regressor_32 = nn.Conv2d(128, 36, 1, bias=True) 107 | self.regressor_16 = nn.Conv2d(256, 36, 1, bias=True) 108 | self.regressor_8 = nn.Conv2d(256, 108, 1, bias=True) 109 | 110 | def forward(self, x): 111 | b = x.shape[0] # batch size, needed for reshaping later 112 | 113 | x = F.pad(x, (0, 1, 0, 1), "constant", 0) 114 | 115 | x = self.backbone1(x) # (b, 128, 32, 32) 116 | y = self.backbone2(x) # (b, 256, 16, 16) 117 | z = self.backbone3(y) # (b, 256, 8, 8) 118 | 119 | y = y + F.relu(self.conv_transpose1(z), True) 120 | y = self.blaze1(y) 121 | 122 | x = x + F.relu(self.conv_transpose2(y), True) 123 | x = self.blaze2(x) 124 | 125 | 126 | # Note: Because PyTorch is NCHW but TFLite is NHWC, we need to 127 | # permute the output from the conv layers before reshaping it. 128 | 129 | c1 = self.classifier_8(z) # (b, 2, 16, 16) 130 | c1 = c1.permute(0, 2, 3, 1) # (b, 16, 16, 2) 131 | c1 = c1.reshape(b, -1, 1) # (b, 512, 1) 132 | 133 | c2 = self.classifier_16(y) # (b, 6, 8, 8) 134 | c2 = c2.permute(0, 2, 3, 1) # (b, 8, 8, 6) 135 | c2 = c2.reshape(b, -1, 1) # (b, 384, 1) 136 | 137 | c3 = self.classifier_32(x) # (b, 6, 8, 8) 138 | c3 = c3.permute(0, 2, 3, 1) # (b, 8, 8, 6) 139 | c3 = c3.reshape(b, -1, 1) # (b, 384, 1) 140 | 141 | c = torch.cat((c3, c2, c1), dim=1) # (b, 896, 1) 142 | 143 | r1 = self.regressor_8(z) # (b, 32, 16, 16) 144 | r1 = r1.permute(0, 2, 3, 1) # (b, 16, 16, 32) 145 | r1 = r1.reshape(b, -1, 18) # (b, 512, 16) 146 | 147 | r2 = self.regressor_16(y) # (b, 96, 8, 8) 148 | r2 = r2.permute(0, 2, 3, 1) # (b, 8, 8, 96) 149 | r2 = r2.reshape(b, -1, 18) # (b, 384, 16) 150 | 151 | r3 = self.regressor_32(x) # (b, 96, 8, 8) 152 | r3 = r3.permute(0, 2, 3, 1) # (b, 8, 8, 96) 153 | r3 = r3.reshape(b, -1, 18) # (b, 384, 16) 154 | 155 | r = torch.cat((r3, r2, r1), dim=1) # (b, 896, 16) 156 | 157 | return [r, c] 158 | -------------------------------------------------------------------------------- /blazepose.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/blazepose.pth -------------------------------------------------------------------------------- /blazepose.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from blazebase import BlazeDetector, BlazeBlock 7 | 8 | 9 | class BlazePose(BlazeDetector): 10 | """The BlazePose pose detection model from MediaPipe. 11 | 12 | Because we won't be training this model, it doesn't need to have 13 | batchnorm layers. These have already been "folded" into the conv 14 | weights by TFLite. 15 | 16 | The conversion to PyTorch is fairly straightforward, but there are 17 | some small differences between TFLite and PyTorch in how they handle 18 | padding on conv layers with stride 2. 19 | 20 | This version works on batches, while the MediaPipe version can only 21 | handle a single image at a time. 
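A minimal sketch of running the palm detector defined in blazepalm.py above on a single 256x256 frame; it assumes the blazepalm.pth and anchors_palm.npy files listed in this repo are on disk:

```python
import numpy as np
import torch

from blazepalm import BlazePalm

torch.set_grad_enabled(False)
palm = BlazePalm()
palm.load_weights("blazepalm.pth")
palm.load_anchors("anchors_palm.npy")

frame = np.zeros((256, 256, 3), dtype=np.uint8)   # stand-in for img1 from resize_pad
detections = palm.predict_on_image(frame)         # (n, 19): 4 box values + 7 keypoints * 2 + score
print(detections.shape)
```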
22 | 23 | Based on code from https://github.com/tkat0/PyTorch_BlazeFace/ and 24 | https://github.com/google/mediapipe/ 25 | """ 26 | def __init__(self): 27 | super(BlazePose, self).__init__() 28 | 29 | # These are the settings from the MediaPipe example graph 30 | # mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt 31 | self.num_classes = 1 32 | self.num_anchors = 896 33 | self.num_coords = 12 34 | self.score_clipping_thresh = 100.0 35 | self.x_scale = 128.0 36 | self.y_scale = 128.0 37 | self.h_scale = 128.0 38 | self.w_scale = 128.0 39 | self.min_score_thresh = 0.75 40 | self.min_suppression_threshold = 0.3 41 | self.num_keypoints = 4 42 | 43 | # These settings are for converting detections to ROIs which can then 44 | # be extracted and feed into the landmark network 45 | # use mediapipe/calculators/util/alignment_points_to_rects_calculator.cc 46 | self.detection2roi_method = 'alignment' 47 | # mediapipe/modules/pose_landmark/pose_detection_to_roi.pbtxt 48 | self.kp1 = 2 49 | self.kp2 = 3 50 | self.theta0 = 90 * np.pi / 180 51 | self.dscale = 1.5 52 | self.dy = 0. 53 | 54 | self._define_layers() 55 | 56 | def _define_layers(self): 57 | self.backbone1 = nn.Sequential( 58 | nn.Conv2d(in_channels=3, out_channels=48, kernel_size=5, stride=1, padding=2, bias=True), 59 | nn.ReLU(inplace=True), 60 | 61 | BlazeBlock(48, 48, 5), 62 | BlazeBlock(48, 48, 5), 63 | BlazeBlock(48, 48, 5), 64 | BlazeBlock(48, 64, 5, 2, skip_proj=True), 65 | 66 | BlazeBlock(64, 64, 5), 67 | BlazeBlock(64, 64, 5), 68 | BlazeBlock(64, 64, 5), 69 | BlazeBlock(64, 96, 5, 2, skip_proj=True), 70 | 71 | BlazeBlock(96, 96, 5), 72 | BlazeBlock(96, 96, 5), 73 | BlazeBlock(96, 96, 5), 74 | BlazeBlock(96, 96, 5), 75 | BlazeBlock(96, 96, 5), 76 | BlazeBlock(96, 96, 5), 77 | BlazeBlock(96, 128, 5, 2, skip_proj=True), 78 | 79 | BlazeBlock(128, 128, 5), 80 | BlazeBlock(128, 128, 5), 81 | BlazeBlock(128, 128, 5), 82 | BlazeBlock(128, 128, 5), 83 | BlazeBlock(128, 128, 5), 84 | BlazeBlock(128, 128, 5), 85 | BlazeBlock(128, 128, 5), 86 | 87 | ) 88 | 89 | self.backbone2 = nn.Sequential( 90 | BlazeBlock(128, 256, 5, 2, skip_proj=True), 91 | BlazeBlock(256, 256, 5), 92 | BlazeBlock(256, 256, 5), 93 | BlazeBlock(256, 256, 5), 94 | BlazeBlock(256, 256, 5), 95 | BlazeBlock(256, 256, 5), 96 | BlazeBlock(256, 256, 5), 97 | 98 | ) 99 | 100 | self.classifier_8 = nn.Conv2d(128, 2, 1, bias=True) 101 | self.classifier_16 = nn.Conv2d(256, 6, 1, bias=True) 102 | 103 | self.regressor_8 = nn.Conv2d(128, 24, 1, bias=True) 104 | self.regressor_16 = nn.Conv2d(256, 72, 1, bias=True) 105 | 106 | def forward(self, x): 107 | # TFLite uses slightly different padding on the first conv layer 108 | # than PyTorch, so do it manually. 109 | # x = F.pad(x, (1, 2, 1, 2), "constant", 0) 110 | 111 | b = x.shape[0] # batch size, needed for reshaping later 112 | 113 | x = self.backbone1(x) 114 | h = self.backbone2(x) 115 | 116 | # Note: Because PyTorch is NCHW but TFLite is NHWC, we need to 117 | # permute the output from the conv layers before reshaping it. 
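The 'alignment' ROI method selected above derives the crop from two detector keypoints: the centre is keypoint kp1, the size is twice its distance to kp2 (then scaled by dscale), and the rotation is the angle of that vector relative to theta0. A numeric sketch with invented keypoint coordinates:

```python
import numpy as np

xc, yc = 64.0, 80.0    # keypoint kp1 (invented, in image pixels)
x1, y1 = 64.0, 40.0    # keypoint kp2, directly above kp1

scale = np.hypot(xc - x1, yc - y1) * 2            # 80.0
scale *= 1.5                                      # dscale for the pose detector -> 120.0
theta = np.arctan2(yc - y1, xc - x1) - np.pi / 2  # 0.0, i.e. an upright ROI
print(scale, theta)
```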
118 | 119 | c1 = self.classifier_8(x) 120 | # print(c1) 121 | # print(c1.shape) 122 | 123 | c1 = c1.permute(0, 2, 3, 1) 124 | c1 = c1.reshape(b, -1, 1) 125 | 126 | c2 = self.classifier_16(h) 127 | c2 = c2.permute(0, 2, 3, 1) 128 | c2 = c2.reshape(b, -1, 1) 129 | 130 | c = torch.cat((c1, c2), dim=1) 131 | 132 | r1 = self.regressor_8(x) 133 | r1 = r1.permute(0, 2, 3, 1) 134 | r1 = r1.reshape(b, -1, 12) 135 | 136 | r2 = self.regressor_16(h) 137 | r2 = r2.permute(0, 2, 3, 1) 138 | r2 = r2.reshape(b, -1, 12) 139 | 140 | r = torch.cat((r1, r2), dim=1) 141 | return [r, c] -------------------------------------------------------------------------------- /blazepose_landmark.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zmurez/MediaPipePyTorch/65f2549ba35cd61dfd29f402f6c21882a32fabb1/blazepose_landmark.pth -------------------------------------------------------------------------------- /blazepose_landmark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from blazebase import BlazeLandmark, BlazeBlock 7 | 8 | class BlazePoseLandmark(BlazeLandmark): 9 | """The hand landmark model from MediaPipe. 10 | 11 | """ 12 | def __init__(self): 13 | super(BlazePoseLandmark, self).__init__() 14 | 15 | # size of ROIs used for input 16 | self.resolution = 256 17 | 18 | self._define_layers() 19 | 20 | def _define_layers(self): 21 | self.backbone1 = nn.Sequential( 22 | nn.Conv2d(in_channels=3, out_channels=24, kernel_size=3, stride=2, padding=0, bias=True), 23 | nn.ReLU(inplace=True), 24 | 25 | BlazeBlock(24, 24, 3), 26 | BlazeBlock(24, 24, 3), 27 | ) 28 | 29 | self.backbone2 = nn.Sequential( 30 | BlazeBlock(24, 48, 3, 2), 31 | BlazeBlock(48, 48, 3), 32 | BlazeBlock(48, 48, 3), 33 | BlazeBlock(48, 48, 3), 34 | ) 35 | 36 | self.backbone3 = nn.Sequential( 37 | BlazeBlock(48, 96, 3, 2), 38 | BlazeBlock(96, 96, 3), 39 | BlazeBlock(96, 96, 3), 40 | BlazeBlock(96, 96, 3), 41 | BlazeBlock(96, 96, 3), 42 | ) 43 | 44 | self.backbone4 = nn.Sequential( 45 | BlazeBlock(96, 192, 3, 2), 46 | BlazeBlock(192, 192, 3), 47 | BlazeBlock(192, 192, 3), 48 | BlazeBlock(192, 192, 3), 49 | BlazeBlock(192, 192, 3), 50 | BlazeBlock(192, 192, 3), 51 | ) 52 | 53 | self.backbone5 = nn.Sequential( 54 | BlazeBlock(192, 288, 3, 2), 55 | BlazeBlock(288, 288, 3), 56 | BlazeBlock(288, 288, 3), 57 | BlazeBlock(288, 288, 3), 58 | BlazeBlock(288, 288, 3), 59 | BlazeBlock(288, 288, 3), 60 | BlazeBlock(288, 288, 3), 61 | ) 62 | 63 | self.up1 = nn.Sequential( 64 | nn.Conv2d(288, 288, 3, 1, 1, groups=288, bias=True), 65 | nn.Conv2d(288, 48, 1, bias=True), 66 | nn.ReLU(True), 67 | ) 68 | 69 | self.up2 = nn.Sequential( 70 | nn.Conv2d(192, 192, 3, 1, 1, groups=192, bias=True), 71 | nn.Conv2d(192, 48, 1, bias=True), 72 | nn.ReLU(True), 73 | ) 74 | 75 | self.up3 = nn.Sequential( 76 | nn.Conv2d(96, 96, 3, 1, 1, groups=96, bias=True), 77 | nn.Conv2d(96, 48, 1, bias=True), 78 | nn.ReLU(True), 79 | ) 80 | 81 | self.up4 = nn.Sequential( 82 | nn.Conv2d(48, 48, 3, 1, 1, groups=48, bias=True), 83 | nn.Conv2d(48, 48, 1, bias=True), 84 | nn.ReLU(True), 85 | ) 86 | 87 | self.block1 = nn.Sequential( 88 | BlazeBlock(48, 96, 3, 2), 89 | BlazeBlock(96, 96, 3), 90 | BlazeBlock(96, 96, 3), 91 | BlazeBlock(96, 96, 3), 92 | BlazeBlock(96, 96, 3), 93 | ) 94 | 95 | self.up5 = nn.Sequential( 96 | nn.Conv2d(96, 96, 3, 1, 1, groups=96, bias=True), 97 | nn.Conv2d(96, 96, 
1, bias=True), 98 | nn.ReLU(True), 99 | ) 100 | 101 | self.block2 = nn.Sequential( 102 | BlazeBlock(96, 192, 3, 2), 103 | BlazeBlock(192, 192, 3), 104 | BlazeBlock(192, 192, 3), 105 | BlazeBlock(192, 192, 3), 106 | BlazeBlock(192, 192, 3), 107 | BlazeBlock(192, 192, 3), 108 | ) 109 | 110 | self.up6 = nn.Sequential( 111 | nn.Conv2d(192, 192, 3, 1, 1, groups=192, bias=True), 112 | nn.Conv2d(192, 192, 1, bias=True), 113 | nn.ReLU(True), 114 | ) 115 | 116 | self.block3 = nn.Sequential( 117 | BlazeBlock(192, 288, 3, 2), 118 | BlazeBlock(288, 288, 3), 119 | BlazeBlock(288, 288, 3), 120 | BlazeBlock(288, 288, 3), 121 | BlazeBlock(288, 288, 3), 122 | BlazeBlock(288, 288, 3), 123 | BlazeBlock(288, 288, 3), 124 | ) 125 | 126 | self.up7 = nn.Sequential( 127 | nn.Conv2d(288, 288, 3, 1, 1, groups=288, bias=True), 128 | nn.Conv2d(288, 288, 1, bias=True), 129 | nn.ReLU(True), 130 | ) 131 | 132 | self.block4 = nn.Sequential( 133 | BlazeBlock(288, 288, 3, 2), 134 | BlazeBlock(288, 288, 3), 135 | BlazeBlock(288, 288, 3), 136 | BlazeBlock(288, 288, 3), 137 | BlazeBlock(288, 288, 3), 138 | BlazeBlock(288, 288, 3), 139 | BlazeBlock(288, 288, 3), 140 | BlazeBlock(288, 288, 3), 141 | 142 | BlazeBlock(288, 288, 3, 2), 143 | BlazeBlock(288, 288, 3), 144 | BlazeBlock(288, 288, 3), 145 | BlazeBlock(288, 288, 3), 146 | BlazeBlock(288, 288, 3), 147 | BlazeBlock(288, 288, 3), 148 | BlazeBlock(288, 288, 3), 149 | ) 150 | 151 | self.up8 = nn.Sequential( 152 | nn.Conv2d(48, 48, 3, 1, 1, groups=48, bias=True), 153 | nn.Conv2d(48, 8, 1, bias=True), 154 | nn.ReLU(True), 155 | ) 156 | 157 | self.up9 = nn.Sequential( 158 | nn.Conv2d(24, 24, 3, 1, 1, groups=24, bias=True), 159 | nn.Conv2d(24, 8, 1, bias=True), 160 | nn.ReLU(True), 161 | ) 162 | 163 | self.block5 = BlazeBlock(288, 288, 3) 164 | 165 | self.block6 = nn.Sequential( 166 | nn.Conv2d(8, 8, 3, 1, 1, groups=8, bias=True), 167 | nn.Conv2d(8, 8, 1, bias=True), 168 | nn.ReLU(True), 169 | ) 170 | 171 | self.flag = nn.Conv2d(288, 1, 2, bias=True) 172 | self.segmentation = nn.Conv2d(8, 1, 3, padding=1, bias=True) 173 | self.landmarks = nn.Conv2d(288, 124, 2, bias=True) 174 | 175 | def forward(self, x): 176 | batch = x.shape[0] 177 | if batch == 0: 178 | return torch.zeros((0,)), torch.zeros((0,31,4)), torch.zeros((0, 128,128)) 179 | 180 | x = F.pad(x, (0, 1, 0, 1), "constant", 0) 181 | 182 | x = self.backbone1(x) 183 | y = self.backbone2(x) 184 | z = self.backbone3(y) 185 | w = self.backbone4(z) 186 | v = self.backbone5(w) 187 | 188 | w1 = self.up2(w) + F.interpolate(self.up1(v), scale_factor=2, mode='bilinear') 189 | z1 = self.up3(z) + F.interpolate(w1, scale_factor=2, mode='bilinear') 190 | y1 = self.up4(y) + F.interpolate(z1, scale_factor=2, mode='bilinear') 191 | 192 | seg = self.up9(x) + F.interpolate(self.up8(y1), scale_factor=2, mode='bilinear') 193 | seg = self.segmentation(self.block6(seg)).squeeze(1) 194 | 195 | out = self.block1(y1) + self.up5(z) 196 | out = self.block2(out) + self.up6(w) 197 | out = self.block3(out) + self.up7(v) 198 | out = self.block4(out) 199 | out = self.block5(out) 200 | flag = self.flag(out).view(-1).sigmoid() 201 | landmarks = self.landmarks(out).view(batch,31,4) / 256 202 | 203 | return flag, landmarks, seg -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | import sys 5 | 6 | from blazebase import resize_pad, denormalize_detections 7 | from blazeface import 
BlazeFace 8 | from blazepalm import BlazePalm 9 | from blazeface_landmark import BlazeFaceLandmark 10 | from blazehand_landmark import BlazeHandLandmark 11 | 12 | from visualization import draw_detections, draw_landmarks, draw_roi, HAND_CONNECTIONS, FACE_CONNECTIONS 13 | 14 | gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | torch.set_grad_enabled(False) 16 | 17 | back_detector = True 18 | 19 | face_detector = BlazeFace(back_model=back_detector).to(gpu) 20 | if back_detector: 21 | face_detector.load_weights("blazefaceback.pth") 22 | face_detector.load_anchors("anchors_face_back.npy") 23 | else: 24 | face_detector.load_weights("blazeface.pth") 25 | face_detector.load_anchors("anchors_face.npy") 26 | 27 | palm_detector = BlazePalm().to(gpu) 28 | palm_detector.load_weights("blazepalm.pth") 29 | palm_detector.load_anchors("anchors_palm.npy") 30 | palm_detector.min_score_thresh = .75 31 | 32 | hand_regressor = BlazeHandLandmark().to(gpu) 33 | hand_regressor.load_weights("blazehand_landmark.pth") 34 | 35 | face_regressor = BlazeFaceLandmark().to(gpu) 36 | face_regressor.load_weights("blazeface_landmark.pth") 37 | 38 | 39 | WINDOW='test' 40 | cv2.namedWindow(WINDOW) 41 | if len(sys.argv) > 1: 42 | capture = cv2.VideoCapture(sys.argv[1]) 43 | mirror_img = False 44 | else: 45 | capture = cv2.VideoCapture(0) 46 | mirror_img = True 47 | 48 | if capture.isOpened(): 49 | hasFrame, frame = capture.read() 50 | frame_ct = 0 51 | else: 52 | hasFrame = False 53 | 54 | while hasFrame: 55 | frame_ct +=1 56 | 57 | if mirror_img: 58 | frame = np.ascontiguousarray(frame[:,::-1,::-1]) 59 | else: 60 | frame = np.ascontiguousarray(frame[:,:,::-1]) 61 | 62 | img1, img2, scale, pad = resize_pad(frame) 63 | 64 | if back_detector: 65 | normalized_face_detections = face_detector.predict_on_image(img1) 66 | else: 67 | normalized_face_detections = face_detector.predict_on_image(img2) 68 | normalized_palm_detections = palm_detector.predict_on_image(img1) 69 | 70 | face_detections = denormalize_detections(normalized_face_detections, scale, pad) 71 | palm_detections = denormalize_detections(normalized_palm_detections, scale, pad) 72 | 73 | 74 | xc, yc, scale, theta = face_detector.detection2roi(face_detections.cpu()) 75 | img, affine, box = face_regressor.extract_roi(frame, xc, yc, theta, scale) 76 | flags, normalized_landmarks = face_regressor(img.to(gpu)) 77 | landmarks = face_regressor.denormalize_landmarks(normalized_landmarks.cpu(), affine) 78 | 79 | 80 | xc, yc, scale, theta = palm_detector.detection2roi(palm_detections.cpu()) 81 | img, affine2, box2 = hand_regressor.extract_roi(frame, xc, yc, theta, scale) 82 | flags2, handed2, normalized_landmarks2 = hand_regressor(img.to(gpu)) 83 | landmarks2 = hand_regressor.denormalize_landmarks(normalized_landmarks2.cpu(), affine2) 84 | 85 | 86 | for i in range(len(flags)): 87 | landmark, flag = landmarks[i], flags[i] 88 | if flag>.5: 89 | draw_landmarks(frame, landmark[:,:2], FACE_CONNECTIONS, size=1) 90 | 91 | 92 | for i in range(len(flags2)): 93 | landmark, flag = landmarks2[i], flags2[i] 94 | if flag>.5: 95 | draw_landmarks(frame, landmark[:,:2], HAND_CONNECTIONS, size=2) 96 | 97 | draw_roi(frame, box) 98 | draw_roi(frame, box2) 99 | draw_detections(frame, face_detections) 100 | draw_detections(frame, palm_detections) 101 | 102 | cv2.imshow(WINDOW, frame[:,:,::-1]) 103 | # cv2.imwrite('sample/%04d.jpg'%frame_ct, frame[:,:,::-1]) 104 | 105 | hasFrame, frame = capture.read() 106 | key = cv2.waitKey(1) 107 | if key == 27: 108 | break 109 | 110 | 
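The slicing at the top of the loop above does two things at once: `[:, ::-1]` mirrors the frame horizontally for the webcam case, and the trailing `[..., ::-1]` converts OpenCV's BGR channel order to the RGB order the models expect. Spelled out:

```python
import numpy as np

frame = np.zeros((480, 640, 3), dtype=np.uint8)             # BGR frame as returned by OpenCV

mirrored_rgb = np.ascontiguousarray(frame[:, ::-1, ::-1])   # horizontal mirror + BGR -> RGB
rgb = np.ascontiguousarray(frame[:, :, ::-1])               # BGR -> RGB only
```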
capture.release() 111 | cv2.destroyAllWindows() 112 | -------------------------------------------------------------------------------- /demo1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | 5 | from blazebase import resize_pad, denormalize_detections 6 | from blazepose import BlazePose 7 | from blazepose_landmark import BlazePoseLandmark 8 | 9 | from visualization import draw_detections, draw_landmarks, draw_roi, POSE_CONNECTIONS 10 | 11 | gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | torch.set_grad_enabled(False) 13 | 14 | 15 | 16 | pose_detector = BlazePose().to(gpu) 17 | pose_detector.load_weights("blazepose.pth") 18 | pose_detector.load_anchors("anchors_pose.npy") 19 | 20 | pose_regressor = BlazePoseLandmark().to(gpu) 21 | pose_regressor.load_weights("blazepose_landmark.pth") 22 | 23 | 24 | WINDOW='test' 25 | cv2.namedWindow(WINDOW) 26 | capture = cv2.VideoCapture(0) 27 | 28 | if capture.isOpened(): 29 | hasFrame, frame = capture.read() 30 | frame_ct = 0 31 | else: 32 | hasFrame = False 33 | 34 | while hasFrame: 35 | frame_ct +=1 36 | 37 | frame = np.ascontiguousarray(frame[:,::-1,::-1]) 38 | 39 | img1, img2, scale, pad = resize_pad(frame) 40 | 41 | normalized_pose_detections = pose_detector.predict_on_image(img2) 42 | pose_detections = denormalize_detections(normalized_pose_detections, scale, pad) 43 | 44 | xc, yc, scale, theta = pose_detector.detection2roi(pose_detections) 45 | img, affine, box = pose_regressor.extract_roi(frame, xc, yc, theta, scale) 46 | flags, normalized_landmarks, mask = pose_regressor(img.to(gpu)) 47 | landmarks = pose_regressor.denormalize_landmarks(normalized_landmarks, affine) 48 | 49 | draw_detections(frame, pose_detections) 50 | draw_roi(frame, box) 51 | 52 | for i in range(len(flags)): 53 | landmark, flag = landmarks[i], flags[i] 54 | if flag>.5: 55 | draw_landmarks(frame, landmark, POSE_CONNECTIONS, size=2) 56 | 57 | cv2.imshow(WINDOW, frame[:,:,::-1]) 58 | # cv2.imwrite('sample/%04d.jpg'%frame_ct, frame[:,:,::-1]) 59 | 60 | hasFrame, frame = capture.read() 61 | key = cv2.waitKey(1) 62 | if key == 27: 63 | break 64 | 65 | capture.release() 66 | cv2.destroyAllWindows() 67 | -------------------------------------------------------------------------------- /visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | 5 | def draw_detections(img, detections, with_keypoints=True): 6 | if isinstance(detections, torch.Tensor): 7 | detections = detections.cpu().numpy() 8 | 9 | if detections.ndim == 1: 10 | detections = np.expand_dims(detections, axis=0) 11 | 12 | n_keypoints = detections.shape[1] // 2 - 2 13 | 14 | for i in range(detections.shape[0]): 15 | ymin = detections[i, 0] 16 | xmin = detections[i, 1] 17 | ymax = detections[i, 2] 18 | xmax = detections[i, 3] 19 | 20 | start_point = (int(xmin), int(ymin)) 21 | end_point = (int(xmax), int(ymax)) 22 | img = cv2.rectangle(img, start_point, end_point, (255, 0, 0), 1) 23 | 24 | if with_keypoints: 25 | for k in range(n_keypoints): 26 | kp_x = int(detections[i, 4 + k*2 ]) 27 | kp_y = int(detections[i, 4 + k*2 + 1]) 28 | cv2.circle(img, (kp_x, kp_y), 2, (0, 0, 255), thickness=2) 29 | return img 30 | 31 | 32 | def draw_roi(img, roi): 33 | for i in range(roi.shape[0]): 34 | (x1,x2,x3,x4), (y1,y2,y3,y4) = roi[i] 35 | cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,0), 2) 36 | cv2.line(img, 
(int(x1), int(y1)), (int(x3), int(y3)), (0,255,0), 2) 37 | cv2.line(img, (int(x2), int(y2)), (int(x4), int(y4)), (0,0,0), 2) 38 | cv2.line(img, (int(x3), int(y3)), (int(x4), int(y4)), (0,0,0), 2) 39 | 40 | 41 | def draw_landmarks(img, points, connections=[], color=(0, 255, 0), size=2): 42 | points = points[:,:2] 43 | for point in points: 44 | x, y = point 45 | x, y = int(x), int(y) 46 | cv2.circle(img, (x, y), size, color, thickness=size) 47 | for connection in connections: 48 | x0, y0 = points[connection[0]] 49 | x1, y1 = points[connection[1]] 50 | x0, y0 = int(x0), int(y0) 51 | x1, y1 = int(x1), int(y1) 52 | cv2.line(img, (x0, y0), (x1, y1), (0,0,0), size) 53 | 54 | 55 | 56 | # https://github.com/metalwhale/hand_tracking/blob/b2a650d61b4ab917a2367a05b85765b81c0564f2/run.py 57 | # 8 12 16 20 58 | # | | | | 59 | # 7 11 15 19 60 | # 4 | | | | 61 | # | 6 10 14 18 62 | # 3 | | | | 63 | # | 5---9---13--17 64 | # 2 \ / 65 | # \ \ / 66 | # 1 \ / 67 | # \ \ / 68 | # ------0- 69 | HAND_CONNECTIONS = [ 70 | (0, 1), (1, 2), (2, 3), (3, 4), 71 | (5, 6), (6, 7), (7, 8), 72 | (9, 10), (10, 11), (11, 12), 73 | (13, 14), (14, 15), (15, 16), 74 | (17, 18), (18, 19), (19, 20), 75 | (0, 5), (5, 9), (9, 13), (13, 17), (0, 17) 76 | ] 77 | 78 | POSE_CONNECTIONS = [ 79 | (0,1), (1,2), (2,3), (3,7), 80 | (0,4), (4,5), (5,6), (6,8), 81 | (9,10), 82 | (11,13), (13,15), (15,17), (17,19), (19,15), (15,21), 83 | (12,14), (14,16), (16,18), (18,20), (20,16), (16,22), 84 | (11,12), (12,24), (24,23), (23,11) 85 | ] 86 | 87 | # Vertex indices can be found in 88 | # github.com/google/mediapipe/modules/face_geometry/data/canonical_face_model_uv_visualisation.png 89 | # Found in github.com/google/mediapipe/python/solutions/face_mesh.py 90 | FACE_CONNECTIONS = [ 91 | # Lips. 92 | (61, 146), (146, 91), (91, 181), (181, 84), (84, 17), 93 | (17, 314), (314, 405), (405, 321), (321, 375), (375, 291), 94 | (61, 185), (185, 40), (40, 39), (39, 37), (37, 0), 95 | (0, 267), (267, 269), (269, 270), (270, 409), (409, 291), 96 | (78, 95), (95, 88), (88, 178), (178, 87), (87, 14), 97 | (14, 317), (317, 402), (402, 318), (318, 324), (324, 308), 98 | (78, 191), (191, 80), (80, 81), (81, 82), (82, 13), 99 | (13, 312), (312, 311), (311, 310), (310, 415), (415, 308), 100 | # Left eye. 101 | (263, 249), (249, 390), (390, 373), (373, 374), (374, 380), 102 | (380, 381), (381, 382), (382, 362), (263, 466), (466, 388), 103 | (388, 387), (387, 386), (386, 385), (385, 384), (384, 398), 104 | (398, 362), 105 | # Left eyebrow. 106 | (276, 283), (283, 282), (282, 295), (295, 285), (300, 293), 107 | (293, 334), (334, 296), (296, 336), 108 | # Right eye. 109 | (33, 7), (7, 163), (163, 144), (144, 145), (145, 153), 110 | (153, 154), (154, 155), (155, 133), (33, 246), (246, 161), 111 | (161, 160), (160, 159), (159, 158), (158, 157), (157, 173), 112 | (173, 133), 113 | # Right eyebrow. 114 | (46, 53), (53, 52), (52, 65), (65, 55), (70, 63), (63, 105), 115 | (105, 66), (66, 107), 116 | # Face oval. 117 | (10, 338), (338, 297), (297, 332), (332, 284), (284, 251), 118 | (251, 389), (389, 356), (356, 454), (454, 323), (323, 361), 119 | (361, 288), (288, 397), (397, 365), (365, 379), (379, 378), 120 | (378, 400), (400, 377), (377, 152), (152, 148), (148, 176), 121 | (176, 149), (149, 150), (150, 136), (136, 172), (172, 58), 122 | (58, 132), (132, 93), (93, 234), (234, 127), (127, 162), 123 | (162, 21), (21, 54), (54, 103), (103, 67), (67, 109), 124 | (109, 10) 125 | ] 126 | --------------------------------------------------------------------------------
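
For quick experiments on a single still image rather than a webcam stream, the face half of `demo.py` can be condensed into the sketch below. It only uses calls that already appear in this repo; `face.jpg` and `face_out.jpg` are placeholder paths, and the `.pth` weights and `.npy` anchor files are assumed to sit in the working directory.

```python
import cv2
import numpy as np
import torch

from blazebase import resize_pad, denormalize_detections
from blazeface import BlazeFace
from blazeface_landmark import BlazeFaceLandmark
from visualization import draw_detections, draw_landmarks, FACE_CONNECTIONS

gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)

# Back-camera face model, matching the default branch of demo.py.
face_detector = BlazeFace(back_model=True).to(gpu)
face_detector.load_weights("blazefaceback.pth")
face_detector.load_anchors("anchors_face_back.npy")

face_regressor = BlazeFaceLandmark().to(gpu)
face_regressor.load_weights("blazeface_landmark.pth")

# OpenCV reads BGR; flip to RGB as demo.py does. 'face.jpg' is a placeholder.
frame = np.ascontiguousarray(cv2.imread("face.jpg")[:, :, ::-1])

# Detect faces on the resized/padded image, then map boxes back to frame coordinates.
img1, img2, scale, pad = resize_pad(frame)
detections = denormalize_detections(face_detector.predict_on_image(img1), scale, pad)

# Crop a rotated ROI per detection and run the landmark regressor on it.
xc, yc, roi_scale, theta = face_detector.detection2roi(detections.cpu())
roi_img, affine, box = face_regressor.extract_roi(frame, xc, yc, theta, roi_scale)
flags, normalized_landmarks = face_regressor(roi_img.to(gpu))
landmarks = face_regressor.denormalize_landmarks(normalized_landmarks.cpu(), affine)

draw_detections(frame, detections)
for flag, landmark in zip(flags, landmarks):
    if flag > 0.5:  # same confidence gate as demo.py
        draw_landmarks(frame, landmark[:, :2], FACE_CONNECTIONS, size=1)

cv2.imwrite("face_out.jpg", frame[:, :, ::-1])  # back to BGR for imwrite
```

Note that `roi_scale` is used here instead of reusing `scale`, only to keep the resize_pad scale and the ROI scale visibly separate; the flow is otherwise the same as the per-frame loop in `demo.py`.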
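
Similarly, `demo1.py` discards the third output of `BlazePoseLandmark` (`mask`, one 128x128 segmentation map per ROI, in ROI space rather than frame space). A still-image pose sketch that also saves that map could look like the following; `person.jpg` and the output filenames are placeholders, and treating the map as logits (hence the sigmoid) is an assumption based on the forward pass above, not something `demo1.py` demonstrates.

```python
import cv2
import numpy as np
import torch

from blazebase import resize_pad, denormalize_detections
from blazepose import BlazePose
from blazepose_landmark import BlazePoseLandmark
from visualization import draw_landmarks, POSE_CONNECTIONS

gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)

pose_detector = BlazePose().to(gpu)
pose_detector.load_weights("blazepose.pth")
pose_detector.load_anchors("anchors_pose.npy")

pose_regressor = BlazePoseLandmark().to(gpu)
pose_regressor.load_weights("blazepose_landmark.pth")

# 'person.jpg' is a placeholder; flip BGR -> RGB as in demo1.py.
frame = np.ascontiguousarray(cv2.imread("person.jpg")[:, :, ::-1])
img1, img2, scale, pad = resize_pad(frame)

# Same pipeline as demo1.py: detect, crop the rotated ROI, regress landmarks.
detections = denormalize_detections(pose_detector.predict_on_image(img2), scale, pad)
xc, yc, roi_scale, theta = pose_detector.detection2roi(detections)
roi_img, affine, box = pose_regressor.extract_roi(frame, xc, yc, theta, roi_scale)
flags, normalized_landmarks, mask = pose_regressor(roi_img.to(gpu))
landmarks = pose_regressor.denormalize_landmarks(normalized_landmarks, affine)

for flag, landmark in zip(flags, landmarks):
    if flag > 0.5:
        draw_landmarks(frame, landmark, POSE_CONNECTIONS, size=2)
cv2.imwrite("person_out.jpg", frame[:, :, ::-1])

# mask holds one 128x128 segmentation map per ROI. The forward pass returns it
# without a sigmoid, so it is treated as logits here (assumption).
if len(mask) > 0:
    person_prob = mask[0].sigmoid().cpu().numpy()
    cv2.imwrite("person_mask.png", (person_prob * 255).astype(np.uint8))
```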