├── DensePose
│   ├── configs
│   │   ├── Base-DensePose-RCNN-FPN.yaml
│   │   └── densepose_rcnn_R_101_FPN_s1x.yaml
│   └── densepose
│       ├── __init__.py
│       ├── config.py
│       ├── densepose_head.py
│       ├── roi_head.py
│       └── structures.py
├── LICENSE
├── PointRend
│   ├── configs
│   │   ├── Base-RCNN-FPN.yaml
│   │   └── InstanceSegmentation
│   │       ├── Base-PointRend-RCNN-FPN.yaml
│   │       └── pointrend_rcnn_R_50_FPN_3x_coco.yaml
│   └── point_rend
│       ├── __init__.py
│       ├── coarse_mask_head.py
│       ├── config.py
│       ├── point_features.py
│       ├── point_head.py
│       └── roi_heads.py
├── README.md
├── additional
│   └── README.md
├── checkpoints
│   └── README.md
├── config.py
├── models
│   ├── __init__.py
│   ├── ief_module.py
│   ├── regressor.py
│   ├── resnet.py
│   └── smpl_official.py
├── pipeline.png
├── predict
│   ├── __init__.py
│   ├── predict_3D.py
│   ├── predict_densepose.py
│   ├── predict_joints2D.py
│   └── predict_silhouette_pointrend.py
├── renderers
│   ├── __init__.py
│   ├── nmr_renderer.py
│   └── weak_perspective_pyrender_renderer.py
├── requirements.txt
├── run_predict.py
└── utils
    ├── __init__.py
    ├── cam_utils.py
    ├── checkpoint_utils.py
    ├── eval_utils.py
    ├── image_utils.py
    ├── joints2d_utils.py
    ├── label_conversions.py
    ├── model_utils.py
    └── rigid_transform_utils.py

/DensePose/configs/Base-DensePose-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | 22 | DENSEPOSE_ON: True 23 | ROI_HEADS: 24 | NAME: "DensePoseROIHeads" 25 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 26 | NUM_CLASSES: 1 27 | ROI_BOX_HEAD: 28 | NAME: "FastRCNNConvFCHead" 29 | NUM_FC: 2 30 | POOLER_RESOLUTION: 7 31 | POOLER_SAMPLING_RATIO: 2 32 | POOLER_TYPE: "ROIAlign" 33 | ROI_DENSEPOSE_HEAD: 34 | NAME: "DensePoseV1ConvXHead" 35 | POOLER_TYPE: "ROIAlign" 36 | NUM_COARSE_SEGM_CHANNELS: 15 37 | DATASETS: 38 | TRAIN: ("densepose_coco_2014_train", "densepose_coco_2014_valminusminival") 39 | TEST: ("densepose_coco_2014_minival",) 40 | SOLVER: 41 | IMS_PER_BATCH: 16 42 | BASE_LR: 0.002 43 | STEPS: (60000, 80000) 44 | MAX_ITER: 90000 45 | WARMUP_FACTOR: 0.1 46 | INPUT: 47 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 48 | -------------------------------------------------------------------------------- /DensePose/configs/densepose_rcnn_R_101_FPN_s1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-DensePose-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" 4 | RESNETS: 5 | DEPTH: 101 6 | SOLVER: 7 | MAX_ITER: 130000 8 | STEPS: (100000, 120000) 9 | -------------------------------------------------------------------------------- /DensePose/densepose/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # from . import dataset # just to register data 3 | from .config import add_densepose_config 4 | from .densepose_head import ROI_DENSEPOSE_HEAD_REGISTRY 5 | from .roi_head import DensePoseROIHeads 6 | from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData 7 | -------------------------------------------------------------------------------- /DensePose/densepose/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from detectron2.config import CfgNode as CN 5 | 6 | 7 | def add_densepose_config(cfg): 8 | """ 9 | Add config for densepose head. 10 | """ 11 | _C = cfg 12 | 13 | _C.MODEL.DENSEPOSE_ON = True 14 | 15 | _C.MODEL.ROI_DENSEPOSE_HEAD = CN() 16 | _C.MODEL.ROI_DENSEPOSE_HEAD.NAME = "" 17 | _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8 18 | # Number of parts used for point labels 19 | _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES = 24 20 | _C.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL = 4 21 | _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512 22 | _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3 23 | _C.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE = 2 24 | _C.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE = 56 25 | _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE = "ROIAlignV2" 26 | _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION = 14 27 | _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO = 2 28 | _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS = 15 # 15 or 2 29 | # Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD) 30 | _C.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD = 0.7 31 | # Loss weights for annotation masks.(14 Parts) 32 | _C.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS = 2.0 33 | # Loss weights for surface parts. (24 Parts) 34 | _C.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS = 0.3 35 | # Loss weights for UV regression. 
36 | _C.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS = 0.1 37 | # For DeepLab head 38 | _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CN() 39 | _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN" 40 | _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = 0 41 | -------------------------------------------------------------------------------- /DensePose/densepose/densepose_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from detectron2.layers import Conv2d, ConvTranspose2d, interpolate 8 | from detectron2.structures.boxes import matched_boxlist_iou 9 | from detectron2.utils.registry import Registry 10 | 11 | from .structures import DensePoseOutput 12 | 13 | ROI_DENSEPOSE_HEAD_REGISTRY = Registry("ROI_DENSEPOSE_HEAD") 14 | 15 | 16 | def initialize_module_params(module): 17 | for name, param in module.named_parameters(): 18 | if "bias" in name: 19 | nn.init.constant_(param, 0) 20 | elif "weight" in name: 21 | nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") 22 | 23 | 24 | @ROI_DENSEPOSE_HEAD_REGISTRY.register() 25 | class DensePoseDeepLabHead(nn.Module): 26 | def __init__(self, cfg, input_channels): 27 | super(DensePoseDeepLabHead, self).__init__() 28 | # fmt: off 29 | hidden_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM 30 | kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL 31 | norm = cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM 32 | self.n_stacked_convs = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS 33 | self.use_nonlocal = cfg.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON 34 | # fmt: on 35 | pad_size = kernel_size // 2 36 | n_channels = input_channels 37 | 38 | self.ASPP = ASPP(input_channels, [6, 12, 56], n_channels) # 6, 12, 56 39 | self.add_module("ASPP", self.ASPP) 40 | 41 | if self.use_nonlocal: 42 | self.NLBlock = NONLocalBlock2D(input_channels, bn_layer=True) 43 | self.add_module("NLBlock", self.NLBlock) 44 | # weight_init.c2_msra_fill(self.ASPP) 45 | 46 | for i in range(self.n_stacked_convs): 47 | norm_module = nn.GroupNorm(32, hidden_dim) if norm == "GN" else None 48 | layer = Conv2d( 49 | n_channels, 50 | hidden_dim, 51 | kernel_size, 52 | stride=1, 53 | padding=pad_size, 54 | bias=not norm, 55 | norm=norm_module, 56 | ) 57 | weight_init.c2_msra_fill(layer) 58 | n_channels = hidden_dim 59 | layer_name = self._get_layer_name(i) 60 | self.add_module(layer_name, layer) 61 | self.n_out_channels = hidden_dim 62 | # initialize_module_params(self) 63 | 64 | def forward(self, features): 65 | x0 = features 66 | x = self.ASPP(x0) 67 | if self.use_nonlocal: 68 | x = self.NLBlock(x) 69 | output = x 70 | for i in range(self.n_stacked_convs): 71 | layer_name = self._get_layer_name(i) 72 | x = getattr(self, layer_name)(x) 73 | x = F.relu(x) 74 | output = x 75 | return output 76 | 77 | def _get_layer_name(self, i): 78 | layer_name = "body_conv_fcn{}".format(i + 1) 79 | return layer_name 80 | 81 | 82 | # Copied from 83 | # https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/deeplabv3.py 84 | # See https://arxiv.org/pdf/1706.05587.pdf for details 85 | class ASPPConv(nn.Sequential): 86 | def __init__(self, in_channels, out_channels, dilation): 87 | modules = [ 88 | nn.Conv2d( 89 | in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False 90 | ), 91 | nn.GroupNorm(32, out_channels), 
92 | nn.ReLU(), 93 | ] 94 | super(ASPPConv, self).__init__(*modules) 95 | 96 | 97 | class ASPPPooling(nn.Sequential): 98 | def __init__(self, in_channels, out_channels): 99 | super(ASPPPooling, self).__init__( 100 | nn.AdaptiveAvgPool2d(1), 101 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 102 | nn.GroupNorm(32, out_channels), 103 | nn.ReLU(), 104 | ) 105 | 106 | def forward(self, x): 107 | size = x.shape[-2:] 108 | x = super(ASPPPooling, self).forward(x) 109 | return F.interpolate(x, size=size, mode="bilinear", align_corners=False) 110 | 111 | 112 | class ASPP(nn.Module): 113 | def __init__(self, in_channels, atrous_rates, out_channels): 114 | super(ASPP, self).__init__() 115 | modules = [] 116 | modules.append( 117 | nn.Sequential( 118 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 119 | nn.GroupNorm(32, out_channels), 120 | nn.ReLU(), 121 | ) 122 | ) 123 | 124 | rate1, rate2, rate3 = tuple(atrous_rates) 125 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 126 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 127 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 128 | modules.append(ASPPPooling(in_channels, out_channels)) 129 | 130 | self.convs = nn.ModuleList(modules) 131 | 132 | self.project = nn.Sequential( 133 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 134 | # nn.BatchNorm2d(out_channels), 135 | nn.ReLU() 136 | # nn.Dropout(0.5) 137 | ) 138 | 139 | def forward(self, x): 140 | res = [] 141 | for conv in self.convs: 142 | res.append(conv(x)) 143 | res = torch.cat(res, dim=1) 144 | return self.project(res) 145 | 146 | 147 | # copied from 148 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_embedded_gaussian.py 149 | # See https://arxiv.org/abs/1711.07971 for details 150 | class _NonLocalBlockND(nn.Module): 151 | def __init__( 152 | self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True 153 | ): 154 | super(_NonLocalBlockND, self).__init__() 155 | 156 | assert dimension in [1, 2, 3] 157 | 158 | self.dimension = dimension 159 | self.sub_sample = sub_sample 160 | 161 | self.in_channels = in_channels 162 | self.inter_channels = inter_channels 163 | 164 | if self.inter_channels is None: 165 | self.inter_channels = in_channels // 2 166 | if self.inter_channels == 0: 167 | self.inter_channels = 1 168 | 169 | if dimension == 3: 170 | conv_nd = nn.Conv3d 171 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 172 | bn = nn.GroupNorm # (32, hidden_dim) #nn.BatchNorm3d 173 | elif dimension == 2: 174 | conv_nd = nn.Conv2d 175 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 176 | bn = nn.GroupNorm # (32, hidden_dim)nn.BatchNorm2d 177 | else: 178 | conv_nd = nn.Conv1d 179 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 180 | bn = nn.GroupNorm # (32, hidden_dim)nn.BatchNorm1d 181 | 182 | self.g = conv_nd( 183 | in_channels=self.in_channels, 184 | out_channels=self.inter_channels, 185 | kernel_size=1, 186 | stride=1, 187 | padding=0, 188 | ) 189 | 190 | if bn_layer: 191 | self.W = nn.Sequential( 192 | conv_nd( 193 | in_channels=self.inter_channels, 194 | out_channels=self.in_channels, 195 | kernel_size=1, 196 | stride=1, 197 | padding=0, 198 | ), 199 | bn(32, self.in_channels), 200 | ) 201 | nn.init.constant_(self.W[1].weight, 0) 202 | nn.init.constant_(self.W[1].bias, 0) 203 | else: 204 | self.W = conv_nd( 205 | in_channels=self.inter_channels, 206 | out_channels=self.in_channels, 207 | kernel_size=1, 208 | stride=1, 209 | padding=0, 210 | ) 211 | 
nn.init.constant_(self.W.weight, 0) 212 | nn.init.constant_(self.W.bias, 0) 213 | 214 | self.theta = conv_nd( 215 | in_channels=self.in_channels, 216 | out_channels=self.inter_channels, 217 | kernel_size=1, 218 | stride=1, 219 | padding=0, 220 | ) 221 | self.phi = conv_nd( 222 | in_channels=self.in_channels, 223 | out_channels=self.inter_channels, 224 | kernel_size=1, 225 | stride=1, 226 | padding=0, 227 | ) 228 | 229 | if sub_sample: 230 | self.g = nn.Sequential(self.g, max_pool_layer) 231 | self.phi = nn.Sequential(self.phi, max_pool_layer) 232 | 233 | def forward(self, x): 234 | """ 235 | :param x: (b, c, t, h, w) 236 | :return: 237 | """ 238 | 239 | batch_size = x.size(0) 240 | 241 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 242 | g_x = g_x.permute(0, 2, 1) 243 | 244 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 245 | theta_x = theta_x.permute(0, 2, 1) 246 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 247 | f = torch.matmul(theta_x, phi_x) 248 | f_div_C = F.softmax(f, dim=-1) 249 | 250 | y = torch.matmul(f_div_C, g_x) 251 | y = y.permute(0, 2, 1).contiguous() 252 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 253 | W_y = self.W(y) 254 | z = W_y + x 255 | 256 | return z 257 | 258 | 259 | class NONLocalBlock2D(_NonLocalBlockND): 260 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 261 | super(NONLocalBlock2D, self).__init__( 262 | in_channels, 263 | inter_channels=inter_channels, 264 | dimension=2, 265 | sub_sample=sub_sample, 266 | bn_layer=bn_layer, 267 | ) 268 | 269 | 270 | @ROI_DENSEPOSE_HEAD_REGISTRY.register() 271 | class DensePoseV1ConvXHead(nn.Module): 272 | def __init__(self, cfg, input_channels): 273 | super(DensePoseV1ConvXHead, self).__init__() 274 | # fmt: off 275 | hidden_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM 276 | kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL 277 | self.n_stacked_convs = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS 278 | # fmt: on 279 | pad_size = kernel_size // 2 280 | n_channels = input_channels 281 | for i in range(self.n_stacked_convs): 282 | layer = Conv2d(n_channels, hidden_dim, kernel_size, stride=1, padding=pad_size) 283 | layer_name = self._get_layer_name(i) 284 | self.add_module(layer_name, layer) 285 | n_channels = hidden_dim 286 | self.n_out_channels = n_channels 287 | initialize_module_params(self) 288 | 289 | def forward(self, features): 290 | x = features 291 | output = x 292 | for i in range(self.n_stacked_convs): 293 | layer_name = self._get_layer_name(i) 294 | x = getattr(self, layer_name)(x) 295 | x = F.relu(x) 296 | output = x 297 | return output 298 | 299 | def _get_layer_name(self, i): 300 | layer_name = "body_conv_fcn{}".format(i + 1) 301 | return layer_name 302 | 303 | 304 | class DensePosePredictor(nn.Module): 305 | def __init__(self, cfg, input_channels): 306 | 307 | super(DensePosePredictor, self).__init__() 308 | dim_in = input_channels 309 | n_segm_chan = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS 310 | dim_out_patches = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES + 1 311 | kernel_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL 312 | self.ann_index_lowres = ConvTranspose2d( 313 | dim_in, n_segm_chan, kernel_size, stride=2, padding=int(kernel_size / 2 - 1) 314 | ) 315 | self.index_uv_lowres = ConvTranspose2d( 316 | dim_in, dim_out_patches, kernel_size, stride=2, padding=int(kernel_size / 2 - 1) 317 | ) 318 | self.u_lowres = ConvTranspose2d( 319 | dim_in, dim_out_patches, kernel_size, 
stride=2, padding=int(kernel_size / 2 - 1) 320 | ) 321 | self.v_lowres = ConvTranspose2d( 322 | dim_in, dim_out_patches, kernel_size, stride=2, padding=int(kernel_size / 2 - 1) 323 | ) 324 | self.scale_factor = cfg.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE 325 | initialize_module_params(self) 326 | 327 | def forward(self, head_outputs): 328 | ann_index_lowres = self.ann_index_lowres(head_outputs) 329 | index_uv_lowres = self.index_uv_lowres(head_outputs) 330 | u_lowres = self.u_lowres(head_outputs) 331 | v_lowres = self.v_lowres(head_outputs) 332 | 333 | def interp2d(input): 334 | return interpolate( 335 | input, scale_factor=self.scale_factor, mode="bilinear", align_corners=False 336 | ) 337 | 338 | ann_index = interp2d(ann_index_lowres) 339 | index_uv = interp2d(index_uv_lowres) 340 | u = interp2d(u_lowres) 341 | v = interp2d(v_lowres) 342 | return ( 343 | (ann_index, index_uv, u, v), 344 | (ann_index_lowres, index_uv_lowres, u_lowres, v_lowres), 345 | ) 346 | 347 | 348 | class DensePoseDataFilter(object): 349 | def __init__(self, cfg): 350 | self.iou_threshold = cfg.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD 351 | 352 | @torch.no_grad() 353 | def __call__(self, proposals_with_targets): 354 | """ 355 | Filters proposals with targets to keep only the ones relevant for 356 | DensePose training 357 | proposals: list(Instances), each element of the list corresponds to 358 | various instances (proposals, GT for boxes and densepose) for one 359 | image 360 | """ 361 | proposals_filtered = [] 362 | for proposals_per_image in proposals_with_targets: 363 | if not hasattr(proposals_per_image, "gt_densepose"): 364 | continue 365 | assert hasattr(proposals_per_image, "gt_boxes") 366 | assert hasattr(proposals_per_image, "proposal_boxes") 367 | gt_boxes = proposals_per_image.gt_boxes 368 | est_boxes = proposals_per_image.proposal_boxes 369 | # apply match threshold for densepose head 370 | iou = matched_boxlist_iou(gt_boxes, est_boxes) 371 | iou_select = iou > self.iou_threshold 372 | proposals_per_image = proposals_per_image[iou_select] 373 | assert len(proposals_per_image.gt_boxes) == len(proposals_per_image.proposal_boxes) 374 | # filter out any target without densepose annotation 375 | gt_densepose = proposals_per_image.gt_densepose 376 | assert len(proposals_per_image.gt_boxes) == len(proposals_per_image.gt_densepose) 377 | selected_indices = [ 378 | i for i, dp_target in enumerate(gt_densepose) if dp_target is not None 379 | ] 380 | if len(selected_indices) != len(gt_densepose): 381 | proposals_per_image = proposals_per_image[selected_indices] 382 | assert len(proposals_per_image.gt_boxes) == len(proposals_per_image.proposal_boxes) 383 | assert len(proposals_per_image.gt_boxes) == len(proposals_per_image.gt_densepose) 384 | proposals_filtered.append(proposals_per_image) 385 | return proposals_filtered 386 | 387 | 388 | def build_densepose_head(cfg, input_channels): 389 | head_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME 390 | return ROI_DENSEPOSE_HEAD_REGISTRY.get(head_name)(cfg, input_channels) 391 | 392 | 393 | def build_densepose_predictor(cfg, input_channels): 394 | predictor = DensePosePredictor(cfg, input_channels) 395 | return predictor 396 | 397 | 398 | def build_densepose_data_filter(cfg): 399 | dp_filter = DensePoseDataFilter(cfg) 400 | return dp_filter 401 | 402 | 403 | def densepose_inference(densepose_outputs, detections): 404 | """ 405 | Infer dense pose estimate based on outputs from the DensePose head 406 | and detections. 
The estimate for each detection instance is stored in its 407 | "pred_densepose" attribute. 408 | 409 | Args: 410 | densepose_outputs (tuple(`torch.Tensor`)): iterable containing 4 elements: 411 | - s (:obj: `torch.Tensor`): segmentation tensor of size (N, A, H, W), 412 | - i (:obj: `torch.Tensor`): classification tensor of size (N, C, H, W), 413 | - u (:obj: `torch.Tensor`): U coordinates for each class of size (N, C, H, W), 414 | - v (:obj: `torch.Tensor`): V coordinates for each class of size (N, C, H, W), 415 | where N is the total number of detections in a batch, 416 | A is the number of segmentations classes (e.g. 15 for coarse body parts), 417 | C is the number of labels (e.g. 25 for fine body parts), 418 | W is the resolution along the X axis 419 | H is the resolution along the Y axis 420 | detections (list[Instances]): A list of N Instances, where N is the number of images 421 | in the batch. Instances are modified by this method: "pred_densepose" attribute 422 | is added to each instance, the attribute contains the corresponding 423 | DensePoseOutput object. 424 | """ 425 | 426 | # DensePose outputs: segmentation, body part indices, U, V 427 | s, index_uv, u, v = densepose_outputs 428 | k = 0 429 | for detection in detections: 430 | n_i = len(detection) 431 | s_i = s[k : k + n_i] 432 | index_uv_i = index_uv[k : k + n_i] 433 | u_i = u[k : k + n_i] 434 | v_i = v[k : k + n_i] 435 | densepose_output_i = DensePoseOutput(s_i, index_uv_i, u_i, v_i) 436 | detection.pred_densepose = densepose_output_i 437 | k += n_i 438 | 439 | 440 | def _linear_interpolation_utilities(v_norm, v0_src, size_src, v0_dst, size_dst, size_z): 441 | """ 442 | Computes utility values for linear interpolation at points v. 443 | The points are given as normalized offsets in the source interval 444 | (v0_src, v0_src + size_src), more precisely: 445 | v = v0_src + v_norm * size_src / 256.0 446 | The computed utilities include lower points v_lo, upper points v_hi, 447 | interpolation weights v_w and flags j_valid indicating whether the 448 | points falls into the destination interval (v0_dst, v0_dst + size_dst). 
449 | 450 | Args: 451 | v_norm (:obj: `torch.Tensor`): tensor of size N containing 452 | normalized point offsets 453 | v0_src (:obj: `torch.Tensor`): tensor of size N containing 454 | left bounds of source intervals for normalized points 455 | size_src (:obj: `torch.Tensor`): tensor of size N containing 456 | source interval sizes for normalized points 457 | v0_dst (:obj: `torch.Tensor`): tensor of size N containing 458 | left bounds of destination intervals 459 | size_dst (:obj: `torch.Tensor`): tensor of size N containing 460 | destination interval sizes 461 | size_z (int): interval size for data to be interpolated 462 | 463 | Returns: 464 | v_lo (:obj: `torch.Tensor`): int tensor of size N containing 465 | indices of lower values used for interpolation, all values are 466 | integers from [0, size_z - 1] 467 | v_hi (:obj: `torch.Tensor`): int tensor of size N containing 468 | indices of upper values used for interpolation, all values are 469 | integers from [0, size_z - 1] 470 | v_w (:obj: `torch.Tensor`): float tensor of size N containing 471 | interpolation weights 472 | j_valid (:obj: `torch.Tensor`): uint8 tensor of size N containing 473 | 0 for points outside the estimation interval 474 | (v0_est, v0_est + size_est) and 1 otherwise 475 | """ 476 | v = v0_src + v_norm * size_src / 256.0 477 | j_valid = (v - v0_dst >= 0) * (v - v0_dst < size_dst) 478 | v_grid = (v - v0_dst) * size_z / size_dst 479 | v_lo = v_grid.floor().long().clamp(min=0, max=size_z - 1) 480 | v_hi = (v_lo + 1).clamp(max=size_z - 1) 481 | v_grid = torch.min(v_hi.float(), v_grid) 482 | v_w = v_grid - v_lo.float() 483 | return v_lo, v_hi, v_w, j_valid 484 | 485 | 486 | def _grid_sampling_utilities( 487 | zh, zw, bbox_xywh_est, bbox_xywh_gt, index_gt, x_norm, y_norm, index_bbox 488 | ): 489 | """ 490 | Prepare tensors used in grid sampling. 491 | 492 | Args: 493 | z_est (:obj: `torch.Tensor`): tensor of size (N,C,H,W) with estimated 494 | values of Z to be extracted for the points X, Y and channel 495 | indices I 496 | bbox_xywh_est (:obj: `torch.Tensor`): tensor of size (N, 4) containing 497 | estimated bounding boxes in format XYWH 498 | bbox_xywh_gt (:obj: `torch.Tensor`): tensor of size (N, 4) containing 499 | matched ground truth bounding boxes in format XYWH 500 | index_gt (:obj: `torch.Tensor`): tensor of size K with point labels for 501 | ground truth points 502 | x_norm (:obj: `torch.Tensor`): tensor of size K with X normalized 503 | coordinates of ground truth points. Image X coordinates can be 504 | obtained as X = Xbbox + x_norm * Wbbox / 255 505 | y_norm (:obj: `torch.Tensor`): tensor of size K with Y normalized 506 | coordinates of ground truth points. Image Y coordinates can be 507 | obtained as Y = Ybbox + y_norm * Hbbox / 255 508 | index_bbox (:obj: `torch.Tensor`): tensor of size K with bounding box 509 | indices for each ground truth point. 
The values are thus in 510 | [0, N-1] 511 | 512 | Returns: 513 | j_valid (:obj: `torch.Tensor`): uint8 tensor of size M containing 514 | 0 for points to be discarded and 1 for points to be selected 515 | y_lo (:obj: `torch.Tensor`): int tensor of indices of upper values 516 | in z_est for each point 517 | y_hi (:obj: `torch.Tensor`): int tensor of indices of lower values 518 | in z_est for each point 519 | x_lo (:obj: `torch.Tensor`): int tensor of indices of left values 520 | in z_est for each point 521 | x_hi (:obj: `torch.Tensor`): int tensor of indices of right values 522 | in z_est for each point 523 | w_ylo_xlo (:obj: `torch.Tensor`): float tensor of size M; 524 | contains upper-left value weight for each point 525 | w_ylo_xhi (:obj: `torch.Tensor`): float tensor of size M; 526 | contains upper-right value weight for each point 527 | w_yhi_xlo (:obj: `torch.Tensor`): float tensor of size M; 528 | contains lower-left value weight for each point 529 | w_yhi_xhi (:obj: `torch.Tensor`): float tensor of size M; 530 | contains lower-right value weight for each point 531 | """ 532 | 533 | x0_gt, y0_gt, w_gt, h_gt = bbox_xywh_gt[index_bbox].unbind(dim=1) 534 | x0_est, y0_est, w_est, h_est = bbox_xywh_est[index_bbox].unbind(dim=1) 535 | x_lo, x_hi, x_w, jx_valid = _linear_interpolation_utilities( 536 | x_norm, x0_gt, w_gt, x0_est, w_est, zw 537 | ) 538 | y_lo, y_hi, y_w, jy_valid = _linear_interpolation_utilities( 539 | y_norm, y0_gt, h_gt, y0_est, h_est, zh 540 | ) 541 | j_valid = jx_valid * jy_valid 542 | 543 | w_ylo_xlo = (1.0 - x_w) * (1.0 - y_w) 544 | w_ylo_xhi = x_w * (1.0 - y_w) 545 | w_yhi_xlo = (1.0 - x_w) * y_w 546 | w_yhi_xhi = x_w * y_w 547 | 548 | return j_valid, y_lo, y_hi, x_lo, x_hi, w_ylo_xlo, w_ylo_xhi, w_yhi_xlo, w_yhi_xhi 549 | 550 | 551 | def _extract_at_points_packed( 552 | z_est, 553 | index_bbox_valid, 554 | slice_index_uv, 555 | y_lo, 556 | y_hi, 557 | x_lo, 558 | x_hi, 559 | w_ylo_xlo, 560 | w_ylo_xhi, 561 | w_yhi_xlo, 562 | w_yhi_xhi, 563 | ): 564 | """ 565 | Extract ground truth values z_gt for valid point indices and estimated 566 | values z_est using bilinear interpolation over top-left (y_lo, x_lo), 567 | top-right (y_lo, x_hi), bottom-left (y_hi, x_lo) and bottom-right 568 | (y_hi, x_hi) values in z_est with corresponding weights: 569 | w_ylo_xlo, w_ylo_xhi, w_yhi_xlo and w_yhi_xhi. 
570 | Use slice_index_uv to slice dim=1 in z_est 571 | """ 572 | z_est_sampled = ( 573 | z_est[index_bbox_valid, slice_index_uv, y_lo, x_lo] * w_ylo_xlo 574 | + z_est[index_bbox_valid, slice_index_uv, y_lo, x_hi] * w_ylo_xhi 575 | + z_est[index_bbox_valid, slice_index_uv, y_hi, x_lo] * w_yhi_xlo 576 | + z_est[index_bbox_valid, slice_index_uv, y_hi, x_hi] * w_yhi_xhi 577 | ) 578 | return z_est_sampled 579 | 580 | 581 | def _resample_data( 582 | z, bbox_xywh_src, bbox_xywh_dst, wout, hout, mode="nearest", padding_mode="zeros" 583 | ): 584 | """ 585 | Args: 586 | z (:obj: `torch.Tensor`): tensor of size (N,C,H,W) with data to be 587 | resampled 588 | bbox_xywh_src (:obj: `torch.Tensor`): tensor of size (N,4) containing 589 | source bounding boxes in format XYWH 590 | bbox_xywh_dst (:obj: `torch.Tensor`): tensor of size (N,4) containing 591 | destination bounding boxes in format XYWH 592 | Return: 593 | zresampled (:obj: `torch.Tensor`): tensor of size (N, C, Hout, Wout) 594 | with resampled values of z, where D is the discretization size 595 | """ 596 | n = bbox_xywh_src.size(0) 597 | assert n == bbox_xywh_dst.size(0), ( 598 | "The number of " 599 | "source ROIs for resampling ({}) should be equal to the number " 600 | "of destination ROIs ({})".format(bbox_xywh_src.size(0), bbox_xywh_dst.size(0)) 601 | ) 602 | x0src, y0src, wsrc, hsrc = bbox_xywh_src.unbind(dim=1) 603 | x0dst, y0dst, wdst, hdst = bbox_xywh_dst.unbind(dim=1) 604 | x0dst_norm = 2 * (x0dst - x0src) / wsrc - 1 605 | y0dst_norm = 2 * (y0dst - y0src) / hsrc - 1 606 | x1dst_norm = 2 * (x0dst + wdst - x0src) / wsrc - 1 607 | y1dst_norm = 2 * (y0dst + hdst - y0src) / hsrc - 1 608 | grid_w = torch.arange(wout, device=z.device, dtype=torch.float) / wout 609 | grid_h = torch.arange(hout, device=z.device, dtype=torch.float) / hout 610 | grid_w_expanded = grid_w[None, None, :].expand(n, hout, wout) 611 | grid_h_expanded = grid_h[None, :, None].expand(n, hout, wout) 612 | dx_expanded = (x1dst_norm - x0dst_norm)[:, None, None].expand(n, hout, wout) 613 | dy_expanded = (y1dst_norm - y0dst_norm)[:, None, None].expand(n, hout, wout) 614 | x0_expanded = x0dst_norm[:, None, None].expand(n, hout, wout) 615 | y0_expanded = y0dst_norm[:, None, None].expand(n, hout, wout) 616 | grid_x = grid_w_expanded * dx_expanded + x0_expanded 617 | grid_y = grid_h_expanded * dy_expanded + y0_expanded 618 | grid = torch.stack((grid_x, grid_y), dim=3) 619 | # resample Z from (N, C, H, W) into (N, C, Hout, Wout) 620 | zresampled = F.grid_sample(z, grid, mode=mode, padding_mode=padding_mode, align_corners=True) 621 | return zresampled 622 | 623 | 624 | def _extract_single_tensors_from_matches_one_image( 625 | proposals_targets, bbox_with_dp_offset, bbox_global_offset 626 | ): 627 | i_gt_all = [] 628 | x_norm_all = [] 629 | y_norm_all = [] 630 | u_gt_all = [] 631 | v_gt_all = [] 632 | s_gt_all = [] 633 | bbox_xywh_gt_all = [] 634 | bbox_xywh_est_all = [] 635 | # Ibbox_all == k should be true for all data that corresponds 636 | # to bbox_xywh_gt[k] and bbox_xywh_est[k] 637 | # index k here is global wrt images 638 | i_bbox_all = [] 639 | # at offset k (k is global) contains index of bounding box data 640 | # within densepose output tensor 641 | i_with_dp = [] 642 | 643 | boxes_xywh_est = proposals_targets.proposal_boxes.clone() 644 | boxes_xywh_gt = proposals_targets.gt_boxes.clone() 645 | n_i = len(boxes_xywh_est) 646 | assert n_i == len(boxes_xywh_gt) 647 | 648 | if n_i: 649 | boxes_xywh_est.tensor[:, 2] -= boxes_xywh_est.tensor[:, 0] 650 | boxes_xywh_est.tensor[:, 
3] -= boxes_xywh_est.tensor[:, 1] 651 | boxes_xywh_gt.tensor[:, 2] -= boxes_xywh_gt.tensor[:, 0] 652 | boxes_xywh_gt.tensor[:, 3] -= boxes_xywh_gt.tensor[:, 1] 653 | if hasattr(proposals_targets, "gt_densepose"): 654 | densepose_gt = proposals_targets.gt_densepose 655 | for k, box_xywh_est, box_xywh_gt, dp_gt in zip( 656 | range(n_i), boxes_xywh_est.tensor, boxes_xywh_gt.tensor, densepose_gt 657 | ): 658 | if (dp_gt is not None) and (len(dp_gt.x) > 0): 659 | i_gt_all.append(dp_gt.i) 660 | x_norm_all.append(dp_gt.x) 661 | y_norm_all.append(dp_gt.y) 662 | u_gt_all.append(dp_gt.u) 663 | v_gt_all.append(dp_gt.v) 664 | s_gt_all.append(dp_gt.segm.unsqueeze(0)) 665 | bbox_xywh_gt_all.append(box_xywh_gt.view(-1, 4)) 666 | bbox_xywh_est_all.append(box_xywh_est.view(-1, 4)) 667 | i_bbox_k = torch.full_like(dp_gt.i, bbox_with_dp_offset + len(i_with_dp)) 668 | i_bbox_all.append(i_bbox_k) 669 | i_with_dp.append(bbox_global_offset + k) 670 | return ( 671 | i_gt_all, 672 | x_norm_all, 673 | y_norm_all, 674 | u_gt_all, 675 | v_gt_all, 676 | s_gt_all, 677 | bbox_xywh_gt_all, 678 | bbox_xywh_est_all, 679 | i_bbox_all, 680 | i_with_dp, 681 | ) 682 | 683 | 684 | def _extract_single_tensors_from_matches(proposals_with_targets): 685 | i_img = [] 686 | i_gt_all = [] 687 | x_norm_all = [] 688 | y_norm_all = [] 689 | u_gt_all = [] 690 | v_gt_all = [] 691 | s_gt_all = [] 692 | bbox_xywh_gt_all = [] 693 | bbox_xywh_est_all = [] 694 | i_bbox_all = [] 695 | i_with_dp_all = [] 696 | n = 0 697 | for i, proposals_targets_per_image in enumerate(proposals_with_targets): 698 | n_i = proposals_targets_per_image.proposal_boxes.tensor.size(0) 699 | if not n_i: 700 | continue 701 | i_gt_img, x_norm_img, y_norm_img, u_gt_img, v_gt_img, s_gt_img, bbox_xywh_gt_img, bbox_xywh_est_img, i_bbox_img, i_with_dp_img = _extract_single_tensors_from_matches_one_image( # noqa 702 | proposals_targets_per_image, len(i_with_dp_all), n 703 | ) 704 | i_gt_all.extend(i_gt_img) 705 | x_norm_all.extend(x_norm_img) 706 | y_norm_all.extend(y_norm_img) 707 | u_gt_all.extend(u_gt_img) 708 | v_gt_all.extend(v_gt_img) 709 | s_gt_all.extend(s_gt_img) 710 | bbox_xywh_gt_all.extend(bbox_xywh_gt_img) 711 | bbox_xywh_est_all.extend(bbox_xywh_est_img) 712 | i_bbox_all.extend(i_bbox_img) 713 | i_with_dp_all.extend(i_with_dp_img) 714 | i_img.extend([i] * len(i_with_dp_img)) 715 | n += n_i 716 | # concatenate all data into a single tensor 717 | if (n > 0) and (len(i_with_dp_all) > 0): 718 | i_gt = torch.cat(i_gt_all, 0).long() 719 | x_norm = torch.cat(x_norm_all, 0) 720 | y_norm = torch.cat(y_norm_all, 0) 721 | u_gt = torch.cat(u_gt_all, 0) 722 | v_gt = torch.cat(v_gt_all, 0) 723 | s_gt = torch.cat(s_gt_all, 0) 724 | bbox_xywh_gt = torch.cat(bbox_xywh_gt_all, 0) 725 | bbox_xywh_est = torch.cat(bbox_xywh_est_all, 0) 726 | i_bbox = torch.cat(i_bbox_all, 0).long() 727 | else: 728 | i_gt = None 729 | x_norm = None 730 | y_norm = None 731 | u_gt = None 732 | v_gt = None 733 | s_gt = None 734 | bbox_xywh_gt = None 735 | bbox_xywh_est = None 736 | i_bbox = None 737 | return ( 738 | i_img, 739 | i_with_dp_all, 740 | bbox_xywh_est, 741 | bbox_xywh_gt, 742 | i_gt, 743 | x_norm, 744 | y_norm, 745 | u_gt, 746 | v_gt, 747 | s_gt, 748 | i_bbox, 749 | ) 750 | 751 | 752 | class DensePoseLosses(object): 753 | def __init__(self, cfg): 754 | # fmt: off 755 | self.heatmap_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE 756 | self.w_points = cfg.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS 757 | self.w_part = cfg.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS 758 | self.w_segm = 
cfg.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS 759 | self.n_segm_chan = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS 760 | # fmt: on 761 | 762 | def __call__(self, proposals_with_gt, densepose_outputs): 763 | losses = {} 764 | # densepose outputs are computed for all images and all bounding boxes; 765 | # i.e. if a batch has 4 images with (3, 1, 2, 1) proposals respectively, 766 | # the outputs will have size(0) == 3+1+2+1 == 7 767 | s, index_uv, u, v = densepose_outputs 768 | assert u.size(2) == v.size(2) 769 | assert u.size(3) == v.size(3) 770 | assert u.size(2) == index_uv.size(2) 771 | assert u.size(3) == index_uv.size(3) 772 | 773 | with torch.no_grad(): 774 | index_uv_img, i_with_dp, bbox_xywh_est, bbox_xywh_gt, index_gt_all, x_norm, y_norm, u_gt_all, v_gt_all, s_gt, index_bbox = _extract_single_tensors_from_matches( # noqa 775 | proposals_with_gt 776 | ) 777 | n_batch = len(i_with_dp) 778 | 779 | # NOTE: we need to keep the same computation graph on all the GPUs to 780 | # perform reduction properly. Hence even if we have no data on one 781 | # of the GPUs, we still need to generate the computation graph. 782 | # Add fake (zero) loss in the form Tensor.sum() * 0 783 | if not n_batch: 784 | losses["loss_densepose_U"] = u.sum() * 0 785 | losses["loss_densepose_V"] = v.sum() * 0 786 | losses["loss_densepose_I"] = index_uv.sum() * 0 787 | losses["loss_densepose_S"] = s.sum() * 0 788 | return losses 789 | 790 | zh = u.size(2) 791 | zw = u.size(3) 792 | 793 | j_valid, y_lo, y_hi, x_lo, x_hi, w_ylo_xlo, w_ylo_xhi, w_yhi_xlo, w_yhi_xhi = _grid_sampling_utilities( # noqa 794 | zh, zw, bbox_xywh_est, bbox_xywh_gt, index_gt_all, x_norm, y_norm, index_bbox 795 | ) 796 | 797 | j_valid_fg = j_valid * (index_gt_all > 0) 798 | 799 | u_gt = u_gt_all[j_valid_fg] 800 | u_est_all = _extract_at_points_packed( 801 | u[i_with_dp], 802 | index_bbox, 803 | index_gt_all, 804 | y_lo, 805 | y_hi, 806 | x_lo, 807 | x_hi, 808 | w_ylo_xlo, 809 | w_ylo_xhi, 810 | w_yhi_xlo, 811 | w_yhi_xhi, 812 | ) 813 | u_est = u_est_all[j_valid_fg] 814 | 815 | v_gt = v_gt_all[j_valid_fg] 816 | v_est_all = _extract_at_points_packed( 817 | v[i_with_dp], 818 | index_bbox, 819 | index_gt_all, 820 | y_lo, 821 | y_hi, 822 | x_lo, 823 | x_hi, 824 | w_ylo_xlo, 825 | w_ylo_xhi, 826 | w_yhi_xlo, 827 | w_yhi_xhi, 828 | ) 829 | v_est = v_est_all[j_valid_fg] 830 | 831 | index_uv_gt = index_gt_all[j_valid] 832 | index_uv_est_all = _extract_at_points_packed( 833 | index_uv[i_with_dp], 834 | index_bbox, 835 | slice(None), 836 | y_lo, 837 | y_hi, 838 | x_lo, 839 | x_hi, 840 | w_ylo_xlo[:, None], 841 | w_ylo_xhi[:, None], 842 | w_yhi_xlo[:, None], 843 | w_yhi_xhi[:, None], 844 | ) 845 | index_uv_est = index_uv_est_all[j_valid, :] 846 | 847 | # Resample everything to the estimated data size, no need to resample 848 | # S_est then: 849 | s_est = s[i_with_dp] 850 | with torch.no_grad(): 851 | s_gt = _resample_data( 852 | s_gt.unsqueeze(1), 853 | bbox_xywh_gt, 854 | bbox_xywh_est, 855 | self.heatmap_size, 856 | self.heatmap_size, 857 | mode="nearest", 858 | padding_mode="zeros", 859 | ).squeeze(1) 860 | 861 | # add point-based losses: 862 | u_loss = F.smooth_l1_loss(u_est, u_gt, reduction="sum") * self.w_points 863 | losses["loss_densepose_U"] = u_loss 864 | v_loss = F.smooth_l1_loss(v_est, v_gt, reduction="sum") * self.w_points 865 | losses["loss_densepose_V"] = v_loss 866 | index_uv_loss = F.cross_entropy(index_uv_est, index_uv_gt.long()) * self.w_part 867 | losses["loss_densepose_I"] = index_uv_loss 868 | 869 | if self.n_segm_chan == 2: 870 | 
s_gt = s_gt > 0 871 | s_loss = F.cross_entropy(s_est, s_gt.long()) * self.w_segm 872 | losses["loss_densepose_S"] = s_loss 873 | return losses 874 | 875 | 876 | def build_densepose_losses(cfg): 877 | losses = DensePoseLosses(cfg) 878 | return losses 879 | -------------------------------------------------------------------------------- /DensePose/densepose/roi_head.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import torch 5 | 6 | from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads 7 | from detectron2.modeling.poolers import ROIPooler 8 | from detectron2.modeling.roi_heads import select_foreground_proposals 9 | 10 | from .densepose_head import ( 11 | build_densepose_data_filter, 12 | build_densepose_head, 13 | build_densepose_losses, 14 | build_densepose_predictor, 15 | densepose_inference, 16 | ) 17 | 18 | 19 | @ROI_HEADS_REGISTRY.register() 20 | class DensePoseROIHeads(StandardROIHeads): 21 | """ 22 | A Standard ROIHeads which contains an addition of DensePose head. 23 | """ 24 | 25 | def __init__(self, cfg, input_shape): 26 | super().__init__(cfg, input_shape) 27 | self._init_densepose_head(cfg) 28 | 29 | def _init_densepose_head(self, cfg): 30 | # fmt: off 31 | self.densepose_on = cfg.MODEL.DENSEPOSE_ON 32 | if not self.densepose_on: 33 | return 34 | self.densepose_data_filter = build_densepose_data_filter(cfg) 35 | dp_pooler_resolution = cfg.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION 36 | dp_pooler_scales = tuple(1.0 / self.feature_strides[k] for k in self.in_features) 37 | dp_pooler_sampling_ratio = cfg.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO 38 | dp_pooler_type = cfg.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE 39 | # fmt: on 40 | in_channels = [self.feature_channels[f] for f in self.in_features][0] 41 | self.densepose_pooler = ROIPooler( 42 | output_size=dp_pooler_resolution, 43 | scales=dp_pooler_scales, 44 | sampling_ratio=dp_pooler_sampling_ratio, 45 | pooler_type=dp_pooler_type, 46 | ) 47 | self.densepose_head = build_densepose_head(cfg, in_channels) 48 | self.densepose_predictor = build_densepose_predictor( 49 | cfg, self.densepose_head.n_out_channels 50 | ) 51 | self.densepose_losses = build_densepose_losses(cfg) 52 | 53 | def _forward_densepose(self, features, instances): 54 | """ 55 | Forward logic of the densepose prediction branch. 56 | 57 | Args: 58 | features (list[Tensor]): #level input features for densepose prediction 59 | instances (list[Instances]): the per-image instances to train/predict densepose. 60 | In training, they can be the proposals. 61 | In inference, they can be the predicted boxes. 62 | 63 | Returns: 64 | In training, a dict of losses. 65 | In inference, update `instances` with new fields "densepose" and return it. 
66 | """ 67 | if not self.densepose_on: 68 | return {} if self.training else instances 69 | 70 | if self.training: 71 | proposals, _ = select_foreground_proposals(instances, self.num_classes) 72 | proposals_dp = self.densepose_data_filter(proposals) 73 | if len(proposals_dp) > 0: 74 | proposal_boxes = [x.proposal_boxes for x in proposals_dp] 75 | features_dp = self.densepose_pooler(features, proposal_boxes) 76 | densepose_head_outputs = self.densepose_head(features_dp) 77 | densepose_outputs, _ = self.densepose_predictor(densepose_head_outputs) 78 | densepose_loss_dict = self.densepose_losses(proposals_dp, densepose_outputs) 79 | return densepose_loss_dict 80 | else: 81 | pred_boxes = [x.pred_boxes for x in instances] 82 | features_dp = self.densepose_pooler(features, pred_boxes) 83 | if len(features_dp) > 0: 84 | densepose_head_outputs = self.densepose_head(features_dp) 85 | densepose_outputs, _ = self.densepose_predictor(densepose_head_outputs) 86 | else: 87 | # If no detection occurred instances 88 | # set densepose_outputs to empty tensors 89 | empty_tensor = torch.zeros(size=(0, 0, 0, 0), device=features_dp.device) 90 | densepose_outputs = tuple([empty_tensor] * 4) 91 | 92 | densepose_inference(densepose_outputs, instances) 93 | return instances 94 | 95 | def forward(self, images, features, proposals, targets=None): 96 | features_list = [features[f] for f in self.in_features] 97 | 98 | instances, losses = super().forward(images, features, proposals, targets) 99 | del targets, images 100 | 101 | if self.training: 102 | losses.update(self._forward_densepose(features_list, instances)) 103 | else: 104 | instances = self._forward_densepose(features_list, instances) 105 | return instances, losses 106 | -------------------------------------------------------------------------------- /DensePose/densepose/structures.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import base64 3 | import numpy as np 4 | from io import BytesIO 5 | import torch 6 | from PIL import Image 7 | from torch.nn import functional as F 8 | 9 | 10 | class DensePoseTransformData(object): 11 | 12 | # Horizontal symmetry label transforms used for horizontal flip 13 | MASK_LABEL_SYMMETRIES = [0, 1, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 14] 14 | # fmt: off 15 | POINT_LABEL_SYMMETRIES = [ 0, 1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23] # noqa 16 | # fmt: on 17 | 18 | def __init__(self, uv_symmetries): 19 | self.mask_label_symmetries = DensePoseTransformData.MASK_LABEL_SYMMETRIES 20 | self.point_label_symmetries = DensePoseTransformData.POINT_LABEL_SYMMETRIES 21 | self.uv_symmetries = uv_symmetries 22 | 23 | @staticmethod 24 | def load(fpath): 25 | import scipy.io 26 | 27 | uv_symmetry_map = scipy.io.loadmat(fpath) 28 | uv_symmetry_map_torch = {} 29 | for key in ["U_transforms", "V_transforms"]: 30 | map_src = uv_symmetry_map[key] 31 | uv_symmetry_map_torch[key] = [] 32 | for i in range(uv_symmetry_map[key].shape[1]): 33 | uv_symmetry_map_torch[key].append( 34 | torch.from_numpy(map_src[0, i]).to(dtype=torch.float) 35 | ) 36 | transform_data = DensePoseTransformData(uv_symmetry_map_torch) 37 | return transform_data 38 | 39 | 40 | class DensePoseDataRelative(object): 41 | """ 42 | Dense pose relative annotations that can be applied to any bounding box: 43 | x - normalized X coordinates [0, 255] of annotated points 44 | y - normalized Y coordinates [0, 255] of annotated points 45 | i - body part labels 0,...,24 for annotated points 46 | u - body part U coordinates [0, 1] for annotated points 47 | v - body part V coordinates [0, 1] for annotated points 48 | segm - 256x256 segmentation mask with values 0,...,14 49 | To obtain absolute x and y data wrt some bounding box one needs to first 50 | divide the data by 256, multiply by the respective bounding box size 51 | and add bounding box offset: 52 | x_img = x0 + x_norm * w / 256.0 53 | y_img = y0 + y_norm * h / 256.0 54 | Segmentation masks are typically sampled to get image-based masks. 
55 | """ 56 | 57 | # Key for normalized X coordinates in annotation dict 58 | X_KEY = "dp_x" 59 | # Key for normalized Y coordinates in annotation dict 60 | Y_KEY = "dp_y" 61 | # Key for U part coordinates in annotation dict 62 | U_KEY = "dp_U" 63 | # Key for V part coordinates in annotation dict 64 | V_KEY = "dp_V" 65 | # Key for I point labels in annotation dict 66 | I_KEY = "dp_I" 67 | # Key for segmentation mask in annotation dict 68 | S_KEY = "dp_masks" 69 | # Number of body parts in segmentation masks 70 | N_BODY_PARTS = 14 71 | # Number of parts in point labels 72 | N_PART_LABELS = 24 73 | MASK_SIZE = 256 74 | 75 | def __init__(self, annotation, cleanup=False): 76 | is_valid, reason_not_valid = DensePoseDataRelative.validate_annotation(annotation) 77 | assert is_valid, "Invalid DensePose annotations: {}".format(reason_not_valid) 78 | self.x = torch.as_tensor(annotation[DensePoseDataRelative.X_KEY]) 79 | self.y = torch.as_tensor(annotation[DensePoseDataRelative.Y_KEY]) 80 | self.i = torch.as_tensor(annotation[DensePoseDataRelative.I_KEY]) 81 | self.u = torch.as_tensor(annotation[DensePoseDataRelative.U_KEY]) 82 | self.v = torch.as_tensor(annotation[DensePoseDataRelative.V_KEY]) 83 | self.segm = DensePoseDataRelative.extract_segmentation_mask(annotation) 84 | self.device = torch.device("cpu") 85 | if cleanup: 86 | DensePoseDataRelative.cleanup_annotation(annotation) 87 | 88 | def to(self, device): 89 | if self.device == device: 90 | return self 91 | new_data = DensePoseDataRelative.__new__(DensePoseDataRelative) 92 | new_data.x = self.x 93 | new_data.x = self.x.to(device) 94 | new_data.y = self.y.to(device) 95 | new_data.i = self.i.to(device) 96 | new_data.u = self.u.to(device) 97 | new_data.v = self.v.to(device) 98 | new_data.segm = self.segm.to(device) 99 | new_data.device = device 100 | return new_data 101 | 102 | @staticmethod 103 | def extract_segmentation_mask(annotation): 104 | import pycocotools.mask as mask_utils 105 | 106 | poly_specs = annotation[DensePoseDataRelative.S_KEY] 107 | segm = torch.zeros((DensePoseDataRelative.MASK_SIZE,) * 2, dtype=torch.float32) 108 | for i in range(DensePoseDataRelative.N_BODY_PARTS): 109 | poly_i = poly_specs[i] 110 | if poly_i: 111 | mask_i = mask_utils.decode(poly_i) 112 | segm[mask_i > 0] = i + 1 113 | return segm 114 | 115 | @staticmethod 116 | def validate_annotation(annotation): 117 | for key in [ 118 | DensePoseDataRelative.X_KEY, 119 | DensePoseDataRelative.Y_KEY, 120 | DensePoseDataRelative.I_KEY, 121 | DensePoseDataRelative.U_KEY, 122 | DensePoseDataRelative.V_KEY, 123 | DensePoseDataRelative.S_KEY, 124 | ]: 125 | if key not in annotation: 126 | return False, "no {key} data in the annotation".format(key=key) 127 | return True, None 128 | 129 | @staticmethod 130 | def cleanup_annotation(annotation): 131 | for key in [ 132 | DensePoseDataRelative.X_KEY, 133 | DensePoseDataRelative.Y_KEY, 134 | DensePoseDataRelative.I_KEY, 135 | DensePoseDataRelative.U_KEY, 136 | DensePoseDataRelative.V_KEY, 137 | DensePoseDataRelative.S_KEY, 138 | ]: 139 | if key in annotation: 140 | del annotation[key] 141 | 142 | def apply_transform(self, transforms, densepose_transform_data): 143 | self._transform_pts(transforms, densepose_transform_data) 144 | self._transform_segm(transforms, densepose_transform_data) 145 | 146 | def _transform_pts(self, transforms, dp_transform_data): 147 | import detectron2.data.transforms as T 148 | 149 | # NOTE: This assumes that HorizFlipTransform is the only one that does flip 150 | do_hflip = sum(isinstance(t, 
T.HFlipTransform) for t in transforms.transforms) % 2 == 1 151 | if do_hflip: 152 | self.x = self.segm.size(1) - self.x 153 | self._flip_iuv_semantics(dp_transform_data) 154 | 155 | def _flip_iuv_semantics(self, dp_transform_data: DensePoseTransformData) -> None: 156 | i_old = self.i.clone() 157 | uv_symmetries = dp_transform_data.uv_symmetries 158 | pt_label_symmetries = dp_transform_data.point_label_symmetries 159 | for i in range(self.N_PART_LABELS): 160 | if i + 1 in i_old: 161 | annot_indices_i = i_old == i + 1 162 | if pt_label_symmetries[i + 1] != i + 1: 163 | self.i[annot_indices_i] = pt_label_symmetries[i + 1] 164 | u_loc = (self.u[annot_indices_i] * 255).long() 165 | v_loc = (self.v[annot_indices_i] * 255).long() 166 | self.u[annot_indices_i] = uv_symmetries["U_transforms"][i][v_loc, u_loc] 167 | self.v[annot_indices_i] = uv_symmetries["V_transforms"][i][v_loc, u_loc] 168 | 169 | def _transform_segm(self, transforms, dp_transform_data): 170 | import detectron2.data.transforms as T 171 | 172 | # NOTE: This assumes that HorizFlipTransform is the only one that does flip 173 | do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1 174 | if do_hflip: 175 | self.segm = torch.flip(self.segm, [1]) 176 | self._flip_segm_semantics(dp_transform_data) 177 | 178 | def _flip_segm_semantics(self, dp_transform_data): 179 | old_segm = self.segm.clone() 180 | mask_label_symmetries = dp_transform_data.mask_label_symmetries 181 | for i in range(self.N_BODY_PARTS): 182 | if mask_label_symmetries[i + 1] != i + 1: 183 | self.segm[old_segm == i + 1] = mask_label_symmetries[i + 1] 184 | 185 | 186 | def normalized_coords_transform(x0, y0, w, h): 187 | """ 188 | Coordinates transform that maps top left corner to (-1, -1) and bottom 189 | right corner to (1, 1). 
Used for torch.grid_sample to initialize the 190 | grid 191 | """ 192 | 193 | def f(p): 194 | return (2 * (p[0] - x0) / w - 1, 2 * (p[1] - y0) / h - 1) 195 | 196 | return f 197 | 198 | 199 | class DensePoseOutput(object): 200 | def __init__(self, S, I, U, V): 201 | self.S = S 202 | self.I = I # noqa: E741 203 | self.U = U 204 | self.V = V 205 | self._check_output_dims(S, I, U, V) 206 | 207 | def _check_output_dims(self, S, I, U, V): 208 | assert ( 209 | len(S.size()) == 4 210 | ), "Segmentation output should have 4 " "dimensions (NCHW), but has size {}".format( 211 | S.size() 212 | ) 213 | assert ( 214 | len(I.size()) == 4 215 | ), "Segmentation output should have 4 " "dimensions (NCHW), but has size {}".format( 216 | S.size() 217 | ) 218 | assert ( 219 | len(U.size()) == 4 220 | ), "Segmentation output should have 4 " "dimensions (NCHW), but has size {}".format( 221 | S.size() 222 | ) 223 | assert ( 224 | len(V.size()) == 4 225 | ), "Segmentation output should have 4 " "dimensions (NCHW), but has size {}".format( 226 | S.size() 227 | ) 228 | assert len(S) == len(I), ( 229 | "Number of output segmentation planes {} " 230 | "should be equal to the number of output part index " 231 | "planes {}".format(len(S), len(I)) 232 | ) 233 | assert S.size()[2:] == I.size()[2:], ( 234 | "Output segmentation plane size {} " 235 | "should be equal to the output part index " 236 | "plane size {}".format(S.size()[2:], I.size()[2:]) 237 | ) 238 | assert I.size() == U.size(), ( 239 | "Part index output shape {} " 240 | "should be the same as U coordinates output shape {}".format(I.size(), U.size()) 241 | ) 242 | assert I.size() == V.size(), ( 243 | "Part index output shape {} " 244 | "should be the same as V coordinates output shape {}".format(I.size(), V.size()) 245 | ) 246 | 247 | def resize(self, image_size_hw): 248 | # do nothing - outputs are invariant to resize 249 | pass 250 | 251 | def _crop(self, S, I, U, V, bbox_old_xywh, bbox_new_xywh): 252 | """ 253 | Resample S, I, U, V from bbox_old to the cropped bbox_new 254 | """ 255 | x0old, y0old, wold, hold = bbox_old_xywh 256 | x0new, y0new, wnew, hnew = bbox_new_xywh 257 | tr_coords = normalized_coords_transform(x0old, y0old, wold, hold) 258 | topleft = (x0new, y0new) 259 | bottomright = (x0new + wnew, y0new + hnew) 260 | topleft_norm = tr_coords(topleft) 261 | bottomright_norm = tr_coords(bottomright) 262 | hsize = S.size(1) 263 | wsize = S.size(2) 264 | grid = torch.meshgrid( 265 | torch.arange( 266 | topleft_norm[1], 267 | bottomright_norm[1], 268 | (bottomright_norm[1] - topleft_norm[1]) / hsize, 269 | )[:hsize], 270 | torch.arange( 271 | topleft_norm[0], 272 | bottomright_norm[0], 273 | (bottomright_norm[0] - topleft_norm[0]) / wsize, 274 | )[:wsize], 275 | ) 276 | grid = torch.stack(grid, dim=2).to(S.device) 277 | assert ( 278 | grid.size(0) == hsize 279 | ), "Resampled grid expected " "height={}, actual height={}".format(hsize, grid.size(0)) 280 | assert grid.size(1) == wsize, "Resampled grid expected " "width={}, actual width={}".format( 281 | wsize, grid.size(1) 282 | ) 283 | S_new = F.grid_sample( 284 | S.unsqueeze(0), 285 | torch.unsqueeze(grid, 0), 286 | mode="bilinear", 287 | padding_mode="border", 288 | align_corners=True, 289 | ).squeeze(0) 290 | I_new = F.grid_sample( 291 | I.unsqueeze(0), 292 | torch.unsqueeze(grid, 0), 293 | mode="bilinear", 294 | padding_mode="border", 295 | align_corners=True, 296 | ).squeeze(0) 297 | U_new = F.grid_sample( 298 | U.unsqueeze(0), 299 | torch.unsqueeze(grid, 0), 300 | mode="bilinear", 301 | 
padding_mode="border", 302 | align_corners=True, 303 | ).squeeze(0) 304 | V_new = F.grid_sample( 305 | V.unsqueeze(0), 306 | torch.unsqueeze(grid, 0), 307 | mode="bilinear", 308 | padding_mode="border", 309 | align_corners=True, 310 | ).squeeze(0) 311 | return S_new, I_new, U_new, V_new 312 | 313 | def crop(self, indices_cropped, bboxes_old, bboxes_new): 314 | """ 315 | Crop outputs for selected bounding boxes to the new bounding boxes. 316 | """ 317 | # VK: cropping is ignored for now 318 | # for i, ic in enumerate(indices_cropped): 319 | # self.S[ic], self.I[ic], self.U[ic], self.V[ic] = \ 320 | # self._crop(self.S[ic], self.I[ic], self.U[ic], self.V[ic], 321 | # bboxes_old[i], bboxes_new[i]) 322 | pass 323 | 324 | def to_result(self, boxes_xywh): 325 | """ 326 | Convert DensePose outputs to results format. Results are more compact, 327 | but cannot be resampled any more 328 | """ 329 | result = DensePoseResult(boxes_xywh, self.S, self.I, self.U, self.V) 330 | return result 331 | 332 | def __getitem__(self, item): 333 | if isinstance(item, int): 334 | S_selected = self.S[item].unsqueeze(0) 335 | I_selected = self.I[item].unsqueeze(0) 336 | U_selected = self.U[item].unsqueeze(0) 337 | V_selected = self.V[item].unsqueeze(0) 338 | else: 339 | S_selected = self.S[item] 340 | I_selected = self.I[item] 341 | U_selected = self.U[item] 342 | V_selected = self.V[item] 343 | return DensePoseOutput(S_selected, I_selected, U_selected, V_selected) 344 | 345 | def __str__(self): 346 | s = "DensePoseOutput S {}, I {}, U {}, V {}".format( 347 | list(self.S.size()), list(self.I.size()), list(self.U.size()), list(self.V.size()) 348 | ) 349 | return s 350 | 351 | def __len__(self): 352 | return self.S.size(0) 353 | 354 | 355 | class DensePoseResult(object): 356 | def __init__(self, boxes_xywh, S, I, U, V): 357 | self.results = [] 358 | self.boxes_xywh = boxes_xywh.cpu().tolist() 359 | assert len(boxes_xywh.size()) == 2 360 | assert boxes_xywh.size(1) == 4 361 | for i, box_xywh in enumerate(boxes_xywh): 362 | result_i = self._output_to_result(box_xywh, S[[i]], I[[i]], U[[i]], V[[i]]) 363 | result_numpy_i = result_i.cpu().numpy() 364 | result_encoded_i = DensePoseResult.encode_png_data(result_numpy_i) 365 | result_encoded_with_shape_i = (result_numpy_i.shape, result_encoded_i) 366 | self.results.append(result_encoded_with_shape_i) 367 | 368 | def __str__(self): 369 | s = "DensePoseResult: N={} [{}]".format( 370 | len(self.results), ", ".join([str(list(r[0])) for r in self.results]) 371 | ) 372 | return s 373 | 374 | def _output_to_result(self, box_xywh, S, I, U, V): 375 | x, y, w, h = box_xywh 376 | w = max(int(w), 1) 377 | h = max(int(h), 1) 378 | result = torch.zeros([3, h, w], dtype=torch.uint8, device=U.device) 379 | assert ( 380 | len(S.size()) == 4 381 | ), "AnnIndex tensor size should have {} " "dimensions but has {}".format(4, len(S.size())) 382 | s_bbox = F.interpolate(S, (h, w), mode="bilinear", align_corners=False).argmax(dim=1) 383 | assert ( 384 | len(I.size()) == 4 385 | ), "IndexUV tensor size should have {} " "dimensions but has {}".format(4, len(S.size())) 386 | i_bbox = ( 387 | F.interpolate(I, (h, w), mode="bilinear", align_corners=False).argmax(dim=1) 388 | * (s_bbox > 0).long() 389 | ).squeeze(0) 390 | assert len(U.size()) == 4, "U tensor size should have {} " "dimensions but has {}".format( 391 | 4, len(U.size()) 392 | ) 393 | u_bbox = F.interpolate(U, (h, w), mode="bilinear", align_corners=False) 394 | assert len(V.size()) == 4, "V tensor size should have {} " "dimensions but has 
{}".format( 395 | 4, len(V.size()) 396 | ) 397 | v_bbox = F.interpolate(V, (h, w), mode="bilinear", align_corners=False) 398 | result[0] = i_bbox 399 | for part_id in range(1, u_bbox.size(1)): 400 | result[1][i_bbox == part_id] = ( 401 | (u_bbox[0, part_id][i_bbox == part_id] * 255).clamp(0, 255).to(torch.uint8) 402 | ) 403 | result[2][i_bbox == part_id] = ( 404 | (v_bbox[0, part_id][i_bbox == part_id] * 255).clamp(0, 255).to(torch.uint8) 405 | ) 406 | assert ( 407 | result.size(1) == h 408 | ), "Results height {} should be equal" "to bounding box height {}".format(result.size(1), h) 409 | assert ( 410 | result.size(2) == w 411 | ), "Results width {} should be equal" "to bounding box width {}".format(result.size(2), w) 412 | return result 413 | 414 | @staticmethod 415 | def encode_png_data(arr): 416 | """ 417 | Encode array data as a PNG image using the highest compression rate 418 | @param arr [in] Data stored in an array of size (3, M, N) of type uint8 419 | @return Base64-encoded string containing PNG-compressed data 420 | """ 421 | assert len(arr.shape) == 3, "Expected a 3D array as an input," " got a {0}D array".format( 422 | len(arr.shape) 423 | ) 424 | assert arr.shape[0] == 3, "Expected first array dimension of size 3," " got {0}".format( 425 | arr.shape[0] 426 | ) 427 | assert arr.dtype == np.uint8, "Expected an array of type np.uint8, " " got {0}".format( 428 | arr.dtype 429 | ) 430 | data = np.moveaxis(arr, 0, -1) 431 | im = Image.fromarray(data) 432 | fstream = BytesIO() 433 | im.save(fstream, format="png", optimize=True) 434 | s = base64.encodebytes(fstream.getvalue()).decode() 435 | return s 436 | 437 | @staticmethod 438 | def decode_png_data(shape, s): 439 | """ 440 | Decode array data from a string that contains PNG-compressed data 441 | @param Base64-encoded string containing PNG-compressed data 442 | @return Data stored in an array of size (3, M, N) of type uint8 443 | """ 444 | fstream = BytesIO(base64.decodebytes(s.encode())) 445 | im = Image.open(fstream) 446 | data = np.moveaxis(np.array(im.getdata(), dtype=np.uint8), -1, 0) 447 | return data.reshape(shape) 448 | 449 | def __len__(self): 450 | return len(self.results) 451 | 452 | def __getitem__(self, item): 453 | result_encoded = self.results[item] 454 | bbox_xywh = self.boxes_xywh[item] 455 | return result_encoded, bbox_xywh 456 | 457 | 458 | class DensePoseList(object): 459 | 460 | _TORCH_DEVICE_CPU = torch.device("cpu") 461 | 462 | def __init__(self, densepose_datas, boxes_xyxy_abs, image_size_hw, device=_TORCH_DEVICE_CPU): 463 | assert len(densepose_datas) == len(boxes_xyxy_abs), ( 464 | "Attempt to initialize DensePoseList with {} DensePose datas " 465 | "and {} boxes".format(len(densepose_datas), len(boxes_xyxy_abs)) 466 | ) 467 | self.densepose_datas = [] 468 | for densepose_data in densepose_datas: 469 | assert isinstance(densepose_data, DensePoseDataRelative) or densepose_data is None, ( 470 | "Attempt to initialize DensePoseList with DensePose datas " 471 | "of type {}, expected DensePoseDataRelative".format(type(densepose_data)) 472 | ) 473 | densepose_data_ondevice = ( 474 | densepose_data.to(device) if densepose_data is not None else None 475 | ) 476 | self.densepose_datas.append(densepose_data_ondevice) 477 | self.boxes_xyxy_abs = boxes_xyxy_abs.to(device) 478 | self.image_size_hw = image_size_hw 479 | self.device = device 480 | 481 | def to(self, device): 482 | if self.device == device: 483 | return self 484 | return DensePoseList(self.densepose_datas, self.boxes_xyxy_abs, self.image_size_hw, device) 
485 | 486 | def __iter__(self): 487 | return iter(self.densepose_datas) 488 | 489 | def __len__(self): 490 | return len(self.densepose_datas) 491 | 492 | def __repr__(self): 493 | s = self.__class__.__name__ + "(" 494 | s += "num_instances={}, ".format(len(self.densepose_datas)) 495 | s += "image_width={}, ".format(self.image_size_hw[1]) 496 | s += "image_height={})".format(self.image_size_hw[0]) 497 | return s 498 | 499 | def __getitem__(self, item): 500 | if isinstance(item, int): 501 | densepose_data_rel = self.densepose_datas[item] 502 | return densepose_data_rel 503 | elif isinstance(item, slice): 504 | densepose_datas_rel = self.densepose_datas[item] 505 | boxes_xyxy_abs = self.boxes_xyxy_abs[item] 506 | return DensePoseList( 507 | densepose_datas_rel, boxes_xyxy_abs, self.image_size_hw, self.device 508 | ) 509 | elif isinstance(item, torch.Tensor) and (item.dtype == torch.bool): 510 | densepose_datas_rel = [self.densepose_datas[i] for i, x in enumerate(item) if x > 0] 511 | boxes_xyxy_abs = self.boxes_xyxy_abs[item] 512 | return DensePoseList( 513 | densepose_datas_rel, boxes_xyxy_abs, self.image_size_hw, self.device 514 | ) 515 | else: 516 | densepose_datas_rel = [self.densepose_datas[i] for i in item] 517 | boxes_xyxy_abs = self.boxes_xyxy_abs[item] 518 | return DensePoseList( 519 | densepose_datas_rel, boxes_xyxy_abs, self.image_size_hw, self.device 520 | ) 521 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Akash Sengupta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /PointRend/configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 43 | -------------------------------------------------------------------------------- /PointRend/configs/InstanceSegmentation/Base-PointRend-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | ROI_HEADS: 4 | NAME: "PointRendROIHeads" 5 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 6 | ROI_BOX_HEAD: 7 | TRAIN_ON_PRED_BOXES: True 8 | ROI_MASK_HEAD: 9 | NAME: "CoarseMaskHead" 10 | FC_DIM: 1024 11 | NUM_FC: 2 12 | OUTPUT_SIDE_RESOLUTION: 7 13 | IN_FEATURES: ["p2"] 14 | POINT_HEAD_ON: True 15 | POINT_HEAD: 16 | FC_DIM: 256 17 | NUM_FC: 3 18 | IN_FEATURES: ["p2"] 19 | INPUT: 20 | MASK_FORMAT: "bitmask" 21 | -------------------------------------------------------------------------------- /PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-PointRend-RCNN-FPN.yaml 2 | MODEL: 3 | WEIGHTS: detectron2://ImageNetPretrained/MSRA/R-50.pkl 4 | MASK_ON: true 5 | RESNETS: 6 | DEPTH: 50 7 | SOLVER: 8 | STEPS: (210000, 250000) 9 | MAX_ITER: 270000 10 | -------------------------------------------------------------------------------- /PointRend/point_rend/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .config import add_pointrend_config 3 | from .coarse_mask_head import CoarseMaskHead 4 | from .roi_heads import PointRendROIHeads 5 | -------------------------------------------------------------------------------- /PointRend/point_rend/coarse_mask_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from detectron2.layers import Conv2d, ShapeSpec 8 | from detectron2.modeling import ROI_MASK_HEAD_REGISTRY 9 | 10 | 11 | @ROI_MASK_HEAD_REGISTRY.register() 12 | class CoarseMaskHead(nn.Module): 13 | """ 14 | A mask head with fully connected layers. Given pooled features it first reduces channels and 15 | spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously 16 | to the standard box head. 17 | """ 18 | 19 | def __init__(self, cfg, input_shape: ShapeSpec): 20 | """ 21 | The following attributes are parsed from config: 22 | conv_dim: the output dimension of the conv layers 23 | fc_dim: the feature dimenstion of the FC layers 24 | num_fc: the number of FC layers 25 | output_side_resolution: side resolution of the output square mask prediction 26 | """ 27 | super(CoarseMaskHead, self).__init__() 28 | 29 | # fmt: off 30 | self.num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES 31 | conv_dim = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM 32 | self.fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM 33 | num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC 34 | self.output_side_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION 35 | self.input_channels = input_shape.channels 36 | self.input_h = input_shape.height 37 | self.input_w = input_shape.width 38 | # fmt: on 39 | 40 | self.conv_layers = [] 41 | if self.input_channels > conv_dim: 42 | self.reduce_channel_dim_conv = Conv2d( 43 | self.input_channels, 44 | conv_dim, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=True, 49 | activation=F.relu, 50 | ) 51 | self.conv_layers.append(self.reduce_channel_dim_conv) 52 | 53 | self.reduce_spatial_dim_conv = Conv2d( 54 | conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu 55 | ) 56 | self.conv_layers.append(self.reduce_spatial_dim_conv) 57 | 58 | input_dim = conv_dim * self.input_h * self.input_w 59 | input_dim //= 4 60 | 61 | self.fcs = [] 62 | for k in range(num_fc): 63 | fc = nn.Linear(input_dim, self.fc_dim) 64 | self.add_module("coarse_mask_fc{}".format(k + 1), fc) 65 | self.fcs.append(fc) 66 | input_dim = self.fc_dim 67 | 68 | output_dim = self.num_classes * self.output_side_resolution * self.output_side_resolution 69 | 70 | self.prediction = nn.Linear(self.fc_dim, output_dim) 71 | # use normal distribution initialization for mask prediction layer 72 | nn.init.normal_(self.prediction.weight, std=0.001) 73 | nn.init.constant_(self.prediction.bias, 0) 74 | 75 | for layer in self.conv_layers: 76 | weight_init.c2_msra_fill(layer) 77 | for layer in self.fcs: 78 | weight_init.c2_xavier_fill(layer) 79 | 80 | def forward(self, x): 81 | N = x.shape[0] 82 | x = x.view(N, self.input_channels, self.input_h, self.input_w) 83 | for layer in self.conv_layers: 84 | x = layer(x) 85 | x = torch.flatten(x, start_dim=1) 86 | for layer in self.fcs: 87 | x = F.relu(layer(x)) 88 | return self.prediction(x).view( 89 | N, self.num_classes, self.output_side_resolution, self.output_side_resolution 90 | ) 91 | -------------------------------------------------------------------------------- /PointRend/point_rend/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from detectron2.config import CfgNode as CN 5 | 6 | 7 | def add_pointrend_config(cfg): 8 | """ 9 | Add config for PointRend. 10 | """ 11 | # Names of the input feature maps to be used by a coarse mask head. 12 | cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES = ("p2",) 13 | cfg.MODEL.ROI_MASK_HEAD.FC_DIM = 1024 14 | cfg.MODEL.ROI_MASK_HEAD.NUM_FC = 2 15 | # The side size of a coarse mask head prediction. 16 | cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION = 7 17 | # True if point head is used. 18 | cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON = False 19 | 20 | cfg.MODEL.POINT_HEAD = CN() 21 | cfg.MODEL.POINT_HEAD.NAME = "StandardPointHead" 22 | cfg.MODEL.POINT_HEAD.NUM_CLASSES = 80 23 | # Names of the input feature maps to be used by a mask point head. 24 | cfg.MODEL.POINT_HEAD.IN_FEATURES = ("p2",) 25 | # Number of points sampled during training for a mask point head. 26 | cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS = 14 * 14 27 | # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the 28 | # original paper. 29 | cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO = 3 30 | # Importance sampling parameter for PointRend point sampling during training. Parameter `beta` in 31 | # the original paper. 32 | cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO = 0.75 33 | # Number of subdivision steps during inference. 34 | cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS = 5 35 | # Maximum number of points selected at each subdivision step (N). 36 | cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS = 28 * 28 37 | cfg.MODEL.POINT_HEAD.FC_DIM = 256 38 | cfg.MODEL.POINT_HEAD.NUM_FC = 3 39 | cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK = False 40 | # If True, then coarse prediction features are used as input for each layer in PointRend's MLP. 41 | cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER = True 42 | -------------------------------------------------------------------------------- /PointRend/point_rend/point_features.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from detectron2.layers import cat 6 | from detectron2.structures import Boxes 7 | 8 | 9 | """ 10 | Shape shorthand in this module: 11 | 12 | N: minibatch dimension size, i.e. the number of RoIs for instance segmentation or the 13 | number of images for semantic segmentation. 14 | R: number of ROIs, combined over all images, in the minibatch 15 | P: number of points 16 | """ 17 | 18 | 19 | def point_sample(input, point_coords, **kwargs): 20 | """ 21 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors. 22 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside 23 | [0, 1] x [0, 1] square. 24 | 25 | Args: 26 | input (Tensor): A tensor of shape (N, C, H, W) that contains a feature map on an H x W grid. 27 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains 28 | [0, 1] x [0, 1] normalized point coordinates. 29 | 30 | Returns: 31 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains 32 | features for points in `point_coords`. The features are obtained via bilinear 33 | interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`. 
34 | """ 35 | add_dim = False 36 | if point_coords.dim() == 3: 37 | add_dim = True 38 | point_coords = point_coords.unsqueeze(2) 39 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs) 40 | if add_dim: 41 | output = output.squeeze(3) 42 | return output 43 | 44 | 45 | def generate_regular_grid_point_coords(R, side_size, device): 46 | """ 47 | Generate regular square grid of points in [0, 1] x [0, 1] coordinate space. 48 | 49 | Args: 50 | R (int): The number of grids to sample, one for each region. 51 | side_size (int): The side size of the regular grid. 52 | device (torch.device): Desired device of returned tensor. 53 | 54 | Returns: 55 | (Tensor): A tensor of shape (R, side_size^2, 2) that contains coordinates 56 | for the regular grids. 57 | """ 58 | aff = torch.tensor([[[0.5, 0, 0.5], [0, 0.5, 0.5]]], device=device) 59 | r = F.affine_grid(aff, torch.Size((1, 1, side_size, side_size)), align_corners=False) 60 | return r.view(1, -1, 2).expand(R, -1, -1) 61 | 62 | 63 | def get_uncertain_point_coords_with_randomness( 64 | coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio 65 | ): 66 | """ 67 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties 68 | are calculated for each point using the 'uncertainty_func' function, which takes a point's logit 69 | prediction as input. 70 | See PointRend paper for details. 71 | 72 | Args: 73 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for 74 | class-specific or class-agnostic prediction. 75 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that 76 | contains logit predictions for P points and returns their uncertainties as a Tensor of 77 | shape (N, 1, P). 78 | num_points (int): The number of points P to sample. 79 | oversample_ratio (int): Oversampling parameter. 80 | importance_sample_ratio (float): Ratio of points that are sampled via importance sampling. 81 | 82 | Returns: 83 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P 84 | sampled points. 85 | """ 86 | assert oversample_ratio >= 1 87 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0 88 | num_boxes = coarse_logits.shape[0] 89 | num_sampled = int(num_points * oversample_ratio) 90 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device) 91 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False) 92 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points. 93 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads 94 | # to worse results. To illustrate the difference: a sampled point between two coarse predictions 95 | # with -1 and 1 logits has 0 logit prediction and therefore 0 uncertainty value, however, if one 96 | # calculates uncertainties for the coarse predictions first (-1 and -1) and sample it for the 97 | # center point, they will get -1 uncertainty. 
98 | point_uncertainties = uncertainty_func(point_logits) 99 | num_uncertain_points = int(importance_sample_ratio * num_points) 100 | num_random_points = num_points - num_uncertain_points 101 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] 102 | shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device) 103 | idx += shift[:, None] 104 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( 105 | num_boxes, num_uncertain_points, 2 106 | ) 107 | if num_random_points > 0: 108 | point_coords = cat( 109 | [ 110 | point_coords, 111 | torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device), 112 | ], 113 | dim=1, 114 | ) 115 | return point_coords 116 | 117 | 118 | def get_uncertain_point_coords_on_grid(uncertainty_map, num_points): 119 | """ 120 | Find `num_points` most uncertain points from `uncertainty_map` grid. 121 | 122 | Args: 123 | uncertainty_map (Tensor): A tensor of shape (N, 1, H, W) that contains uncertainty 124 | values for a set of points on a regular H x W grid. 125 | num_points (int): The number of points P to select. 126 | 127 | Returns: 128 | point_indices (Tensor): A tensor of shape (N, P) that contains indices from 129 | [0, H x W) of the most uncertain points. 130 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains [0, 1] x [0, 1] normalized 131 | coordinates of the most uncertain points from the H x W grid. 132 | """ 133 | R, _, H, W = uncertainty_map.shape 134 | h_step = 1.0 / float(H) 135 | w_step = 1.0 / float(W) 136 | 137 | num_points = min(H * W, num_points) 138 | point_indices = torch.topk(uncertainty_map.view(R, H * W), k=num_points, dim=1)[1] 139 | point_coords = torch.zeros(R, num_points, 2, dtype=torch.float, device=uncertainty_map.device) 140 | point_coords[:, :, 0] = w_step / 2.0 + (point_indices % W).to(torch.float) * w_step 141 | point_coords[:, :, 1] = h_step / 2.0 + (point_indices // W).to(torch.float) * h_step 142 | return point_indices, point_coords 143 | 144 | 145 | def point_sample_fine_grained_features(features_list, feature_scales, boxes, point_coords): 146 | """ 147 | Get features from feature maps in `features_list` that correspond to specific point coordinates 148 | inside each bounding box from `boxes`. 149 | 150 | Args: 151 | features_list (list[Tensor]): A list of feature map tensors to get features from. 152 | feature_scales (list[float]): A list of scales for tensors in `features_list`. 153 | boxes (list[Boxes]): A list of I Boxes objects that contain R_1 + ... + R_I = R boxes all 154 | together. 155 | point_coords (Tensor): A tensor of shape (R, P, 2) that contains 156 | [0, 1] x [0, 1] box-normalized coordinates of the P sampled points. 157 | 158 | Returns: 159 | point_features (Tensor): A tensor of shape (R, C, P) that contains features sampled 160 | from all features maps in feature_list for P sampled points for all R boxes in `boxes`. 161 | point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains image-level 162 | coordinates of P points. 
163 | """ 164 | cat_boxes = Boxes.cat(boxes) 165 | num_boxes = [len(b) for b in boxes] 166 | 167 | point_coords_wrt_image = get_point_coords_wrt_image(cat_boxes.tensor, point_coords) 168 | split_point_coords_wrt_image = torch.split(point_coords_wrt_image, num_boxes) 169 | 170 | point_features = [] 171 | for idx_img, point_coords_wrt_image_per_image in enumerate(split_point_coords_wrt_image): 172 | point_features_per_image = [] 173 | for idx_feature, feature_map in enumerate(features_list): 174 | h, w = feature_map.shape[-2:] 175 | scale = torch.tensor([w, h], device=feature_map.device) / feature_scales[idx_feature] 176 | point_coords_scaled = point_coords_wrt_image_per_image / scale 177 | point_features_per_image.append( 178 | point_sample( 179 | feature_map[idx_img].unsqueeze(0), 180 | point_coords_scaled.unsqueeze(0), 181 | align_corners=False, 182 | ) 183 | .squeeze(0) 184 | .transpose(1, 0) 185 | ) 186 | point_features.append(cat(point_features_per_image, dim=1)) 187 | 188 | return cat(point_features, dim=0), point_coords_wrt_image 189 | 190 | 191 | def get_point_coords_wrt_image(boxes_coords, point_coords): 192 | """ 193 | Convert box-normalized [0, 1] x [0, 1] point cooordinates to image-level coordinates. 194 | 195 | Args: 196 | boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding boxes. 197 | coordinates. 198 | point_coords (Tensor): A tensor of shape (R, P, 2) that contains 199 | [0, 1] x [0, 1] box-normalized coordinates of the P sampled points. 200 | 201 | Returns: 202 | point_coords_wrt_image (Tensor): A tensor of shape (R, P, 2) that contains 203 | image-normalized coordinates of P sampled points. 204 | """ 205 | with torch.no_grad(): 206 | point_coords_wrt_image = point_coords.clone() 207 | point_coords_wrt_image[:, :, 0] = point_coords_wrt_image[:, :, 0] * ( 208 | boxes_coords[:, None, 2] - boxes_coords[:, None, 0] 209 | ) 210 | point_coords_wrt_image[:, :, 1] = point_coords_wrt_image[:, :, 1] * ( 211 | boxes_coords[:, None, 3] - boxes_coords[:, None, 1] 212 | ) 213 | point_coords_wrt_image[:, :, 0] += boxes_coords[:, None, 0] 214 | point_coords_wrt_image[:, :, 1] += boxes_coords[:, None, 1] 215 | return point_coords_wrt_image 216 | -------------------------------------------------------------------------------- /PointRend/point_rend/point_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from detectron2.layers import ShapeSpec, cat 8 | from detectron2.structures import BitMasks 9 | from detectron2.utils.events import get_event_storage 10 | from detectron2.utils.registry import Registry 11 | 12 | from .point_features import point_sample 13 | 14 | POINT_HEAD_REGISTRY = Registry("POINT_HEAD") 15 | POINT_HEAD_REGISTRY.__doc__ = """ 16 | Registry for point heads, which makes prediction for a given set of per-point features. 17 | 18 | The registered object will be called with `obj(cfg, input_shape)`. 19 | """ 20 | 21 | 22 | def roi_mask_point_loss(mask_logits, instances, points_coord): 23 | """ 24 | Compute the point-based loss for instance segmentation mask predictions. 
25 | 26 | Args: 27 | mask_logits (Tensor): A tensor of shape (R, C, P) or (R, 1, P) for class-specific or 28 | class-agnostic, where R is the total number of predicted masks in all images, C is the 29 | number of foreground classes, and P is the number of points sampled for each mask. 30 | The values are logits. 31 | instances (list[Instances]): A list of N Instances, where N is the number of images 32 | in the batch. These instances are in 1:1 correspondence with the `mask_logits`. So, i_th 33 | elememt of the list contains R_i objects and R_1 + ... + R_N is equal to R. 34 | The ground-truth labels (class, box, mask, ...) associated with each instance are stored 35 | in fields. 36 | points_coords (Tensor): A tensor of shape (R, P, 2), where R is the total number of 37 | predicted masks and P is the number of points for each mask. The coordinates are in 38 | the image pixel coordinate space, i.e. [0, H] x [0, W]. 39 | Returns: 40 | point_loss (Tensor): A scalar tensor containing the loss. 41 | """ 42 | assert len(instances) == 0 or isinstance( 43 | instances[0].gt_masks, BitMasks 44 | ), "Point head works with GT in 'bitmask' format only. Set INPUT.MASK_FORMAT to 'bitmask'." 45 | with torch.no_grad(): 46 | cls_agnostic_mask = mask_logits.size(1) == 1 47 | total_num_masks = mask_logits.size(0) 48 | 49 | gt_classes = [] 50 | gt_mask_logits = [] 51 | idx = 0 52 | for instances_per_image in instances: 53 | if not cls_agnostic_mask: 54 | gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64) 55 | gt_classes.append(gt_classes_per_image) 56 | 57 | gt_bit_masks = instances_per_image.gt_masks.tensor 58 | h, w = instances_per_image.gt_masks.image_size 59 | scale = torch.tensor([w, h], dtype=torch.float, device=gt_bit_masks.device) 60 | points_coord_grid_sample_format = ( 61 | points_coord[idx : idx + len(instances_per_image)] / scale 62 | ) 63 | idx += len(instances_per_image) 64 | gt_mask_logits.append( 65 | point_sample( 66 | gt_bit_masks.to(torch.float32).unsqueeze(1), 67 | points_coord_grid_sample_format, 68 | align_corners=False, 69 | ).squeeze(1) 70 | ) 71 | gt_mask_logits = cat(gt_mask_logits) 72 | 73 | # torch.mean (in binary_cross_entropy_with_logits) doesn't 74 | # accept empty tensors, so handle it separately 75 | if gt_mask_logits.numel() == 0: 76 | return mask_logits.sum() * 0 77 | 78 | if cls_agnostic_mask: 79 | mask_logits = mask_logits[:, 0] 80 | else: 81 | indices = torch.arange(total_num_masks) 82 | gt_classes = cat(gt_classes, dim=0) 83 | mask_logits = mask_logits[indices, gt_classes] 84 | 85 | # Log the training accuracy (using gt classes and 0.0 threshold for the logits) 86 | mask_accurate = (mask_logits > 0.0) == gt_mask_logits.to(dtype=torch.uint8) 87 | mask_accuracy = mask_accurate.nonzero().size(0) / mask_accurate.numel() 88 | get_event_storage().put_scalar("point_rend/accuracy", mask_accuracy) 89 | 90 | point_loss = F.binary_cross_entropy_with_logits( 91 | mask_logits, gt_mask_logits.to(dtype=torch.float32), reduction="mean" 92 | ) 93 | return point_loss 94 | 95 | 96 | @POINT_HEAD_REGISTRY.register() 97 | class StandardPointHead(nn.Module): 98 | """ 99 | A point head multi-layer perceptron which we model with conv1d layers with kernel 1. The head 100 | takes both fine-grained and coarse prediction features as its input. 
101 | """ 102 | 103 | def __init__(self, cfg, input_shape: ShapeSpec): 104 | """ 105 | The following attributes are parsed from config: 106 | fc_dim: the output dimension of each FC layers 107 | num_fc: the number of FC layers 108 | coarse_pred_each_layer: if True, coarse prediction features are concatenated to each 109 | layer's input 110 | """ 111 | super(StandardPointHead, self).__init__() 112 | # fmt: off 113 | num_classes = cfg.MODEL.POINT_HEAD.NUM_CLASSES 114 | fc_dim = cfg.MODEL.POINT_HEAD.FC_DIM 115 | num_fc = cfg.MODEL.POINT_HEAD.NUM_FC 116 | cls_agnostic_mask = cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK 117 | self.coarse_pred_each_layer = cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER 118 | input_channels = input_shape.channels 119 | # fmt: on 120 | 121 | fc_dim_in = input_channels + num_classes 122 | self.fc_layers = [] 123 | for k in range(num_fc): 124 | fc = nn.Conv1d(fc_dim_in, fc_dim, kernel_size=1, stride=1, padding=0, bias=True) 125 | self.add_module("fc{}".format(k + 1), fc) 126 | self.fc_layers.append(fc) 127 | fc_dim_in = fc_dim 128 | fc_dim_in += num_classes if self.coarse_pred_each_layer else 0 129 | 130 | num_mask_classes = 1 if cls_agnostic_mask else num_classes 131 | self.predictor = nn.Conv1d(fc_dim_in, num_mask_classes, kernel_size=1, stride=1, padding=0) 132 | 133 | for layer in self.fc_layers: 134 | weight_init.c2_msra_fill(layer) 135 | # use normal distribution initialization for mask prediction layer 136 | nn.init.normal_(self.predictor.weight, std=0.001) 137 | if self.predictor.bias is not None: 138 | nn.init.constant_(self.predictor.bias, 0) 139 | 140 | def forward(self, fine_grained_features, coarse_features): 141 | x = torch.cat((fine_grained_features, coarse_features), dim=1) 142 | for layer in self.fc_layers: 143 | x = F.relu(layer(x)) 144 | if self.coarse_pred_each_layer: 145 | x = cat((x, coarse_features), dim=1) 146 | return self.predictor(x) 147 | 148 | 149 | def build_point_head(cfg, input_channels): 150 | """ 151 | Build a point head defined by `cfg.MODEL.POINT_HEAD.NAME`. 152 | """ 153 | head_name = cfg.MODEL.POINT_HEAD.NAME 154 | return POINT_HEAD_REGISTRY.get(head_name)(cfg, input_channels) 155 | -------------------------------------------------------------------------------- /PointRend/point_rend/roi_heads.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | # File: 4 | import numpy as np 5 | import torch 6 | 7 | from detectron2.layers import ShapeSpec, cat, interpolate 8 | from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads 9 | from detectron2.modeling.roi_heads.mask_head import ( 10 | build_mask_head, 11 | mask_rcnn_inference, 12 | mask_rcnn_loss, 13 | ) 14 | from detectron2.modeling.roi_heads.roi_heads import select_foreground_proposals 15 | 16 | from .point_features import ( 17 | generate_regular_grid_point_coords, 18 | get_uncertain_point_coords_on_grid, 19 | get_uncertain_point_coords_with_randomness, 20 | point_sample, 21 | point_sample_fine_grained_features, 22 | ) 23 | from .point_head import build_point_head, roi_mask_point_loss 24 | 25 | 26 | def calculate_uncertainty(logits, classes): 27 | """ 28 | We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the 29 | foreground class in `classes`. 30 | 31 | Args: 32 | logits (Tensor): A tensor of shape (R, C, ...) or (R, 1, ...) 
for class-specific or 33 | class-agnostic, where R is the total number of predicted masks in all images and C is 34 | the number of foreground classes. The values are logits. 35 | classes (list): A list of length R that contains either the predicted or the ground truth class 36 | for each predicted mask. 37 | 38 | Returns: 39 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with 40 | the most uncertain locations having the highest uncertainty score. 41 | """ 42 | if logits.shape[1] == 1: 43 | gt_class_logits = logits.clone() 44 | else: 45 | gt_class_logits = logits[ 46 | torch.arange(logits.shape[0], device=logits.device), classes 47 | ].unsqueeze(1) 48 | return -torch.abs(gt_class_logits) 49 | 50 | 51 | @ROI_HEADS_REGISTRY.register() 52 | class PointRendROIHeads(StandardROIHeads): 53 | """ 54 | The RoI heads class for PointRend instance segmentation models. 55 | 56 | In this class we redefine the mask head of `StandardROIHeads` leaving all other heads intact. 57 | To avoid namespace conflict with other heads we use names starting from `mask_` for all 58 | variables that correspond to the mask head in the class's namespace. 59 | """ 60 | 61 | def _init_mask_head(self, cfg): 62 | # fmt: off 63 | self.mask_on = cfg.MODEL.MASK_ON 64 | if not self.mask_on: 65 | return 66 | self.mask_coarse_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES 67 | self.mask_coarse_side_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION 68 | # fmt: on 69 | 70 | in_channels = np.sum([self.feature_channels[f] for f in self.mask_coarse_in_features]) 71 | self.mask_coarse_head = build_mask_head( 72 | cfg, 73 | ShapeSpec( 74 | channels=in_channels, 75 | width=self.mask_coarse_side_size, 76 | height=self.mask_coarse_side_size, 77 | ), 78 | ) 79 | self._init_point_head(cfg) 80 | 81 | def _init_point_head(self, cfg): 82 | # fmt: off 83 | self.mask_point_on = cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON 84 | if not self.mask_point_on: 85 | return 86 | assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES 87 | self.mask_point_in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES 88 | self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS 89 | self.mask_point_oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO 90 | self.mask_point_importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO 91 | # next two parameters are used in the adaptive subdivision inference procedure 92 | self.mask_point_subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS 93 | self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS 94 | # fmt: on 95 | 96 | in_channels = np.sum([self.feature_channels[f] for f in self.mask_point_in_features]) 97 | self.mask_point_head = build_point_head( 98 | cfg, ShapeSpec(channels=in_channels, width=1, height=1) 99 | ) 100 | 101 | def _forward_mask(self, features, instances): 102 | """ 103 | Forward logic of the mask prediction branch. 104 | 105 | Args: 106 | features (list[Tensor]): #level input features for mask prediction 107 | instances (list[Instances]): the per-image instances to train/predict masks. 108 | In training, they can be the proposals. 109 | In inference, they can be the predicted boxes. 110 | 111 | Returns: 112 | In training, a dict of losses. 113 | In inference, update `instances` with new fields "pred_masks" and return it. 
114 | """ 115 | if not self.mask_on: 116 | return {} if self.training else instances 117 | 118 | if self.training: 119 | proposals, _ = select_foreground_proposals(instances, self.num_classes) 120 | proposal_boxes = [x.proposal_boxes for x in proposals] 121 | mask_coarse_logits = self._forward_mask_coarse(features, proposal_boxes) 122 | 123 | losses = {"loss_mask": mask_rcnn_loss(mask_coarse_logits, proposals)} 124 | losses.update(self._forward_mask_point(features, mask_coarse_logits, proposals)) 125 | return losses 126 | else: 127 | pred_boxes = [x.pred_boxes for x in instances] 128 | mask_coarse_logits = self._forward_mask_coarse(features, pred_boxes) 129 | 130 | mask_logits = self._forward_mask_point(features, mask_coarse_logits, instances) 131 | mask_rcnn_inference(mask_logits, instances) 132 | return instances 133 | 134 | def _forward_mask_coarse(self, features, boxes): 135 | """ 136 | Forward logic of the coarse mask head. 137 | """ 138 | point_coords = generate_regular_grid_point_coords( 139 | np.sum(len(x) for x in boxes), self.mask_coarse_side_size, boxes[0].device 140 | ) 141 | mask_coarse_features_list = [ 142 | features[self.in_features.index(k)] for k in self.mask_coarse_in_features 143 | ] 144 | features_scales = [1.0 / self.feature_strides[k] for k in self.mask_coarse_in_features] 145 | # For regular grids of points, this function is equivalent to `len(features_list)' calls 146 | # of `ROIAlign` (with `SAMPLING_RATIO=2`), and concat the results. 147 | mask_features, _ = point_sample_fine_grained_features( 148 | mask_coarse_features_list, features_scales, boxes, point_coords 149 | ) 150 | return self.mask_coarse_head(mask_features) 151 | 152 | def _forward_mask_point(self, features, mask_coarse_logits, instances): 153 | """ 154 | Forward logic of the mask point head. 
155 | """ 156 | if not self.mask_point_on: 157 | return {} if self.training else mask_coarse_logits 158 | 159 | mask_features_list = [ 160 | features[self.in_features.index(k)] for k in self.mask_point_in_features 161 | ] 162 | features_scales = [1.0 / self.feature_strides[k] for k in self.mask_point_in_features] 163 | 164 | if self.training: 165 | proposal_boxes = [x.proposal_boxes for x in instances] 166 | gt_classes = cat([x.gt_classes for x in instances]) 167 | with torch.no_grad(): 168 | point_coords = get_uncertain_point_coords_with_randomness( 169 | mask_coarse_logits, 170 | lambda logits: calculate_uncertainty(logits, gt_classes), 171 | self.mask_point_train_num_points, 172 | self.mask_point_oversample_ratio, 173 | self.mask_point_importance_sample_ratio, 174 | ) 175 | 176 | fine_grained_features, point_coords_wrt_image = point_sample_fine_grained_features( 177 | mask_features_list, features_scales, proposal_boxes, point_coords 178 | ) 179 | coarse_features = point_sample(mask_coarse_logits, point_coords, align_corners=False) 180 | point_logits = self.mask_point_head(fine_grained_features, coarse_features) 181 | return { 182 | "loss_mask_point": roi_mask_point_loss( 183 | point_logits, instances, point_coords_wrt_image 184 | ) 185 | } 186 | else: 187 | pred_boxes = [x.pred_boxes for x in instances] 188 | pred_classes = cat([x.pred_classes for x in instances]) 189 | # The subdivision code will fail with the empty list of boxes 190 | if len(pred_classes) == 0: 191 | return mask_coarse_logits 192 | 193 | mask_logits = mask_coarse_logits.clone() 194 | for _ in range(self.mask_point_subdivision_steps): 195 | mask_logits = interpolate( 196 | mask_logits, scale_factor=2, mode="bilinear", align_corners=False 197 | ) 198 | uncertainty_map = calculate_uncertainty(mask_logits, pred_classes) 199 | point_indices, point_coords = get_uncertain_point_coords_on_grid( 200 | uncertainty_map, self.mask_point_subdivision_num_points 201 | ) 202 | fine_grained_features, _ = point_sample_fine_grained_features( 203 | mask_features_list, features_scales, pred_boxes, point_coords 204 | ) 205 | coarse_features = point_sample( 206 | mask_coarse_logits, point_coords, align_corners=False 207 | ) 208 | point_logits = self.mask_point_head(fine_grained_features, coarse_features) 209 | 210 | # put mask point predictions to the right places on the upsampled grid. 
R, C, H, W = mask_logits.shape 212 | point_indices = point_indices.unsqueeze(1).expand(-1, C, -1) 213 | mask_logits = ( 214 | mask_logits.reshape(R, C, H * W) 215 | .scatter_(2, point_indices, point_logits) 216 | .view(R, C, H, W) 217 | ) 218 | return mask_logits 219 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LASOR: Learning Accurate 3D Human Pose and Shape Via Synthetic Occlusion-Aware Data and Neural Mesh Rendering 2 | 3 | Code repository for the paper: 4 | **LASOR: Learning Accurate 3D Human Pose and Shape Via Synthetic Occlusion-Aware Data and Neural Mesh Rendering** 5 | Kaibing Yang; Renshu Gu; Maoyu Wang; Masahiro Toyoura; Gang Xu 6 | 7 | IEEE Transactions on Image Processing 2022 8 | [[paper](https://ieeexplore.ieee.org/document/9709705)] 9 | 10 | ![pipeline](pipeline.png) 11 | 12 | 13 | ### Requirements 14 | - Linux or macOS 15 | - Python ≥ 3.6 16 | 17 | ### Instructions 18 | We recommend using a virtual environment to install relevant dependencies: 19 | ``` 20 | python3 -m venv LASOR 21 | source LASOR/bin/activate 22 | ``` 23 | After creating a virtual environment, first install torch and torchvision: `pip install torch==1.4.0 torchvision==0.5.0` 24 | 25 | Then install detectron2 and its dependencies (cython and pycocotools): 26 | ``` 27 | pip install cython 28 | pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 29 | pip install 'git+https://github.com/akashsengupta1997/detectron2.git' 30 | ``` 31 | 32 | The remaining dependencies can be installed by simply running: `pip install -r requirements.txt`. This will be sufficient for inference. If you wish to run model training, you will require the PyTorch port of Neural Mesh Renderer: `pip install neural_renderer_pytorch==1.1.3`. 33 | 34 | ### Additional files 35 | You will need to download the SMPL model. The [neutral model](http://smplify.is.tue.mpg.de) is required for training and running the demo code. If you want to evaluate the model on datasets with gendered SMPL labels (such as 3DPW and SSP-3D), the male and female models are available [here](http://smpl.is.tue.mpg.de). You will need to convert the SMPL model files to be compatible with python3 by removing any chumpy objects. To do so, please follow the instructions [here](https://github.com/vchoutas/smplx/tree/master/tools). 36 | 37 | Download the required additional files [here](https://drive.google.com/drive/folders/1phJix1Fp-AbJgoLImb19eXCWEK7ZnAp_?usp=sharing). Place both the SMPL model and the additional files in the `additional` directory such that they have the following structure and filenames: 38 | 39 | LASOR 40 | ├── additional # Folder with additional files 41 | │ ├── smpl 42 | │ ├── SMPL_NEUTRAL.pkl # Gender-neutral SMPL model 43 | │ ├── cocoplus_regressor.npy # Cocoplus joints regressor 44 | │ ├── J_regressor_h36m.npy # Human3.6M joints regressor 45 | │ ├── J_regressor_extra.npy # Extra joints regressor 46 | │ ├── neutral_smpl_mean_params_6dpose.npz # Mean gender-neutral SMPL parameters 47 | │ ├── smpl_faces.npy # SMPL mesh faces 48 | │ ├── cube_parts.npy 49 | │ └── vertex_texture.npy 50 | └── ... 
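As a quick optional sanity check (a minimal sketch, assuming the directory layout and filenames shown above and running from the repository root), you can verify that the additional files are in place before training or running inference:
```
import os

required = ['additional/smpl/SMPL_NEUTRAL.pkl',
            'additional/cocoplus_regressor.npy',
            'additional/J_regressor_h36m.npy',
            'additional/J_regressor_extra.npy',
            'additional/neutral_smpl_mean_params_6dpose.npz',
            'additional/smpl_faces.npy',
            'additional/cube_parts.npy',
            'additional/vertex_texture.npy']
# Report any file that is not where the tree above expects it to be.
missing = [path for path in required if not os.path.isfile(path)]
print('Missing additional files:', missing if missing else 'none')
```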
51 | 52 | ### Model checkpoints 53 | Download pre-trained model checkpoints [here](https://drive.google.com/file/d/1xKM8PRhc3TbZVKig9kULJ01-VVWUkICi/view?usp=sharing) for our SMPL regressor, as well as for PointRend and DensePose (via detectron2) from [here](https://drive.google.com/drive/folders/1QX5NBR6GgmfP206bMHN9ZgK8_1QEKSdg?usp=sharing). Place these files in the `checkpoints` directory. 54 | 55 | ### Training data 56 | We use the training data from [here](https://drive.google.com/drive/folders/1CLOqQBrTos7vhohjFcU2OFkNYmyvQf6t?usp=sharing) provided by STRAPS. Place these files in the `data` directory. 57 | 58 | ## Inference 59 | `run_predict.py` is used to run inference on a given folder of input images. For example, to run inference on the demo folder, do: 60 | ``` 61 | python run_predict.py --input ./demo --checkpoint checkpoints/LASOR.tar --silh_from pointrend 62 | ``` 63 | You can choose between using `--silh_from pointrend` and `--silh_from densepose` to obtain human silhouettes. PointRend provides more accurate silhouettes for easy body poses but DensePose is more robust to challenging body poses. Best results are achieved when the image is roughly cropped and centred around the subject person. 64 | 65 | If you are running inference on a remote machine, you might run into problems getting `pyrender` to run off-screen/headless rendering. If you have EGL installed, uncommenting the appropriate line in `run_predict.py` might work. If not, simply disable pyrender rendering during inference. 66 | 67 | 68 | ## Citations 69 | 70 | If you find this code useful in your research, please cite the following publication: 71 | ``` 72 | @ARTICLE{9709705, 73 | author={Yang, Kaibing and Gu, Renshu and Wang, Maoyu and Toyoura, Masahiro and Xu, Gang}, 74 | journal={IEEE Transactions on Image Processing}, 75 | title={LASOR: Learning Accurate 3D Human Pose and Shape via Synthetic Occlusion-Aware Data and Neural Mesh Rendering}, 76 | year={2022}, 77 | volume={31}, 78 | number={}, 79 | pages={1938-1948}, 80 | doi={10.1109/TIP.2022.3149229} 81 | } 82 | ``` 83 | 84 | 85 | ## Acknowledgments 86 | Code was adapted from/influenced by the following repos - thanks to the authors! 87 | - [STRAPS](https://github.com/akashsengupta1997/STRAPS-3DHumanShapePose) 88 | - [HMR](https://github.com/akanazawa/hmr) 89 | - [SPIN](https://github.com/nkolot/SPIN) 90 | - [VIBE](https://github.com/mkocabas/VIBE) 91 | - [detectron2](https://github.com/facebookresearch/detectron2) 92 | - [NMR PyTorch](https://github.com/daniilidis-group/neural_renderer) 93 | -------------------------------------------------------------------------------- /additional/README.md: -------------------------------------------------------------------------------- 1 | # Additional files directory 2 | Place SMPL model files and additional files specified in the installation instructions here. -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | # Checkpoints directory 2 | Place pre-trained SMPL regressor checkpoint, as well as PointRend and DensePose checkpoints, in this directory (as specified in the installation instructions). 
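The joint index maps defined in `config.py` (the next file below) are meant to be used by indexing into the joints superset returned by the SMPL wrapper. A minimal illustrative sketch (the tensor here is random and its name is hypothetical; the 90-joint superset size follows the comments in `config.py`):
```
import torch
import config

# Hypothetical joints superset: 45 SMPL + 9 extra + 19 cocoplus + 17 H36M = 90 joints per sample.
joints_all = torch.randn(8, 90, 3)  # (batch, num_joints, 3)

joints_coco = joints_all[:, config.ALL_JOINTS_TO_COCO_MAP, :]   # (8, 17, 3) COCO-convention joints
joints_h36m = joints_all[:, config.ALL_JOINTS_TO_H36M_MAP, :]   # (8, 17, 3) H36M-convention joints
joints_lsp = joints_h36m[:, config.H36M_TO_J14, :]              # (8, 14, 3) 14 LSP joints
```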
-------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # ------------------------ Paths ------------------------ 2 | # Additional files 3 | SMPL_MODEL_DIR = 'additional/smpl' 4 | SMPL_FACES_PATH = 'additional/smpl_faces.npy' 5 | SMPL_MEAN_PARAMS_PATH = 'additional/neutral_smpl_mean_params_6dpose.npz' 6 | J_REGRESSOR_EXTRA_PATH = 'additional/J_regressor_extra.npy' 7 | COCOPLUS_REGRESSOR_PATH = 'additional/cocoplus_regressor.npy' 8 | H36M_REGRESSOR_PATH = 'additional/J_regressor_h36m.npy' 9 | VERTEX_TEXTURE_PATH = 'additional/vertex_texture.npy' 10 | CUBE_PARTS_PATH = 'additional/cube_parts.npy' 11 | 12 | # ------------------------ Constants ------------------------ 13 | FOCAL_LENGTH = 5000. 14 | REGRESSOR_IMG_WH = 256 15 | 16 | # ------------------------ Joint label conventions ------------------------ 17 | # The SMPL model (in smpl_official.py) returns a large superset of joints. 18 | # Different subsets are used during training - e.g. H36M 3D joints convention and COCO 2D joints convention. 19 | # You may wish to use different subsets in accordance with your training data/inference needs. 20 | 21 | # The joints superset is broken down into: 45 SMPL joints (24 standard + additional fingers/toes/face), 22 | # 9 extra joints, 19 cocoplus joints and 17 H36M joints. 23 | # The 45 SMPL joints are converted to COCO joints with the map below. 24 | # (Not really sure how coco and cocoplus are related.) 25 | 26 | # Indices to get 17 COCO joints and 17 H36M joints from joints superset. 27 | ALL_JOINTS_TO_COCO_MAP = [24, 26, 25, 28, 27, 16, 17, 18, 19, 20, 21, 1, 2, 4, 5, 7, 8] 28 | ALL_JOINTS_TO_H36M_MAP = list(range(73, 90)) 29 | 30 | # Indices to get the 14 LSP joints from the 17 H36M joints 31 | H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9] 32 | H36M_TO_J14 = H36M_TO_J17[:14] 33 | 34 | 35 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iGame-Lab/LASOR/7627d1311402ff803e64fbc2b21ad75dfe5d0451/models/__init__.py -------------------------------------------------------------------------------- /models/ief_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import config 6 | 7 | 8 | class IEFModule(nn.Module): 9 | """ 10 | Iterative error feedback module that regresses SMPL body model parameters (and 11 | weak-perspective camera parameters) given input features. 
12 | """ 13 | def __init__(self, fc_layers_neurons, in_features, num_output_params, iterations=3): 14 | super(IEFModule, self).__init__() 15 | 16 | self.fc1 = nn.Linear(in_features + num_output_params, fc_layers_neurons[0]) 17 | self.fc2 = nn.Linear(fc_layers_neurons[0], fc_layers_neurons[1]) 18 | self.fc3 = nn.Linear(fc_layers_neurons[1], num_output_params) 19 | self.relu = nn.ReLU(inplace=True) 20 | torch.nn.init.zeros_(self.fc1.bias) 21 | torch.nn.init.zeros_(self.fc2.bias) 22 | torch.nn.init.zeros_(self.fc3.bias) 23 | 24 | self.ief_layers = nn.Sequential(self.fc1, 25 | self.relu, 26 | self.fc2, 27 | self.relu, 28 | self.fc3) 29 | 30 | self.iterations = iterations 31 | self.initial_params_estimate = self.load_mean_params_6d_pose(config.SMPL_MEAN_PARAMS_PATH) 32 | 33 | def load_mean_params_6d_pose(self, mean_params_path): 34 | mean_smpl = np.load(mean_params_path) 35 | mean_pose = mean_smpl['pose'] 36 | mean_shape = mean_smpl['shape'] 37 | 38 | mean_params = np.zeros(3 + 24*6 + 10) 39 | mean_params[3:] = np.concatenate((mean_pose, mean_shape)) 40 | 41 | # Set initial weak-perspective camera parameters - [s, tx, ty] 42 | mean_params[0] = 0.9 # Initialise scale to 0.9 43 | mean_params[1] = 0.0 44 | mean_params[2] = 0.0 45 | 46 | return torch.from_numpy(mean_params.astype(np.float32)).float() 47 | 48 | def forward(self, img_features): 49 | batch_size = img_features.size(0) 50 | 51 | params_estimate = self.initial_params_estimate.repeat([batch_size, 1]) 52 | params_estimate = params_estimate.to(img_features.device) 53 | 54 | state = torch.cat([img_features, params_estimate], dim=1) 55 | for i in range(self.iterations): 56 | delta = self.ief_layers(state) 57 | params_estimate += delta 58 | state = torch.cat([img_features, params_estimate], dim=1) 59 | 60 | cam_params = params_estimate[:, :3] 61 | pose_params = params_estimate[:, 3:3 + 24*6] 62 | shape_params = params_estimate[:, 3 + 24*6:] 63 | 64 | return cam_params, pose_params, shape_params 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /models/regressor.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.resnet import resnet18, resnet50 4 | from models.ief_module import IEFModule 5 | 6 | 7 | class SingleInputRegressor(nn.Module): 8 | """ 9 | Combined encoder + regressor model that takes proxy representation input (e.g. 10 | silhouettes + 2D joints) and outputs SMPL body model parameters + weak-perspective 11 | camera. 12 | """ 13 | def __init__(self, 14 | resnet_in_channels=1, 15 | resnet_layers=18, 16 | ief_iters=3): 17 | """ 18 | :param resnet_in_channels: 1 if input silhouette/segmentation, 1 + num_joints if 19 | input silhouette/segmentation + joints. 20 | :param resnet_layers: number of layers in ResNet backbone (18 or 50) 21 | :param ief_iters: number of IEF iterations. 
22 | """ 23 | super(SingleInputRegressor, self).__init__() 24 | 25 | num_pose_params = 24*6 26 | num_output_params = 3 + num_pose_params + 10 27 | 28 | if resnet_layers == 18: 29 | self.image_encoder = resnet18(in_channels=resnet_in_channels, 30 | pretrained=False) 31 | self.ief_module = IEFModule([512, 512], 32 | 512, 33 | num_output_params, 34 | iterations=ief_iters) 35 | elif resnet_layers == 50: 36 | self.image_encoder = resnet50(in_channels=resnet_in_channels, 37 | pretrained=False) 38 | self.ief_module = IEFModule([1024, 1024], 39 | 2048, 40 | num_output_params, 41 | iterations=ief_iters) 42 | 43 | def forward(self, input): 44 | input_feats = self.image_encoder(input) 45 | cam_params, pose_params, shape_params = self.ief_module(input_feats) 46 | 47 | return cam_params, pose_params, shape_params 48 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from pytorch source code (I've just removed the final FC layer). 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.utils.model_zoo as model_zoo 8 | 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 12 | 'wide_resnet50_2', 'wide_resnet101_2'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 25 | } 26 | 27 | 28 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 29 | """3x3 convolution with padding""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 31 | padding=dilation, groups=groups, bias=False, dilation=dilation) 32 | 33 | 34 | def conv1x1(in_planes, out_planes, stride=1): 35 | """1x1 convolution""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | expansion = 1 41 | __constants__ = ['downsample'] 42 | 43 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 44 | base_width=64, dilation=1, norm_layer=None): 45 | super(BasicBlock, self).__init__() 46 | if norm_layer is None: 47 | norm_layer = nn.BatchNorm2d 48 | if groups != 1 or base_width != 64: 49 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 50 | if dilation > 1: 51 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 52 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 53 | self.conv1 = conv3x3(inplanes, planes, stride) 54 | self.bn1 = norm_layer(planes) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.conv2 = conv3x3(planes, planes) 57 | self.bn2 = norm_layer(planes) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | 
identity = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | 71 | if self.downsample is not None: 72 | identity = self.downsample(x) 73 | 74 | out += identity 75 | out = self.relu(out) 76 | 77 | return out 78 | 79 | 80 | class Bottleneck(nn.Module): 81 | expansion = 4 82 | __constants__ = ['downsample'] 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 85 | base_width=64, dilation=1, norm_layer=None): 86 | super(Bottleneck, self).__init__() 87 | if norm_layer is None: 88 | norm_layer = nn.BatchNorm2d 89 | width = int(planes * (base_width / 64.)) * groups 90 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 91 | self.conv1 = conv1x1(inplanes, width) 92 | self.bn1 = norm_layer(width) 93 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 94 | self.bn2 = norm_layer(width) 95 | self.conv3 = conv1x1(width, planes * self.expansion) 96 | self.bn3 = norm_layer(planes * self.expansion) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.downsample = downsample 99 | self.stride = stride 100 | 101 | def forward(self, x): 102 | identity = x 103 | 104 | out = self.conv1(x) 105 | out = self.bn1(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv2(out) 109 | out = self.bn2(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv3(out) 113 | out = self.bn3(out) 114 | 115 | if self.downsample is not None: 116 | identity = self.downsample(x) 117 | 118 | out += identity 119 | out = self.relu(out) 120 | 121 | return out 122 | 123 | 124 | class ResNet(nn.Module): 125 | 126 | def __init__(self, block, layers, in_channels, num_classes=1000, zero_init_residual=False, 127 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 128 | norm_layer=None): 129 | super(ResNet, self).__init__() 130 | if norm_layer is None: 131 | norm_layer = nn.BatchNorm2d 132 | self._norm_layer = norm_layer 133 | 134 | self.inplanes = 64 135 | self.dilation = 1 136 | if replace_stride_with_dilation is None: 137 | # each element in the tuple indicates if we should replace 138 | # the 2x2 stride with a dilated convolution instead 139 | replace_stride_with_dilation = [False, False, False] 140 | if len(replace_stride_with_dilation) != 3: 141 | raise ValueError("replace_stride_with_dilation should be None " 142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 143 | self.groups = groups 144 | self.base_width = width_per_group 145 | self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, 146 | bias=False) 147 | self.bn1 = norm_layer(self.inplanes) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 152 | dilate=replace_stride_with_dilation[0]) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 154 | dilate=replace_stride_with_dilation[1]) 155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 156 | dilate=replace_stride_with_dilation[2]) 157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 158 | # self.fc = nn.Linear(512 * block.expansion, num_classes) - don't need final FC layer 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | 
nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | # Zero-initialize the last BN in each residual branch, 168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 170 | if zero_init_residual: 171 | for m in self.modules(): 172 | if isinstance(m, Bottleneck): 173 | nn.init.constant_(m.bn3.weight, 0) 174 | elif isinstance(m, BasicBlock): 175 | nn.init.constant_(m.bn2.weight, 0) 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 178 | norm_layer = self._norm_layer 179 | downsample = None 180 | previous_dilation = self.dilation 181 | if dilate: 182 | self.dilation *= stride 183 | stride = 1 184 | if stride != 1 or self.inplanes != planes * block.expansion: 185 | downsample = nn.Sequential( 186 | conv1x1(self.inplanes, planes * block.expansion, stride), 187 | norm_layer(planes * block.expansion), 188 | ) 189 | 190 | layers = [] 191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 192 | self.base_width, previous_dilation, norm_layer)) 193 | self.inplanes = planes * block.expansion 194 | for _ in range(1, blocks): 195 | layers.append(block(self.inplanes, planes, groups=self.groups, 196 | base_width=self.base_width, dilation=self.dilation, 197 | norm_layer=norm_layer)) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def forward(self, x): 202 | x = self.conv1(x) 203 | x = self.bn1(x) 204 | x = self.relu(x) 205 | 206 | x = self.maxpool(x) 207 | 208 | x = self.layer1(x) 209 | x = self.layer2(x) 210 | x = self.layer3(x) 211 | x = self.layer4(x) 212 | 213 | x = self.avgpool(x) 214 | x = torch.flatten(x, 1) 215 | 216 | return x 217 | 218 | 219 | def _resnet(arch, block, layers, in_channels, pretrained, progress, **kwargs): 220 | model = ResNet(block, layers, in_channels, **kwargs) 221 | if pretrained: 222 | state_dict = model_zoo.load_url(model_urls[arch], 223 | progress=progress) 224 | model.load_state_dict(state_dict, strict=False) # not using final FC layer 225 | return model 226 | 227 | 228 | def resnet18(in_channels, pretrained=False, progress=True, **kwargs): 229 | r"""ResNet-18 model from 230 | `"Deep Residual Learning for Image Recognition" `_ 231 | Args: 232 | pretrained (bool): If True, returns a model pre-trained on ImageNet 233 | progress (bool): If True, displays a progress bar of the download to stderr 234 | """ 235 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], in_channels, pretrained, progress, 236 | **kwargs) 237 | 238 | 239 | def resnet34(in_channels, pretrained=False, progress=True, **kwargs): 240 | r"""ResNet-34 model from 241 | `"Deep Residual Learning for Image Recognition" `_ 242 | Args: 243 | pretrained (bool): If True, returns a model pre-trained on ImageNet 244 | progress (bool): If True, displays a progress bar of the download to stderr 245 | """ 246 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], in_channels, pretrained, progress, 247 | **kwargs) 248 | 249 | 250 | def resnet50(in_channels, pretrained=False, progress=True, **kwargs): 251 | r"""ResNet-50 model from 252 | `"Deep Residual Learning for Image Recognition" `_ 253 | Args: 254 | pretrained (bool): If True, returns a model pre-trained on ImageNet 255 | progress (bool): If True, displays a progress bar of the download to stderr 256 | """ 257 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], in_channels, pretrained, progress, 258 | **kwargs) 259 | 260 | 261 | def 
resnet101(in_channels, pretrained=False, progress=True, **kwargs): 262 | r"""ResNet-101 model from 263 | `"Deep Residual Learning for Image Recognition" `_ 264 | Args: 265 | pretrained (bool): If True, returns a model pre-trained on ImageNet 266 | progress (bool): If True, displays a progress bar of the download to stderr 267 | """ 268 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], in_channels, pretrained, progress, 269 | **kwargs) 270 | 271 | 272 | def resnet152(in_channels, pretrained=False, progress=True, **kwargs): 273 | r"""ResNet-152 model from 274 | `"Deep Residual Learning for Image Recognition" `_ 275 | Args: 276 | pretrained (bool): If True, returns a model pre-trained on ImageNet 277 | progress (bool): If True, displays a progress bar of the download to stderr 278 | """ 279 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], in_channels, pretrained, progress, 280 | **kwargs) 281 | 282 | 283 | def resnext50_32x4d(in_channels, pretrained=False, progress=True, **kwargs): 284 | r"""ResNeXt-50 32x4d model from 285 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | progress (bool): If True, displays a progress bar of the download to stderr 289 | """ 290 | kwargs['groups'] = 32 291 | kwargs['width_per_group'] = 4 292 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], in_channels, 293 | pretrained, progress, **kwargs) 294 | 295 | 296 | def resnext101_32x8d(in_channels, pretrained=False, progress=True, **kwargs): 297 | r"""ResNeXt-101 32x8d model from 298 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 299 | Args: 300 | pretrained (bool): If True, returns a model pre-trained on ImageNet 301 | progress (bool): If True, displays a progress bar of the download to stderr 302 | """ 303 | kwargs['groups'] = 32 304 | kwargs['width_per_group'] = 8 305 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], in_channels, 306 | pretrained, progress, **kwargs) 307 | 308 | 309 | def wide_resnet50_2(in_channels, pretrained=False, progress=True, **kwargs): 310 | r"""Wide ResNet-50-2 model from 311 | `"Wide Residual Networks" `_ 312 | The model is the same as ResNet except for the bottleneck number of channels 313 | which is twice larger in every block. The number of channels in outer 1x1 314 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 315 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 316 | Args: 317 | pretrained (bool): If True, returns a model pre-trained on ImageNet 318 | progress (bool): If True, displays a progress bar of the download to stderr 319 | """ 320 | kwargs['width_per_group'] = 64 * 2 321 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 322 | pretrained, progress, **kwargs) 323 | 324 | 325 | def wide_resnet101_2(in_channels, pretrained=False, progress=True, **kwargs): 326 | r"""Wide ResNet-101-2 model from 327 | `"Wide Residual Networks" `_ 328 | The model is the same as ResNet except for the bottleneck number of channels 329 | which is twice larger in every block. The number of channels in outer 1x1 330 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 331 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | kwargs['width_per_group'] = 64 * 2 337 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 338 | pretrained, progress, **kwargs) 339 | -------------------------------------------------------------------------------- /models/smpl_official.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from smplx import SMPL as _SMPL 4 | from smplx.body_models import ModelOutput 5 | from smplx.lbs import vertices2joints 6 | 7 | import config 8 | 9 | 10 | class SMPL(_SMPL): 11 | """ 12 | Extension of the official SMPL (from the smplx python package) implementation to 13 | support more joints. 14 | """ 15 | def __init__(self, *args, **kwargs): 16 | super(SMPL, self).__init__(*args, **kwargs) 17 | J_regressor_extra = np.load(config.J_REGRESSOR_EXTRA_PATH) 18 | J_regressor_cocoplus = np.load(config.COCOPLUS_REGRESSOR_PATH) 19 | J_regressor_h36m = np.load(config.H36M_REGRESSOR_PATH) 20 | self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, 21 | dtype=torch.float32)) 22 | self.register_buffer('J_regressor_cocoplus', torch.tensor(J_regressor_cocoplus, 23 | dtype=torch.float32)) 24 | self.register_buffer('J_regressor_h36m', torch.tensor(J_regressor_h36m, 25 | dtype=torch.float32)) 26 | 27 | def forward(self, *args, **kwargs): 28 | kwargs['get_skin'] = True 29 | smpl_output = super(SMPL, self).forward(*args, **kwargs) 30 | extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices) 31 | cocoplus_joints = vertices2joints(self.J_regressor_cocoplus, smpl_output.vertices) 32 | h36m_joints = vertices2joints(self.J_regressor_h36m, smpl_output.vertices) 33 | all_joints = torch.cat([smpl_output.joints, extra_joints, cocoplus_joints, 34 | h36m_joints], dim=1) 35 | output = ModelOutput(vertices=smpl_output.vertices, 36 | global_orient=smpl_output.global_orient, 37 | body_pose=smpl_output.body_pose, 38 | joints=all_joints, 39 | betas=smpl_output.betas, 40 | full_pose=smpl_output.full_pose) 41 | return output 42 | -------------------------------------------------------------------------------- /pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iGame-Lab/LASOR/7627d1311402ff803e64fbc2b21ad75dfe5d0451/pipeline.png -------------------------------------------------------------------------------- /predict/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iGame-Lab/LASOR/7627d1311402ff803e64fbc2b21ad75dfe5d0451/predict/__init__.py -------------------------------------------------------------------------------- /predict/predict_3D.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from smplx.lbs import batch_rodrigues 6 | 7 | from detectron2.config import get_cfg 8 | from detectron2 import model_zoo 9 | from detectron2.engine import DefaultPredictor 10 | 11 | from PointRend.point_rend import add_pointrend_config 12 | from DensePose.densepose import add_densepose_config 13 | 14 | import config 15 | 16 | from predict.predict_joints2D import predict_joints2D 17 | from predict.predict_silhouette_pointrend import predict_silhouette_pointrend 18 | from 
predict.predict_densepose import predict_densepose 19 | 20 | from models.smpl_official import SMPL 21 | from renderers.weak_perspective_pyrender_renderer import Renderer 22 | 23 | from utils.image_utils import pad_to_square, crop_and_resize_silhouette_joints 24 | from utils.cam_utils import orthographic_project_torch 25 | from utils.joints2d_utils import undo_keypoint_normalisation 26 | from utils.label_conversions import convert_multiclass_to_binary_labels, \ 27 | convert_2Djoints_to_gaussian_heatmaps 28 | from utils.rigid_transform_utils import rot6d_to_rotmat 29 | 30 | import matplotlib 31 | matplotlib.use('agg') 32 | import matplotlib.pyplot as plt 33 | 34 | 35 | def setup_detectron2_predictors(silhouettes_from='densepose'): 36 | # Keypoint-RCNN 37 | kprcnn_config_file = "COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml" 38 | kprcnn_cfg = get_cfg() 39 | kprcnn_cfg.merge_from_file(model_zoo.get_config_file(kprcnn_config_file)) 40 | kprcnn_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 # set threshold for this model 41 | kprcnn_cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(kprcnn_config_file) 42 | kprcnn_cfg.freeze() 43 | joints2D_predictor = DefaultPredictor(kprcnn_cfg) 44 | 45 | if silhouettes_from == 'pointrend': 46 | # PointRend-RCNN-R50-FPN 47 | pointrend_config_file = "PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml" 48 | pointrend_cfg = get_cfg() 49 | add_pointrend_config(pointrend_cfg) 50 | pointrend_cfg.merge_from_file(pointrend_config_file) 51 | pointrend_cfg.MODEL.WEIGHTS = "checkpoints/pointrend_rcnn_R_50_fpn.pkl" 52 | pointrend_cfg.freeze() 53 | silhouette_predictor = DefaultPredictor(pointrend_cfg) 54 | elif silhouettes_from == 'densepose': 55 | # DensePose-RCNN-R101-FPN 56 | densepose_config_file = "DensePose/configs/densepose_rcnn_R_101_FPN_s1x.yaml" 57 | densepose_cfg = get_cfg() 58 | add_densepose_config(densepose_cfg) 59 | densepose_cfg.merge_from_file(densepose_config_file) 60 | densepose_cfg.MODEL.WEIGHTS = "checkpoints/densepose_rcnn_R_101_fpn_s1x.pkl" 61 | densepose_cfg.freeze() 62 | silhouette_predictor = DefaultPredictor(densepose_cfg) 63 | 64 | return joints2D_predictor, silhouette_predictor 65 | 66 | 67 | def create_proxy_representation(silhouette, 68 | joints2D, 69 | out_wh): 70 | 71 | heatmaps = convert_2Djoints_to_gaussian_heatmaps(joints2D.astype(np.int16), 72 | out_wh) 73 | proxy_rep = np.concatenate([silhouette[:, :, None], heatmaps], axis=-1) 74 | proxy_rep = np.transpose(proxy_rep, [2, 0, 1]) # (C, out_wh, out_WH) 75 | 76 | return proxy_rep 77 | 78 | 79 | def predict_3D(input, 80 | regressor, 81 | device, 82 | silhouettes_from='densepose', 83 | proxy_rep_input_wh=512, 84 | save_proxy_vis=True, 85 | render_vis=True): 86 | 87 | # Set-up proxy representation predictors. 88 | joints2D_predictor, silhouette_predictor = setup_detectron2_predictors(silhouettes_from=silhouettes_from) 89 | 90 | # Set-up SMPL model. 91 | smpl = SMPL(config.SMPL_MODEL_DIR, batch_size=1).to(device) 92 | 93 | if render_vis: 94 | # Set-up renderer for visualisation. 
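# (Weak-perspective pyrender renderer; its output resolution matches the square proxy-representation input.)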
95 | wp_renderer = Renderer(resolution=(proxy_rep_input_wh, proxy_rep_input_wh)) 96 | 97 | if os.path.isdir(input): 98 | image_fnames = [f for f in sorted(os.listdir(input)) if f.endswith('.png') or 99 | f.endswith('.jpg')] 100 | for fname in image_fnames: 101 | print("Predicting on:", fname) 102 | image = cv2.imread(os.path.join(input, fname)) 103 | # Pre-process for 2D detectors 104 | image = pad_to_square(image) 105 | image = cv2.resize(image, (proxy_rep_input_wh, proxy_rep_input_wh), 106 | interpolation=cv2.INTER_LINEAR) 107 | # Predict 2D 108 | joints2D, joints2D_vis = predict_joints2D(image, joints2D_predictor) 109 | if silhouettes_from == 'pointrend': 110 | silhouette, silhouette_vis = predict_silhouette_pointrend(image, 111 | silhouette_predictor) 112 | elif silhouettes_from == 'densepose': 113 | silhouette, silhouette_vis = predict_densepose(image, silhouette_predictor) 114 | silhouette = convert_multiclass_to_binary_labels(silhouette) 115 | # Crop around silhouette 116 | silhouette, joints2D, image = crop_and_resize_silhouette_joints(silhouette, 117 | joints2D, 118 | out_wh=config.REGRESSOR_IMG_WH, 119 | image=image, 120 | image_out_wh=proxy_rep_input_wh, 121 | bbox_scale_factor=1.2) 122 | # Create proxy representation 123 | proxy_rep = create_proxy_representation(silhouette, joints2D, 124 | out_wh=config.REGRESSOR_IMG_WH) 125 | proxy_rep = proxy_rep[None, :, :, :] # add batch dimension 126 | proxy_rep = torch.from_numpy(proxy_rep).float().to(device) 127 | 128 | # Predict 3D 129 | regressor.eval() 130 | with torch.no_grad(): 131 | pred_cam_wp, pred_pose, pred_shape = regressor(proxy_rep) 132 | # Convert pred pose to rotation matrices 133 | if pred_pose.shape[-1] == 24 * 3: 134 | pred_pose_rotmats = batch_rodrigues(pred_pose.contiguous().view(-1, 3)) 135 | pred_pose_rotmats = pred_pose_rotmats.view(-1, 24, 3, 3) 136 | elif pred_pose.shape[-1] == 24 * 6: 137 | pred_pose_rotmats = rot6d_to_rotmat(pred_pose.contiguous()).view(-1, 24, 3, 3) 138 | 139 | pred_smpl_output = smpl(body_pose=pred_pose_rotmats[:, 1:], 140 | global_orient=pred_pose_rotmats[:, 0].unsqueeze(1), 141 | betas=pred_shape, 142 | pose2rot=False) 143 | pred_vertices = pred_smpl_output.vertices 144 | pred_vertices2d = orthographic_project_torch(pred_vertices, pred_cam_wp) 145 | pred_vertices2d = undo_keypoint_normalisation(pred_vertices2d, 146 | proxy_rep_input_wh) 147 | 148 | pred_reposed_smpl_output = smpl(betas=pred_shape) 149 | pred_reposed_vertices = pred_reposed_smpl_output.vertices 150 | 151 | # Numpy-fying 152 | pred_vertices = pred_vertices.cpu().detach().numpy()[0] 153 | pred_vertices2d = pred_vertices2d.cpu().detach().numpy()[0] 154 | pred_reposed_vertices = pred_reposed_vertices.cpu().detach().numpy()[0] 155 | pred_cam_wp = pred_cam_wp.cpu().detach().numpy()[0] 156 | 157 | if not os.path.isdir(os.path.join(input, 'verts_vis')): 158 | os.makedirs(os.path.join(input, 'verts_vis')) 159 | plt.figure() 160 | plt.imshow(image[:,:,::-1]) 161 | plt.scatter(pred_vertices2d[:, 0], pred_vertices2d[:, 1], s=0.3) 162 | plt.gca().set_axis_off() 163 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) 164 | plt.margins(0, 0) 165 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 166 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 167 | plt.savefig(os.path.join(input, 'verts_vis', 'verts_'+fname)) 168 | 169 | if render_vis: 170 | rend_img = wp_renderer.render(verts=pred_vertices, cam=pred_cam_wp, img=image) 171 | rend_reposed_img = wp_renderer.render(verts=pred_reposed_vertices, 172 | 
cam=np.array([0.8, 0., -0.2]), 173 | angle=180, 174 | axis=[1, 0, 0]) 175 | if not os.path.isdir(os.path.join(input, 'rend_vis')): 176 | os.makedirs(os.path.join(input, 'rend_vis')) 177 | cv2.imwrite(os.path.join(input, 'rend_vis', 'rend_'+fname), rend_img) 178 | cv2.imwrite(os.path.join(input, 'rend_vis', 'reposed_'+fname), rend_reposed_img) 179 | if save_proxy_vis: 180 | if not os.path.isdir(os.path.join(input, 'proxy_vis')): 181 | os.makedirs(os.path.join(input, 'proxy_vis')) 182 | cv2.imwrite(os.path.join(input, 'proxy_vis', 'silhouette_'+fname), silhouette_vis) 183 | cv2.imwrite(os.path.join(input, 'proxy_vis', 'joints2D_'+fname), joints2D_vis) 184 | -------------------------------------------------------------------------------- /predict/predict_densepose.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import cv2 3 | import numpy as np 4 | import matplotlib 5 | matplotlib.use('TkAgg') 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | from detectron2.config import get_cfg 10 | from detectron2.engine.defaults import DefaultPredictor 11 | from detectron2.structures.boxes import BoxMode 12 | 13 | from DensePose.densepose import add_densepose_config 14 | from DensePose.densepose.structures import DensePoseResult 15 | 16 | 17 | 18 | def apply_colormap(image, vmin=None, vmax=None, cmap='viridis', cmap_seed=1): 19 | """ 20 | Apply a matplotlib colormap to an image. 21 | 22 | This method will preserve the exact image size. `cmap` can be either a 23 | matplotlib colormap name, a discrete number, or a colormap instance. If it 24 | is a number, a discrete colormap will be generated based on the HSV 25 | colorspace. The permutation of colors is random and can be controlled with 26 | the `cmap_seed`. The state of the RNG is preserved. 27 | """ 28 | image = image.astype("float64") # Returns a copy. 29 | # Normalization. 30 | if vmin is not None: 31 | imin = float(vmin) 32 | image = np.clip(image, vmin, sys.float_info.max) 33 | else: 34 | imin = np.min(image) 35 | if vmax is not None: 36 | imax = float(vmax) 37 | image = np.clip(image, -sys.float_info.max, vmax) 38 | else: 39 | imax = np.max(image) 40 | image -= imin 41 | image /= (imax - imin) 42 | # Visualization. 43 | cmap_ = plt.get_cmap(cmap) 44 | vis = cmap_(image, bytes=True) 45 | return vis 46 | 47 | 48 | def get_largest_centred_bounding_box(bboxes, orig_w, orig_h): 49 | """ 50 | Given an array of bounding boxes, return the index of the largest + roughly-centred 51 | bounding box. 52 | :param bboxes: (N, 4) array of [x1 y1 x2 y2] bounding boxes 53 | :param orig_w: original image width 54 | :param orig_h: original image height 55 | """ 56 | bboxes_area = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1]) 57 | sorted_bbox_indices = np.argsort(bboxes_area)[::-1] # Indices of bboxes sorted by area. 
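# Boxes are visited from largest to smallest; the first one whose centre lies within orig_w/5 of the image centre (along both axes) is kept.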
58 | bbox_found = False 59 | i = 0 60 | while not bbox_found and i < sorted_bbox_indices.shape[0]: 61 | bbox_index = sorted_bbox_indices[i] 62 | bbox = bboxes[bbox_index] 63 | bbox_centre = ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0) # Centre (width, height) 64 | if abs(bbox_centre[0] - orig_w / 2.0) < orig_w/5.0 and abs(bbox_centre[1] - orig_h / 2.0) < orig_w/5.0: 65 | largest_centred_bbox_index = bbox_index 66 | bbox_found = True 67 | i += 1 68 | 69 | # If can't find bbox sufficiently close to centre, just use biggest bbox as prediction 70 | if not bbox_found: 71 | largest_centred_bbox_index = sorted_bbox_indices[0] 72 | 73 | return largest_centred_bbox_index 74 | 75 | 76 | def predict_densepose(input_image, predictor): 77 | """ 78 | Predicts densepose output given a cropped and centred input image. 79 | :param input_images: (wh, wh) 80 | :param predictor: instance of detectron2 DefaultPredictor class, created with the 81 | appropriate config file. 82 | """ 83 | orig_h, orig_w = input_image.shape[:2] 84 | outputs = predictor(input_image)["instances"] 85 | bboxes = outputs.pred_boxes.tensor.cpu() # Multiple densepose predictions if there are multiple people in the image 86 | bboxes_XYWH = BoxMode.convert(bboxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) 87 | bboxes = bboxes.cpu().detach().numpy() 88 | largest_centred_bbox_index = get_largest_centred_bounding_box(bboxes, orig_w, orig_h) # Picks out centred person that is largest in the image. 89 | 90 | pred_densepose = outputs.pred_densepose.to_result(bboxes_XYWH) 91 | iuv_arr = DensePoseResult.decode_png_data(*pred_densepose.results[largest_centred_bbox_index]) 92 | 93 | # Round bbox to int 94 | largest_bbox = bboxes[largest_centred_bbox_index] 95 | w1 = largest_bbox[0] 96 | w2 = largest_bbox[0] + iuv_arr.shape[2] 97 | h1 = largest_bbox[1] 98 | h2 = largest_bbox[1] + iuv_arr.shape[1] 99 | 100 | I_image = np.zeros((orig_h, orig_w)) 101 | I_image[int(h1):int(h2), int(w1):int(w2)] = iuv_arr[0, :, :] 102 | # U_image = np.zeros((orig_h, orig_w)) 103 | # U_image[int(h1):int(h2), int(w1):int(w2)] = iuv_arr[1, :, :] 104 | # V_image = np.zeros((orig_h, orig_w)) 105 | # V_image[int(h1):int(h2), int(w1):int(w2)] = iuv_arr[2, :, :] 106 | 107 | vis_I_image = apply_colormap(I_image, vmin=0, vmax=24) 108 | vis_I_image = vis_I_image[:, :, :3] 109 | vis_I_image[I_image == 0, :] = np.zeros(3, dtype=np.uint8) 110 | overlay_vis = cv2.addWeighted(input_image, 111 | 0.6, 112 | vis_I_image, 113 | 0.4, 114 | gamma=0) 115 | 116 | return I_image, overlay_vis 117 | 118 | -------------------------------------------------------------------------------- /predict/predict_joints2D.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def get_largest_centred_bounding_box(bboxes, orig_w, orig_h): 6 | """ 7 | Given an array of bounding boxes, return the index of the largest + roughly-centred 8 | bounding box. 9 | :param bboxes: (N, 4) array of [x1 y1 x2 y2] bounding boxes 10 | :param orig_w: original image width 11 | :param orig_h: original image height 12 | """ 13 | bboxes_area = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1]) 14 | sorted_bbox_indices = np.argsort(bboxes_area)[::-1] # Indices of bboxes sorted by area. 
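# Same largest-and-roughly-centred selection as in predict_densepose.py, but with a tighter orig_w/6 centring threshold.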
15 | bbox_found = False 16 | i = 0 17 | while not bbox_found and i < sorted_bbox_indices.shape[0]: 18 | bbox_index = sorted_bbox_indices[i] 19 | bbox = bboxes[bbox_index] 20 | bbox_centre = ((bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0) # Centre (width, height) 21 | if abs(bbox_centre[0] - orig_w / 2.0) < orig_w/6.0 and abs(bbox_centre[1] - orig_h / 2.0) < orig_w/6.0: 22 | largest_centred_bbox_index = bbox_index 23 | bbox_found = True 24 | i += 1 25 | 26 | # If can't find bbox sufficiently close to centre, just use biggest bbox as prediction 27 | if not bbox_found: 28 | largest_centred_bbox_index = sorted_bbox_indices[0] 29 | 30 | return largest_centred_bbox_index 31 | 32 | 33 | def predict_joints2D(input_image, predictor): 34 | """ 35 | Predicts 2D joints (17 2D joints in COCO convention along with prediction confidence) 36 | given a cropped and centred input image. 37 | :param input_images: (wh, wh) 38 | :param predictor: instance of detectron2 DefaultPredictor class, created with the 39 | appropriate config file. 40 | """ 41 | image = np.copy(input_image) 42 | orig_h, orig_w = image.shape[:2] 43 | outputs = predictor(image) # Multiple bboxes + keypoints predictions if there are multiple people in the image 44 | bboxes = outputs['instances'].pred_boxes.tensor.cpu().numpy() 45 | if bboxes.shape[0] == 0: # Can't find any people in image 46 | keypoints = np.zeros((17, 3)) 47 | else: 48 | largest_centred_bbox_index = get_largest_centred_bounding_box(bboxes, orig_w, orig_h) # Picks out centred person that is largest in the image. 49 | keypoints = outputs['instances'].pred_keypoints.cpu().numpy() 50 | keypoints = keypoints[largest_centred_bbox_index] 51 | 52 | for j in range(keypoints.shape[0]): 53 | cv2.circle(image, (keypoints[j, 0], keypoints[j, 1]), 5, (0, 255, 0), -1) 54 | font = cv2.FONT_HERSHEY_SIMPLEX 55 | fontScale = 0.5 56 | fontColor = (0, 0, 255) 57 | cv2.putText(image, str(j), (keypoints[j, 0], keypoints[j, 1]), 58 | font, fontScale, fontColor, lineType=2) 59 | 60 | return keypoints, image 61 | 62 | -------------------------------------------------------------------------------- /predict/predict_silhouette_pointrend.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | def get_largest_centred_mask(human_masks, orig_w, orig_h): 8 | """ 9 | Given an array of human segmentation masks, return the index of the largest + 10 | roughly-centred mask. 11 | :param human_masks: (N, img_wh, img_wh) human segmentation masks. 12 | :param orig_w: original image width 13 | :param orig_h: original image height 14 | """ 15 | mask_areas = np.sum(human_masks, axis=(1, 2)) 16 | sorted_mask_indices = np.argsort(mask_areas)[::-1] # Indices of masks sorted by area. 17 | mask_found = False 18 | i = 0 19 | while not mask_found and i < sorted_mask_indices.shape[0]: 20 | mask_index = sorted_mask_indices[i] 21 | mask = human_masks[mask_index, :, :] 22 | mask_pixels = np.argwhere(mask != 0) 23 | bbox_corners = np.amin(mask_pixels, axis=0), np.amax(mask_pixels, axis=0) # (row_min, col_min), (row_max, col_max) 24 | bbox_centre = ((bbox_corners[0][0] + bbox_corners[1][0]) / 2.0, 25 | (bbox_corners[0][1] + bbox_corners[1][1]) / 2.0) # Centre in rows, columns (i.e. 
height, width) 26 | 27 | if abs(bbox_centre[0] - orig_h / 2.0) < orig_w/4.0 and abs(bbox_centre[1] - orig_w / 2.0) < orig_w/6.0: 28 | largest_centred_mask_index = mask_index 29 | mask_found = True 30 | i += 1 31 | 32 | # If can't find mask sufficiently close to centre, just use biggest mask as prediction 33 | if not mask_found: 34 | largest_centred_mask_index = sorted_mask_indices[0] 35 | 36 | return largest_centred_mask_index 37 | 38 | 39 | def predict_silhouette_pointrend(input_image, predictor): 40 | """ 41 | Predicts human silhouette (binary segmetnation) given a cropped and centred input image. 42 | :param input_images: (wh, wh) 43 | :param predictor: instance of detectron2 DefaultPredictor class, created with the 44 | appropriate config file. 45 | """ 46 | orig_h, orig_w = input_image.shape[:2] 47 | outputs = predictor(input_image)['instances'] # Multiple silhouette predictions if there are multiple people in the image 48 | classes = outputs.pred_classes 49 | masks = outputs.pred_masks 50 | human_masks = masks[classes == 0] 51 | human_masks = human_masks.cpu().detach().numpy() 52 | largest_centred_mask_index = get_largest_centred_mask(human_masks, orig_w, orig_h) # Picks out centred person that is largest in the image. 53 | human_mask = human_masks[largest_centred_mask_index, :, :].astype(np.uint8) 54 | overlay_vis = cv2.addWeighted(input_image, 1.0, 55 | 255 * np.tile(human_mask[:, :, None], [1, 1, 3]), 56 | 0.5, gamma=0) 57 | 58 | return human_mask, overlay_vis 59 | -------------------------------------------------------------------------------- /renderers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iGame-Lab/LASOR/7627d1311402ff803e64fbc2b21ad75dfe5d0451/renderers/__init__.py -------------------------------------------------------------------------------- /renderers/nmr_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import neural_renderer as nr 6 | import config 7 | 8 | 9 | class NMRRenderer(nn.Module): 10 | """ 11 | Neural mesh renderer module - renders 6 body-part segmentations or RGB images. 12 | Code adapted from https://github.com/nkolot/SPIN/blob/master/utils/part_utils.py 13 | 6 body-part convention: 14 | 0 - background 15 | 1 - left arm 16 | 2 - right arm 17 | 3 - head 18 | 4 - left leg 19 | 5 - right leg 20 | 6 - torso 21 | """ 22 | def __init__(self, 23 | batch_size, 24 | cam_K, 25 | cam_R, 26 | img_wh=256, 27 | rend_parts_seg=False): 28 | """ 29 | :param batch_size 30 | :param cam_K: (bs, 3, 3) camera intrinsics matrix 31 | :param cam_R: (bs, 3, 3) camera rotation matrix (usually identity). 32 | :param img_wh: output render width/height 33 | :param rend_parts_seg: if True, render 6 part segmentation, else render RGB. 
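If cam_K and cam_R are passed without a batch dimension, i.e. as (3, 3) matrices, they are expanded to (batch_size, 3, 3) internally.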
34 | """ 35 | super(NMRRenderer, self).__init__() 36 | 37 | faces = np.load(config.SMPL_FACES_PATH) 38 | faces = torch.from_numpy(faces.astype(np.int32)) 39 | faces = faces[None, :].expand(batch_size, -1, -1) 40 | self.register_buffer('faces', faces) 41 | 42 | if rend_parts_seg: 43 | textures = np.load(config.VERTEX_TEXTURE_PATH) 44 | textures = torch.from_numpy(textures).float() 45 | textures = textures.expand(batch_size, -1, -1, -1, -1, -1) 46 | self.register_buffer('textures', textures) 47 | 48 | cube_parts = np.load(config.CUBE_PARTS_PATH) 49 | cube_parts = torch.from_numpy(cube_parts).float() 50 | self.register_buffer('cube_parts', cube_parts) 51 | else: 52 | texture_size = 2 53 | textures = torch.ones(batch_size, self.faces.shape[1], texture_size, texture_size, 54 | texture_size, 3, dtype=torch.float32) 55 | self.register_buffer('textures', textures) 56 | 57 | # Setup renderer 58 | if cam_K.ndim != 3: 59 | print("Expanding cam_K and cam_R by batch size.") 60 | cam_K = cam_K[None, :, :].expand(batch_size, -1, -1) 61 | cam_R = cam_R[None, :, :].expand(batch_size, -1, -1) 62 | renderer = nr.Renderer(camera_mode='projection', 63 | K=cam_K, 64 | R=cam_R, 65 | image_size=img_wh, 66 | orig_size=img_wh, 67 | light_direction=[0, 0, 1]) 68 | if rend_parts_seg: 69 | renderer.light_intensity_ambient = 1 70 | renderer.anti_aliasing = False 71 | renderer.light_intensity_directional = 0 72 | self.renderer = renderer 73 | 74 | self.rend_parts_seg = rend_parts_seg 75 | 76 | def forward(self, vertices, cam_ts): 77 | """ 78 | :param vertices: (B, N, 3) 79 | :param cam_ts: (B, 1, 3) 80 | """ 81 | if cam_ts.ndim == 2: 82 | cam_ts = cam_ts.unsqueeze(1) 83 | if self.rend_parts_seg: 84 | parts, _, mask = self.renderer(vertices, self.faces, self.textures, 85 | t=cam_ts) 86 | parts = self.get_parts(parts, mask) 87 | return parts 88 | else: 89 | rend_image, depth, _ = self.renderer(vertices, self.faces, self.textures, 90 | t=cam_ts) 91 | return rend_image, depth 92 | 93 | def get_parts(self, parts, mask): 94 | """Process renderer part image to get body part indices.""" 95 | bn,c,h,w = parts.shape 96 | mask = mask.view(-1,1) 97 | parts_index = torch.floor(100*parts.permute(0,2,3,1).contiguous().view(-1,3)).long() 98 | parts = self.cube_parts[parts_index[:,0], parts_index[:,1], parts_index[:,2], None] 99 | parts *= mask 100 | parts = parts.view(bn,h,w).long() 101 | return parts -------------------------------------------------------------------------------- /renderers/weak_perspective_pyrender_renderer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Code from https://github.com/mkocabas/VIBE 3 | 4 | import math 5 | import trimesh 6 | import pyrender 7 | import numpy as np 8 | from pyrender.constants import RenderFlags 9 | 10 | import config 11 | 12 | 13 | class WeakPerspectiveCamera(pyrender.Camera): 14 | def __init__(self, 15 | scale, 16 | translation, 17 | znear=pyrender.camera.DEFAULT_Z_NEAR, 18 | zfar=None, 19 | name=None): 20 | super(WeakPerspectiveCamera, self).__init__( 21 | znear=znear, 22 | zfar=zfar, 23 | name=name, 24 | ) 25 | self.scale = scale 26 | self.translation = translation 27 | 28 | def get_projection_matrix(self, width=None, height=None): 29 | P = np.eye(4) 30 | P[0, 0] = self.scale[0] 31 | P[1, 1] = self.scale[1] 32 | P[0, 3] = self.translation[0] * self.scale[0] 33 | P[1, 3] = -self.translation[1] * self.scale[1] 34 | P[2, 2] = -1 35 | return P 36 | 37 | 38 | class Renderer(): 39 | def __init__(self, resolution=(256, 
256)): 40 | self.resolution = resolution 41 | 42 | self.faces = np.load(config.SMPL_FACES_PATH) 43 | self.renderer = pyrender.OffscreenRenderer( 44 | viewport_width=self.resolution[0], 45 | viewport_height=self.resolution[1], 46 | point_size=1.0 47 | ) 48 | 49 | # set the scene 50 | self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.3, 0.3, 0.3)) 51 | 52 | light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=1.) 53 | 54 | light_pose = np.eye(4) 55 | light_pose[:3, 3] = [0, -1, 1] 56 | self.scene.add(light, pose=light_pose) 57 | 58 | light_pose[:3, 3] = [0, 1, 1] 59 | self.scene.add(light, pose=light_pose) 60 | 61 | # light_pose[:3, 3] = [1, 1, 2] 62 | # self.scene.add(light, pose=light_pose) 63 | 64 | def render(self, verts, cam, img=None, angle=None, axis=None, mesh_filename=None, color=[0.8, 0.3, 0.3], 65 | return_mask=False): 66 | 67 | mesh = trimesh.Trimesh(vertices=verts, faces=self.faces) 68 | 69 | Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0]) 70 | mesh.apply_transform(Rx) 71 | 72 | if mesh_filename is not None: 73 | mesh.export(mesh_filename) 74 | 75 | if angle and axis: 76 | R = trimesh.transformations.rotation_matrix(math.radians(angle), axis) 77 | mesh.apply_transform(R) 78 | 79 | if cam.shape[-1] == 4: 80 | sx, sy, tx, ty = cam 81 | elif cam.shape[-1] == 3: 82 | s, tx, ty = cam 83 | sx = sy = s 84 | 85 | camera = WeakPerspectiveCamera( 86 | scale=[sx, sy], 87 | translation=[tx, ty], 88 | zfar=1000. 89 | ) 90 | 91 | material = pyrender.MetallicRoughnessMaterial( 92 | metallicFactor=0.2, 93 | alphaMode='OPAQUE', 94 | baseColorFactor=(color[0], color[1], color[2], 1.0) 95 | ) 96 | 97 | mesh = pyrender.Mesh.from_trimesh(mesh, material=material) 98 | 99 | mesh_node = self.scene.add(mesh, 'mesh') 100 | 101 | camera_pose = np.eye(4) 102 | cam_node = self.scene.add(camera, pose=camera_pose) 103 | 104 | rgb, rend_depth = self.renderer.render(self.scene, flags=RenderFlags.RGBA) 105 | valid_mask = (rend_depth > 0) 106 | if return_mask: 107 | return valid_mask 108 | else: 109 | if img is None: 110 | img = np.zeros((self.resolution[0], self.resolution[1], 3)) 111 | valid_mask = valid_mask[:, :, None] 112 | output_img = rgb[:, :, :-1] * valid_mask + (1 - valid_mask) * img 113 | image = output_img.astype(np.uint8) 114 | 115 | self.scene.remove_node(mesh_node) 116 | self.scene.remove_node(cam_node) 117 | 118 | return image 119 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | matplotlib 3 | pyrender 4 | tqdm 5 | trimesh 6 | smplx 7 | scikit-image==0.16.2 8 | -------------------------------------------------------------------------------- /run_predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | 5 | from models.regressor import SingleInputRegressor 6 | from predict.predict_3D import predict_3D 7 | 8 | 9 | def main(input_path, checkpoint_path, device, silhouettes_from): 10 | regressor = SingleInputRegressor(resnet_in_channels=18, 11 | resnet_layers=50, 12 | ief_iters=3) 13 | 14 | print("Regressor loaded. 
Weights from:", checkpoint_path) 15 | regressor.to(device) 16 | checkpoint = torch.load(checkpoint_path, map_location=device) 17 | regressor.load_state_dict(checkpoint['best_model_state_dict']) 18 | 19 | predict_3D(input_path, regressor, device, silhouettes_from=silhouettes_from, 20 | save_proxy_vis=True, render_vis=True) 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--input', type=str, help='Path to input image/folder of images.') 26 | parser.add_argument('--checkpoint', type=str, help='Path to model checkpoint') 27 | parser.add_argument('--silh_from', choices=['densepose', 'pointrend']) 28 | parser.add_argument('--gpu', default='0') 29 | args = parser.parse_args() 30 | 31 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 32 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 33 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 34 | 35 | # Regarding body mesh visualisation using pyrender: 36 | # If you are running this script on a remote machine via ssh, you might need to use EGL 37 | # to create an OpenGL context. If EGL is installed on the remote machine, uncommenting the 38 | # following line should work. 39 | # os.environ['PYOPENGL_PLATFORM'] = 'egl' 40 | # If this still doesn't work, just disable rendering visualisation by setting render_vis 41 | # argument in predict_3D to False. 42 | 43 | main(args.input, args.checkpoint, device, args.silh_from) 44 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iGame-Lab/LASOR/7627d1311402ff803e64fbc2b21ad75dfe5d0451/utils/__init__.py -------------------------------------------------------------------------------- /utils/cam_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def orthographic_project_torch(points3D, cam_params): 6 | """ 7 | Scaled orthographic projection (i.e. weak perspective projection). 8 | Should be going from SMPL 3D coords to [-1, 1] scaled image coords. 9 | cam_params are [s, tx, ty] - i.e. scaling and 2D translation. 10 | """ 11 | x = points3D[:, :, 0] 12 | y = points3D[:, :, 1] 13 | 14 | # Scaling 15 | s = torch.unsqueeze(cam_params[:, 0], dim=1) 16 | 17 | # Translation 18 | t_x = torch.unsqueeze(cam_params[:, 1], dim=1) 19 | t_y = torch.unsqueeze(cam_params[:, 2], dim=1) 20 | 21 | u = s * (x + t_x) 22 | v = s * (y + t_y) 23 | 24 | proj_points = torch.stack([u, v], dim=-1) 25 | 26 | return proj_points 27 | 28 | 29 | def get_intrinsics_matrix(img_width, img_height, focal_length): 30 | """ 31 | Camera intrinsic matrix (calibration matrix) given focal length and img_width and 32 | img_height. Assumes that principal point is at (width/2, height/2). 33 | """ 34 | K = np.array([[focal_length, 0., img_width/2.0], 35 | [0., focal_length, img_height/2.0], 36 | [0., 0., 1.]]) 37 | return K 38 | 39 | 40 | def perspective_project_torch(points, rotation, translation, cam_K=None, 41 | focal_length=None, img_wh=None): 42 | """ 43 | This function computes the perspective projection of a set of points in torch. 
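Points are rotated and translated into camera coordinates, divided by their depth and then mapped to pixel coordinates with the intrinsics matrix; only the first two (pixel) coordinates are returned.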
44 | Input: 45 | points (bs, N, 3): 3D points 46 | rotation (bs, 3, 3): Camera rotation 47 | translation (bs, 3): Camera translation 48 | Either 49 | cam_K (bs, 3, 3): Camera intrinsics matrix 50 | Or 51 | focal_length (bs,) or scalar: Focal length 52 | camera_center (bs, 2): Camera center 53 | """ 54 | batch_size = points.shape[0] 55 | if cam_K is None: 56 | cam_K = torch.from_numpy(get_intrinsics_matrix(img_wh, img_wh, focal_length).astype( 57 | np.float32)) 58 | cam_K = torch.cat(batch_size * [cam_K[None, :, :]], dim=0) 59 | cam_K = cam_K.to(points.device) 60 | 61 | # Transform points 62 | points = torch.einsum('bij,bkj->bki', rotation, points) 63 | points = points + translation.unsqueeze(1) 64 | 65 | # Apply perspective distortion 66 | projected_points = points / points[:, :, -1].unsqueeze(-1) 67 | 68 | # Apply camera intrinsics 69 | projected_points = torch.einsum('bij,bkj->bki', cam_K, projected_points) 70 | 71 | return projected_points[:, :, :-1] 72 | 73 | 74 | def convert_weak_perspective_to_camera_translation(cam_wp, focal_length, resolution): 75 | cam_t = np.array([cam_wp[1], cam_wp[2], 2 * focal_length / (resolution * cam_wp[0] + 1e-9)]) 76 | return cam_t 77 | 78 | 79 | def batch_convert_weak_perspective_to_camera_translation(wp_cams, focal_length, resolution): 80 | num = wp_cams.shape[0] 81 | cam_ts = np.zeros((num, 3), dtype=np.float32) 82 | for i in range(num): 83 | cam_t = convert_weak_perspective_to_camera_translation(wp_cams[i], 84 | focal_length, 85 | resolution) 86 | cam_ts[i] = cam_t.astype(np.float32) 87 | return cam_ts 88 | -------------------------------------------------------------------------------- /utils/checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def load_training_info_from_checkpoint(checkpoint, save_val_metrics): 5 | current_epoch = checkpoint['epoch'] + 1 6 | best_epoch = checkpoint['best_epoch'] 7 | best_model_wts = checkpoint['best_model_state_dict'] 8 | best_epoch_val_metrics = checkpoint['best_epoch_val_metrics'] 9 | # ^ best val metrics, happened at best_epoch 10 | 11 | # If different save_val_metrics used upon re-starting training, set best values for those 12 | # metrics to infinity. 13 | for metric in save_val_metrics: 14 | if metric not in best_epoch_val_metrics.keys(): 15 | best_epoch_val_metrics[metric] = np.inf 16 | metrics_to_del = [metric for metric in best_epoch_val_metrics.keys() if 17 | metric not in save_val_metrics] 18 | for metric in metrics_to_del: 19 | del best_epoch_val_metrics[metric] 20 | 21 | print('\nTraining information loaded from checkpoint.') 22 | print('Current epoch:', current_epoch) 23 | print('Best epoch val metrics from last training run:', best_epoch_val_metrics, 24 | ' - achieved in epoch:', best_epoch) 25 | 26 | return current_epoch, best_epoch, best_model_wts, best_epoch_val_metrics 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Parts of the code are adapted from https://github.com/akanazawa/hmr 3 | """ 4 | import numpy as np 5 | 6 | 7 | def compute_similarity_transform(S1, S2): 8 | """ 9 | Computes a similarity transform (sR, t) that takes 10 | a set of 3D points S1 (3 x N) closest to a set of 3D points S2, 11 | where R is an 3x3 rotation matrix, t 3x1 translation, s scale. 12 | i.e. solves the orthogonal Procrutes problem. 
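After mean-centring, the rotation is recovered from the SVD of K = X1 @ X2.T: with K = U S V^T, R = V Z U^T, where Z flips the last singular direction if needed so that det(R) = 1. The scale is trace(R K) / ||X1||_F^2 and the translation is t = mu2 - s * R @ mu1.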
13 | """ 14 | transposed = False 15 | if S1.shape[0] != 3 and S1.shape[0] != 2: 16 | S1 = S1.T 17 | S2 = S2.T 18 | transposed = True 19 | assert(S2.shape[1] == S1.shape[1]) 20 | 21 | # 1. Remove mean. 22 | mu1 = S1.mean(axis=1, keepdims=True) 23 | mu2 = S2.mean(axis=1, keepdims=True) 24 | X1 = S1 - mu1 25 | X2 = S2 - mu2 26 | 27 | # 2. Compute variance of X1 used for scale. 28 | var1 = np.sum(X1**2) 29 | 30 | # 3. The outer product of X1 and X2. 31 | K = X1.dot(X2.T) 32 | 33 | # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are 34 | # singular vectors of K. 35 | U, s, Vh = np.linalg.svd(K) 36 | V = Vh.T 37 | # Construct Z that fixes the orientation of R to get det(R)=1. 38 | Z = np.eye(U.shape[0]) 39 | Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T))) 40 | # Construct R. 41 | R = V.dot(Z.dot(U.T)) 42 | 43 | # 5. Recover scale. 44 | scale = np.trace(R.dot(K)) / var1 45 | 46 | # 6. Recover translation. 47 | t = mu2 - scale*(R.dot(mu1)) 48 | 49 | # 7. Error: 50 | S1_hat = scale*R.dot(S1) + t 51 | 52 | if transposed: 53 | S1_hat = S1_hat.T 54 | 55 | return S1_hat 56 | 57 | 58 | def procrustes_analysis_batch(S1, S2): 59 | """Batched version of compute_similarity_transform.""" 60 | S1_hat = np.zeros_like(S1) 61 | for i in range(S1.shape[0]): 62 | S1_hat[i] = compute_similarity_transform(S1[i], S2[i]) 63 | return S1_hat 64 | 65 | 66 | def scale_and_translation_transform_batch(P, T): 67 | """ 68 | First Normalises batch of input 3D meshes P such that each mesh has mean (0, 0, 0) and 69 | RMS distance from mean = 1. 70 | Then transforms P such that it has the same mean and RMSD as T. 71 | :param P: (batch_size, N, 3) batch of N 3D meshes to transform. 72 | :param T: (batch_size, N, 3) batch of N reference 3D meshes. 73 | :return: P transformed 74 | """ 75 | P_mean = np.mean(P, axis=1, keepdims=True) 76 | P_trans = P - P_mean 77 | P_scale = np.sqrt(np.sum(P_trans ** 2, axis=(1, 2), keepdims=True) / P.shape[1]) 78 | P_normalised = P_trans / P_scale 79 | 80 | T_mean = np.mean(T, axis=1, keepdims=True) 81 | T_scale = np.sqrt(np.sum((T - T_mean) ** 2, axis=(1, 2), keepdims=True) / T.shape[1]) 82 | 83 | P_transformed = P_normalised * T_scale + T_mean 84 | 85 | return P_transformed 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | def pad_to_square(image): 7 | """ 8 | Pad image to square shape. 9 | """ 10 | height, width = image.shape[:2] 11 | 12 | if width < height: 13 | border_width = (height - width) // 2 14 | image = cv2.copyMakeBorder(image, 0, 0, border_width, border_width, 15 | cv2.BORDER_CONSTANT, value=0) 16 | else: 17 | border_width = (width - height) // 2 18 | image = cv2.copyMakeBorder(image, border_width, border_width, 0, 0, 19 | cv2.BORDER_CONSTANT, value=0) 20 | 21 | return image 22 | 23 | def convert_bbox_corners_to_centre_hw(bbox_corners): 24 | """ 25 | Converst bbox coordinates from x1, y1, x2, y2 to centre, height, width. 
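Note that height is measured along the first coordinate (x2 - x1) and width along the second (y2 - y1), matching the (row, column) ordering used by the callers. E.g. bbox_corners = [10, 20, 50, 100] gives centre = [30, 60], height = 40, width = 80.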
26 | """ 27 | x1, y1, x2, y2 = bbox_corners 28 | centre = np.array([(x1+x2)/2.0, (y1+y2)/2.0]) 29 | height = x2 - x1 30 | width = y2 - y1 31 | 32 | return centre, height, width 33 | 34 | 35 | def convert_bbox_centre_hw_to_corners(centre, height, width): 36 | x1 = centre[0] - height/2.0 37 | x2 = centre[0] + height/2.0 38 | y1 = centre[1] - width/2.0 39 | y2 = centre[1] + width/2.0 40 | 41 | return np.array([x1, y1, x2, y2]) 42 | 43 | 44 | def batch_crop_seg_to_bounding_box(seg, joints2D, orig_scale_factor=1.2, delta_scale_range=None, delta_centre_range=None): 45 | """ 46 | seg: (bs, wh, wh) 47 | joints2D: (bs, num joints, 2) 48 | scale: bbox expansion scale 49 | """ 50 | all_cropped_segs = [] 51 | all_cropped_joints2D = [] 52 | for i in range(seg.shape[0]): 53 | body_pixels = np.argwhere(seg[i] != 0) 54 | bbox_corners = np.amin(body_pixels, axis=0), np.amax(body_pixels, axis=0) 55 | bbox_corners = np.concatenate(bbox_corners) 56 | centre, height, width = convert_bbox_corners_to_centre_hw(bbox_corners) 57 | if delta_scale_range is not None: 58 | l, h = delta_scale_range 59 | delta_scale = (h - l) * np.random.rand() + l 60 | scale_factor = orig_scale_factor + delta_scale 61 | else: 62 | scale_factor = orig_scale_factor 63 | 64 | if delta_centre_range is not None: 65 | l, h = delta_centre_range 66 | delta_centre = (h - l) * np.random.rand(2) + l 67 | centre = centre + delta_centre 68 | 69 | wh = max(height, width) * scale_factor 70 | 71 | bbox_corners = convert_bbox_centre_hw_to_corners(centre, wh, wh) 72 | 73 | top_left = bbox_corners[:2].astype(np.int16) 74 | bottom_right = bbox_corners[2:].astype(np.int16) 75 | top_left[top_left < 0] = 0 76 | bottom_right[bottom_right < 0] = 0 77 | 78 | cropped_joints2d = joints2D[i] - top_left[::-1] 79 | cropped_seg = seg[i, top_left[0]: bottom_right[0], top_left[1]: bottom_right[1]] 80 | all_cropped_joints2D.append(cropped_joints2d) 81 | all_cropped_segs.append(cropped_seg) 82 | return all_cropped_segs, all_cropped_joints2D 83 | 84 | 85 | def batch_resize(all_cropped_segs, all_cropped_joints2D, img_wh): 86 | """ 87 | all_cropped_seg: list of cropped segs with len = batch size 88 | """ 89 | all_resized_segs = [] 90 | all_resized_joints2D = [] 91 | for i in range(len(all_cropped_segs)): 92 | seg = all_cropped_segs[i] 93 | orig_height, orig_width = seg.shape[:2] 94 | resized_seg = cv2.resize(seg, (img_wh, img_wh), interpolation=cv2.INTER_NEAREST) 95 | all_resized_segs.append(resized_seg) 96 | 97 | joints2D = all_cropped_joints2D[i] 98 | resized_joints2D = joints2D * np.array([img_wh / float(orig_width), 99 | img_wh / float(orig_height)]) 100 | all_resized_joints2D.append(resized_joints2D) 101 | 102 | all_resized_segs = np.stack(all_resized_segs, axis=0) 103 | all_resized_joints2D = np.stack(all_resized_joints2D, axis=0) 104 | 105 | return all_resized_segs, all_resized_joints2D 106 | 107 | 108 | def crop_and_resize_silhouette_joints(silhouette, 109 | joints2D, 110 | out_wh, 111 | image=None, 112 | image_out_wh=None, 113 | bbox_scale_factor=1.2): 114 | # Find bounding box around silhouette 115 | body_pixels = np.argwhere(silhouette != 0) 116 | bbox_centre, height, width = convert_bbox_corners_to_centre_hw(np.concatenate([np.amin(body_pixels, axis=0), 117 | np.amax(body_pixels, axis=0)])) 118 | wh = max(height, width) * bbox_scale_factor # Make bounding box square with sides = wh 119 | bbox_corners = convert_bbox_centre_hw_to_corners(bbox_centre, wh, wh) 120 | top_left = bbox_corners[:2].astype(np.int16) 121 | bottom_right = bbox_corners[2:].astype(np.int16) 
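# Un-clamped copies of the corners are kept below so that crops extending beyond the image borders can be zero-padded back to a square.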
122 | top_left_orig = top_left.copy() 123 | bottom_right_orig = bottom_right.copy() 124 | top_left[top_left < 0] = 0 125 | bottom_right[bottom_right < 0] = 0 126 | # Crop silhouette 127 | orig_height, orig_width = silhouette.shape[:2] 128 | silhouette = silhouette[top_left[0]: bottom_right[0], top_left[1]: bottom_right[1]] 129 | # Pad silhouette if crop not square 130 | silhouette = cv2.copyMakeBorder(src=silhouette, 131 | top=max(0, -top_left_orig[0]), 132 | bottom=max(0, bottom_right_orig[0] - orig_height), 133 | left=max(0, -top_left_orig[1]), 134 | right=max(0, bottom_right_orig[1] - orig_width), 135 | borderType=cv2.BORDER_CONSTANT, 136 | value=0) 137 | crop_height, crop_width = silhouette.shape[:2] 138 | # Resize silhouette 139 | silhouette = cv2.resize(silhouette, (out_wh, out_wh), 140 | interpolation=cv2.INTER_NEAREST) 141 | 142 | # Translate and resize joints2D 143 | joints2D = joints2D[:, :2] - top_left_orig[::-1] 144 | joints2D = joints2D * np.array([out_wh / float(crop_width), 145 | out_wh / float(crop_height)]) 146 | 147 | if image is not None: 148 | # Crop image 149 | orig_height, orig_width = image.shape[:2] 150 | image = image[top_left[0]: bottom_right[0], top_left[1]: bottom_right[1]] 151 | # Pad image if crop not square 152 | image = cv2.copyMakeBorder(src=image, 153 | top=max(0, -top_left_orig[0]), 154 | bottom=max(0, bottom_right_orig[0] - orig_height), 155 | left=max(0, -top_left_orig[1]), 156 | right=max(0, bottom_right_orig[1] - orig_width), 157 | borderType=cv2.BORDER_CONSTANT, 158 | value=0) 159 | # Resize silhouette 160 | image = cv2.resize(image, (image_out_wh, image_out_wh), 161 | interpolation=cv2.INTER_LINEAR) 162 | 163 | return silhouette, joints2D, image 164 | 165 | -------------------------------------------------------------------------------- /utils/joints2d_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def undo_keypoint_normalisation(normalised_keypoints, img_wh): 6 | """ 7 | Converts normalised keypoints from [-1, 1] space to pixel space i.e. [0, img_wh] 8 | """ 9 | keypoints = (normalised_keypoints + 1) * (img_wh/2.0) 10 | return keypoints 11 | 12 | 13 | def check_joints2d_visibility(joints2d, img_wh): 14 | vis = np.ones(joints2d.shape[1]) 15 | vis[joints2d[0] > img_wh] = 0 16 | vis[joints2d[1] > img_wh] = 0 17 | vis[joints2d[0] < 0] = 0 18 | vis[joints2d[1] < 0] = 0 19 | 20 | return vis 21 | 22 | 23 | def check_joints2d_visibility_torch(joints2d, img_wh): 24 | """ 25 | Checks if 2D joints are within the image dimensions. 26 | """ 27 | vis = torch.ones(joints2d.shape[:2], device=joints2d.device, dtype=torch.bool) 28 | vis[joints2d[:, :, 0] > img_wh] = 0 29 | vis[joints2d[:, :, 1] > img_wh] = 0 30 | vis[joints2d[:, :, 0] < 0] = 0 31 | vis[joints2d[:, :, 1] < 0] = 0 32 | 33 | return vis -------------------------------------------------------------------------------- /utils/label_conversions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains functions for label conversions. 
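Includes DensePose 24-part to 6-part LSP label conversion, multiclass-to-binary silhouette conversion, and 2D joint to Gaussian heatmap conversion (NumPy and PyTorch variants).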
3 | """ 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def convert_densepose_to_6part_lsp_labels(densepose_seg): 9 | lsp_6part_seg = np.zeros_like(densepose_seg) 10 | 11 | lsp_6part_seg[densepose_seg == 1] = 6 12 | lsp_6part_seg[densepose_seg == 2] = 6 13 | lsp_6part_seg[densepose_seg == 3] = 2 14 | lsp_6part_seg[densepose_seg == 4] = 1 15 | lsp_6part_seg[densepose_seg == 5] = 4 16 | lsp_6part_seg[densepose_seg == 6] = 5 17 | lsp_6part_seg[densepose_seg == 7] = 5 18 | lsp_6part_seg[densepose_seg == 8] = 4 19 | lsp_6part_seg[densepose_seg == 9] = 5 20 | lsp_6part_seg[densepose_seg == 10] = 4 21 | lsp_6part_seg[densepose_seg == 11] = 5 22 | lsp_6part_seg[densepose_seg == 12] = 4 23 | lsp_6part_seg[densepose_seg == 13] = 5 24 | lsp_6part_seg[densepose_seg == 14] = 4 25 | lsp_6part_seg[densepose_seg == 15] = 1 26 | lsp_6part_seg[densepose_seg == 16] = 2 27 | lsp_6part_seg[densepose_seg == 17] = 1 28 | lsp_6part_seg[densepose_seg == 18] = 2 29 | lsp_6part_seg[densepose_seg == 19] = 1 30 | lsp_6part_seg[densepose_seg == 20] = 2 31 | lsp_6part_seg[densepose_seg == 21] = 1 32 | lsp_6part_seg[densepose_seg == 22] = 2 33 | lsp_6part_seg[densepose_seg == 23] = 3 34 | lsp_6part_seg[densepose_seg == 24] = 3 35 | 36 | return lsp_6part_seg 37 | 38 | 39 | def convert_multiclass_to_binary_labels(multiclass_labels): 40 | """ 41 | Converts multiclass segmentation labels into a binary mask. 42 | """ 43 | binary_labels = np.zeros_like(multiclass_labels) 44 | binary_labels[multiclass_labels != 0] = 1 45 | 46 | return binary_labels 47 | 48 | def convert_multiclass_to_binary_labels_torch(multiclass_labels): 49 | """ 50 | Converts multiclass segmentation labels into a binary mask. 51 | """ 52 | binary_labels = torch.zeros_like(multiclass_labels) 53 | binary_labels[multiclass_labels != 0] = 1 54 | 55 | return binary_labels 56 | 57 | 58 | def convert_2Djoints_to_gaussian_heatmaps(joints2D, img_wh, std=4): 59 | """ 60 | Converts 2D joints locations to img_wh x img_wh x num_joints gaussian heatmaps with given 61 | standard deviation var. 62 | """ 63 | num_joints = joints2D.shape[0] 64 | size = 2 * std # Truncate gaussian at 2 std from joint location. 65 | heatmaps = np.zeros((img_wh, img_wh, num_joints), dtype=np.float32) 66 | for i in range(joints2D.shape[0]): 67 | if np.all(joints2D[i] > -size) and np.all(joints2D[i] < img_wh-1+size): 68 | x, y = np.meshgrid(np.linspace(-size, size, 2*size), 69 | np.linspace(-size, size, 2*size)) 70 | d = np.sqrt(x * x + y * y) 71 | gaussian = np.exp(-(d ** 2 / (2.0 * std ** 2))) 72 | 73 | joint_centre = joints2D[i] 74 | hmap_start_x = max(0, joint_centre[0] - size) 75 | hmap_end_x = min(img_wh-1, joint_centre[0] + size) 76 | hmap_start_y = max(0, joint_centre[1] - size) 77 | hmap_end_y = min(img_wh-1, joint_centre[1] + size) 78 | 79 | g_start_x = max(0, size - joint_centre[0]) 80 | g_end_x = min(2*size, 2*size - (size + joint_centre[0] - (img_wh-1))) 81 | g_start_y = max(0, size - joint_centre[1]) 82 | g_end_y = min(2 * size, 2 * size - (size + joint_centre[1] - (img_wh-1))) 83 | 84 | heatmaps[hmap_start_y:hmap_end_y, 85 | hmap_start_x:hmap_end_x, i] = gaussian[g_start_y:g_end_y, g_start_x:g_end_x] 86 | 87 | return heatmaps 88 | 89 | 90 | def convert_2Djoints_to_gaussian_heatmaps_torch(joints2D, img_wh, std=4): 91 | """ 92 | Converts 2D joints locations to img_wh x img_wh x num_joints gaussian heatmaps with given 93 | standard deviation var. 94 | :param joints2D: (B, N, 2) tensor - batch of 2D joints. 95 | :return heatmaps: (B, N, img_wh, img_wh) - batch of 2D joint heatmaps. 
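Each joint is rendered as an isotropic Gaussian truncated at 2 * std from the (rounded) joint centre; joints falling entirely outside the image leave their heatmap channel at zero.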
96 | """ 97 | joints2D_rounded = joints2D.int() 98 | batch_size = joints2D_rounded.shape[0] 99 | num_joints = joints2D_rounded.shape[1] 100 | device = joints2D_rounded.device 101 | heatmaps = torch.zeros((batch_size, num_joints, img_wh, img_wh), device=device).float() 102 | 103 | size = 2 * std # Truncate gaussian at 2 std from joint location. 104 | x, y = torch.meshgrid(torch.linspace(-size, size, 2 * size), 105 | torch.linspace(-size, size, 2 * size)) 106 | x = x.to(device) 107 | y = y.to(device) 108 | d = torch.sqrt(x * x + y * y) 109 | gaussian = torch.exp(-(d ** 2 / (2.0 * std ** 2))) 110 | 111 | for i in range(batch_size): 112 | for j in range(num_joints): 113 | if torch.all(joints2D_rounded[i, j] > -size) and torch.all(joints2D_rounded[i, j] < img_wh-1+size): 114 | joint_centre = joints2D_rounded[i, j] 115 | hmap_start_x = max(0, joint_centre[0].item() - size) 116 | hmap_end_x = min(img_wh-1, joint_centre[0].item() + size) 117 | hmap_start_y = max(0, joint_centre[1].item() - size) 118 | hmap_end_y = min(img_wh-1, joint_centre[1].item() + size) 119 | 120 | g_start_x = max(0, size - joint_centre[0].item()) 121 | g_end_x = min(2*size, 2*size - (size + joint_centre[0].item() - (img_wh-1))) 122 | g_start_y = max(0, size - joint_centre[1].item()) 123 | g_end_y = min(2 * size, 2 * size - (size + joint_centre[1].item() - (img_wh-1))) 124 | 125 | heatmaps[i, j, hmap_start_y:hmap_end_y, hmap_start_x:hmap_end_x] = gaussian[g_start_y:g_end_y, g_start_x:g_end_x] 126 | 127 | return heatmaps 128 | 129 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def count_parameters(model): 4 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 5 | -------------------------------------------------------------------------------- /utils/rigid_transform_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from torch.nn import functional as F 5 | 6 | 7 | def rotate_translate_verts_torch(vertices, axis, angle, trans): 8 | """ 9 | Rotates and translates batch of vertices. 10 | :param vertices: B, N, 3 11 | :param axis: 3, 12 | :param angle: angle in radians 13 | :param trans: 3, 14 | :return: 15 | """ 16 | r = angle * axis 17 | R = cv2.Rodrigues(r)[0] 18 | R = torch.from_numpy(R.astype(np.float32)).to(vertices.device) 19 | trans = torch.from_numpy(trans.astype(np.float32)).to(vertices.device) 20 | 21 | vertices = torch.einsum('ij,bkj->bki', R, vertices) 22 | vertices = vertices + trans 23 | 24 | return vertices 25 | 26 | 27 | def rot6d_to_rotmat(x): 28 | """Convert 6D rotation representation to 3x3 rotation matrix. 29 | Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019 30 | Input: 31 | (B,6) Batch of 6-D rotation representations 32 | Output: 33 | (B,3,3) Batch of corresponding rotation matrices 34 | """ 35 | x = x.view(-1,3,2) 36 | a1 = x[:, :, 0] 37 | a2 = x[:, :, 1] 38 | b1 = F.normalize(a1) # Ensuring columns are unit vectors 39 | b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1) # Ensuring column 1 and column 2 are orthogonal 40 | b3 = torch.cross(b1, b2) 41 | return torch.stack((b1, b2, b3), dim=-1) --------------------------------------------------------------------------------
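A minimal round-trip sketch for rot6d_to_rotmat (illustration only, not part of the repository; it assumes the repository root is on PYTHONPATH so that utils.rigid_transform_utils is importable). The 6-D representation is the row-major flattening of the first two columns of a rotation matrix, so reconstructing a known rotation should return the original matrix:

import math
import torch
from utils.rigid_transform_utils import rot6d_to_rotmat

# Known rotation about the z-axis.
c, s = math.cos(0.3), math.sin(0.3)
R = torch.tensor([[c, -s, 0.],
                  [s,  c, 0.],
                  [0., 0., 1.]])

# First two columns, flattened row-major to match x.view(-1, 3, 2) inside rot6d_to_rotmat.
sixd = R[:, :2].reshape(1, 6)

R_rec = rot6d_to_rotmat(sixd)  # (1, 3, 3)
assert torch.allclose(R_rec[0], R, atol=1e-5)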