├── FaceBoxes ├── .gitignore ├── FaceBoxes.py ├── __init__.py ├── build_cpu_nms.sh ├── models │ └── faceboxes.py ├── readme.md ├── utils │ ├── .gitignore │ ├── __init__.py │ ├── box_utils.py │ ├── build.py │ ├── config.py │ ├── functions.py │ ├── nms │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── cpu_nms.pyx │ │ └── py_cpu_nms.py │ ├── nms_wrapper.py │ ├── prior_box.py │ └── timer.py └── weights │ ├── FaceBoxesProd.pth │ └── readme.md ├── LICENSE ├── README.md ├── Sim3DR ├── .gitignore ├── Sim3DR.py ├── __init__.py ├── _init_paths.py ├── build_sim3dr.sh ├── lib │ ├── rasterize.h │ ├── rasterize.pyx │ └── rasterize_kernel.cpp ├── lighting.py ├── readme.md ├── setup.py └── tests │ ├── .gitignore │ ├── CMakeLists.txt │ ├── io.cpp │ ├── io.h │ └── test.cpp ├── artistic.py ├── backbone_nets ├── ResNeSt │ ├── __init__.py │ ├── ablation.py │ ├── resnest.py │ ├── resnet.py │ └── splat.py ├── ghostnet_backbone.py ├── mobilenetv1_backbone.py ├── mobilenetv2_backbone.py ├── pointnet_backbone.py └── resnet_backbone.py ├── benchmark.py ├── benchmark_aflw2000.py ├── benchmark_validate.py ├── demo ├── 0.png ├── 1.png ├── 10.png ├── 11.png ├── 12.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png ├── 9.png ├── AF-1.png ├── AF-2.png ├── alignment.png ├── comparison-deca.png ├── demo.gif ├── multiple.png ├── orientation.png ├── single.png └── teaser.png ├── img ├── sample_1.jpg ├── sample_2.jpg ├── sample_3.jpg └── sample_4.jpg ├── loss_definition.py ├── main_train.py ├── model_building.py ├── pretrained └── __init__.py ├── setup.py ├── singleImage.py ├── singleImage_simple.py ├── synergy3DMM.py ├── synergy_demo.ipynb ├── train_script.sh ├── utils ├── __init__.py ├── ddfa.py ├── inference.py ├── io.py ├── params.py └── render.py └── uv_texture_realFaces.py /FaceBoxes/.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__ 3 | **/__pycache__ -------------------------------------------------------------------------------- /FaceBoxes/FaceBoxes.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os.path as osp 4 | 5 | import torch 6 | import numpy as np 7 | import cv2 8 | 9 | from .utils.prior_box import PriorBox 10 | from .utils.nms_wrapper import nms 11 | from .utils.box_utils import decode 12 | from .utils.timer import Timer 13 | from .utils.functions import check_keys, remove_prefix, load_model 14 | from .utils.config import cfg 15 | from .models.faceboxes import FaceBoxesNet 16 | 17 | # some global configs 18 | confidence_threshold = 0.05 19 | top_k = 5000 20 | keep_top_k = 750 21 | nms_threshold = 0.3 22 | vis_thres = 0.5 #0.5 23 | resize = 1 24 | 25 | scale_flag = True 26 | HEIGHT, WIDTH = 720, 1080 27 | 28 | make_abs_path = lambda fn: osp.join(osp.dirname(osp.realpath(__file__)), fn) 29 | pretrained_path = make_abs_path('weights/FaceBoxesProd.pth') 30 | 31 | 32 | def viz_bbox(img, dets, wfp='out.jpg'): 33 | # show 34 | for b in dets: 35 | if b[4] < vis_thres: 36 | continue 37 | text = "{:.4f}".format(b[4]) 38 | b = list(map(int, b)) 39 | cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) 40 | cx = b[0] 41 | cy = b[1] + 12 42 | cv2.putText(img, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255)) 43 | cv2.imwrite(wfp, img) 44 | print(f'Viz bbox to {wfp}') 45 | 46 | 47 | class FaceBoxes: 48 | def __init__(self, timer_flag=False): 49 | net = FaceBoxesNet(phase='test', size=None, num_classes=2) # initialize detector 50 | self.net = load_model(net, pretrained_path=pretrained_path, load_to_cpu=True) 51 | self.net.eval() 52 | 53 | for p in self.net.parameters(): 54 | p.requires_grad_(False) 55 | 56 | # print('Finished loading model!') 57 | 58 | self.timer_flag = timer_flag 59 | 60 | def __call__(self, img_): 61 | img_raw = img_.copy() 62 | 63 | # scaling to speed up 64 | scale = 1 65 | if scale_flag: 66 | h, w = img_raw.shape[:2] 67 | if h > HEIGHT: 68 | scale = HEIGHT / h 69 | if w * scale > WIDTH: 70 | scale *= WIDTH / (w * scale) 71 | # print(scale) 72 | if scale == 1: 73 | img_raw_scale = img_raw 74 | else: 75 | h_s = int(scale * h) 76 | w_s = int(scale * w) 77 | # print(h_s, w_s) 78 | img_raw_scale = cv2.resize(img_raw, dsize=(w_s, h_s)) 79 | # print(img_raw_scale.shape) 80 | 81 | img = np.float32(img_raw_scale) 82 | else: 83 | img = np.float32(img_raw) 84 | 85 | # forward 86 | _t = {'forward_pass': Timer(), 'misc': Timer()} 87 | im_height, im_width, _ = img.shape 88 | scale_bbox = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]]) 89 | img -= (104, 117, 123) 90 | img = img.transpose(2, 0, 1) 91 | img = torch.from_numpy(img).unsqueeze(0) 92 | 93 | _t['forward_pass'].tic() 94 | loc, conf = self.net(img) # forward pass 95 | _t['forward_pass'].toc() 96 | _t['misc'].tic() 97 | priorbox = PriorBox(image_size=(im_height, im_width)) 98 | priors = priorbox.forward() 99 | prior_data = priors.data 100 | boxes = decode(loc.data.squeeze(0), prior_data, cfg['variance']) 101 | if scale_flag: 102 | boxes = boxes * scale_bbox / scale / resize 103 | else: 104 | boxes = boxes * scale_bbox / resize 105 | 106 | boxes = boxes.cpu().numpy() 107 | scores = conf.squeeze(0).data.cpu().numpy()[:, 1] 108 | 109 | # ignore low scores 110 | inds = np.where(scores > confidence_threshold)[0] 111 | boxes = boxes[inds] 112 | scores = scores[inds] 113 | 114 | # keep top-K before NMS 115 | order = scores.argsort()[::-1][:top_k] 116 | boxes = boxes[order] 117 | scores = scores[order] 118 | 119 | # do NMS 120 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) 121 | # keep = py_cpu_nms(dets, args.nms_threshold) 122 | keep = nms(dets, nms_threshold) 123 | dets = dets[keep, :] 124 | 125 | # keep top-K faster NMS 126 | dets = dets[:keep_top_k, :] 127 | _t['misc'].toc() 128 | 129 | if self.timer_flag: 130 | print('Detection: {:d}/{:d} forward_pass_time: {:.4f}s misc: {:.4f}s'.format(1, 1, _t[ 131 | 'forward_pass'].average_time, _t['misc'].average_time)) 132 | 133 | # filter using vis_thres 134 | det_bboxes = [] 135 | for b in dets: 136 | if b[4] > vis_thres: 137 | xmin, ymin, xmax, ymax, score = b[0], b[1], b[2], b[3], b[4] 138 | w = xmax - xmin + 1 139 | h = ymax - ymin + 1 140 | bbox = [xmin, ymin, xmax, ymax, score] 141 | det_bboxes.append(bbox) 142 | 143 | return det_bboxes 144 | 145 | 146 | def main(): 147 | face_boxes = FaceBoxes(timer_flag=True) 148 | 149 | fn = 'trump_hillary.jpg' 150 | img_fp = f'../examples/inputs/{fn}' 151 | img = cv2.imread(img_fp) 152 | dets = face_boxes(img) # xmin, ymin, w, h 153 | # print(dets) 154 | 155 | wfn = fn.replace('.jpg', '_det.jpg') 156 | wfp = osp.join('../examples/results', wfn) 157 | viz_bbox(img, dets, wfp) 158 | 159 | 160 | if __name__ == '__main__': 161 | main() 162 | -------------------------------------------------------------------------------- /FaceBoxes/__init__.py: -------------------------------------------------------------------------------- 1 | from .FaceBoxes import FaceBoxes 2 | -------------------------------------------------------------------------------- /FaceBoxes/build_cpu_nms.sh: -------------------------------------------------------------------------------- 1 | cd utils 2 | python3 build.py build_ext --inplace 3 | cd .. -------------------------------------------------------------------------------- /FaceBoxes/models/faceboxes.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class BasicConv2d(nn.Module): 9 | 10 | def __init__(self, in_channels, out_channels, **kwargs): 11 | super(BasicConv2d, self).__init__() 12 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 13 | self.bn = nn.BatchNorm2d(out_channels, eps=1e-5) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = self.bn(x) 18 | return F.relu(x, inplace=True) 19 | 20 | 21 | class Inception(nn.Module): 22 | def __init__(self): 23 | super(Inception, self).__init__() 24 | self.branch1x1 = BasicConv2d(128, 32, kernel_size=1, padding=0) 25 | self.branch1x1_2 = BasicConv2d(128, 32, kernel_size=1, padding=0) 26 | self.branch3x3_reduce = BasicConv2d(128, 24, kernel_size=1, padding=0) 27 | self.branch3x3 = BasicConv2d(24, 32, kernel_size=3, padding=1) 28 | self.branch3x3_reduce_2 = BasicConv2d(128, 24, kernel_size=1, padding=0) 29 | self.branch3x3_2 = BasicConv2d(24, 32, kernel_size=3, padding=1) 30 | self.branch3x3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1) 31 | 32 | def forward(self, x): 33 | branch1x1 = self.branch1x1(x) 34 | 35 | branch1x1_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 36 | branch1x1_2 = self.branch1x1_2(branch1x1_pool) 37 | 38 | branch3x3_reduce = self.branch3x3_reduce(x) 39 | branch3x3 = self.branch3x3(branch3x3_reduce) 40 | 41 | branch3x3_reduce_2 = self.branch3x3_reduce_2(x) 42 | branch3x3_2 = self.branch3x3_2(branch3x3_reduce_2) 43 | branch3x3_3 = self.branch3x3_3(branch3x3_2) 44 | 45 | outputs = [branch1x1, branch1x1_2, branch3x3, branch3x3_3] 46 | return torch.cat(outputs, 1) 47 | 48 | 49 | class CRelu(nn.Module): 50 | 51 | def __init__(self, in_channels, out_channels, **kwargs): 52 | super(CRelu, self).__init__() 53 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 54 | self.bn = nn.BatchNorm2d(out_channels, eps=1e-5) 55 | 56 | def forward(self, x): 57 | x = self.conv(x) 58 | x = self.bn(x) 59 | x = torch.cat([x, -x], 1) 60 | x = F.relu(x, inplace=True) 61 | return x 62 | 63 | 64 | class FaceBoxesNet(nn.Module): 65 | 66 | def __init__(self, phase, size, num_classes): 67 | super(FaceBoxesNet, self).__init__() 68 | self.phase = phase 69 | self.num_classes = num_classes 70 | self.size = size 71 | 72 | self.conv1 = CRelu(3, 24, kernel_size=7, stride=4, padding=3) 73 | self.conv2 = CRelu(48, 64, kernel_size=5, stride=2, padding=2) 74 | 75 | self.inception1 = Inception() 76 | self.inception2 = Inception() 77 | self.inception3 = Inception() 78 | 79 | self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0) 80 | self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1) 81 | 82 | self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0) 83 | self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1) 84 | 85 | self.loc, self.conf = self.multibox(self.num_classes) 86 | 87 | if self.phase == 'test': 88 | self.softmax = nn.Softmax(dim=-1) 89 | 90 | if self.phase == 'train': 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | if m.bias is not None: 94 | nn.init.xavier_normal_(m.weight.data) 95 | m.bias.data.fill_(0.02) 96 | else: 97 | m.weight.data.normal_(0, 0.01) 98 | elif isinstance(m, nn.BatchNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | 102 | def multibox(self, num_classes): 103 | loc_layers = [] 104 | conf_layers = [] 105 | loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)] 106 | conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)] 107 | loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)] 108 | conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)] 109 | loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)] 110 | conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)] 111 | return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers) 112 | 113 | def forward(self, x): 114 | 115 | detection_sources = list() 116 | loc = list() 117 | conf = list() 118 | 119 | x = self.conv1(x) 120 | x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) 121 | x = self.conv2(x) 122 | x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) 123 | x = self.inception1(x) 124 | x = self.inception2(x) 125 | x = self.inception3(x) 126 | detection_sources.append(x) 127 | 128 | x = self.conv3_1(x) 129 | x = self.conv3_2(x) 130 | detection_sources.append(x) 131 | 132 | x = self.conv4_1(x) 133 | x = self.conv4_2(x) 134 | detection_sources.append(x) 135 | 136 | for (x, l, c) in zip(detection_sources, self.loc, self.conf): 137 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 138 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 139 | 140 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 141 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 142 | 143 | if self.phase == "test": 144 | output = (loc.view(loc.size(0), -1, 4), 145 | self.softmax(conf.view(conf.size(0), -1, self.num_classes))) 146 | else: 147 | output = (loc.view(loc.size(0), -1, 4), 148 | conf.view(conf.size(0), -1, self.num_classes)) 149 | 150 | return output 151 | -------------------------------------------------------------------------------- /FaceBoxes/readme.md: -------------------------------------------------------------------------------- 1 | ## How to fun FaceBoxes 2 | 3 | ### Build the cpu version of NMS 4 | ```shell script 5 | cd utils 6 | python3 build.py build_ext --inplace 7 | ``` 8 | 9 | or just run 10 | 11 | ```shell script 12 | sh ./build_cpu_nms.sh 13 | ``` 14 | 15 | ### Run the demo of face detection 16 | ```shell script 17 | python3 FaceBoxes.py 18 | ``` -------------------------------------------------------------------------------- /FaceBoxes/utils/.gitignore: -------------------------------------------------------------------------------- 1 | utils/build 2 | utils/nms/*.so 3 | utils/*.c 4 | build/ 5 | -------------------------------------------------------------------------------- /FaceBoxes/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/FaceBoxes/utils/__init__.py -------------------------------------------------------------------------------- /FaceBoxes/utils/box_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def point_form(boxes): 8 | """ Convert prior_boxes to (xmin, ymin, xmax, ymax) 9 | representation for comparison to point form ground truth data. 10 | Args: 11 | boxes: (tensor) center-size default boxes from priorbox layers. 12 | Return: 13 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 14 | """ 15 | return torch.cat((boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin 16 | boxes[:, :2] + boxes[:, 2:] / 2), 1) # xmax, ymax 17 | 18 | 19 | def center_size(boxes): 20 | """ Convert prior_boxes to (cx, cy, w, h) 21 | representation for comparison to center-size form ground truth data. 22 | Args: 23 | boxes: (tensor) point_form boxes 24 | Return: 25 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 26 | """ 27 | return torch.cat((boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy 28 | boxes[:, 2:] - boxes[:, :2], 1) # w, h 29 | 30 | 31 | def intersect(box_a, box_b): 32 | """ We resize both tensors to [A,B,2] without new malloc: 33 | [A,2] -> [A,1,2] -> [A,B,2] 34 | [B,2] -> [1,B,2] -> [A,B,2] 35 | Then we compute the area of intersect between box_a and box_b. 36 | Args: 37 | box_a: (tensor) bounding boxes, Shape: [A,4]. 38 | box_b: (tensor) bounding boxes, Shape: [B,4]. 39 | Return: 40 | (tensor) intersection area, Shape: [A,B]. 41 | """ 42 | A = box_a.size(0) 43 | B = box_b.size(0) 44 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 45 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 46 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 47 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 48 | inter = torch.clamp((max_xy - min_xy), min=0) 49 | return inter[:, :, 0] * inter[:, :, 1] 50 | 51 | 52 | def jaccard(box_a, box_b): 53 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 54 | is simply the intersection over union of two boxes. Here we operate on 55 | ground truth boxes and default boxes. 56 | E.g.: 57 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 58 | Args: 59 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 60 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 61 | Return: 62 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 63 | """ 64 | inter = intersect(box_a, box_b) 65 | area_a = ((box_a[:, 2] - box_a[:, 0]) * 66 | (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 67 | area_b = ((box_b[:, 2] - box_b[:, 0]) * 68 | (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 69 | union = area_a + area_b - inter 70 | return inter / union # [A,B] 71 | 72 | 73 | def matrix_iou(a, b): 74 | """ 75 | return iou of a and b, numpy version for data augenmentation 76 | """ 77 | lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) 78 | rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) 79 | 80 | area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) 81 | area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) 82 | area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) 83 | return area_i / (area_a[:, np.newaxis] + area_b - area_i) 84 | 85 | 86 | def matrix_iof(a, b): 87 | """ 88 | return iof of a and b, numpy version for data augenmentation 89 | """ 90 | lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) 91 | rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) 92 | 93 | area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) 94 | area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) 95 | return area_i / np.maximum(area_a[:, np.newaxis], 1) 96 | 97 | 98 | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): 99 | """Match each prior box with the ground truth box of the highest jaccard 100 | overlap, encode the bounding boxes, then return the matched indices 101 | corresponding to both confidence and location preds. 102 | Args: 103 | threshold: (float) The overlap threshold used when mathing boxes. 104 | truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors]. 105 | priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. 106 | variances: (tensor) Variances corresponding to each prior coord, 107 | Shape: [num_priors, 4]. 108 | labels: (tensor) All the class labels for the image, Shape: [num_obj]. 109 | loc_t: (tensor) Tensor to be filled w/ endcoded location targets. 110 | conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. 111 | idx: (int) current batch index 112 | Return: 113 | The matched indices corresponding to 1)location and 2)confidence preds. 114 | """ 115 | # jaccard index 116 | overlaps = jaccard( 117 | truths, 118 | point_form(priors) 119 | ) 120 | # (Bipartite Matching) 121 | # [1,num_objects] best prior for each ground truth 122 | best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) 123 | 124 | # ignore hard gt 125 | valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 126 | best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] 127 | if best_prior_idx_filter.shape[0] <= 0: 128 | loc_t[idx] = 0 129 | conf_t[idx] = 0 130 | return 131 | 132 | # [1,num_priors] best ground truth for each prior 133 | best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) 134 | best_truth_idx.squeeze_(0) 135 | best_truth_overlap.squeeze_(0) 136 | best_prior_idx.squeeze_(1) 137 | best_prior_idx_filter.squeeze_(1) 138 | best_prior_overlap.squeeze_(1) 139 | best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior 140 | # TODO refactor: index best_prior_idx with long tensor 141 | # ensure every gt matches with its prior of max overlap 142 | for j in range(best_prior_idx.size(0)): 143 | best_truth_idx[best_prior_idx[j]] = j 144 | matches = truths[best_truth_idx] # Shape: [num_priors,4] 145 | conf = labels[best_truth_idx] # Shape: [num_priors] 146 | conf[best_truth_overlap < threshold] = 0 # label as background 147 | loc = encode(matches, priors, variances) 148 | loc_t[idx] = loc # [num_priors,4] encoded offsets to learn 149 | conf_t[idx] = conf # [num_priors] top class label for each prior 150 | 151 | 152 | def encode(matched, priors, variances): 153 | """Encode the variances from the priorbox layers into the ground truth boxes 154 | we have matched (based on jaccard overlap) with the prior boxes. 155 | Args: 156 | matched: (tensor) Coords of ground truth for each prior in point-form 157 | Shape: [num_priors, 4]. 158 | priors: (tensor) Prior boxes in center-offset form 159 | Shape: [num_priors,4]. 160 | variances: (list[float]) Variances of priorboxes 161 | Return: 162 | encoded boxes (tensor), Shape: [num_priors, 4] 163 | """ 164 | 165 | # dist b/t match center and prior's center 166 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] 167 | # encode variance 168 | g_cxcy /= (variances[0] * priors[:, 2:]) 169 | # match wh / prior wh 170 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 171 | g_wh = torch.log(g_wh) / variances[1] 172 | # return target for smooth_l1_loss 173 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 174 | 175 | 176 | # Adapted from https://github.com/Hakuyume/chainer-ssd 177 | def decode(loc, priors, variances): 178 | """Decode locations from predictions using priors to undo 179 | the encoding we did for offset regression at train time. 180 | Args: 181 | loc (tensor): location predictions for loc layers, 182 | Shape: [num_priors,4] 183 | priors (tensor): Prior boxes in center-offset form. 184 | Shape: [num_priors,4]. 185 | variances: (list[float]) Variances of priorboxes 186 | Return: 187 | decoded bounding box predictions 188 | """ 189 | 190 | boxes = torch.cat(( 191 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 192 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 193 | boxes[:, :2] -= boxes[:, 2:] / 2 194 | boxes[:, 2:] += boxes[:, :2] 195 | return boxes 196 | 197 | 198 | def log_sum_exp(x): 199 | """Utility function for computing log_sum_exp while determining 200 | This will be used to determine unaveraged confidence loss across 201 | all examples in a batch. 202 | Args: 203 | x (Variable(tensor)): conf_preds from conf layers 204 | """ 205 | x_max = x.data.max() 206 | return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max 207 | 208 | 209 | # Original author: Francisco Massa: 210 | # https://github.com/fmassa/object-detection.torch 211 | # Ported to PyTorch by Max deGroot (02/01/2017) 212 | def nms(boxes, scores, overlap=0.5, top_k=200): 213 | """Apply non-maximum suppression at test time to avoid detecting too many 214 | overlapping bounding boxes for a given object. 215 | Args: 216 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. 217 | scores: (tensor) The class predscores for the img, Shape:[num_priors]. 218 | overlap: (float) The overlap thresh for suppressing unnecessary boxes. 219 | top_k: (int) The Maximum number of box preds to consider. 220 | Return: 221 | The indices of the kept boxes with respect to num_priors. 222 | """ 223 | 224 | keep = torch.Tensor(scores.size(0)).fill_(0).long() 225 | if boxes.numel() == 0: 226 | return keep 227 | x1 = boxes[:, 0] 228 | y1 = boxes[:, 1] 229 | x2 = boxes[:, 2] 230 | y2 = boxes[:, 3] 231 | area = torch.mul(x2 - x1, y2 - y1) 232 | v, idx = scores.sort(0) # sort in ascending order 233 | # I = I[v >= 0.01] 234 | idx = idx[-top_k:] # indices of the top-k largest vals 235 | xx1 = boxes.new() 236 | yy1 = boxes.new() 237 | xx2 = boxes.new() 238 | yy2 = boxes.new() 239 | w = boxes.new() 240 | h = boxes.new() 241 | 242 | # keep = torch.Tensor() 243 | count = 0 244 | while idx.numel() > 0: 245 | i = idx[-1] # index of current largest val 246 | # keep.append(i) 247 | keep[count] = i 248 | count += 1 249 | if idx.size(0) == 1: 250 | break 251 | idx = idx[:-1] # remove kept element from view 252 | # load bboxes of next highest vals 253 | torch.index_select(x1, 0, idx, out=xx1) 254 | torch.index_select(y1, 0, idx, out=yy1) 255 | torch.index_select(x2, 0, idx, out=xx2) 256 | torch.index_select(y2, 0, idx, out=yy2) 257 | # store element-wise max with next highest score 258 | xx1 = torch.clamp(xx1, min=x1[i]) 259 | yy1 = torch.clamp(yy1, min=y1[i]) 260 | xx2 = torch.clamp(xx2, max=x2[i]) 261 | yy2 = torch.clamp(yy2, max=y2[i]) 262 | w.resize_as_(xx2) 263 | h.resize_as_(yy2) 264 | w = xx2 - xx1 265 | h = yy2 - yy1 266 | # check sizes of xx1 and xx2.. after each iteration 267 | w = torch.clamp(w, min=0.0) 268 | h = torch.clamp(h, min=0.0) 269 | inter = w * h 270 | # IoU = i / (area(a) + area(b) - i) 271 | rem_areas = torch.index_select(area, 0, idx) # load remaining areas) 272 | union = (rem_areas - inter) + area[i] 273 | IoU = inter / union # store result in iou 274 | # keep only elements with an IoU <= overlap 275 | idx = idx[IoU.le(overlap)] 276 | return keep, count 277 | -------------------------------------------------------------------------------- /FaceBoxes/utils/build.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # -------------------------------------------------------- 4 | # Fast R-CNN 5 | # Copyright (c) 2015 Microsoft 6 | # Licensed under The MIT License [see LICENSE for details] 7 | # Written by Ross Girshick 8 | # -------------------------------------------------------- 9 | 10 | import os 11 | from os.path import join as pjoin 12 | import numpy as np 13 | from distutils.core import setup 14 | from distutils.extension import Extension 15 | from Cython.Distutils import build_ext 16 | 17 | 18 | def find_in_path(name, path): 19 | "Find a file in a search path" 20 | # adapted fom http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/ 21 | for dir in path.split(os.pathsep): 22 | binpath = pjoin(dir, name) 23 | if os.path.exists(binpath): 24 | return os.path.abspath(binpath) 25 | return None 26 | 27 | 28 | # Obtain the numpy include directory. This logic works across numpy versions. 29 | try: 30 | numpy_include = np.get_include() 31 | except AttributeError: 32 | numpy_include = np.get_numpy_include() 33 | 34 | 35 | # run the customize_compiler 36 | class custom_build_ext(build_ext): 37 | def build_extensions(self): 38 | # customize_compiler_for_nvcc(self.compiler) 39 | build_ext.build_extensions(self) 40 | 41 | 42 | ext_modules = [ 43 | Extension( 44 | "nms.cpu_nms", 45 | ["nms/cpu_nms.pyx"], 46 | # extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]}, 47 | extra_compile_args=["-Wno-cpp", "-Wno-unused-function"], 48 | include_dirs=[numpy_include] 49 | ) 50 | ] 51 | 52 | setup( 53 | name='mot_utils', 54 | ext_modules=ext_modules, 55 | # inject our custom trigger 56 | cmdclass={'build_ext': custom_build_ext}, 57 | ) 58 | -------------------------------------------------------------------------------- /FaceBoxes/utils/config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | cfg = { 4 | 'name': 'FaceBoxes', 5 | 'min_sizes': [[32, 64, 128], [256], [512]], 6 | 'steps': [32, 64, 128], 7 | 'variance': [0.1, 0.2], 8 | 'clip': False 9 | } 10 | -------------------------------------------------------------------------------- /FaceBoxes/utils/functions.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import sys 4 | import os.path as osp 5 | import torch 6 | 7 | def check_keys(model, pretrained_state_dict): 8 | ckpt_keys = set(pretrained_state_dict.keys()) 9 | model_keys = set(model.state_dict().keys()) 10 | used_pretrained_keys = model_keys & ckpt_keys 11 | unused_pretrained_keys = ckpt_keys - model_keys 12 | missing_keys = model_keys - ckpt_keys 13 | # print('Missing keys:{}'.format(len(missing_keys))) 14 | # print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys))) 15 | # print('Used keys:{}'.format(len(used_pretrained_keys))) 16 | assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint' 17 | return True 18 | 19 | 20 | def remove_prefix(state_dict, prefix): 21 | ''' Old style model is stored with all names of parameters sharing common prefix 'module.' ''' 22 | # print('remove prefix \'{}\''.format(prefix)) 23 | f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x 24 | return {f(key): value for key, value in state_dict.items()} 25 | 26 | 27 | def load_model(model, pretrained_path, load_to_cpu): 28 | if not osp.isfile(pretrained_path): 29 | print(f'The pre-trained FaceBoxes model {pretrained_path} does not exist') 30 | sys.exit('-1') 31 | # print('Loading pretrained model from {}'.format(pretrained_path)) 32 | if load_to_cpu: 33 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage) 34 | else: 35 | device = torch.cuda.current_device() 36 | pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device)) 37 | if "state_dict" in pretrained_dict.keys(): 38 | pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.') 39 | else: 40 | pretrained_dict = remove_prefix(pretrained_dict, 'module.') 41 | check_keys(model, pretrained_dict) 42 | model.load_state_dict(pretrained_dict, strict=False) 43 | return model 44 | -------------------------------------------------------------------------------- /FaceBoxes/utils/nms/.gitignore: -------------------------------------------------------------------------------- 1 | *.c 2 | *.so 3 | -------------------------------------------------------------------------------- /FaceBoxes/utils/nms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/FaceBoxes/utils/nms/__init__.py -------------------------------------------------------------------------------- /FaceBoxes/utils/nms/cpu_nms.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | cimport numpy as np 10 | 11 | cdef inline np.float32_t max(np.float32_t a, np.float32_t b): 12 | return a if a >= b else b 13 | 14 | cdef inline np.float32_t min(np.float32_t a, np.float32_t b): 15 | return a if a <= b else b 16 | 17 | def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): 18 | cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] 19 | cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] 20 | cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] 21 | cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] 22 | cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] 23 | 24 | cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) 25 | cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1] 26 | 27 | cdef int ndets = dets.shape[0] 28 | cdef np.ndarray[np.int_t, ndim=1] suppressed = \ 29 | np.zeros((ndets), dtype=np.int64) 30 | 31 | # nominal indices 32 | cdef int _i, _j 33 | # sorted indices 34 | cdef int i, j 35 | # temp variables for box i's (the box currently under consideration) 36 | cdef np.float32_t ix1, iy1, ix2, iy2, iarea 37 | # variables for computing overlap with box j (lower scoring box) 38 | cdef np.float32_t xx1, yy1, xx2, yy2 39 | cdef np.float32_t w, h 40 | cdef np.float32_t inter, ovr 41 | 42 | keep = [] 43 | for _i in range(ndets): 44 | i = order[_i] 45 | if suppressed[i] == 1: 46 | continue 47 | keep.append(i) 48 | ix1 = x1[i] 49 | iy1 = y1[i] 50 | ix2 = x2[i] 51 | iy2 = y2[i] 52 | iarea = areas[i] 53 | for _j in range(_i + 1, ndets): 54 | j = order[_j] 55 | if suppressed[j] == 1: 56 | continue 57 | xx1 = max(ix1, x1[j]) 58 | yy1 = max(iy1, y1[j]) 59 | xx2 = min(ix2, x2[j]) 60 | yy2 = min(iy2, y2[j]) 61 | w = max(0.0, xx2 - xx1 + 1) 62 | h = max(0.0, yy2 - yy1 + 1) 63 | inter = w * h 64 | ovr = inter / (iarea + areas[j] - inter) 65 | if ovr >= thresh: 66 | suppressed[j] = 1 67 | 68 | return keep 69 | 70 | def cpu_soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): 71 | cdef unsigned int N = boxes.shape[0] 72 | cdef float iw, ih, box_area 73 | cdef float ua 74 | cdef int pos = 0 75 | cdef float maxscore = 0 76 | cdef int maxpos = 0 77 | cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov 78 | 79 | for i in range(N): 80 | maxscore = boxes[i, 4] 81 | maxpos = i 82 | 83 | tx1 = boxes[i,0] 84 | ty1 = boxes[i,1] 85 | tx2 = boxes[i,2] 86 | ty2 = boxes[i,3] 87 | ts = boxes[i,4] 88 | 89 | pos = i + 1 90 | # get max box 91 | while pos < N: 92 | if maxscore < boxes[pos, 4]: 93 | maxscore = boxes[pos, 4] 94 | maxpos = pos 95 | pos = pos + 1 96 | 97 | # add max box as a detection 98 | boxes[i,0] = boxes[maxpos,0] 99 | boxes[i,1] = boxes[maxpos,1] 100 | boxes[i,2] = boxes[maxpos,2] 101 | boxes[i,3] = boxes[maxpos,3] 102 | boxes[i,4] = boxes[maxpos,4] 103 | 104 | # swap ith box with position of max box 105 | boxes[maxpos,0] = tx1 106 | boxes[maxpos,1] = ty1 107 | boxes[maxpos,2] = tx2 108 | boxes[maxpos,3] = ty2 109 | boxes[maxpos,4] = ts 110 | 111 | tx1 = boxes[i,0] 112 | ty1 = boxes[i,1] 113 | tx2 = boxes[i,2] 114 | ty2 = boxes[i,3] 115 | ts = boxes[i,4] 116 | 117 | pos = i + 1 118 | # NMS iterations, note that N changes if detection boxes fall below threshold 119 | while pos < N: 120 | x1 = boxes[pos, 0] 121 | y1 = boxes[pos, 1] 122 | x2 = boxes[pos, 2] 123 | y2 = boxes[pos, 3] 124 | s = boxes[pos, 4] 125 | 126 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 127 | iw = (min(tx2, x2) - max(tx1, x1) + 1) 128 | if iw > 0: 129 | ih = (min(ty2, y2) - max(ty1, y1) + 1) 130 | if ih > 0: 131 | ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) 132 | ov = iw * ih / ua #iou between max box and detection box 133 | 134 | if method == 1: # linear 135 | if ov > Nt: 136 | weight = 1 - ov 137 | else: 138 | weight = 1 139 | elif method == 2: # gaussian 140 | weight = np.exp(-(ov * ov)/sigma) 141 | else: # original NMS 142 | if ov > Nt: 143 | weight = 0 144 | else: 145 | weight = 1 146 | 147 | boxes[pos, 4] = weight*boxes[pos, 4] 148 | 149 | # if box score falls below threshold, discard the box by swapping with last box 150 | # update N 151 | if boxes[pos, 4] < threshold: 152 | boxes[pos,0] = boxes[N-1, 0] 153 | boxes[pos,1] = boxes[N-1, 1] 154 | boxes[pos,2] = boxes[N-1, 2] 155 | boxes[pos,3] = boxes[N-1, 3] 156 | boxes[pos,4] = boxes[N-1, 4] 157 | N = N - 1 158 | pos = pos - 1 159 | 160 | pos = pos + 1 161 | 162 | keep = [i for i in range(N)] 163 | return keep 164 | -------------------------------------------------------------------------------- /FaceBoxes/utils/nms/py_cpu_nms.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | 10 | def py_cpu_nms(dets, thresh): 11 | """Pure Python NMS baseline.""" 12 | x1 = dets[:, 0] 13 | y1 = dets[:, 1] 14 | x2 = dets[:, 2] 15 | y2 = dets[:, 3] 16 | scores = dets[:, 4] 17 | 18 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 19 | order = scores.argsort()[::-1] 20 | 21 | keep = [] 22 | while order.size > 0: 23 | i = order[0] 24 | keep.append(i) 25 | xx1 = np.maximum(x1[i], x1[order[1:]]) 26 | yy1 = np.maximum(y1[i], y1[order[1:]]) 27 | xx2 = np.minimum(x2[i], x2[order[1:]]) 28 | yy2 = np.minimum(y2[i], y2[order[1:]]) 29 | 30 | w = np.maximum(0.0, xx2 - xx1 + 1) 31 | h = np.maximum(0.0, yy2 - yy1 + 1) 32 | inter = w * h 33 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 34 | 35 | inds = np.where(ovr <= thresh)[0] 36 | order = order[inds + 1] 37 | 38 | return keep 39 | -------------------------------------------------------------------------------- /FaceBoxes/utils/nms_wrapper.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # -------------------------------------------------------- 4 | # Fast R-CNN 5 | # Copyright (c) 2015 Microsoft 6 | # Licensed under The MIT License [see LICENSE for details] 7 | # Written by Ross Girshick 8 | # -------------------------------------------------------- 9 | 10 | from .nms.cpu_nms import cpu_nms, cpu_soft_nms 11 | 12 | 13 | def nms(dets, thresh): 14 | """Dispatch to either CPU or GPU NMS implementations.""" 15 | 16 | if dets.shape[0] == 0: 17 | return [] 18 | return cpu_nms(dets, thresh) 19 | # return gpu_nms(dets, thresh) 20 | -------------------------------------------------------------------------------- /FaceBoxes/utils/prior_box.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from .config import cfg 4 | 5 | import torch 6 | from itertools import product as product 7 | from math import ceil 8 | 9 | 10 | class PriorBox(object): 11 | def __init__(self, image_size=None): 12 | super(PriorBox, self).__init__() 13 | # self.aspect_ratios = cfg['aspect_ratios'] 14 | self.min_sizes = cfg['min_sizes'] 15 | self.steps = cfg['steps'] 16 | self.clip = cfg['clip'] 17 | self.image_size = image_size 18 | self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps] 19 | 20 | def forward(self): 21 | anchors = [] 22 | for k, f in enumerate(self.feature_maps): 23 | min_sizes = self.min_sizes[k] 24 | for i, j in product(range(f[0]), range(f[1])): 25 | for min_size in min_sizes: 26 | s_kx = min_size / self.image_size[1] 27 | s_ky = min_size / self.image_size[0] 28 | if min_size == 32: 29 | dense_cx = [x * self.steps[k] / self.image_size[1] for x in 30 | [j + 0, j + 0.25, j + 0.5, j + 0.75]] 31 | dense_cy = [y * self.steps[k] / self.image_size[0] for y in 32 | [i + 0, i + 0.25, i + 0.5, i + 0.75]] 33 | for cy, cx in product(dense_cy, dense_cx): 34 | anchors += [cx, cy, s_kx, s_ky] 35 | elif min_size == 64: 36 | dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0, j + 0.5]] 37 | dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0, i + 0.5]] 38 | for cy, cx in product(dense_cy, dense_cx): 39 | anchors += [cx, cy, s_kx, s_ky] 40 | else: 41 | cx = (j + 0.5) * self.steps[k] / self.image_size[1] 42 | cy = (i + 0.5) * self.steps[k] / self.image_size[0] 43 | anchors += [cx, cy, s_kx, s_ky] 44 | # back to torch land 45 | output = torch.Tensor(anchors).view(-1, 4) 46 | if self.clip: 47 | output.clamp_(max=1, min=0) 48 | return output 49 | -------------------------------------------------------------------------------- /FaceBoxes/utils/timer.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # -------------------------------------------------------- 4 | # Fast R-CNN 5 | # Copyright (c) 2015 Microsoft 6 | # Licensed under The MIT License [see LICENSE for details] 7 | # Written by Ross Girshick 8 | # -------------------------------------------------------- 9 | 10 | import time 11 | 12 | 13 | class Timer(object): 14 | """A simple timer.""" 15 | 16 | def __init__(self): 17 | self.total_time = 0. 18 | self.calls = 0 19 | self.start_time = 0. 20 | self.diff = 0. 21 | self.average_time = 0. 22 | 23 | def tic(self): 24 | # using time.time instead of time.clock because time time.clock 25 | # does not normalize for multithreading 26 | self.start_time = time.time() 27 | 28 | def toc(self, average=True): 29 | self.diff = time.time() - self.start_time 30 | self.total_time += self.diff 31 | self.calls += 1 32 | self.average_time = self.total_time / self.calls 33 | if average: 34 | return self.average_time 35 | else: 36 | return self.diff 37 | 38 | def clear(self): 39 | self.total_time = 0. 40 | self.calls = 0 41 | self.start_time = 0. 42 | self.diff = 0. 43 | self.average_time = 0. 44 | -------------------------------------------------------------------------------- /FaceBoxes/weights/FaceBoxesProd.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/FaceBoxes/weights/FaceBoxesProd.pth -------------------------------------------------------------------------------- /FaceBoxes/weights/readme.md: -------------------------------------------------------------------------------- 1 | The pre-trained model `FaceBoxesProd.pth` is downloaded from [Google Drive](https://drive.google.com/file/d/1tRVwOlu0QtjvADQ2H7vqrRwsWEmaqioI). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Cho Ying Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #
SynergyNet
2 | 3DV 2021: Synergy between 3DMM and 3D Landmarks for Accurate 3D Facial Geometry 3 | 4 | Cho-Ying Wu, Qiangeng Xu, Ulrich Neumann, CGIT Lab at University of Souther California 5 | 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/synergy-between-3dmm-and-3d-landmarks-for/face-alignment-on-aflw)](https://paperswithcode.com/sota/face-alignment-on-aflw?p=synergy-between-3dmm-and-3d-landmarks-for) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/synergy-between-3dmm-and-3d-landmarks-for/head-pose-estimation-on-aflw2000)](https://paperswithcode.com/sota/head-pose-estimation-on-aflw2000?p=synergy-between-3dmm-and-3d-landmarks-for) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/synergy-between-3dmm-and-3d-landmarks-for/face-alignment-on-aflw2000-3d)](https://paperswithcode.com/sota/face-alignment-on-aflw2000-3d?p=synergy-between-3dmm-and-3d-landmarks-for) 9 | 10 | [paper] [video] [project page] 11 | 12 | News [Dec 11, 2024]: Add details of training UV-texture GAN. See the section "Training UV-texture GAN" 13 | 14 | News [Jul 10, 2022]: Add simplified api for getting 3d landmarks, face mesh, and face pose in only one line. See "Simplified API" It's convenient if you simply want to plug in this method in your work. 15 | 16 | News: Add Colab demo 17 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1q9HRLA3wGxz4IFIseZFK1maOyH0wutYk?usp=sharing) 18 | 19 | News: Our new work [Cross-Modal Perceptionist] is accepted to CVPR 2022, which is based on this SynergyNet project.
20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | ##
Advantages
30 | 31 | :+1: SOTA on all 3D facial alignment, face orientation estimation, and 3D face modeling.

32 | :+1: Fast inference with 3000fps on a laptop RTX 2080.

33 | :+1: Simple implementation with only widely used operations.

34 | 35 | (This project is built/tested on Python 3.8 and PyTorch 1.9 on a compatible GPU) 36 | 37 | ##
Single Image Inference Demo
38 | 39 | 1. Clone 40 | 41 | ```git clone https://github.com/choyingw/SynergyNet``` 42 | 43 | ```cd SynergyNet ``` 44 | 45 | 2. Use conda 46 | 47 | ```conda create --name SynergyNet``` 48 | 49 | ```conda activate SynergyNet``` 50 | 51 | 3. Install pre-requisite common packages 52 | 53 | ```PyTorch 1.9 (should also be compatiable with 1.0+ versions), Torchvision, Opencv, Scipy, Matplotlib, Cython ``` 54 | 55 | 4. Download data [here] and 56 | [here]. Extract these data under the repo root. 57 | 58 | These data are processed from [3DDFA] and [FSA-Net]. 59 | 60 | Download pretrained weights [here]. Put the model under 'pretrained/' 61 | 62 | 5. Compile Sim3DR and FaceBoxes: 63 | 64 | ```cd Sim3DR``` 65 | 66 | ```./build_sim3dr.sh``` 67 | 68 | ```cd ../FaceBoxes``` 69 | 70 | ```./build_cpu_nms.sh``` 71 | 72 | ```cd ..``` 73 | 74 | 6. Inference 75 | 76 | ```python singleImage.py -f img``` 77 | 78 | The default inference requires a compatible GPU to run. If you would like to run on a CPU, please comment the .cuda() and load the pretrained weights into cpu. 79 | 80 | ##
Simplified API
81 | 82 | We provide a simple API for convenient usage if you want to plug in this method into your work. 83 | 84 | ```python 85 | import cv2 86 | from synergy3DMM import SynergyNet 87 | model = SynergyNet() 88 | I = cv2.imread() 89 | # get landmark [[y, x, z], 68 (points)], mesh [[y, x, z], 53215 (points)], and face pose (Euler angles [yaw, pitch, roll] and translation [y, x, z]) 90 | lmk3d, mesh, pose = model.get_all_outputs(I) 91 | ``` 92 | We provide a simple script in singleImage_simple.py 93 | 94 | We also provide a setup.py file. Run pip install -e . You can do from synergy3DMM import SynergyNet in other directory. Note that [3dmm_data] and [pretrained weight] (Put the model under 'pretrained/') need to be present. 95 | 96 | ##
Benchmark Evaluation
97 | 98 | 1. Follow Single Image Inference Demo: Step 1-4 99 | 100 | 2. Benchmarking 101 | 102 | ```python benchmark.py -w pretrained/best.pth.tar``` 103 | 104 | Print-out results and visualization fo first-50 examples are stored under 'results/' (see 'demo/' for some pre-generated samples as references) are shown. 105 | 106 | Updates: Best head pose estimation [pretrained model] (Mean MAE: 3.31) that is better than number reported in paper (3.35). Use -w to load different pretrained models. 107 | 108 | ##
Training
109 | 110 | 1. Follow Single Image Inference Demo: Step 1-4. 111 | 112 | 2. Download training data from [3DDFA]: train_aug_120x120.zip and extract the zip file under the root folder (Containing about 680K images). 113 | 114 | 3. 115 | ```bash train_script.sh``` 116 | 117 | 4. Please refer to train_script for hyperparameters, such as learning rate, epochs, or GPU device. The default settings take ~19G on a 3090 GPU and about 6 hours for training. If your GPU is less than this size, please decrease the batch size and learning rate proportionally. 118 | 119 | ##
Textured Artistic Face Meshes
120 | 121 | 1. Follow Single Image Inference Demo: Step 1-5. 122 | 123 | 2. Download artistic faces data [here], which are from [AF-Dataset]. Download our predicted UV maps [here] by UV-texture GAN. Extract them under the root folder. 124 | 125 | 3. 126 | ```python artistic.py -f art-all --png```(whole folder) 127 | 128 | ```python artistic.py -f art-all/122.png```(single image) 129 | 130 | 131 | Note that this artistic face dataset contains many different level/style face abstration. If a testing image is close to real, the result is much better than those of highly abstract samples. 132 | 133 | ##
Textured Real Face Renderings
134 | 135 | 1. Follow Single Image Inference Demo: Step 1-5. 136 | 137 | 2. Download our predicted UV maps and real face images for AFLW2000-3D [here] by UV-texture GAN. Extract them under the root folder. 138 | 139 | 3. 140 | ```python uv_texture_realFaces.py -f texture_data/real --png``` (whole folder) 141 | 142 | ```python uv_texture_realFaces.py -f texture_data/real/image00002_real_A.png``` (single image) 143 | 144 | The results (3D meshes and renderings) are stored under 'inference_output' 145 | 146 | ##
Training UV-texture GAN
147 | 148 | 1. Acquire AFLW2000-3D dataset and use [MGC-Net] test pipeline to get UV-texture for the AFLW2000 images. 149 | 150 | 2. Use [Pix2Pix] and train LSGAN with un-paired loss by their training recipe. In the input layer, concat the mean UV-texture and image and also shortcut add the mean texture at the output of generator. 151 | 152 | 3. The mean UV-texture can be got from original BFM set or from [face3d] 153 | 154 | ##
More Results
155 | 156 | We show a comparison with [DECA] using the top-3 largest roll angle samples in AFLW2000-3D. 157 | 158 | 159 | 160 | 161 | Facial alignemnt on AFLW2000-3D (NME of facial landmarks): 162 | 163 | 164 | 165 | Face orientation estimation on AFLW2000-3D (MAE of Euler angles): 166 | 167 | 168 | 169 | Results on artistic faces: 170 | 171 | 172 | 173 | 174 | 175 | **Related Project** 176 | 177 | [Cross-Modal Perceptionist] (analysis on relation for voice and 3D face) 178 | 179 | **Bibtex** 180 | 181 | If you find our work useful, please consider to cite our work 182 | 183 | @INPROCEEDINGS{wu2021synergy, 184 | author={Wu, Cho-Ying and Xu, Qiangeng and Neumann, Ulrich}, 185 | booktitle={2021 International Conference on 3D Vision (3DV)}, 186 | title={Synergy between 3DMM and 3D Landmarks for Accurate 3D Facial Geometry}, 187 | year={2021} 188 | } 189 | 190 | **Acknowledgement** 191 | 192 | The project is developed on [3DDFA] and [FSA-Net]. Thank them for their wonderful work. Thank [3DDFA-V2] for the face detector and rendering codes. 193 | -------------------------------------------------------------------------------- /Sim3DR/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | cmake-build-debug/ 3 | .idea/ 4 | build/ 5 | *.so 6 | data/ 7 | 8 | lib/rasterize.cpp -------------------------------------------------------------------------------- /Sim3DR/Sim3DR.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from . import _init_paths 4 | import numpy as np 5 | import Sim3DR_Cython 6 | 7 | 8 | def get_normal(vertices, triangles): 9 | normal = np.zeros_like(vertices, dtype=np.float32) 10 | Sim3DR_Cython.get_normal(normal, vertices, triangles, vertices.shape[0], triangles.shape[0]) 11 | return normal 12 | 13 | 14 | def rasterize(vertices, triangles, colors, bg=None, 15 | height=None, width=None, channel=None, 16 | reverse=False): 17 | if bg is not None: 18 | height, width, channel = bg.shape 19 | else: 20 | assert height is not None and width is not None and channel is not None 21 | bg = np.zeros((height, width, channel), dtype=np.float32) 22 | 23 | buffer = np.zeros((height, width), dtype=np.float32) - 1e8 24 | 25 | if colors.dtype != np.float32: 26 | colors = colors.astype(np.float32) 27 | Sim3DR_Cython.rasterize(bg, vertices, triangles, colors, buffer, triangles.shape[0], height, width, channel, 28 | reverse=reverse) 29 | return bg 30 | -------------------------------------------------------------------------------- /Sim3DR/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from .Sim3DR import get_normal, rasterize 4 | from .lighting import RenderPipeline 5 | -------------------------------------------------------------------------------- /Sim3DR/_init_paths.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os.path as osp 4 | import sys 5 | 6 | 7 | def add_path(path): 8 | if path not in sys.path: 9 | sys.path.insert(0, path) 10 | 11 | 12 | this_dir = osp.dirname(__file__) 13 | lib_path = osp.join(this_dir, '.') 14 | add_path(lib_path) 15 | -------------------------------------------------------------------------------- /Sim3DR/build_sim3dr.sh: -------------------------------------------------------------------------------- 1 | python3 setup.py build_ext --inplace -------------------------------------------------------------------------------- /Sim3DR/lib/rasterize.h: -------------------------------------------------------------------------------- 1 | #ifndef MESH_CORE_HPP_ 2 | #define MESH_CORE_HPP_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | using namespace std; 12 | 13 | class Point3D { 14 | public: 15 | float x; 16 | float y; 17 | float z; 18 | 19 | public: 20 | Point3D() : x(0.f), y(0.f), z(0.f) {} 21 | Point3D(float x_, float y_, float z_) : x(x_), y(y_), z(z_) {} 22 | 23 | void initialize(float x_, float y_, float z_){ 24 | this->x = x_; this->y = y_; this->z = z_; 25 | } 26 | 27 | Point3D cross(Point3D &p){ 28 | Point3D c; 29 | c.x = this->y * p.z - this->z * p.y; 30 | c.y = this->z * p.x - this->x * p.z; 31 | c.z = this->x * p.y - this->y * p.x; 32 | return c; 33 | } 34 | 35 | float dot(Point3D &p) { 36 | return this->x * p.x + this->y * p.y + this->z * p.z; 37 | } 38 | 39 | Point3D operator-(const Point3D &p) { 40 | Point3D np; 41 | np.x = this->x - p.x; 42 | np.y = this->y - p.y; 43 | np.z = this->z - p.z; 44 | return np; 45 | } 46 | 47 | }; 48 | 49 | class Point { 50 | public: 51 | float x; 52 | float y; 53 | 54 | public: 55 | Point() : x(0.f), y(0.f) {} 56 | Point(float x_, float y_) : x(x_), y(y_) {} 57 | float dot(Point p) { 58 | return this->x * p.x + this->y * p.y; 59 | } 60 | 61 | Point operator-(const Point &p) { 62 | Point np; 63 | np.x = this->x - p.x; 64 | np.y = this->y - p.y; 65 | return np; 66 | } 67 | 68 | Point operator+(const Point &p) { 69 | Point np; 70 | np.x = this->x + p.x; 71 | np.y = this->y + p.y; 72 | return np; 73 | } 74 | 75 | Point operator*(float s) { 76 | Point np; 77 | np.x = s * this->x; 78 | np.y = s * this->y; 79 | return np; 80 | } 81 | }; 82 | 83 | 84 | bool is_point_in_tri(Point p, Point p0, Point p1, Point p2); 85 | 86 | void get_point_weight(float *weight, Point p, Point p0, Point p1, Point p2); 87 | 88 | void _get_tri_normal(float *tri_normal, float *vertices, int *triangles, int ntri, bool norm_flg); 89 | 90 | void _get_ver_normal(float *ver_normal, float *tri_normal, int *triangles, int nver, int ntri); 91 | 92 | void _get_normal(float *ver_normal, float *vertices, int *triangles, int nver, int ntri); 93 | 94 | void _rasterize_triangles( 95 | float *vertices, int *triangles, float *depth_buffer, int *triangle_buffer, float *barycentric_weight, 96 | int ntri, int h, int w); 97 | 98 | void _rasterize( 99 | unsigned char *image, float *vertices, int *triangles, float *colors, 100 | float *depth_buffer, int ntri, int h, int w, int c, float alpha, bool reverse); 101 | 102 | void _render_texture_core( 103 | float *image, float *vertices, int *triangles, 104 | float *texture, float *tex_coords, int *tex_triangles, 105 | float *depth_buffer, 106 | int nver, int tex_nver, int ntri, 107 | int h, int w, int c, 108 | int tex_h, int tex_w, int tex_c, 109 | int mapping_type); 110 | 111 | void _write_obj_with_colors_texture(string filename, string mtl_name, 112 | float *vertices, int *triangles, float *colors, float *uv_coords, 113 | int nver, int ntri, int ntexver); 114 | 115 | #endif -------------------------------------------------------------------------------- /Sim3DR/lib/rasterize.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport numpy as np 3 | # from libcpp.string cimport string 4 | cimport cython 5 | from libcpp cimport bool 6 | 7 | # from cpython import bool 8 | 9 | # use the Numpy-C-API from Cython 10 | np.import_array() 11 | 12 | # cdefine the signature of our c function 13 | cdef extern from "rasterize.h": 14 | void _rasterize_triangles( 15 | float*vertices, int*triangles, float*depth_buffer, int*triangle_buffer, float*barycentric_weight, 16 | int ntri, int h, int w 17 | ) 18 | 19 | void _rasterize( 20 | unsigned char*image, float*vertices, int*triangles, float*colors, float*depth_buffer, 21 | int ntri, int h, int w, int c, float alpha, bool reverse 22 | ) 23 | 24 | # void _render_texture_core( 25 | # float* image, float* vertices, int* triangles, 26 | # float* texture, float* tex_coords, int* tex_triangles, 27 | # float* depth_buffer, 28 | # int nver, int tex_nver, int ntri, 29 | # int h, int w, int c, 30 | # int tex_h, int tex_w, int tex_c, 31 | # int mapping_type) 32 | 33 | void _get_tri_normal(float *tri_normal, float *vertices, int *triangles, int nver, bool norm_flg) 34 | void _get_ver_normal(float *ver_normal, float*tri_normal, int*triangles, int nver, int ntri) 35 | void _get_normal(float *ver_normal, float *vertices, int *triangles, int nver, int ntri) 36 | 37 | 38 | # void _write_obj_with_colors_texture(string filename, string mtl_name, 39 | # float* vertices, int* triangles, float* colors, float* uv_coords, 40 | # int nver, int ntri, int ntexver) 41 | 42 | @cython.boundscheck(False) 43 | @cython.wraparound(False) 44 | def get_tri_normal(np.ndarray[float, ndim=2, mode="c"] tri_normal not None, 45 | np.ndarray[float, ndim=2, mode = "c"] vertices not None, 46 | np.ndarray[int, ndim=2, mode="c"] triangles not None, 47 | int ntri, bool norm_flg = False): 48 | _get_tri_normal( np.PyArray_DATA(tri_normal), np.PyArray_DATA(vertices), 49 | np.PyArray_DATA(triangles), ntri, norm_flg) 50 | 51 | @cython.boundscheck(False) # turn off bounds-checking for entire function 52 | @cython.wraparound(False) # turn off negative index wrapping for entire function 53 | def get_ver_normal(np.ndarray[float, ndim=2, mode = "c"] ver_normal not None, 54 | np.ndarray[float, ndim=2, mode = "c"] tri_normal not None, 55 | np.ndarray[int, ndim=2, mode="c"] triangles not None, 56 | int nver, int ntri): 57 | _get_ver_normal( 58 | np.PyArray_DATA(ver_normal), np.PyArray_DATA(tri_normal), np.PyArray_DATA(triangles), 59 | nver, ntri) 60 | 61 | @cython.boundscheck(False) # turn off bounds-checking for entire function 62 | @cython.wraparound(False) # turn off negative index wrapping for entire function 63 | def get_normal(np.ndarray[float, ndim=2, mode = "c"] ver_normal not None, 64 | np.ndarray[float, ndim=2, mode = "c"] vertices not None, 65 | np.ndarray[int, ndim=2, mode="c"] triangles not None, 66 | int nver, int ntri): 67 | _get_normal( 68 | np.PyArray_DATA(ver_normal), np.PyArray_DATA(vertices), np.PyArray_DATA(triangles), 69 | nver, ntri) 70 | 71 | 72 | @cython.boundscheck(False) # turn off bounds-checking for entire function 73 | @cython.wraparound(False) # turn off negative index wrapping for entire function 74 | def rasterize_triangles( 75 | np.ndarray[float, ndim=2, mode = "c"] vertices not None, 76 | np.ndarray[int, ndim=2, mode="c"] triangles not None, 77 | np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None, 78 | np.ndarray[int, ndim=2, mode = "c"] triangle_buffer not None, 79 | np.ndarray[float, ndim=2, mode = "c"] barycentric_weight not None, 80 | int ntri, int h, int w 81 | ): 82 | _rasterize_triangles( 83 | np.PyArray_DATA(vertices), np.PyArray_DATA(triangles), 84 | np.PyArray_DATA(depth_buffer), np.PyArray_DATA(triangle_buffer), 85 | np.PyArray_DATA(barycentric_weight), 86 | ntri, h, w) 87 | 88 | @cython.boundscheck(False) # turn off bounds-checking for entire function 89 | @cython.wraparound(False) # turn off negative index wrapping for entire function 90 | def rasterize(np.ndarray[unsigned char, ndim=3, mode = "c"] image not None, 91 | np.ndarray[float, ndim=2, mode = "c"] vertices not None, 92 | np.ndarray[int, ndim=2, mode="c"] triangles not None, 93 | np.ndarray[float, ndim=2, mode = "c"] colors not None, 94 | np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None, 95 | int ntri, int h, int w, int c, float alpha = 1, bool reverse = False 96 | ): 97 | _rasterize( 98 | np.PyArray_DATA(image), np.PyArray_DATA(vertices), 99 | np.PyArray_DATA(triangles), 100 | np.PyArray_DATA(colors), 101 | np.PyArray_DATA(depth_buffer), 102 | ntri, h, w, c, alpha, reverse) 103 | 104 | # def render_texture_core(np.ndarray[float, ndim=3, mode = "c"] image not None, 105 | # np.ndarray[float, ndim=2, mode = "c"] vertices not None, 106 | # np.ndarray[int, ndim=2, mode="c"] triangles not None, 107 | # np.ndarray[float, ndim=3, mode = "c"] texture not None, 108 | # np.ndarray[float, ndim=2, mode = "c"] tex_coords not None, 109 | # np.ndarray[int, ndim=2, mode="c"] tex_triangles not None, 110 | # np.ndarray[float, ndim=2, mode = "c"] depth_buffer not None, 111 | # int nver, int tex_nver, int ntri, 112 | # int h, int w, int c, 113 | # int tex_h, int tex_w, int tex_c, 114 | # int mapping_type 115 | # ): 116 | # _render_texture_core( 117 | # np.PyArray_DATA(image), np.PyArray_DATA(vertices), np.PyArray_DATA(triangles), 118 | # np.PyArray_DATA(texture), np.PyArray_DATA(tex_coords), np.PyArray_DATA(tex_triangles), 119 | # np.PyArray_DATA(depth_buffer), 120 | # nver, tex_nver, ntri, 121 | # h, w, c, 122 | # tex_h, tex_w, tex_c, 123 | # mapping_type) 124 | # 125 | # def write_obj_with_colors_texture_core(string filename, string mtl_name, 126 | # np.ndarray[float, ndim=2, mode = "c"] vertices not None, 127 | # np.ndarray[int, ndim=2, mode="c"] triangles not None, 128 | # np.ndarray[float, ndim=2, mode = "c"] colors not None, 129 | # np.ndarray[float, ndim=2, mode = "c"] uv_coords not None, 130 | # int nver, int ntri, int ntexver 131 | # ): 132 | # _write_obj_with_colors_texture(filename, mtl_name, 133 | # np.PyArray_DATA(vertices), np.PyArray_DATA(triangles), np.PyArray_DATA(colors), np.PyArray_DATA(uv_coords), 134 | # nver, ntri, ntexver) 135 | -------------------------------------------------------------------------------- /Sim3DR/lighting.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import numpy as np 4 | from .Sim3DR import get_normal, rasterize 5 | 6 | _norm = lambda arr: arr / np.sqrt(np.sum(arr ** 2, axis=1))[:, None] 7 | 8 | 9 | def norm_vertices(vertices): 10 | vertices -= vertices.min(0)[None, :] 11 | vertices /= vertices.max() 12 | vertices *= 2 13 | vertices -= vertices.max(0)[None, :] / 2 14 | return vertices 15 | 16 | 17 | def convert_type(obj): 18 | if isinstance(obj, tuple) or isinstance(obj, list): 19 | return np.array(obj, dtype=np.float32)[None, :] 20 | return obj 21 | 22 | 23 | class RenderPipeline(object): 24 | def __init__(self, **kwargs): 25 | self.intensity_ambient = convert_type(kwargs.get('intensity_ambient', 0.3)) 26 | self.intensity_directional = convert_type(kwargs.get('intensity_directional', 0.6)) 27 | self.intensity_specular = convert_type(kwargs.get('intensity_specular', 0.1)) 28 | self.specular_exp = kwargs.get('specular_exp', 5) 29 | self.color_ambient = convert_type(kwargs.get('color_ambient', (1, 1, 1))) 30 | self.color_directional = convert_type(kwargs.get('color_directional', (1, 1, 1))) 31 | self.light_pos = convert_type(kwargs.get('light_pos', (0, 0, 5))) 32 | self.view_pos = convert_type(kwargs.get('view_pos', (0, 0, 5))) 33 | 34 | def update_light_pos(self, light_pos): 35 | self.light_pos = convert_type(light_pos) 36 | 37 | def __call__(self, vertices, triangles, bg, texture=None): 38 | normal = get_normal(vertices, triangles) 39 | 40 | # 1. lighting 41 | light = np.zeros_like(vertices, dtype=np.float32) 42 | # ambient component 43 | if self.intensity_ambient > 0: 44 | light += self.intensity_ambient * self.color_ambient 45 | 46 | vertices_n = norm_vertices(vertices.copy()) 47 | if self.intensity_directional > 0: 48 | # diffuse component 49 | direction = _norm(self.light_pos - vertices_n) 50 | cos = np.sum(normal * direction, axis=1)[:, None] 51 | # cos = np.clip(cos, 0, 1) 52 | # todo: check below 53 | light += self.intensity_directional * (self.color_directional * np.clip(cos, 0, 1)) 54 | 55 | # specular component 56 | if self.intensity_specular > 0: 57 | v2v = _norm(self.view_pos - vertices_n) 58 | reflection = 2 * cos * normal - direction 59 | spe = np.sum((v2v * reflection) ** self.specular_exp, axis=1)[:, None] 60 | spe = np.where(cos != 0, np.clip(spe, 0, 1), np.zeros_like(spe)) 61 | light += self.intensity_specular * self.color_directional * np.clip(spe, 0, 1) 62 | light = np.clip(light, 0, 1) 63 | 64 | # 2. rasterization, [0, 1] 65 | if texture is None: 66 | render_img = rasterize(vertices, triangles, light, bg=bg) 67 | return render_img 68 | else: 69 | texture *= light 70 | render_img = rasterize(vertices, triangles, texture, bg=bg) 71 | return render_img 72 | 73 | 74 | def main(): 75 | pass 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /Sim3DR/readme.md: -------------------------------------------------------------------------------- 1 | ## Sim3DR 2 | This is a simple 3D render, written by c++ and cython. 3 | 4 | ### Build Sim3DR 5 | 6 | ```shell script 7 | python3 setup.py build_ext --inplace 8 | ``` -------------------------------------------------------------------------------- /Sim3DR/setup.py: -------------------------------------------------------------------------------- 1 | ''' 2 | python setup.py build_ext -i 3 | to compile 4 | ''' 5 | 6 | from distutils.core import setup, Extension 7 | from Cython.Build import cythonize 8 | from Cython.Distutils import build_ext 9 | import numpy 10 | 11 | setup( 12 | name='Sim3DR_Cython', # not the package name 13 | cmdclass={'build_ext': build_ext}, 14 | ext_modules=[Extension("Sim3DR_Cython", 15 | sources=["lib/rasterize.pyx", "lib/rasterize_kernel.cpp"], 16 | language='c++', 17 | include_dirs=[numpy.get_include()], 18 | extra_compile_args=["-std=c++11"])], 19 | ) 20 | -------------------------------------------------------------------------------- /Sim3DR/tests/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | -------------------------------------------------------------------------------- /Sim3DR/tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) 2 | 3 | set(TARGET test) 4 | project(${TARGET}) 5 | 6 | #find_package( OpenCV REQUIRED ) 7 | #include_directories( ${OpenCV_INCLUDE_DIRS} ) 8 | 9 | #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC -O3") 10 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -std=c++11") 11 | add_executable(${TARGET} test.cpp rasterize_kernel.cpp io.cpp) 12 | target_include_directories(${TARGET} PRIVATE ${PROJECT_SOURCE_DIR}) 13 | -------------------------------------------------------------------------------- /Sim3DR/tests/io.cpp: -------------------------------------------------------------------------------- 1 | #include "io.h" 2 | 3 | //void load_obj(const string obj_fp, float* vertices, float* colors, float* triangles){ 4 | // string line; 5 | // ifstream in(obj_fp); 6 | // 7 | // if(in.is_open()){ 8 | // while (getline(in, line)){ 9 | // stringstream ss(line); 10 | // 11 | // char t; // type: v, f 12 | // ss >> t; 13 | // if (t == 'v'){ 14 | // 15 | // } 16 | // } 17 | // } 18 | //} 19 | 20 | void load_obj(const char *obj_fp, float *vertices, float *colors, int *triangles, int nver, int ntri) { 21 | FILE *fp; 22 | fp = fopen(obj_fp, "r"); 23 | 24 | char t; // type: v or f 25 | if (fp != nullptr) { 26 | for (int i = 0; i < nver; ++i) { 27 | fscanf(fp, "%c", &t); 28 | for (int j = 0; j < 3; ++j) 29 | fscanf(fp, " %f", &vertices[3 * i + j]); 30 | for (int j = 0; j < 3; ++j) 31 | fscanf(fp, " %f", &colors[3 * i + j]); 32 | fscanf(fp, "\n"); 33 | } 34 | // fscanf(fp, "%c", &t); 35 | for (int i = 0; i < ntri; ++i) { 36 | fscanf(fp, "%c", &t); 37 | for (int j = 0; j < 3; ++j) { 38 | fscanf(fp, " %d", &triangles[3 * i + j]); 39 | triangles[3 * i + j] -= 1; 40 | } 41 | fscanf(fp, "\n"); 42 | } 43 | 44 | fclose(fp); 45 | } 46 | } 47 | 48 | void load_ply(const char *ply_fp, float *vertices, int *triangles, int nver, int ntri) { 49 | FILE *fp; 50 | fp = fopen(ply_fp, "r"); 51 | 52 | // char s[256]; 53 | char t; 54 | if (fp != nullptr) { 55 | // for (int i = 0; i < 9; ++i) 56 | // fscanf(fp, "%s", s); 57 | for (int i = 0; i < nver; ++i) 58 | fscanf(fp, "%f %f %f\n", &vertices[3 * i], &vertices[3 * i + 1], &vertices[3 * i + 2]); 59 | 60 | for (int i = 0; i < ntri; ++i) 61 | fscanf(fp, "%c %d %d %d\n", &t, &triangles[3 * i], &triangles[3 * i + 1], &triangles[3 * i + 2]); 62 | 63 | fclose(fp); 64 | } 65 | } 66 | 67 | void write_ppm(const char *filename, unsigned char *img, int h, int w, int c) { 68 | FILE *fp; 69 | //open file for output 70 | fp = fopen(filename, "wb"); 71 | if (!fp) { 72 | fprintf(stderr, "Unable to open file '%s'\n", filename); 73 | exit(1); 74 | } 75 | 76 | //write the header file 77 | //image format 78 | fprintf(fp, "P6\n"); 79 | 80 | //image size 81 | fprintf(fp, "%d %d\n", w, h); 82 | 83 | // rgb component depth 84 | fprintf(fp, "%d\n", MAX_PXL_VALUE); 85 | 86 | // pixel data 87 | fwrite(img, sizeof(unsigned char), size_t(h * w * c), fp); 88 | fclose(fp); 89 | } -------------------------------------------------------------------------------- /Sim3DR/tests/io.h: -------------------------------------------------------------------------------- 1 | #ifndef IO_H_ 2 | #define IO_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | using namespace std; 11 | 12 | #define MAX_PXL_VALUE 255 13 | 14 | void load_obj(const char* obj_fp, float* vertices, float* colors, int* triangles, int nver, int ntri); 15 | void load_ply(const char* ply_fp, float* vertices, int* triangles, int nver, int ntri); 16 | 17 | 18 | void write_ppm(const char *filename, unsigned char *img, int h, int w, int c); 19 | 20 | #endif -------------------------------------------------------------------------------- /Sim3DR/tests/test.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Tesing cases 3 | */ 4 | 5 | #include 6 | #include 7 | #include "rasterize.h" 8 | #include "io.h" 9 | 10 | void test_isPointInTri() { 11 | Point p0(0, 0); 12 | Point p1(1, 0); 13 | Point p2(1, 1); 14 | 15 | Point p(0.2, 0.2); 16 | 17 | if (is_point_in_tri(p, p0, p1, p2)) 18 | std::cout << "In"; 19 | else 20 | std::cout << "Out"; 21 | std::cout << std::endl; 22 | } 23 | 24 | void test_getPointWeight() { 25 | Point p0(0, 0); 26 | Point p1(1, 0); 27 | Point p2(1, 1); 28 | 29 | Point p(0.2, 0.2); 30 | 31 | float weight[3]; 32 | get_point_weight(weight, p, p0, p1, p2); 33 | std::cout << weight[0] << " " << weight[1] << " " << weight[2] << std::endl; 34 | } 35 | 36 | void test_get_tri_normal() { 37 | float tri_normal[3]; 38 | // float vertices[9] = {1, 0, 0, 0, 0, 0, 0, 1, 0}; 39 | float vertices[9] = {1, 1.1, 0, 0, 0, 0, 0, 0.6, 0.7}; 40 | int triangles[3] = {0, 1, 2}; 41 | int ntri = 1; 42 | 43 | _get_tri_normal(tri_normal, vertices, triangles, ntri); 44 | 45 | for (int i = 0; i < 3; ++i) 46 | std::cout << tri_normal[i] << ", "; 47 | std::cout << std::endl; 48 | } 49 | 50 | void test_load_obj() { 51 | const char *fp = "../data/vd005_mesh.obj"; 52 | int nver = 35709; 53 | int ntri = 70789; 54 | 55 | auto *vertices = new float[nver]; 56 | auto *colors = new float[nver]; 57 | auto *triangles = new int[ntri]; 58 | load_obj(fp, vertices, colors, triangles, nver, ntri); 59 | 60 | delete[] vertices; 61 | delete[] colors; 62 | delete[] triangles; 63 | } 64 | 65 | void test_render() { 66 | // 1. loading obj 67 | // const char *fp = "/Users/gjz/gjzprojects/Sim3DR/data/vd005_mesh.obj"; 68 | const char *fp = "/Users/gjz/gjzprojects/Sim3DR/data/face1.obj"; 69 | int nver = 35709; //53215; //35709; 70 | int ntri = 70789; //105840;//70789; 71 | 72 | auto *vertices = new float[3 * nver]; 73 | auto *colors = new float[3 * nver]; 74 | auto *triangles = new int[3 * ntri]; 75 | load_obj(fp, vertices, colors, triangles, nver, ntri); 76 | 77 | // 2. rendering 78 | int h = 224, w = 224, c = 3; 79 | 80 | // enlarging 81 | int scale = 4; 82 | h *= scale; 83 | w *= scale; 84 | for (int i = 0; i < nver * 3; ++i) vertices[i] *= scale; 85 | 86 | auto *image = new unsigned char[h * w * c](); 87 | auto *depth_buffer = new float[h * w](); 88 | 89 | for (int i = 0; i < h * w; ++i) depth_buffer[i] = -999999; 90 | 91 | clock_t t; 92 | t = clock(); 93 | 94 | _rasterize(image, vertices, triangles, colors, depth_buffer, ntri, h, w, c, true); 95 | t = clock() - t; 96 | double time_taken = ((double) t) / CLOCKS_PER_SEC; // in seconds 97 | printf("Render took %f seconds to execute \n", time_taken); 98 | 99 | 100 | // auto *image_char = new u_char[h * w * c](); 101 | // for (int i = 0; i < h * w * c; ++i) 102 | // image_char[i] = u_char(255 * image[i]); 103 | write_ppm("res.ppm", image, h, w, c); 104 | 105 | // delete[] image_char; 106 | delete[] vertices; 107 | delete[] colors; 108 | delete[] triangles; 109 | delete[] image; 110 | delete[] depth_buffer; 111 | } 112 | 113 | void test_light() { 114 | // 1. loading obj 115 | const char *fp = "/Users/gjz/gjzprojects/Sim3DR/data/emma_input_0_noheader.ply"; 116 | int nver = 53215; //35709; 117 | int ntri = 105840; //70789; 118 | 119 | auto *vertices = new float[3 * nver]; 120 | auto *colors = new float[3 * nver]; 121 | auto *triangles = new int[3 * ntri]; 122 | load_ply(fp, vertices, triangles, nver, ntri); 123 | 124 | // 2. rendering 125 | // int h = 1901, w = 3913, c = 3; 126 | int h = 2000, w = 4000, c = 3; 127 | 128 | // enlarging 129 | // int scale = 1; 130 | // h *= scale; 131 | // w *= scale; 132 | // for (int i = 0; i < nver * 3; ++i) vertices[i] *= scale; 133 | 134 | auto *image = new unsigned char[h * w * c](); 135 | auto *depth_buffer = new float[h * w](); 136 | 137 | for (int i = 0; i < h * w; ++i) depth_buffer[i] = -999999; 138 | for (int i = 0; i < 3 * nver; ++i) colors[i] = 0.8; 139 | 140 | clock_t t; 141 | t = clock(); 142 | 143 | _rasterize(image, vertices, triangles, colors, depth_buffer, ntri, h, w, c, true); 144 | t = clock() - t; 145 | double time_taken = ((double) t) / CLOCKS_PER_SEC; // in seconds 146 | printf("Render took %f seconds to execute \n", time_taken); 147 | 148 | 149 | // auto *image_char = new u_char[h * w * c](); 150 | // for (int i = 0; i < h * w * c; ++i) 151 | // image_char[i] = u_char(255 * image[i]); 152 | write_ppm("emma.ppm", image, h, w, c); 153 | 154 | // delete[] image_char; 155 | delete[] vertices; 156 | delete[] colors; 157 | delete[] triangles; 158 | delete[] image; 159 | delete[] depth_buffer; 160 | } 161 | 162 | int main(int argc, char *argv[]) { 163 | // std::cout << "Hello CMake!" << std::endl; 164 | 165 | // test_isPointInTri(); 166 | // test_getPointWeight(); 167 | // test_get_tri_normal(); 168 | // test_load_obj(); 169 | // test_render(); 170 | test_light(); 171 | return 0; 172 | } -------------------------------------------------------------------------------- /artistic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | import cv2 5 | from utils.ddfa import ToTensor, Normalize 6 | from model_building import SynergyNet 7 | from utils.inference import crop_img, predict_denseVert 8 | import argparse 9 | import torch.backends.cudnn as cudnn 10 | cudnn.benchmark = True 11 | import os 12 | import os.path as osp 13 | import glob 14 | from FaceBoxes import FaceBoxes 15 | 16 | # Following 3DDFA-V2, we also use 120x120 resolution 17 | IMG_SIZE = 120 18 | 19 | def write_obj_with_colors(obj_name, vertices, triangles, colors): 20 | triangles = triangles.copy() 21 | 22 | if obj_name.split('.')[-1] != 'obj': 23 | obj_name = obj_name + '.obj' 24 | with open(obj_name, 'w') as f: 25 | for i in range(vertices.shape[1]): 26 | s = 'v {:.4f} {:.4f} {:.4f} {} {} {}\n'.format(vertices[0, i], vertices[1, i], vertices[2, i], colors[i, 2], 27 | colors[i, 1], colors[i, 0]) 28 | f.write(s) 29 | for i in range(triangles.shape[1]): 30 | s = 'f {} {} {}\n'.format(triangles[0, i], triangles[1, i], triangles[2, i]) 31 | f.write(s) 32 | 33 | def main(args): 34 | # load pre-tained model 35 | checkpoint_fp = 'pretrained/best.pth.tar' 36 | args.arch = 'mobilenet_v2' 37 | args.devices_id = [0] 38 | 39 | checkpoint = torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)['state_dict'] 40 | 41 | model = SynergyNet(args) 42 | model_dict = model.state_dict() 43 | 44 | # load BFM_UV mapping and kept indicies and deleted triangles 45 | uv_vert=np.load('3dmm_data/BFM_UV.npy') 46 | coord_u = (uv_vert[:,1]*255.0).astype(np.int32) 47 | coord_v = (uv_vert[:,0]*255.0).astype(np.int32) 48 | keep_ind = np.load('3dmm_data/keptInd.npy') 49 | tri_deletion = np.load('3dmm_data/deletedTri.npy') 50 | 51 | # because the model is trained by multiple gpus, prefix 'module' should be removed 52 | for k in checkpoint.keys(): 53 | model_dict[k.replace('module.', '')] = checkpoint[k] 54 | 55 | model.load_state_dict(model_dict, strict=False) 56 | model = model.cuda() 57 | model.eval() 58 | 59 | # face detector 60 | face_boxes = FaceBoxes() 61 | 62 | # preparation 63 | transform = transforms.Compose([ToTensor(), Normalize(mean=127.5, std=128)]) 64 | if osp.isdir(args.files): 65 | if not args.files[-1] == '/': 66 | args.files = args.files + '/' 67 | if not args.png: 68 | files = sorted(glob.glob(args.files+'*.jpg')) 69 | else: 70 | files = sorted(glob.glob(args.files+'*.png')) 71 | else: 72 | files = [args.files] 73 | 74 | for img_fp in files: 75 | print("Process the image: ", img_fp) 76 | 77 | img_ori = cv2.imread(img_fp) 78 | 79 | # crop faces 80 | rects = face_boxes(img_ori) 81 | 82 | # storage 83 | vertices_lst = [] 84 | for rect in rects: 85 | roi_box = rect 86 | 87 | # enlarge the bbox a little and do a square crop 88 | HCenter = (rect[1] + rect[3])/2 89 | WCenter = (rect[0] + rect[2])/2 90 | side_len = roi_box[3]-roi_box[1] 91 | margin = side_len * 1.2 // 2 92 | roi_box[0], roi_box[1], roi_box[2], roi_box[3] = WCenter-margin, HCenter-margin, WCenter+margin, HCenter+margin 93 | 94 | img = crop_img(img_ori, roi_box) 95 | img = cv2.resize(img, dsize=(IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR) 96 | 97 | input = transform(img).unsqueeze(0) 98 | with torch.no_grad(): 99 | input = input.cuda() 100 | param = model.forward_test(input) 101 | param = param.squeeze().cpu().numpy().flatten().astype(np.float32) 102 | 103 | # dense pts 104 | vertices = predict_denseVert(param, roi_box, transform=True) 105 | vertices_lst.append(vertices) 106 | 107 | # textured obj file output 108 | if not osp.exists(f'inference_output/obj/'): 109 | os.makedirs(f'inference_output/obj/') 110 | 111 | name = img_fp.rsplit('/',1)[-1][:-4] # drop off the extension 112 | colors = cv2.imread(f'uv_art/{name}_fake_B.png',-1) 113 | colors = np.flip(colors,axis=0) 114 | colors_uv = (colors[coord_u, coord_v,:]) 115 | 116 | wfp = f'inference_output/obj/{name}.obj' 117 | write_obj_with_colors(wfp, vertices[:,keep_ind], tri_deletion, colors_uv[keep_ind,:].astype(np.float32)) 118 | 119 | if __name__ == '__main__': 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument('-f', '--files', default='', help='path to a single image or path to a folder containing multiple images') 122 | parser.add_argument("--png", action="store_true", help="if images are with .png extension") 123 | parser.add_argument('--img_size', default=120, type=int) 124 | parser.add_argument('-b', '--batch-size', default=1, type=int) 125 | 126 | args = parser.parse_args() 127 | main(args) -------------------------------------------------------------------------------- /backbone_nets/ResNeSt/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnest import * 2 | from .ablation import * 3 | -------------------------------------------------------------------------------- /backbone_nets/ResNeSt/ablation.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2020 5 | ## 6 | ## LICENSE file in the root directory of this source tree 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | """ResNeSt ablation study models""" 9 | 10 | import torch 11 | from .resnet import ResNet, Bottleneck 12 | 13 | __all__ = ['resnest50_fast_1s1x64d', 'resnest50_fast_2s1x64d', 'resnest50_fast_4s1x64d', 14 | 'resnest50_fast_1s2x40d', 'resnest50_fast_2s2x40d', 'resnest50_fast_4s2x40d', 15 | 'resnest50_fast_1s4x24d'] 16 | 17 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 18 | 19 | _model_sha256 = {name: checksum for checksum, name in [ 20 | ('d8fbf808', 'resnest50_fast_1s1x64d'), 21 | ('44938639', 'resnest50_fast_2s1x64d'), 22 | ('f74f3fc3', 'resnest50_fast_4s1x64d'), 23 | ('32830b84', 'resnest50_fast_1s2x40d'), 24 | ('9d126481', 'resnest50_fast_2s2x40d'), 25 | ('41d14ed0', 'resnest50_fast_4s2x40d'), 26 | ('d4a4f76f', 'resnest50_fast_1s4x24d'), 27 | ]} 28 | 29 | def short_hash(name): 30 | if name not in _model_sha256: 31 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 32 | return _model_sha256[name][:8] 33 | 34 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for 35 | name in _model_sha256.keys() 36 | } 37 | 38 | def resnest50_fast_1s1x64d(pretrained=False, root='~/.encoding/models', **kwargs): 39 | model = ResNet(Bottleneck, [3, 4, 6, 3], 40 | radix=1, groups=1, bottleneck_width=64, 41 | deep_stem=True, stem_width=32, avg_down=True, 42 | avd=True, avd_first=True, **kwargs) 43 | if pretrained: 44 | model.load_state_dict(torch.hub.load_state_dict_from_url( 45 | resnest_model_urls['resnest50_fast_1s1x64d'], progress=True, check_hash=True)) 46 | return model 47 | 48 | def resnest50_fast_2s1x64d(pretrained=False, root='~/.encoding/models', **kwargs): 49 | model = ResNet(Bottleneck, [3, 4, 6, 3], 50 | radix=2, groups=1, bottleneck_width=64, 51 | deep_stem=True, stem_width=32, avg_down=True, 52 | avd=True, avd_first=True, **kwargs) 53 | if pretrained: 54 | model.load_state_dict(torch.hub.load_state_dict_from_url( 55 | resnest_model_urls['resnest50_fast_2s1x64d'], progress=True, check_hash=True)) 56 | return model 57 | 58 | def resnest50_fast_4s1x64d(pretrained=False, root='~/.encoding/models', **kwargs): 59 | model = ResNet(Bottleneck, [3, 4, 6, 3], 60 | radix=4, groups=1, bottleneck_width=64, 61 | deep_stem=True, stem_width=32, avg_down=True, 62 | avd=True, avd_first=True, **kwargs) 63 | if pretrained: 64 | model.load_state_dict(torch.hub.load_state_dict_from_url( 65 | resnest_model_urls['resnest50_fast_4s1x64d'], progress=True, check_hash=True)) 66 | return model 67 | 68 | def resnest50_fast_1s2x40d(pretrained=False, root='~/.encoding/models', **kwargs): 69 | model = ResNet(Bottleneck, [3, 4, 6, 3], 70 | radix=1, groups=2, bottleneck_width=40, 71 | deep_stem=True, stem_width=32, avg_down=True, 72 | avd=True, avd_first=True, **kwargs) 73 | if pretrained: 74 | model.load_state_dict(torch.hub.load_state_dict_from_url( 75 | resnest_model_urls['resnest50_fast_1s2x40d'], progress=True, check_hash=True)) 76 | return model 77 | 78 | def resnest50_fast_2s2x40d(pretrained=False, root='~/.encoding/models', **kwargs): 79 | model = ResNet(Bottleneck, [3, 4, 6, 3], 80 | radix=2, groups=2, bottleneck_width=40, 81 | deep_stem=True, stem_width=32, avg_down=True, 82 | avd=True, avd_first=True, **kwargs) 83 | if pretrained: 84 | model.load_state_dict(torch.hub.load_state_dict_from_url( 85 | resnest_model_urls['resnest50_fast_2s2x40d'], progress=True, check_hash=True)) 86 | return model 87 | 88 | def resnest50_fast_4s2x40d(pretrained=False, root='~/.encoding/models', **kwargs): 89 | model = ResNet(Bottleneck, [3, 4, 6, 3], 90 | radix=4, groups=2, bottleneck_width=40, 91 | deep_stem=True, stem_width=32, avg_down=True, 92 | avd=True, avd_first=True, **kwargs) 93 | if pretrained: 94 | model.load_state_dict(torch.hub.load_state_dict_from_url( 95 | resnest_model_urls['resnest50_fast_4s2x40d'], progress=True, check_hash=True)) 96 | return model 97 | 98 | def resnest50_fast_1s4x24d(pretrained=False, root='~/.encoding/models', **kwargs): 99 | model = ResNet(Bottleneck, [3, 4, 6, 3], 100 | radix=1, groups=4, bottleneck_width=24, 101 | deep_stem=True, stem_width=32, avg_down=True, 102 | avd=True, avd_first=True, **kwargs) 103 | if pretrained: 104 | model.load_state_dict(torch.hub.load_state_dict_from_url( 105 | resnest_model_urls['resnest50_fast_1s4x24d'], progress=True, check_hash=True)) 106 | return model 107 | -------------------------------------------------------------------------------- /backbone_nets/ResNeSt/resnest.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2020 5 | ## 6 | ## LICENSE file in the root directory of this source tree 7 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | """ResNeSt models""" 9 | 10 | import torch 11 | from .resnet import ResNet, Bottleneck 12 | 13 | __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] 14 | 15 | _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' 16 | 17 | _model_sha256 = {name: checksum for checksum, name in [ 18 | ('528c19ca', 'resnest50'), 19 | ('22405ba7', 'resnest101'), 20 | ('75117900', 'resnest200'), 21 | ('0cc87c48', 'resnest269'), 22 | ]} 23 | 24 | def short_hash(name): 25 | if name not in _model_sha256: 26 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 27 | return _model_sha256[name][:8] 28 | 29 | resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for 30 | name in _model_sha256.keys() 31 | } 32 | 33 | def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): 34 | model = ResNet(Bottleneck, [3, 4, 6, 3], 35 | radix=2, groups=1, bottleneck_width=64, 36 | deep_stem=True, stem_width=32, avg_down=True, 37 | avd=True, avd_first=False, **kwargs) 38 | if pretrained: 39 | model.load_state_dict(torch.hub.load_state_dict_from_url( 40 | resnest_model_urls['resnest50'], progress=True, check_hash=True), strict=False) 41 | return model 42 | 43 | def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): 44 | model = ResNet(Bottleneck, [3, 4, 23, 3], 45 | radix=2, groups=1, bottleneck_width=64, 46 | deep_stem=True, stem_width=64, avg_down=True, 47 | avd=True, avd_first=False, **kwargs) 48 | if pretrained: 49 | model.load_state_dict(torch.hub.load_state_dict_from_url( 50 | resnest_model_urls['resnest101'], progress=True, check_hash=True), strict=False) 51 | return model 52 | 53 | def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): 54 | model = ResNet(Bottleneck, [3, 24, 36, 3], 55 | radix=2, groups=1, bottleneck_width=64, 56 | deep_stem=True, stem_width=64, avg_down=True, 57 | avd=True, avd_first=False, **kwargs) 58 | if pretrained: 59 | model.load_state_dict(torch.hub.load_state_dict_from_url( 60 | resnest_model_urls['resnest200'], progress=True, check_hash=True)) 61 | return model 62 | 63 | def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): 64 | model = ResNet(Bottleneck, [3, 30, 48, 8], 65 | radix=2, groups=1, bottleneck_width=64, 66 | deep_stem=True, stem_width=64, avg_down=True, 67 | avd=True, avd_first=False, **kwargs) 68 | if pretrained: 69 | model.load_state_dict(torch.hub.load_state_dict_from_url( 70 | resnest_model_urls['resnest269'], progress=True, check_hash=True)) 71 | return model 72 | -------------------------------------------------------------------------------- /backbone_nets/ResNeSt/splat.py: -------------------------------------------------------------------------------- 1 | """Split-Attention""" 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU 7 | from torch.nn.modules.utils import _pair 8 | 9 | __all__ = ['SplAtConv2d'] 10 | 11 | class SplAtConv2d(Module): 12 | """Split-Attention Conv2d 13 | """ 14 | def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0), 15 | dilation=(1, 1), groups=1, bias=True, 16 | radix=2, reduction_factor=4, 17 | rectify=False, rectify_avg=False, norm_layer=None, 18 | dropblock_prob=0.0, **kwargs): 19 | super(SplAtConv2d, self).__init__() 20 | padding = _pair(padding) 21 | self.rectify = rectify and (padding[0] > 0 or padding[1] > 0) 22 | self.rectify_avg = rectify_avg 23 | inter_channels = max(in_channels*radix//reduction_factor, 32) 24 | self.radix = radix 25 | self.cardinality = groups 26 | self.channels = channels 27 | self.dropblock_prob = dropblock_prob 28 | if self.rectify: 29 | from rfconv import RFConv2d 30 | self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 31 | groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs) 32 | else: 33 | self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation, 34 | groups=groups*radix, bias=bias, **kwargs) 35 | self.use_bn = norm_layer is not None 36 | if self.use_bn: 37 | self.bn0 = norm_layer(channels*radix) 38 | self.relu = ReLU(inplace=True) 39 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 40 | if self.use_bn: 41 | self.bn1 = norm_layer(inter_channels) 42 | self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality) 43 | if dropblock_prob > 0.0: 44 | self.dropblock = DropBlock2D(dropblock_prob, 3) 45 | self.rsoftmax = rSoftMax(radix, groups) 46 | 47 | def forward(self, x): 48 | x = self.conv(x) 49 | if self.use_bn: 50 | x = self.bn0(x) 51 | if self.dropblock_prob > 0.0: 52 | x = self.dropblock(x) 53 | x = self.relu(x) 54 | 55 | batch, rchannel = x.shape[:2] 56 | if self.radix > 1: 57 | if torch.__version__ < '1.5': 58 | splited = torch.split(x, int(rchannel//self.radix), dim=1) 59 | else: 60 | splited = torch.split(x, rchannel//self.radix, dim=1) 61 | gap = sum(splited) 62 | else: 63 | gap = x 64 | gap = F.adaptive_avg_pool2d(gap, 1) 65 | gap = self.fc1(gap) 66 | 67 | if self.use_bn: 68 | gap = self.bn1(gap) 69 | gap = self.relu(gap) 70 | 71 | atten = self.fc2(gap) 72 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 73 | 74 | if self.radix > 1: 75 | if torch.__version__ < '1.5': 76 | attens = torch.split(atten, int(rchannel//self.radix), dim=1) 77 | else: 78 | attens = torch.split(atten, rchannel//self.radix, dim=1) 79 | out = sum([att*split for (att, split) in zip(attens, splited)]) 80 | else: 81 | out = atten * x 82 | return out.contiguous() 83 | 84 | class rSoftMax(nn.Module): 85 | def __init__(self, radix, cardinality): 86 | super().__init__() 87 | self.radix = radix 88 | self.cardinality = cardinality 89 | 90 | def forward(self, x): 91 | batch = x.size(0) 92 | if self.radix > 1: 93 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 94 | x = F.softmax(x, dim=1) 95 | x = x.reshape(batch, -1) 96 | else: 97 | x = torch.sigmoid(x) 98 | return x 99 | 100 | -------------------------------------------------------------------------------- /backbone_nets/ghostnet_backbone.py: -------------------------------------------------------------------------------- 1 | # 2020.06.09-Changed for building GhostNet 2 | # Huawei Technologies Co., Ltd. 3 | """ 4 | Creates a GhostNet Model as defined in: 5 | GhostNet: More Features from Cheap Operations By Kai Han, Yunhe Wang, Qi Tian, Jianyuan Guo, Chunjing Xu, Chang Xu. 6 | https://arxiv.org/abs/1911.11907 7 | Modified from https://github.com/d-li14/mobilenetv3.pytorch and https://github.com/rwightman/pytorch-image-models 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import math 13 | 14 | 15 | __all__ = ['ghostnet'] 16 | 17 | 18 | def _make_divisible(v, divisor, min_value=None): 19 | """ 20 | This function is taken from the original tf repo. 21 | It ensures that all layers have a channel number that is divisible by 8 22 | It can be seen here: 23 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 24 | """ 25 | if min_value is None: 26 | min_value = divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < 0.9 * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | def hard_sigmoid(x, inplace: bool = False): 35 | if inplace: 36 | return x.add_(3.).clamp_(0., 6.).div_(6.) 37 | else: 38 | return F.relu6(x + 3.) / 6. 39 | 40 | 41 | class SqueezeExcite(nn.Module): 42 | def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, 43 | act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_): 44 | super(SqueezeExcite, self).__init__() 45 | self.gate_fn = gate_fn 46 | reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) 47 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 48 | self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) 49 | self.act1 = act_layer(inplace=True) 50 | self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) 51 | 52 | def forward(self, x): 53 | x_se = self.avg_pool(x) 54 | x_se = self.conv_reduce(x_se) 55 | x_se = self.act1(x_se) 56 | x_se = self.conv_expand(x_se) 57 | x = x * self.gate_fn(x_se) 58 | return x 59 | 60 | 61 | class ConvBnAct(nn.Module): 62 | def __init__(self, in_chs, out_chs, kernel_size, 63 | stride=1, act_layer=nn.ReLU): 64 | super(ConvBnAct, self).__init__() 65 | self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False) 66 | self.bn1 = nn.BatchNorm2d(out_chs) 67 | self.act1 = act_layer(inplace=True) 68 | 69 | def forward(self, x): 70 | x = self.conv(x) 71 | x = self.bn1(x) 72 | x = self.act1(x) 73 | return x 74 | 75 | 76 | class GhostModule(nn.Module): 77 | def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): 78 | super(GhostModule, self).__init__() 79 | self.oup = oup 80 | init_channels = math.ceil(oup / ratio) 81 | new_channels = init_channels*(ratio-1) 82 | 83 | self.primary_conv = nn.Sequential( 84 | nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), 85 | nn.BatchNorm2d(init_channels), 86 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 87 | ) 88 | 89 | self.cheap_operation = nn.Sequential( 90 | nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), 91 | nn.BatchNorm2d(new_channels), 92 | nn.ReLU(inplace=True) if relu else nn.Sequential(), 93 | ) 94 | 95 | def forward(self, x): 96 | x1 = self.primary_conv(x) 97 | x2 = self.cheap_operation(x1) 98 | out = torch.cat([x1,x2], dim=1) 99 | return out[:,:self.oup,:,:] 100 | 101 | 102 | class GhostBottleneck(nn.Module): 103 | """ Ghost bottleneck w/ optional SE""" 104 | 105 | def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3, 106 | stride=1, act_layer=nn.ReLU, se_ratio=0.): 107 | super(GhostBottleneck, self).__init__() 108 | has_se = se_ratio is not None and se_ratio > 0. 109 | self.stride = stride 110 | 111 | # Point-wise expansion 112 | self.ghost1 = GhostModule(in_chs, mid_chs, relu=True) 113 | 114 | # Depth-wise convolution 115 | if self.stride > 1: 116 | self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride, 117 | padding=(dw_kernel_size-1)//2, 118 | groups=mid_chs, bias=False) 119 | self.bn_dw = nn.BatchNorm2d(mid_chs) 120 | 121 | # Squeeze-and-excitation 122 | if has_se: 123 | self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) 124 | else: 125 | self.se = None 126 | 127 | # Point-wise linear projection 128 | self.ghost2 = GhostModule(mid_chs, out_chs, relu=False) 129 | 130 | # shortcut 131 | if (in_chs == out_chs and self.stride == 1): 132 | self.shortcut = nn.Sequential() 133 | else: 134 | self.shortcut = nn.Sequential( 135 | nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride, 136 | padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), 137 | nn.BatchNorm2d(in_chs), 138 | nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), 139 | nn.BatchNorm2d(out_chs), 140 | ) 141 | 142 | 143 | def forward(self, x): 144 | residual = x 145 | 146 | # 1st ghost bottleneck 147 | x = self.ghost1(x) 148 | 149 | # Depth-wise convolution 150 | if self.stride > 1: 151 | x = self.conv_dw(x) 152 | x = self.bn_dw(x) 153 | 154 | # Squeeze-and-excitation 155 | if self.se is not None: 156 | x = self.se(x) 157 | 158 | # 2nd ghost bottleneck 159 | x = self.ghost2(x) 160 | 161 | x += self.shortcut(residual) 162 | return x 163 | 164 | 165 | class GhostNet(nn.Module): 166 | def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2): 167 | super(GhostNet, self).__init__() 168 | # setting of inverted residual blocks 169 | self.cfgs = cfgs 170 | self.dropout = dropout 171 | 172 | # building first layer 173 | output_channel = _make_divisible(16 * width, 4) 174 | self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False) 175 | self.bn1 = nn.BatchNorm2d(output_channel) 176 | self.act1 = nn.ReLU(inplace=True) 177 | input_channel = output_channel 178 | 179 | # building inverted residual blocks 180 | stages = [] 181 | block = GhostBottleneck 182 | for cfg in self.cfgs: 183 | layers = [] 184 | for k, exp_size, c, se_ratio, s in cfg: 185 | output_channel = _make_divisible(c * width, 4) 186 | hidden_channel = _make_divisible(exp_size * width, 4) 187 | layers.append(block(input_channel, hidden_channel, output_channel, k, s, 188 | se_ratio=se_ratio)) 189 | input_channel = output_channel 190 | stages.append(nn.Sequential(*layers)) 191 | 192 | output_channel = _make_divisible(exp_size * width, 4) 193 | stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1))) 194 | input_channel = output_channel 195 | 196 | self.blocks = nn.Sequential(*stages) 197 | 198 | # building last several layers 199 | output_channel = 1280 200 | self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) 201 | self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True) 202 | self.act2 = nn.ReLU(inplace=True) 203 | 204 | self.num_ori = 12 205 | self.num_shape = 40 206 | self.num_exp = 10 207 | self.num_texture = 40 208 | 209 | self.classifier_ori = nn.Linear(output_channel, self.num_ori) 210 | self.classifier_shape = nn.Linear(output_channel, self.num_shape) 211 | self.classifier_exp = nn.Linear(output_channel, self.num_exp) 212 | self.classifier_texture = nn.Linear(output_channel, self.num_texture) 213 | 214 | def forward(self, x): 215 | x = self.conv_stem(x) 216 | x = self.bn1(x) 217 | x = self.act1(x) 218 | x = self.blocks(x) 219 | x = self.global_pool(x) 220 | x = self.conv_head(x) 221 | x = self.act2(x) 222 | x = x.view(x.size(0), -1) 223 | if self.dropout > 0.: 224 | x = F.dropout(x, p=self.dropout, training=self.training) 225 | #x = self.classifier(x) 226 | 227 | x_ori = self.classifier_ori(x) 228 | x_shape = self.classifier_shape(x) 229 | x_exp = self.classifier_exp(x) 230 | x_tex = self.classifier_texture(x) 231 | x = torch.cat((x_ori, x_shape, x_exp, x_tex), dim=1) 232 | 233 | return x 234 | 235 | 236 | def ghostnet(**kwargs): 237 | """ 238 | Constructs a GhostNet model 239 | """ 240 | cfgs = [ 241 | # k, t, c, SE, s 242 | # stage1 243 | [[3, 16, 16, 0, 1]], 244 | # stage2 245 | [[3, 48, 24, 0, 2]], 246 | [[3, 72, 24, 0, 1]], 247 | # stage3 248 | [[5, 72, 40, 0.25, 2]], 249 | [[5, 120, 40, 0.25, 1]], 250 | # stage4 251 | [[3, 240, 80, 0, 2]], 252 | [[3, 200, 80, 0, 1], 253 | [3, 184, 80, 0, 1], 254 | [3, 184, 80, 0, 1], 255 | [3, 480, 112, 0.25, 1], 256 | [3, 672, 112, 0.25, 1] 257 | ], 258 | # stage5 259 | [[5, 672, 160, 0.25, 2]], 260 | [[5, 960, 160, 0, 1], 261 | [5, 960, 160, 0.25, 1], 262 | [5, 960, 160, 0, 1], 263 | [5, 960, 160, 0.25, 1] 264 | ] 265 | ] 266 | return GhostNet(cfgs, **kwargs) 267 | 268 | 269 | if __name__=='__main__': 270 | model = ghostnet() 271 | model.eval() 272 | print(model) 273 | input = torch.randn(32,3,320,256) 274 | y = model(input) 275 | print(y.size()) -------------------------------------------------------------------------------- /backbone_nets/mobilenetv1_backbone.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | from __future__ import division 5 | 6 | """ 7 | Creates a MobileNet Model as defined in: 8 | Andrew G. Howard Menglong Zhu Bo Chen, et.al. (2017). 9 | MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications. 10 | Copyright (c) Yang Lu, 2017 11 | 12 | Modified By cleardusk 13 | """ 14 | import math 15 | import torch 16 | import torch.nn as nn 17 | 18 | __all__ = ['mobilenet_2', 'mobilenet_1', 'mobilenet_075', 'mobilenet_05', 'mobilenet_025'] 19 | 20 | 21 | class DepthWiseBlock(nn.Module): 22 | def __init__(self, inplanes, planes, stride=1, prelu=False): 23 | super(DepthWiseBlock, self).__init__() 24 | inplanes, planes = int(inplanes), int(planes) 25 | self.conv_dw = nn.Conv2d(inplanes, inplanes, kernel_size=3, padding=1, stride=stride, groups=inplanes, 26 | bias=False) 27 | self.bn_dw = nn.BatchNorm2d(inplanes) 28 | self.conv_sep = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False) 29 | self.bn_sep = nn.BatchNorm2d(planes) 30 | if prelu: 31 | self.relu = nn.PReLU() 32 | else: 33 | self.relu = nn.ReLU(inplace=True) 34 | 35 | def forward(self, x): 36 | out = self.conv_dw(x) 37 | out = self.bn_dw(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv_sep(out) 41 | out = self.bn_sep(out) 42 | out = self.relu(out) 43 | 44 | return out 45 | 46 | 47 | class MobileNet(nn.Module): 48 | def __init__(self, widen_factor=1.0, num_classes=1000, prelu=False, input_channel=3): 49 | """ Constructor 50 | Args: 51 | widen_factor: config of widen_factor 52 | num_classes: number of classes 53 | """ 54 | super(MobileNet, self).__init__() 55 | 56 | block = DepthWiseBlock 57 | self.conv1 = nn.Conv2d(input_channel, int(32 * widen_factor), kernel_size=3, stride=2, padding=1, 58 | bias=False) 59 | 60 | self.bn1 = nn.BatchNorm2d(int(32 * widen_factor)) 61 | if prelu: 62 | self.relu = nn.PReLU() 63 | else: 64 | self.relu = nn.ReLU(inplace=True) 65 | 66 | self.dw2_1 = block(32 * widen_factor, 64 * widen_factor, prelu=prelu) 67 | self.dw2_2 = block(64 * widen_factor, 128 * widen_factor, stride=2, prelu=prelu) 68 | 69 | self.dw3_1 = block(128 * widen_factor, 128 * widen_factor, prelu=prelu) 70 | self.dw3_2 = block(128 * widen_factor, 256 * widen_factor, stride=2, prelu=prelu) 71 | 72 | self.dw4_1 = block(256 * widen_factor, 256 * widen_factor, prelu=prelu) 73 | self.dw4_2 = block(256 * widen_factor, 512 * widen_factor, stride=2, prelu=prelu) 74 | 75 | self.dw5_1 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 76 | self.dw5_2 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 77 | self.dw5_3 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 78 | self.dw5_4 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 79 | self.dw5_5 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 80 | self.dw5_6 = block(512 * widen_factor, 1024 * widen_factor, stride=2, prelu=prelu) 81 | 82 | self.dw6 = block(1024 * widen_factor, 1024 * widen_factor, prelu=prelu) 83 | 84 | self.avgpool = nn.AdaptiveAvgPool2d(1) 85 | #self.fc = nn.Linear(int(1024 * widen_factor), num_classes) 86 | 87 | self.num_ori = 12 88 | self.num_shape = 40 89 | self.num_exp = 10 90 | self.num_texture = 40 91 | 92 | #self.fc = nn.Linear(int(1024 * widen_factor), num_classes) 93 | 94 | ### Multi-decoder output 95 | self.fc_ori = nn.Linear(int(1024 * widen_factor), self.num_ori) 96 | self.fc_shape = nn.Linear(int(1024 * widen_factor), self.num_shape) 97 | self.fc_exp = nn.Linear(int(1024 * widen_factor), self.num_exp) 98 | self.fc_tex = nn.Linear(int(1024 * widen_factor), self.num_texture) 99 | 100 | for m in self.modules(): 101 | if isinstance(m, nn.Conv2d): 102 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 103 | m.weight.data.normal_(0, math.sqrt(2. / n)) 104 | elif isinstance(m, nn.BatchNorm2d): 105 | m.weight.data.fill_(1) 106 | m.bias.data.zero_() 107 | 108 | def forward(self, x): 109 | x = self.conv1(x) 110 | x = self.bn1(x) 111 | x = self.relu(x) 112 | 113 | x = self.dw2_1(x) 114 | x = self.dw2_2(x) 115 | x = self.dw3_1(x) 116 | x = self.dw3_2(x) 117 | x = self.dw4_1(x) 118 | x = self.dw4_2(x) 119 | x = self.dw5_1(x) 120 | x = self.dw5_2(x) 121 | x = self.dw5_3(x) 122 | x = self.dw5_4(x) 123 | x = self.dw5_5(x) 124 | x = self.dw5_6(x) 125 | x = self.dw6(x) 126 | 127 | x = self.avgpool(x) 128 | x = x.view(x.size(0), -1) 129 | #x = self.fc(x) 130 | 131 | ### Multi-decoder output 132 | x_ori = self.fc_ori(x) 133 | x_shp = self.fc_shape(x) 134 | x_exp = self.fc_exp(x) 135 | #x = torch.cat((x_ori, x_shp, x_exp), dim=1) 136 | 137 | x_tex = self.fc_tex(x) 138 | x = torch.cat((x_ori, x_shp, x_exp, x_tex), dim=1) 139 | 140 | return x 141 | 142 | class MobileNet_ori(nn.Module): 143 | def __init__(self, widen_factor=1.0, num_classes=1000, prelu=False, input_channel=3): 144 | """ Constructor 145 | Args: 146 | widen_factor: config of widen_factor 147 | num_classes: number of classes 148 | """ 149 | super(MobileNet_ori, self).__init__() 150 | 151 | block = DepthWiseBlock 152 | self.conv1 = nn.Conv2d(input_channel, int(32 * widen_factor), kernel_size=3, stride=2, padding=1, 153 | bias=False) 154 | 155 | self.bn1 = nn.BatchNorm2d(int(32 * widen_factor)) 156 | if prelu: 157 | self.relu = nn.PReLU() 158 | else: 159 | self.relu = nn.ReLU(inplace=True) 160 | 161 | self.dw2_1 = block(32 * widen_factor, 64 * widen_factor, prelu=prelu) 162 | self.dw2_2 = block(64 * widen_factor, 128 * widen_factor, stride=2, prelu=prelu) 163 | 164 | self.dw3_1 = block(128 * widen_factor, 128 * widen_factor, prelu=prelu) 165 | self.dw3_2 = block(128 * widen_factor, 256 * widen_factor, stride=2, prelu=prelu) 166 | 167 | self.dw4_1 = block(256 * widen_factor, 256 * widen_factor, prelu=prelu) 168 | self.dw4_2 = block(256 * widen_factor, 512 * widen_factor, stride=2, prelu=prelu) 169 | 170 | self.dw5_1 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 171 | self.dw5_2 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 172 | self.dw5_3 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 173 | self.dw5_4 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 174 | self.dw5_5 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 175 | self.dw5_6 = block(512 * widen_factor, 1024 * widen_factor, stride=2, prelu=prelu) 176 | 177 | self.dw6 = block(1024 * widen_factor, 1024 * widen_factor, prelu=prelu) 178 | 179 | self.avgpool = nn.AdaptiveAvgPool2d(1) 180 | self.fc = nn.Linear(int(1024 * widen_factor), num_classes) 181 | 182 | for m in self.modules(): 183 | if isinstance(m, nn.Conv2d): 184 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 185 | m.weight.data.normal_(0, math.sqrt(2. / n)) 186 | elif isinstance(m, nn.BatchNorm2d): 187 | m.weight.data.fill_(1) 188 | m.bias.data.zero_() 189 | 190 | def forward(self, x): 191 | x = self.conv1(x) 192 | x = self.bn1(x) 193 | x = self.relu(x) 194 | 195 | x = self.dw2_1(x) 196 | x = self.dw2_2(x) 197 | x = self.dw3_1(x) 198 | x = self.dw3_2(x) 199 | x = self.dw4_1(x) 200 | x = self.dw4_2(x) 201 | x = self.dw5_1(x) 202 | x = self.dw5_2(x) 203 | x = self.dw5_3(x) 204 | x = self.dw5_4(x) 205 | x = self.dw5_5(x) 206 | x = self.dw5_6(x) 207 | x = self.dw6(x) 208 | 209 | x = self.avgpool(x) 210 | x = x.view(x.size(0), -1) 211 | x = self.fc(x) 212 | return x 213 | 214 | 215 | def mobilenet(widen_factor=1.0, num_classes=1000): 216 | """ 217 | Construct MobileNet. 218 | widen_factor=1.0 for mobilenet_1 219 | widen_factor=0.75 for mobilenet_075 220 | widen_factor=0.5 for mobilenet_05 221 | widen_factor=0.25 for mobilenet_025 222 | """ 223 | model = MobileNet(widen_factor=widen_factor, num_classes=num_classes) 224 | return model 225 | 226 | 227 | def mobilenet_2(num_classes=62, input_channel=3): 228 | model = MobileNet(widen_factor=2.0, num_classes=num_classes, input_channel=input_channel) 229 | return model 230 | 231 | 232 | def mobilenet_1(num_classes=62, input_channel=3): 233 | #model = MobileNet(widen_factor=1.0, num_classes=num_classes, input_channel=input_channel) 234 | model = MobileNet(widen_factor=1.0, num_classes=num_classes, input_channel=input_channel) 235 | return model 236 | 237 | 238 | def mobilenet_075(num_classes=62, input_channel=3): 239 | model = MobileNet(widen_factor=0.75, num_classes=num_classes, input_channel=input_channel) 240 | return model 241 | 242 | 243 | def mobilenet_05(num_classes=62, input_channel=3): 244 | model = MobileNet(widen_factor=0.5, num_classes=num_classes, input_channel=input_channel) 245 | return model 246 | 247 | 248 | def mobilenet_025(num_classes=62, input_channel=3): 249 | model = MobileNet(widen_factor=0.25, num_classes=num_classes, input_channel=input_channel) 250 | return model 251 | -------------------------------------------------------------------------------- /backbone_nets/mobilenetv2_backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.hub import load_state_dict_from_url 4 | 5 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 6 | 7 | 8 | model_urls = { 9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 10 | } 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | class ConvBNReLU(nn.Sequential): 34 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None): 35 | padding = (kernel_size - 1) // 2 36 | if norm_layer is None: 37 | norm_layer = nn.BatchNorm2d 38 | super(ConvBNReLU, self).__init__( 39 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 40 | norm_layer(out_planes), 41 | nn.ReLU6(inplace=True) 42 | ) 43 | 44 | 45 | class InvertedResidual(nn.Module): 46 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None): 47 | super(InvertedResidual, self).__init__() 48 | self.stride = stride 49 | assert stride in [1, 2] 50 | 51 | if norm_layer is None: 52 | norm_layer = nn.BatchNorm2d 53 | 54 | hidden_dim = int(round(inp * expand_ratio)) 55 | self.use_res_connect = self.stride == 1 and inp == oup 56 | 57 | layers = [] 58 | if expand_ratio != 1: 59 | # pw 60 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 61 | layers.extend([ 62 | # dw 63 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 64 | # pw-linear 65 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 66 | norm_layer(oup), 67 | ]) 68 | self.conv = nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | if self.use_res_connect: 72 | return x + self.conv(x) 73 | else: 74 | return self.conv(x) 75 | 76 | 77 | class MobileNetV2(nn.Module): 78 | def __init__(self, 79 | num_classes=1000, 80 | width_mult=1.0, 81 | inverted_residual_setting=None, 82 | round_nearest=8, 83 | block=None, 84 | norm_layer=None): 85 | """ 86 | MobileNet V2 main class 87 | Args: 88 | num_classes (int): Number of classes 89 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 90 | inverted_residual_setting: Network structure 91 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 92 | Set to 1 to turn off rounding 93 | block: Module specifying inverted residual building block for mobilenet 94 | norm_layer: Module specifying the normalization layer to use 95 | """ 96 | super(MobileNetV2, self).__init__() 97 | 98 | if block is None: 99 | block = InvertedResidual 100 | 101 | if norm_layer is None: 102 | norm_layer = nn.BatchNorm2d 103 | 104 | input_channel = 32 105 | last_channel = 1280 106 | 107 | if inverted_residual_setting is None: 108 | inverted_residual_setting = [ 109 | # t, c, n, s 110 | [1, 16, 1, 1], 111 | [6, 24, 2, 2], 112 | [6, 32, 3, 2], 113 | [6, 64, 4, 2], 114 | [6, 96, 3, 1], 115 | [6, 160, 3, 2], 116 | [6, 320, 1, 1], 117 | ] 118 | 119 | # only check the first element, assuming user knows t,c,n,s are required 120 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 121 | raise ValueError("inverted_residual_setting should be non-empty " 122 | "or a 4-element list, got {}".format(inverted_residual_setting)) 123 | 124 | # building first layer 125 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 126 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 127 | features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] 128 | # building inverted residual blocks 129 | for t, c, n, s in inverted_residual_setting: 130 | output_channel = _make_divisible(c * width_mult, round_nearest) 131 | for i in range(n): 132 | stride = s if i == 0 else 1 133 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 134 | input_channel = output_channel 135 | # building last several layers 136 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 137 | # make it nn.Sequential 138 | self.features = nn.Sequential(*features) 139 | 140 | # building classifier 141 | 142 | self.num_ori = 12 143 | self.num_shape = 40 144 | self.num_exp = 10 145 | 146 | 147 | self.classifier_ori = nn.Sequential( 148 | nn.Dropout(0.2), 149 | nn.Linear(self.last_channel, self.num_ori), 150 | ) 151 | self.classifier_shape = nn.Sequential( 152 | nn.Dropout(0.2), 153 | nn.Linear(self.last_channel, self.num_shape), 154 | ) 155 | self.classifier_exp = nn.Sequential( 156 | nn.Dropout(0.2), 157 | nn.Linear(self.last_channel, self.num_exp), 158 | ) 159 | 160 | # weight initialization 161 | for m in self.modules(): 162 | if isinstance(m, nn.Conv2d): 163 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 164 | if m.bias is not None: 165 | nn.init.zeros_(m.bias) 166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 167 | nn.init.ones_(m.weight) 168 | nn.init.zeros_(m.bias) 169 | elif isinstance(m, nn.Linear): 170 | nn.init.normal_(m.weight, 0, 0.01) 171 | nn.init.zeros_(m.bias) 172 | 173 | def _forward_impl(self, x): 174 | # This exists since TorchScript doesn't support inheritance, so the superclass method 175 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 176 | 177 | x = self.features(x) 178 | 179 | x = nn.functional.adaptive_avg_pool2d(x, 1) 180 | x = x.reshape(x.shape[0], -1) 181 | 182 | pool_x = x.clone() 183 | 184 | x_ori = self.classifier_ori(x) 185 | x_shape = self.classifier_shape(x) 186 | x_exp = self.classifier_exp(x) 187 | 188 | x = torch.cat((x_ori, x_shape, x_exp), dim=1) 189 | return x, pool_x 190 | 191 | def forward(self, x): 192 | return self._forward_impl(x) 193 | 194 | 195 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 196 | """ 197 | Constructs a MobileNetV2 architecture from 198 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | progress (bool): If True, displays a progress bar of the download to stderr 202 | """ 203 | model = MobileNetV2(**kwargs) 204 | if pretrained: 205 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 206 | progress=progress) 207 | model.load_state_dict(state_dict, strict=False) 208 | return model -------------------------------------------------------------------------------- /backbone_nets/pointnet_backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['MLP_for', 'MLP_rev'] 6 | 7 | class MLP_for(nn.Module): 8 | def __init__(self, num_pts): 9 | super(MLP_for,self).__init__() 10 | self.conv1 = torch.nn.Conv1d(3,64,1) 11 | self.conv2 = torch.nn.Conv1d(64,64,1) 12 | self.conv3 = torch.nn.Conv1d(64,64,1) 13 | self.conv4 = torch.nn.Conv1d(64,128,1) 14 | self.conv5 = torch.nn.Conv1d(128,1024,1) 15 | self.conv6 = nn.Conv1d(2418, 512, 1) # 1024 + 64 + 1280 = 2368 16 | self.conv7 = nn.Conv1d(512, 256, 1) 17 | self.conv8 = nn.Conv1d(256, 128, 1) 18 | self.conv9 = nn.Conv1d(128, 3, 1) 19 | self.bn1 = nn.BatchNorm1d(64) 20 | self.bn2 = nn.BatchNorm1d(64) 21 | self.bn3 = nn.BatchNorm1d(64) 22 | self.bn4 = nn.BatchNorm1d(128) 23 | self.bn5 = nn.BatchNorm1d(1024) 24 | self.bn6 = nn.BatchNorm1d(512) 25 | self.bn7 = nn.BatchNorm1d(256) 26 | self.bn8 = nn.BatchNorm1d(128) 27 | self.bn9 = nn.BatchNorm1d(3) 28 | self.num_pts = num_pts 29 | self.max_pool = nn.MaxPool1d(num_pts) 30 | 31 | def forward(self,x, other_input1=None, other_input2=None, other_input3=None): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | point_features = out 35 | out = F.relu(self.bn3(self.conv3(out))) 36 | out = F.relu(self.bn4(self.conv4(out))) 37 | out = F.relu(self.bn5(self.conv5(out))) 38 | global_features = self.max_pool(out) 39 | global_features_repeated = global_features.repeat(1,1, self.num_pts) 40 | 41 | #out = F.relu(self.bn6(self.conv6(torch.cat([point_features, global_features_repeated],1)))) 42 | 43 | # Avg_pool 44 | # avgpool = other_input1 45 | # avgpool = avgpool.unsqueeze(2).repeat(1,1,self.num_pts) 46 | # out = F.relu(self.bn6(self.conv6(torch.cat([point_features, global_features_repeated, avgpool],1)))) 47 | 48 | #3DMMImg 49 | avgpool = other_input1 50 | avgpool = avgpool.unsqueeze(2).repeat(1,1,self.num_pts) 51 | 52 | shape_code = other_input2 53 | shape_code = shape_code.unsqueeze(2).repeat(1,1,self.num_pts) 54 | 55 | expr_code = other_input3 56 | expr_code = expr_code.unsqueeze(2).repeat(1,1,self.num_pts) 57 | 58 | out = F.relu(self.bn6(self.conv6(torch.cat([point_features, global_features_repeated, avgpool, shape_code, expr_code],1)))) 59 | 60 | 61 | out = F.relu(self.bn7(self.conv7(out))) 62 | out = F.relu(self.bn8(self.conv8(out))) 63 | out = F.relu(self.bn9(self.conv9(out))) 64 | return out 65 | 66 | 67 | class MLP_rev(nn.Module): 68 | def __init__(self, num_pts): 69 | super(MLP_rev,self).__init__() 70 | self.conv1 = torch.nn.Conv1d(3,64,1) 71 | self.conv2 = torch.nn.Conv1d(64,64,1) 72 | self.conv3 = torch.nn.Conv1d(64,64,1) 73 | self.conv4 = torch.nn.Conv1d(64,128,1) 74 | self.conv5 = torch.nn.Conv1d(128,1024,1) 75 | self.conv6_1 = nn.Conv1d(1024, 12, 1) 76 | self.conv6_2 = nn.Conv1d(1024, 40, 1) 77 | self.conv6_3 = nn.Conv1d(1024, 10, 1) 78 | 79 | self.bn1 = nn.BatchNorm1d(64) 80 | self.bn2 = nn.BatchNorm1d(64) 81 | self.bn3 = nn.BatchNorm1d(64) 82 | self.bn4 = nn.BatchNorm1d(128) 83 | self.bn5 = nn.BatchNorm1d(1024) 84 | self.bn6_1 = nn.BatchNorm1d(12) 85 | self.bn6_2 = nn.BatchNorm1d(40) 86 | self.bn6_3 = nn.BatchNorm1d(10) 87 | self.num_pts = num_pts 88 | self.max_pool = nn.MaxPool1d(num_pts) 89 | 90 | def forward(self,x, other_input1=None, other_input2=None, other_input3=None): 91 | out = F.relu(self.bn1(self.conv1(x))) 92 | out = F.relu(self.bn2(self.conv2(out))) 93 | out = F.relu(self.bn3(self.conv3(out))) 94 | out = F.relu(self.bn4(self.conv4(out))) 95 | out = F.relu(self.bn5(self.conv5(out))) 96 | global_features = self.max_pool(out) 97 | 98 | # Global point feature 99 | out_rot = F.relu(self.bn6_1(self.conv6_1(global_features))) 100 | out_shape = F.relu(self.bn6_2(self.conv6_2(global_features))) 101 | out_expr = F.relu(self.bn6_3(self.conv6_3(global_features))) 102 | 103 | 104 | out = torch.cat([out_rot, out_shape, out_expr], 1).squeeze(2) 105 | 106 | return out -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import torch.backends.cudnn as cudnn 6 | import time 7 | import numpy as np 8 | 9 | from utils.ddfa import ToTensor, Normalize, CenterCrop, DDFATestDataset 10 | from model_building import SynergyNet 11 | from benchmark_aflw2000 import calc_nme as calc_nme_alfw2000 12 | from benchmark_aflw2000 import ana_msg as ana_alfw2000 13 | 14 | import argparse 15 | import os 16 | import glob 17 | import math 18 | from math import cos, atan2, asin 19 | import cv2 20 | 21 | from utils.params import ParamsPack 22 | param_pack = ParamsPack() 23 | 24 | def parse_pose(param): 25 | '''parse parameters into pose''' 26 | if len(param)==62: 27 | param = param * param_pack.param_std[:62] + param_pack.param_mean[:62] 28 | else: 29 | param = param * param_pack.param_std + param_pack.param_mean 30 | Ps = param[:12].reshape(3, -1) # camera matrix 31 | s, R, t3d = P2sRt(Ps) 32 | P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale 33 | pose = matrix2angle(R) # yaw, pitch, roll 34 | return P, pose 35 | 36 | def P2sRt(P): 37 | '''decomposing camera matrix P''' 38 | t3d = P[:, 3] 39 | R1 = P[0:1, :3] 40 | R2 = P[1:2, :3] 41 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0 42 | r1 = R1 / np.linalg.norm(R1) 43 | r2 = R2 / np.linalg.norm(R2) 44 | r3 = np.cross(r1, r2) 45 | R = np.concatenate((r1, r2, r3), 0) 46 | return s, R, t3d 47 | 48 | def matrix2angle(R): 49 | '''convert matrix to angle''' 50 | if R[2, 0] != 1 and R[2, 0] != -1: 51 | x = asin(R[2, 0]) 52 | y = atan2(R[1, 2] / cos(x), R[2, 2] / cos(x)) 53 | z = atan2(R[0, 1] / cos(x), R[0, 0] / cos(x)) 54 | 55 | else: # Gimbal lock 56 | z = 0 # can be anything 57 | if R[2, 0] == -1: 58 | x = np.pi / 2 59 | y = z + atan2(R[0, 1], R[0, 2]) 60 | else: 61 | x = -np.pi / 2 62 | y = -z + atan2(-R[0, 1], -R[0, 2]) 63 | 64 | rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi 65 | 66 | return [rx, ry, rz] 67 | 68 | def parsing(param): 69 | p_ = param[:, :12].reshape(-1, 3, 4) 70 | p = p_[:, :, :3] 71 | offset = p_[:, :, -1].reshape(-1, 3, 1) 72 | alpha_shp = param[:, 12:52].reshape(-1, 40, 1) 73 | alpha_exp = param[:, 52:62].reshape(-1, 10, 1) 74 | return p, offset, alpha_shp, alpha_exp 75 | 76 | def reconstruct_vertex(param, data_param, whitening=True, transform=True, lmk_pts=68): 77 | """ 78 | This function includes parameter de-whitening, reconstruction of landmarks, and transform from coordinate space (x,y) to image space (u,v) 79 | """ 80 | param_mean, param_std, w_shp_base, u_base, w_exp_base = data_param 81 | 82 | if whitening: 83 | if param.shape[1] == 62: 84 | param = param * param_std[:62] + param_mean[:62] 85 | else: 86 | raise NotImplementedError("Parameter length must be 62") 87 | 88 | if param.shape[1] == 62: 89 | p, offset, alpha_shp, alpha_exp = parsing(param) 90 | else: 91 | raise NotImplementedError("Parameter length must be 62") 92 | 93 | vertex = p @ (u_base + w_shp_base @ alpha_shp + w_exp_base @ alpha_exp).contiguous().view(-1, lmk_pts, 3).transpose(1,2) + offset 94 | if transform: 95 | vertex[:, 1, :] = param_pack.std_size + 1 - vertex[:, 1, :] 96 | 97 | return vertex 98 | 99 | def extract_param(checkpoint_fp, root='', args=None, filelists=None, device_ids=[0], 100 | batch_size=128, num_workers=4): 101 | map_location = {'cuda:{}'.format(i): 'cuda:0' for i in range(8)} 102 | checkpoint = torch.load(checkpoint_fp, map_location=map_location)['state_dict'] 103 | 104 | # Need to take off these for different numbers of base landmark points 105 | # del checkpoint['module.u_base'] 106 | # del checkpoint['module.w_shp_base'] 107 | # del checkpoint['module.w_exp_base'] 108 | 109 | torch.cuda.set_device(device_ids[0]) 110 | 111 | model = SynergyNet(args) 112 | model = nn.DataParallel(model, device_ids=device_ids).cuda() 113 | model.load_state_dict(checkpoint, strict=False) 114 | 115 | dataset = DDFATestDataset(filelists=filelists, root=root, 116 | transform=transforms.Compose([ToTensor(), CenterCrop(5, mode='test') , Normalize(mean=127.5, std=128) ])) 117 | data_loader = data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) 118 | 119 | cudnn.benchmark = True 120 | model.eval() 121 | 122 | end = time.time() 123 | outputs = [] 124 | with torch.no_grad(): 125 | for _, inputs in enumerate(data_loader): 126 | inputs = inputs.cuda() 127 | output = model.module.forward_test(inputs) 128 | 129 | for i in range(output.shape[0]): 130 | param_prediction = output[i].cpu().numpy().flatten() 131 | 132 | outputs.append(param_prediction) 133 | outputs = np.array(outputs, dtype=np.float32) 134 | 135 | print('Extracting params take {: .3f}s'.format(time.time() - end)) 136 | return outputs, model.module.data_param 137 | 138 | 139 | def _benchmark_aflw2000(outputs): 140 | '''Calculate the error statistics.''' 141 | return ana_alfw2000(calc_nme_alfw2000(outputs,option='ori')) 142 | 143 | # AFLW2000 facial alignment 144 | img_list = sorted(glob.glob('./aflw2000_data/AFLW2000-3D_crop/*.jpg')) 145 | def benchmark_aflw2000_params(params, data_param): 146 | '''Reconstruct the landmark points and calculate the statistics''' 147 | outputs = [] 148 | params = torch.Tensor(params).cuda() 149 | 150 | batch_size = 50 151 | num_samples = params.shape[0] 152 | iter_num = math.floor(num_samples / batch_size) 153 | residual = num_samples % batch_size 154 | for i in range(iter_num+1): 155 | if i == iter_num: 156 | if residual == 0: 157 | break 158 | batch_data = params[i*batch_size: i*batch_size + residual] 159 | lm = reconstruct_vertex(batch_data, data_param, lmk_pts=68) 160 | lm = lm.cpu().numpy() 161 | for j in range(residual): 162 | outputs.append(lm[j, :2, :]) 163 | else: 164 | batch_data = params[i*batch_size: (i+1)*batch_size] 165 | lm = reconstruct_vertex(batch_data, data_param, lmk_pts=68) 166 | lm = lm.cpu().numpy() 167 | for j in range(batch_size): 168 | if i == 0: 169 | #plot the first 50 samples for validation 170 | bkg = cv2.imread(img_list[i*batch_size+j],-1) 171 | lm_sample = lm[j] 172 | c0 = np.clip((lm_sample[1,:]).astype(np.int64), 0, 119) 173 | c1 = np.clip((lm_sample[0,:]).astype(np.int64), 0, 119) 174 | for y, x, in zip([c0,c0,c0-1,c0-1],[c1,c1-1,c1,c1-1]): 175 | bkg[y, x, :] = np.array([233,193,133]) 176 | cv2.imwrite(f'./results/{i*batch_size+j}.png', bkg) 177 | 178 | outputs.append(lm[j, :2, :]) 179 | return _benchmark_aflw2000(outputs) 180 | 181 | 182 | # AFLW2000 face orientation estimation 183 | def benchmark_FOE(params): 184 | """ 185 | FOE benchmark validation. Only calculate the groundtruth of angles within [-99, 99] (following FSA-Net https://github.com/shamangary/FSA-Net) 186 | """ 187 | 188 | # AFLW200 groundturh and indices for skipping, whose yaw angle lies outside [-99, 99] 189 | exclude_aflw2000 = './aflw2000_data/eval/ALFW2000-3D_pose_3ANG_excl.npy' 190 | skip_aflw2000 = './aflw2000_data/eval/ALFW2000-3D_pose_3ANG_skip.npy' 191 | 192 | if not os.path.isfile(exclude_aflw2000) or not os.path.isfile(skip_aflw2000): 193 | raise RuntimeError('Missing data') 194 | 195 | pose_GT = np.load(exclude_aflw2000) 196 | skip_indices = np.load(skip_aflw2000) 197 | pose_mat = np.ones((pose_GT.shape[0],3)) 198 | 199 | idx = 0 200 | for i in range(params.shape[0]): 201 | if i in skip_indices: 202 | continue 203 | P, angles = parse_pose(params[i]) 204 | angles[0], angles[1], angles[2] = angles[1], angles[0], angles[2] # we decode raw-ptich-yaw order 205 | pose_mat[idx,:] = np.array(angles) 206 | idx += 1 207 | 208 | pose_analyis = np.mean(np.abs(pose_mat-pose_GT),axis=0) # pose GT uses [pitch-yaw-roll] order 209 | MAE = np.mean(pose_analyis) 210 | yaw = pose_analyis[1] 211 | pitch = pose_analyis[0] 212 | roll = pose_analyis[2] 213 | msg = 'Mean MAE = %3.3f (in deg), [yaw,pitch,roll] = [%3.3f, %3.3f, %3.3f]'%(MAE, yaw, pitch, roll) 214 | print('\nFace orientation estimation:') 215 | print(msg) 216 | return msg 217 | 218 | def benchmark(checkpoint_fp, args): 219 | '''benchmark validation pipeline''' 220 | device_ids = [0] 221 | 222 | def aflw2000(): 223 | root = './aflw2000_data/AFLW2000-3D_crop' 224 | filelists = './aflw2000_data/AFLW2000-3D_crop.list' 225 | 226 | if not os.path.isdir(root) or not os.path.isfile(filelists): 227 | raise RuntimeError('check if the testing data exist') 228 | 229 | params, data_param = extract_param( 230 | checkpoint_fp=checkpoint_fp, 231 | root=root, 232 | args= args, 233 | filelists=filelists, 234 | device_ids=device_ids, 235 | batch_size=128) 236 | 237 | info_out_fal = benchmark_aflw2000_params(params, data_param) 238 | print(info_out_fal) 239 | info_out_foe = benchmark_FOE(params) 240 | 241 | aflw2000() 242 | 243 | def main(): 244 | parser = argparse.ArgumentParser(description='SynergyNet benchmark on AFLW2000-3D') 245 | parser.add_argument('-a', '--arch', default='mobilenet_v2', type=str) 246 | parser.add_argument('-w', '--weights', default='models/best.pth.tar', type=str) 247 | parser.add_argument('-d', '--device', default='0', type=str) 248 | parser.add_argument('--img_size', default='120', type=int) 249 | args = parser.parse_args() 250 | args.device = [int(d) for d in args.device.split(',')] 251 | 252 | benchmark(args.weights, args) 253 | 254 | 255 | if __name__ == '__main__': 256 | main() 257 | -------------------------------------------------------------------------------- /benchmark_aflw2000.py: -------------------------------------------------------------------------------- 1 | """ 2 | This evaluation script follows 3DDFA and 3DDFA_V2 3 | https://github.com/cleardusk/3DDFA 4 | https://github.com/cleardusk/3DDFA_V2 5 | """ 6 | 7 | import os.path as osp 8 | import numpy as np 9 | from math import sqrt 10 | from utils.io import _load 11 | 12 | 13 | d = './aflw2000_data/eval' 14 | yaws_list = _load(osp.join(d, 'AFLW2000-3D.pose.npy')) 15 | # origin 16 | pts68_all_ori = _load(osp.join(d, 'AFLW2000-3D.pts68.npy')) 17 | # reannonated 18 | pts68_all_re = _load(osp.join(d, 'AFLW2000-3D-Reannotated.pts68.npy')) 19 | roi_boxs = _load(osp.join(d, 'AFLW2000-3D_crop.roi_box.npy')) 20 | 21 | 22 | def ana(nme_list): 23 | yaw_list_abs = np.abs(yaws_list) 24 | ind_yaw_1 = yaw_list_abs <= 30 25 | ind_yaw_2 = np.bitwise_and(yaw_list_abs > 30, yaw_list_abs <= 60) 26 | ind_yaw_3 = yaw_list_abs > 60 27 | 28 | nme_1 = nme_list[ind_yaw_1] 29 | nme_2 = nme_list[ind_yaw_2] 30 | nme_3 = nme_list[ind_yaw_3] 31 | 32 | mean_nme_1 = np.mean(nme_1) * 100 33 | mean_nme_2 = np.mean(nme_2) * 100 34 | mean_nme_3 = np.mean(nme_3) * 100 35 | 36 | std_nme_1 = np.std(nme_1) * 100 37 | std_nme_2 = np.std(nme_2) * 100 38 | std_nme_3 = np.std(nme_3) * 100 39 | 40 | mean_all = [mean_nme_1, mean_nme_2, mean_nme_3] 41 | mean = np.mean(mean_all) 42 | std = np.std(mean_all) 43 | 44 | s1 = '[ 0, 30]\tMean: \x1b[32m{:.3f}\x1b[0m, Std: {:.3f}'.format(mean_nme_1, std_nme_1) 45 | s2 = '[30, 60]\tMean: \x1b[32m{:.3f}\x1b[0m, Std: {:.3f}'.format(mean_nme_2, std_nme_2) 46 | s3 = '[60, 90]\tMean: \x1b[32m{:.3f}\x1b[0m, Std: {:.3f}'.format(mean_nme_3, std_nme_3) 47 | s5 = '[ 0, 90]\tMean: \x1b[31m{:.3f}\x1b[0m, Std: \x1b[31m{:.3f}\x1b[0m'.format(mean, std) 48 | 49 | s = '\n'.join([s1, s2, s3, s5]) 50 | 51 | print(s) 52 | 53 | return mean_nme_1, mean_nme_2, mean_nme_3, mean, std 54 | 55 | def ana_msg(nme_list): 56 | 57 | leng = nme_list.shape[0] 58 | yaw_list_abs = np.abs(yaws_list)[:leng] 59 | ind_yaw_1 = yaw_list_abs <= 30 60 | ind_yaw_2 = np.bitwise_and(yaw_list_abs > 30, yaw_list_abs <= 60) 61 | ind_yaw_3 = yaw_list_abs > 60 62 | 63 | nme_1 = nme_list[ind_yaw_1] 64 | nme_2 = nme_list[ind_yaw_2] 65 | nme_3 = nme_list[ind_yaw_3] 66 | 67 | mean_nme_1 = np.mean(nme_1) * 100 68 | mean_nme_2 = np.mean(nme_2) * 100 69 | mean_nme_3 = np.mean(nme_3) * 100 70 | 71 | std_nme_1 = np.std(nme_1) * 100 72 | std_nme_2 = np.std(nme_2) * 100 73 | std_nme_3 = np.std(nme_3) * 100 74 | 75 | mean_all = [mean_nme_1, mean_nme_2, mean_nme_3] 76 | mean = np.mean(mean_all) 77 | std = np.std(mean_all) 78 | 79 | s0 = '\nFacial Alignment on AFLW2000-3D (NME):' 80 | s1 = '[ 0, 30]\tMean: {:.3f}, Std: {:.3f}'.format(mean_nme_1, std_nme_1) 81 | s2 = '[30, 60]\tMean: {:.3f}, Std: {:.3f}'.format(mean_nme_2, std_nme_2) 82 | s3 = '[60, 90]\tMean: {:.3f}, Std: {:.3f}'.format(mean_nme_3, std_nme_3) 83 | s4 = '[ 0, 90]\tMean: {:.3f}, Std: {:.3f}'.format(mean, std) 84 | 85 | s = '\n'.join([s0, s1, s2, s3, s4]) 86 | 87 | return s 88 | 89 | def convert_to_ori(lms, i): 90 | std_size = 120 91 | sx, sy, ex, ey = roi_boxs[i] 92 | scale_x = (ex - sx) / std_size 93 | scale_y = (ey - sy) / std_size 94 | lms[0, :] = lms[0, :] * scale_x + sx 95 | lms[1, :] = lms[1, :] * scale_y + sy 96 | return lms 97 | 98 | def convert_to_crop(lms, i): 99 | std_size = 120 100 | sx, sy, ex, ey = roi_boxs[i] 101 | scale_x = (ex - sx) / std_size 102 | scale_y = (ey - sy) / std_size 103 | lms[0, :] = (lms[0, :] - sx)/ scale_x 104 | lms[1, :] = (lms[1, :] - sy)/ scale_y 105 | return lms 106 | 107 | def calc_nme(pts68_fit_all, option='ori'): 108 | if option == 'ori': 109 | pts68_all = pts68_all_ori 110 | elif option == 're': 111 | pts68_all = pts68_all_re 112 | std_size = 120 113 | 114 | nme_list = [] 115 | length_list = [] 116 | for i in range(len(roi_boxs)): 117 | pts68_fit = pts68_fit_all[i] 118 | pts68_gt = pts68_all[i] 119 | 120 | sx, sy, ex, ey = roi_boxs[i] 121 | scale_x = (ex - sx) / std_size 122 | scale_y = (ey - sy) / std_size 123 | pts68_fit[0, :] = pts68_fit[0, :] * scale_x + sx 124 | pts68_fit[1, :] = pts68_fit[1, :] * scale_y + sy 125 | 126 | # build bbox 127 | minx, maxx = np.min(pts68_gt[0, :]), np.max(pts68_gt[0, :]) 128 | miny, maxy = np.min(pts68_gt[1, :]), np.max(pts68_gt[1, :]) 129 | llength = sqrt((maxx - minx) * (maxy - miny)) 130 | length_list.append(llength) 131 | 132 | dis = pts68_fit - pts68_gt[:2, :] 133 | dis = np.sqrt(np.sum(np.power(dis, 2), 0)) 134 | dis = np.mean(dis) 135 | nme = dis / llength 136 | nme_list.append(nme) 137 | 138 | nme_list = np.array(nme_list, dtype=np.float32) 139 | return nme_list 140 | 141 | 142 | if __name__ == '__main__': 143 | pass 144 | -------------------------------------------------------------------------------- /benchmark_validate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.data as data 7 | import torchvision.transforms as transforms 8 | import torch.backends.cudnn as cudnn 9 | import time 10 | import numpy as np 11 | 12 | from benchmark_aflw2000 import calc_nme as calc_nme_alfw2000 13 | from benchmark_aflw2000 import ana_msg as ana_alfw2000 14 | 15 | from utils.ddfa import ToTensor, Normalize, DDFATestDataset, CenterCrop 16 | import argparse 17 | 18 | import logging 19 | import os 20 | from utils.params import ParamsPack 21 | param_pack = ParamsPack() 22 | import glob 23 | import scipy.io as sio 24 | import math 25 | from math import cos, sin, atan2, asin, sqrt 26 | 27 | # Only work with numpy without batch 28 | def parse_pose(param): 29 | """ 30 | Parse the parameters into 3x4 affine matrix and pose angles 31 | """ 32 | param = param * param_pack.param_std[:62] + param_pack.param_mean[:62] 33 | Ps = param[:12].reshape(3, -1) # camera matrix 34 | s, R, t3d = P2sRt(Ps) 35 | P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale 36 | pose = matrix2angle_corr(R) # yaw, pitch, roll 37 | return P, pose 38 | 39 | def P2sRt(P): 40 | ''' 41 | Decompositing camera matrix P. 42 | ''' 43 | t3d = P[:, 3] 44 | R1 = P[0:1, :3] 45 | R2 = P[1:2, :3] 46 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0 47 | r1 = R1 / np.linalg.norm(R1) 48 | r2 = R2 / np.linalg.norm(R2) 49 | r3 = np.cross(r1, r2) 50 | 51 | R = np.concatenate((r1, r2, r3), 0) 52 | return s, R, t3d 53 | 54 | # def matrix2angle(R): 55 | # ''' 56 | # Compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf 57 | # ''' 58 | 59 | # if R[2, 0] != 1 and R[2, 0] != -1: 60 | # x = asin(R[2, 0]) 61 | # y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x)) 62 | # z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x)) 63 | 64 | # else: # Gimbal lock 65 | # z = 0 # can be anything 66 | # if R[2, 0] == -1: 67 | # x = np.pi / 2 68 | # y = z + atan2(R[0, 1], R[0, 2]) 69 | # else: 70 | # x = -np.pi / 2 71 | # y = -z + atan2(-R[0, 1], -R[0, 2]) 72 | 73 | # rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi 74 | 75 | # return [rx, ry, rz] 76 | 77 | #numpy 78 | def matrix2angle_corr(R): 79 | ''' 80 | Compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf 81 | ''' 82 | 83 | if R[2, 0] != 1 and R[2, 0] != -1: 84 | x = asin(R[2, 0]) 85 | y = atan2(R[1, 2] / cos(x), R[2, 2] / cos(x)) 86 | z = atan2(R[0, 1] / cos(x), R[0, 0] / cos(x)) 87 | 88 | else: # Gimbal lock 89 | z = 0 # can be anything 90 | if R[2, 0] == -1: 91 | x = np.pi / 2 92 | y = z + atan2(R[0, 1], R[0, 2]) 93 | else: 94 | x = -np.pi / 2 95 | y = -z + atan2(-R[0, 1], -R[0, 2]) 96 | 97 | rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi 98 | 99 | return [rx, ry, rz] 100 | 101 | 102 | 103 | def parse_param_62_batch(param): 104 | """batch styler""" 105 | p_ = param[:, :12].reshape(-1, 3, 4) 106 | p = p_[:, :, :3] 107 | offset = p_[:, :, -1].reshape(-1, 3, 1) 108 | alpha_shp = param[:, 12:52].reshape(-1, 40, 1) 109 | alpha_exp = param[:, 52:62].reshape(-1, 10, 1) 110 | return p, offset, alpha_shp, alpha_exp 111 | 112 | 113 | # 62-with-false-rot 114 | def reconstruct_vertex(param, data_param, whitening=True, dense=False, transform=True): 115 | """ 116 | Whitening param -> 3d vertex, based on the 3dmm param: u_base, w_shp, w_exp 117 | dense: if True, return dense vertex, else return 68 sparse landmarks. All dense or sparse vertex is transformed to 118 | image coordinate space, but without alignment caused by face cropping. 119 | transform: whether transform to image space 120 | Working with Tensor with batch. Using Fortan-type reshape. 121 | """ 122 | param_mean, param_std, w_shp_base, u_base, w_exp_base = data_param 123 | 124 | if whitening: 125 | if param.shape[1] == 62: 126 | param = param * param_std[:62] + param_mean[:62] 127 | 128 | p, offset, alpha_shp, alpha_exp = parse_param_62_batch(param) 129 | 130 | """For 68 pts""" 131 | vertex = p @ (u_base + w_shp_base @ alpha_shp + w_exp_base @ alpha_exp).contiguous().view(-1, 68, 3).transpose(1,2) + offset 132 | 133 | if transform: 134 | # transform to image coordinate space 135 | vertex[:, 1, :] = param_pack.std_size + 1 - vertex[:, 1, :] ## corrected 136 | 137 | 138 | return vertex 139 | 140 | 141 | def extract_param(model, root='', filelists=None, 142 | batch_size=128, num_workers=4): 143 | 144 | dataset = DDFATestDataset(filelists=filelists, root=root, 145 | transform=transforms.Compose([ToTensor(), CenterCrop(5, mode='test'), Normalize(mean=127.5, std=130)])) 146 | data_loader = data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) 147 | 148 | cudnn.benchmark = True 149 | model.eval() 150 | 151 | end = time.time() 152 | outputs = [] 153 | with torch.no_grad(): 154 | for _, inputs in enumerate(data_loader): 155 | inputs = inputs.cuda() 156 | output = model.module.forward_test(inputs) 157 | 158 | for i in range(output.shape[0]): 159 | param_prediction = output[i].cpu().numpy().flatten() 160 | 161 | outputs.append(param_prediction) 162 | outputs = np.array(outputs, dtype=np.float32) 163 | 164 | logging.info('Extracting params take {: .3f}s\n'.format(time.time() - end)) 165 | return outputs 166 | 167 | 168 | def _benchmark_aflw2000(outputs): 169 | """ 170 | Calculate the error statistics. 171 | """ 172 | return ana_alfw2000(calc_nme_alfw2000(outputs, option='ori')) 173 | 174 | 175 | def benchmark_aflw2000_params(params, data_param): 176 | """ 177 | Reconstruct the landmark points and calculate the statistics 178 | """ 179 | outputs = [] 180 | params = torch.Tensor(params).cuda() 181 | 182 | batch_size = 50 183 | num_samples = params.shape[0] 184 | iter_num = math.floor(num_samples / batch_size) 185 | residual = num_samples % batch_size 186 | for i in range(iter_num+1): 187 | if i == iter_num: 188 | if residual == 0: 189 | break 190 | batch_data = params[i*batch_size: i*batch_size + residual] 191 | lm = reconstruct_vertex(batch_data, data_param) 192 | lm = lm.cpu().numpy() 193 | for j in range(residual): 194 | outputs.append(lm[j, :2, :]) 195 | else: 196 | batch_data = params[i*batch_size: (i+1)*batch_size] 197 | lm = reconstruct_vertex(batch_data, data_param) 198 | lm = lm.cpu().numpy() 199 | for j in range(batch_size): 200 | outputs.append(lm[j, :2, :]) 201 | return _benchmark_aflw2000(outputs) 202 | 203 | 204 | def benchmark_FOE(params): 205 | """ 206 | FOE benchmark validation. Only calculate the groundtruth of angles within [-99, 99] 207 | """ 208 | 209 | # Define the data path for AFLW200 groundturh and skip indices, where the data and structure lie on S3 buckets (fixed structure) 210 | groundtruth_excl = './aflw2000_data/eval/ALFW2000-3D_pose_3ANG_excl.npy' 211 | skip_aflw2000 = './aflw2000_data/eval/ALFW2000-3D_pose_3ANG_skip.npy' 212 | 213 | if not os.path.isfile(groundtruth_excl) or not os.path.isfile(skip_aflw2000): 214 | raise RuntimeError('The data is not properly downloaded from the S3 bucket. Please check your S3 bucket access permission') 215 | 216 | 217 | pose_GT = np.load(groundtruth_excl) # groundtruth load 218 | skip_indices = np.load(skip_aflw2000) # load the skip indices in AFLW2000 219 | pose_mat = np.ones((pose_GT.shape[0],3)) 220 | 221 | total = 0 222 | for i in range(params.shape[0]): 223 | if i in skip_indices: 224 | continue 225 | P, angles = parse_pose(params[i]) # original per-sample decode 226 | angles[0], angles[1], angles[2] = angles[1], angles[0], angles[2] 227 | pose_mat[total,:] = np.array(angles) 228 | total += 1 229 | 230 | pose_analy = np.mean(np.abs(pose_mat-pose_GT),axis=0) 231 | MAE = np.mean(pose_analy) 232 | yaw = pose_analy[1] 233 | pitch = pose_analy[0] 234 | roll = pose_analy[2] 235 | msg = 'MAE = %3.3f, [yaw,pitch,roll] = [%3.3f, %3.3f, %3.3f]'%(MAE, yaw, pitch, roll) 236 | print('\n--------------------------------------------------------------------------------') 237 | print(msg) 238 | print('--------------------------------------------------------------------------------') 239 | return msg 240 | 241 | 242 | # 102 243 | def benchmark_pipeline(model): 244 | """ 245 | Run the benchmark validation pipeline for Facial Alignments: AFLW and AFLW2000, FOE: AFLW2000. 246 | """ 247 | 248 | def aflw2000(data_param): 249 | root = './aflw2000_data/AFLW2000-3D_crop' 250 | filelists = './aflw2000_data/AFLW2000-3D_crop.list' 251 | 252 | if not os.path.isdir(root) or not os.path.isfile(filelists): 253 | raise RuntimeError('The data is not properly downloaded from the S3 bucket. Please check your S3 bucket access permission') 254 | 255 | params = extract_param( 256 | model=model, 257 | root=root, 258 | filelists=filelists, 259 | batch_size=128) 260 | 261 | s2 = benchmark_aflw2000_params(params, data_param) 262 | logging.info(s2) 263 | # s3 = benchmark_FOE(params) 264 | # logging.info(s3) 265 | 266 | aflw2000(model.module.data_param) 267 | 268 | 269 | def main(): 270 | parser = argparse.ArgumentParser(description='3DDFA Benchmark') 271 | parser.add_argument('--arch', default='mobilenet_1', type=str) 272 | parser.add_argument('-c', '--checkpoint-fp', default='models/phase1_wpdc.pth.tar', type=str) 273 | args = parser.parse_args() 274 | 275 | benchmark_pipeline(args.arch, args.checkpoint_fp) 276 | 277 | 278 | if __name__ == '__main__': 279 | main() 280 | -------------------------------------------------------------------------------- /demo/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/0.png -------------------------------------------------------------------------------- /demo/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/1.png -------------------------------------------------------------------------------- /demo/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/10.png -------------------------------------------------------------------------------- /demo/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/11.png -------------------------------------------------------------------------------- /demo/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/12.png -------------------------------------------------------------------------------- /demo/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/2.png -------------------------------------------------------------------------------- /demo/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/3.png -------------------------------------------------------------------------------- /demo/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/4.png -------------------------------------------------------------------------------- /demo/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/5.png -------------------------------------------------------------------------------- /demo/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/6.png -------------------------------------------------------------------------------- /demo/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/7.png -------------------------------------------------------------------------------- /demo/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/8.png -------------------------------------------------------------------------------- /demo/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/9.png -------------------------------------------------------------------------------- /demo/AF-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/AF-1.png -------------------------------------------------------------------------------- /demo/AF-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/AF-2.png -------------------------------------------------------------------------------- /demo/alignment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/alignment.png -------------------------------------------------------------------------------- /demo/comparison-deca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/comparison-deca.png -------------------------------------------------------------------------------- /demo/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/demo.gif -------------------------------------------------------------------------------- /demo/multiple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/multiple.png -------------------------------------------------------------------------------- /demo/orientation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/orientation.png -------------------------------------------------------------------------------- /demo/single.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/single.png -------------------------------------------------------------------------------- /demo/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/demo/teaser.png -------------------------------------------------------------------------------- /img/sample_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/img/sample_1.jpg -------------------------------------------------------------------------------- /img/sample_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/img/sample_2.jpg -------------------------------------------------------------------------------- /img/sample_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/img/sample_3.jpg -------------------------------------------------------------------------------- /img/sample_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/img/sample_4.jpg -------------------------------------------------------------------------------- /loss_definition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.params import ParamsPack 4 | param_pack = ParamsPack() 5 | import math 6 | 7 | 8 | class WingLoss(nn.Module): 9 | def __init__(self, omega=10, epsilon=2): 10 | super(WingLoss, self).__init__() 11 | self.omega = omega 12 | self.epsilon = epsilon 13 | self.log_term = math.log(1 + self.omega / self.epsilon) 14 | 15 | def forward(self, pred, target, kp=False): 16 | n_points = pred.shape[2] 17 | pred = pred.transpose(1,2).contiguous().view(-1, 3*n_points) 18 | target = target.transpose(1,2).contiguous().view(-1, 3*n_points) 19 | y = target 20 | y_hat = pred 21 | delta_y = (y - y_hat).abs() 22 | delta_y1 = delta_y[delta_y < self.omega] 23 | delta_y2 = delta_y[delta_y >= self.omega] 24 | loss1 = self.omega * torch.log(1 + delta_y1 / self.epsilon) 25 | C = self.omega - self.omega * self.log_term 26 | loss2 = delta_y2 - C 27 | return (loss1.sum() + loss2.sum()) / (len(loss1) + len(loss2)) 28 | 29 | class ParamLoss(nn.Module): 30 | """Input and target are all 62-d param""" 31 | def __init__(self): 32 | super(ParamLoss, self).__init__() 33 | self.criterion = nn.MSELoss(reduction="none") 34 | 35 | def forward(self, input, target, mode = 'normal'): 36 | if mode == 'normal': 37 | loss = self.criterion(input[:,:12], target[:,:12]).mean(1) + self.criterion(input[:,12:], target[:,12:]).mean(1) 38 | return torch.sqrt(loss) 39 | elif mode == 'only_3dmm': 40 | loss = self.criterion(input[:,:50], target[:,12:62]).mean(1) 41 | return torch.sqrt(loss) 42 | return torch.sqrt(loss.mean(1)) 43 | 44 | 45 | if __name__ == "__main__": 46 | pass 47 | 48 | -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | import os 4 | import os.path as osp 5 | from pathlib import Path 6 | import numpy as np 7 | import argparse 8 | import time 9 | import logging 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.transforms as transforms 14 | from torch.utils.data import DataLoader 15 | import torch.backends.cudnn as cudnn 16 | cudnn.benchmark=True 17 | 18 | from utils.ddfa import DDFADataset, ToTensor, Normalize, SGD_NanHandler, CenterCrop, Compose_GT, ColorJitter 19 | from utils.ddfa import str2bool, AverageMeter 20 | from utils.io import mkdir 21 | from model_building import SynergyNet as SynergyNet 22 | from benchmark_validate import benchmark_pipeline 23 | 24 | 25 | # global args (configuration) 26 | args = None # define the static training setting, which wouldn't and shouldn't be changed over the whole experiements. 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description='3DMM Fitting') 30 | parser.add_argument('-j', '--workers', default=6, type=int) 31 | parser.add_argument('--epochs', default=40, type=int) 32 | parser.add_argument('--start-epoch', default=1, type=int) 33 | parser.add_argument('-b', '--batch-size', default=128, type=int) 34 | parser.add_argument('-vb', '--val-batch-size', default=32, type=int) 35 | parser.add_argument('--base-lr', '--learning-rate', default=0.001, type=float) 36 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 37 | help='momentum') 38 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float) 39 | parser.add_argument('--print-freq', '-p', default=20, type=int) 40 | parser.add_argument('--resume', default='', type=str, metavar='PATH') 41 | parser.add_argument('--resume_pose', default='', type=str, metavar='PATH') 42 | parser.add_argument('--devices-id', default='0,1', type=str) 43 | parser.add_argument('--filelists-train',default='', type=str) 44 | parser.add_argument('--root', default='') 45 | parser.add_argument('--snapshot', default='', type=str) 46 | parser.add_argument('--log-file', default='output.log', type=str) 47 | parser.add_argument('--log-mode', default='w', type=str) 48 | parser.add_argument('--arch', default='mobilenet_v2', type=str, help="Please choose [mobilenet_v2, mobilenet_1, resnet50, resnet101, or ghostnet]") 49 | parser.add_argument('--milestones', default='15,25,30', type=str) 50 | parser.add_argument('--task', default='all', type=str) 51 | parser.add_argument('--test_initial', default='false', type=str2bool) 52 | parser.add_argument('--warmup', default=-1, type=int) 53 | parser.add_argument('--param-fp-train',default='',type=str) 54 | parser.add_argument('--img_size', default=120, type=int) 55 | parser.add_argument('--save_val_freq', default=10, type=int) 56 | 57 | global args 58 | args = parser.parse_args() 59 | 60 | # some other operations 61 | args.devices_id = [int(d) for d in args.devices_id.split(',')] 62 | args.milestones = [int(m) for m in args.milestones.split(',')] 63 | 64 | snapshot_dir = osp.split(args.snapshot)[0] 65 | mkdir(snapshot_dir) 66 | 67 | 68 | def print_args(args): 69 | for arg in vars(args): 70 | s = arg + ': ' + str(getattr(args, arg)) 71 | logging.info(s) 72 | 73 | 74 | def adjust_learning_rate(optimizer, epoch, milestones=None): 75 | """Sets the learning rate: milestone is a list/tuple""" 76 | 77 | def to(epoch): 78 | if epoch <= args.warmup: 79 | return 1 80 | elif args.warmup < epoch <= milestones[0]: 81 | return 0 82 | for i in range(1, len(milestones)): 83 | if milestones[i - 1] < epoch <= milestones[i]: 84 | return i 85 | return len(milestones) 86 | 87 | n = to(epoch) 88 | 89 | #global lr 90 | lr = args.base_lr * (0.2 ** n) 91 | for param_group in optimizer.param_groups: 92 | param_group['lr'] = lr 93 | 94 | return lr 95 | 96 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 97 | torch.save(state, filename) 98 | logging.info(f'Save checkpoint to {filename}') 99 | 100 | def count_parameters(model): 101 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 102 | 103 | def train(train_loader, model, optimizer, epoch, lr): 104 | """Network training, loss updates, and backward calculation""" 105 | 106 | # AverageMeter for statistics 107 | batch_time = AverageMeter() 108 | data_time = AverageMeter() 109 | losses_name = list(model.module.get_losses()) 110 | losses_name.append('loss_total') 111 | losses_meter = [AverageMeter() for i in range(len(losses_name))] 112 | 113 | model.train() 114 | 115 | end = time.time() 116 | for i, (input, target) in enumerate(train_loader): 117 | 118 | input = input.cuda(non_blocking=True) 119 | 120 | target = target[:,:62] 121 | target.requires_grad = False 122 | target = target.float().cuda(non_blocking=True) 123 | 124 | losses = model(input, target) 125 | 126 | data_time.update(time.time() - end) 127 | 128 | loss_total = 0 129 | for j, name in enumerate(losses): 130 | mean_loss = losses[name].mean() 131 | losses_meter[j].update(mean_loss, input.size(0)) 132 | loss_total += mean_loss 133 | 134 | losses_meter[j+1].update(loss_total, input.size(0)) 135 | 136 | ### compute gradient and do SGD step 137 | optimizer.zero_grad() 138 | loss_total.backward() 139 | flag, _ = optimizer.step_handleNan() 140 | 141 | if flag: 142 | print("Nan encounter! Backward gradient error. Not updating the associated gradients.") 143 | 144 | batch_time.update(time.time() - end) 145 | end = time.time() 146 | 147 | if i % args.print_freq == 0: 148 | msg = 'Epoch: [{}][{}/{}]\t'.format(epoch, i, len(train_loader)) + \ 149 | 'LR: {:.8f}\t'.format(lr) + \ 150 | 'Time: {:.3f} ({:.3f})\t'.format(batch_time.val, batch_time.avg) 151 | for k in range(len(losses_meter)): 152 | msg += '{}: {:.4f} ({:.4f})\t'.format(losses_name[k], losses_meter[k].val, losses_meter[k].avg) 153 | logging.info(msg) 154 | 155 | 156 | def main(): 157 | """ Main funtion for the training process""" 158 | parse_args() # parse global argsl 159 | 160 | # logging setup 161 | logging.basicConfig( 162 | format='[%(asctime)s] [p%(process)s] [%(pathname)s:%(lineno)d] [%(levelname)s] %(message)s', 163 | level=logging.INFO, 164 | handlers=[ 165 | logging.FileHandler(args.log_file, mode=args.log_mode), 166 | logging.StreamHandler() 167 | ] 168 | ) 169 | 170 | print_args(args) # print args 171 | 172 | # step1: define the model structure 173 | model = SynergyNet(args) 174 | torch.cuda.set_device(args.devices_id[0]) 175 | 176 | model = nn.DataParallel(model, device_ids=args.devices_id).cuda() # -> GPU 177 | 178 | # step2: optimization: loss and optimization method 179 | 180 | optimizer = SGD_NanHandler(model.parameters(), 181 | lr=args.base_lr, 182 | momentum=args.momentum, 183 | weight_decay=args.weight_decay, 184 | nesterov=True) 185 | 186 | # step 2.1 resume 187 | if args.resume: 188 | if Path(args.resume).is_file(): 189 | logging.info(f'=> loading checkpoint {args.resume}') 190 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)['state_dict'] 191 | model.load_state_dict(checkpoint, strict=False) 192 | 193 | else: 194 | logging.info(f'=> no checkpoint found at {args.resume}') 195 | 196 | # step3: data 197 | normalize = Normalize(mean=127.5, std=128) 198 | 199 | train_dataset = DDFADataset( 200 | root=args.root, 201 | filelists=args.filelists_train, 202 | param_fp=args.param_fp_train, 203 | gt_transform=True, 204 | transform=Compose_GT([ColorJitter(0.4,0.4,0.4), ToTensor(), CenterCrop(5), normalize]) # 205 | ) 206 | 207 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers, 208 | shuffle=True, pin_memory=True, drop_last=True) 209 | 210 | 211 | # step4: run 212 | cudnn.benchmark = True 213 | if args.test_initial: # if testing the performance from the initial 214 | logging.info('Testing from initial') 215 | benchmark_pipeline(model) 216 | 217 | 218 | for epoch in range(args.start_epoch, args.epochs + 1): 219 | # adjust learning rate 220 | lr = adjust_learning_rate(optimizer, epoch, args.milestones) 221 | 222 | # train for one epoch 223 | train(train_loader, model, optimizer, epoch, lr) 224 | 225 | filename = f'{args.snapshot}_checkpoint_epoch_{epoch}.pth.tar' 226 | # save checkpoints and current model validation 227 | if (epoch % args.save_val_freq == 0) or (epoch==args.epochs): 228 | save_checkpoint( 229 | { 230 | 'epoch': epoch, 231 | 'state_dict': model.state_dict(), 232 | }, 233 | filename 234 | ) 235 | logging.info('\nVal[{}]'.format(epoch)) 236 | benchmark_pipeline(model) 237 | 238 | if __name__ == '__main__': 239 | main() 240 | -------------------------------------------------------------------------------- /model_building.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torchvision import transforms as T 5 | import scipy.io as sio 6 | 7 | # All data parameters import 8 | from utils.params import ParamsPack 9 | param_pack = ParamsPack() 10 | 11 | from backbone_nets import resnet_backbone 12 | from backbone_nets import mobilenetv1_backbone 13 | from backbone_nets import mobilenetv2_backbone 14 | from backbone_nets import ghostnet_backbone 15 | from backbone_nets.pointnet_backbone import MLP_for, MLP_rev 16 | from loss_definition import ParamLoss, WingLoss 17 | 18 | from backbone_nets.ResNeSt import resnest50, resnest101 19 | import time 20 | from utils.inference import predict_sparseVert, predict_denseVert, predict_pose, crop_img 21 | from FaceBoxes import FaceBoxes 22 | import cv2 23 | import types 24 | 25 | def parse_param_62(param): 26 | """Work for only tensor""" 27 | p_ = param[:, :12].reshape(-1, 3, 4) 28 | p = p_[:, :, :3] 29 | offset = p_[:, :, -1].reshape(-1, 3, 1) 30 | alpha_shp = param[:, 12:52].reshape(-1, 40, 1) 31 | alpha_exp = param[:, 52:62].reshape(-1, 10, 1) 32 | return p, offset, alpha_shp, alpha_exp 33 | 34 | # Image-to-parameter 35 | class I2P(nn.Module): 36 | def __init__(self, args): 37 | super(I2P, self).__init__() 38 | self.args = args 39 | # backbone definition 40 | if 'mobilenet_v2' in self.args.arch: 41 | self.backbone = getattr(mobilenetv2_backbone, args.arch)(pretrained=False) 42 | elif 'mobilenet' in self.args.arch: 43 | self.backbone = getattr(mobilenetv1_backbone, args.arch)() 44 | elif 'resnet' in self.args.arch: 45 | self.backbone = getattr(resnet_backbone, args.arch)(pretrained=False) 46 | elif 'ghostnet' in self.args.arch: 47 | self.backbone = getattr(ghostnet_backbone, args.arch)() 48 | elif 'resnest' in self.args.arch: 49 | self.backbone = resnest50() 50 | else: 51 | raise RuntimeError("Please choose [mobilenet_v2, mobilenet_1, resnet50, or ghostnet]") 52 | 53 | def forward(self,input, target): 54 | """Training time forward""" 55 | _3D_attr, avgpool = self.backbone(input) 56 | _3D_attr_GT = target.type(torch.cuda.FloatTensor) 57 | return _3D_attr, _3D_attr_GT, avgpool 58 | 59 | def forward_test(self, input): 60 | """ Testing time forward.""" 61 | _3D_attr, avgpool = self.backbone(input) 62 | return _3D_attr, avgpool 63 | 64 | # Main model SynergyNet definition 65 | class SynergyNet(nn.Module): 66 | def __init__(self, args): 67 | super(SynergyNet, self).__init__() 68 | self.triangles = sio.loadmat('./3dmm_data/tri.mat')['tri'] -1 69 | self.triangles = torch.Tensor(self.triangles.astype(np.int64)).long().cuda() 70 | self.img_size = args.img_size 71 | # Image-to-parameter 72 | self.I2P = I2P(args) 73 | # Forward 74 | self.forwardDirection = MLP_for(68) 75 | # Reverse 76 | self.reverseDirection = MLP_rev(68) 77 | self.LMKLoss_3D = WingLoss() 78 | self.ParamLoss = ParamLoss() 79 | 80 | self.loss = {'loss_LMK_f0':0.0, 81 | 'loss_LMK_pointNet': 0.0, 82 | 'loss_Param_In':0.0, 83 | 'loss_Param_S2': 0.0, 84 | 'loss_Param_S1S2': 0.0, 85 | } 86 | 87 | self.register_buffer('param_mean', torch.Tensor(param_pack.param_mean).cuda(non_blocking=True)) 88 | self.register_buffer('param_std', torch.Tensor(param_pack.param_std).cuda(non_blocking=True)) 89 | self.register_buffer('w_shp', torch.Tensor(param_pack.w_shp).cuda(non_blocking=True)) 90 | self.register_buffer('u', torch.Tensor(param_pack.u).cuda(non_blocking=True)) 91 | self.register_buffer('w_exp', torch.Tensor(param_pack.w_exp).cuda(non_blocking=True)) 92 | 93 | # If doing only offline evaluation, use these 94 | # self.u_base = torch.Tensor(param_pack.u_base).cuda(non_blocking=True) 95 | # self.w_shp_base = torch.Tensor(param_pack.w_shp_base).cuda(non_blocking=True) 96 | # self.w_exp_base = torch.Tensor(param_pack.w_exp_base).cuda(non_blocking=True) 97 | 98 | # Online training needs these to parallel 99 | self.register_buffer('u_base', torch.Tensor(param_pack.u_base).cuda(non_blocking=True)) 100 | self.register_buffer('w_shp_base', torch.Tensor(param_pack.w_shp_base).cuda(non_blocking=True)) 101 | self.register_buffer('w_exp_base', torch.Tensor(param_pack.w_exp_base).cuda(non_blocking=True)) 102 | self.keypoints = torch.Tensor(param_pack.keypoints).long() 103 | 104 | self.data_param = [self.param_mean, self.param_std, self.w_shp_base, self.u_base, self.w_exp_base] 105 | 106 | def reconstruct_vertex_62(self, param, whitening=True, dense=False, transform=True, lmk_pts=68): 107 | """ 108 | Whitening param -> 3d vertex, based on the 3dmm param: u_base, w_shp, w_exp 109 | dense: if True, return dense vertex, else return 68 sparse landmarks. All dense or sparse vertex is transformed to 110 | image coordinate space, but without alignment caused by face cropping. 111 | transform: whether transform to image space 112 | Working with batched tensors. Using Fortan-type reshape. 113 | """ 114 | 115 | if whitening: 116 | if param.shape[1] == 62: 117 | param_ = param * self.param_std[:62] + self.param_mean[:62] 118 | else: 119 | raise RuntimeError('length of params mismatch') 120 | 121 | p, offset, alpha_shp, alpha_exp = parse_param_62(param_) 122 | 123 | if dense: 124 | 125 | vertex = p @ (self.u + self.w_shp @ alpha_shp + self.w_exp @ alpha_exp).contiguous().view(-1, 53215, 3).transpose(1,2) + offset 126 | 127 | if transform: 128 | # transform to image coordinate space 129 | vertex[:, 1, :] = param_pack.std_size + 1 - vertex[:, 1, :] 130 | 131 | else: 132 | """For 68 pts""" 133 | vertex = p @ (self.u_base + self.w_shp_base @ alpha_shp + self.w_exp_base @ alpha_exp).contiguous().view(-1, lmk_pts, 3).transpose(1,2) + offset 134 | 135 | if transform: 136 | # transform to image coordinate space 137 | vertex[:, 1, :] = param_pack.std_size + 1 - vertex[:, 1, :] 138 | 139 | return vertex 140 | 141 | def forward(self, input, target): 142 | _3D_attr, _3D_attr_GT, avgpool = self.I2P(input, target) 143 | 144 | vertex_lmk = self.reconstruct_vertex_62(_3D_attr, dense=False) 145 | vertex_GT_lmk = self.reconstruct_vertex_62(_3D_attr_GT, dense=False) 146 | self.loss['loss_LMK_f0'] = 0.05 *self.LMKLoss_3D(vertex_lmk, vertex_GT_lmk, kp=True) 147 | self.loss['loss_Param_In'] = 0.02 * self.ParamLoss(_3D_attr, _3D_attr_GT) 148 | 149 | point_residual = self.forwardDirection(vertex_lmk, avgpool, _3D_attr[:,12:52], _3D_attr[:,52:62]) 150 | vertex_lmk = vertex_lmk + 0.05 * point_residual 151 | self.loss['loss_LMK_pointNet'] = 0.05 * self.LMKLoss_3D(vertex_lmk, vertex_GT_lmk, kp=True) 152 | 153 | _3D_attr_S2 = self.reverseDirection(vertex_lmk) 154 | self.loss['loss_Param_S2'] = 0.02 * self.ParamLoss(_3D_attr_S2, _3D_attr_GT, mode='only_3dmm') 155 | self.loss['loss_Param_S1S2'] = 0.001 * self.ParamLoss(_3D_attr_S2, _3D_attr, mode='only_3dmm') 156 | 157 | return self.loss 158 | 159 | def forward_test(self, input): 160 | """test time forward""" 161 | _3D_attr, _ = self.I2P.forward_test(input) 162 | return _3D_attr 163 | 164 | def get_losses(self): 165 | return self.loss.keys() 166 | 167 | 168 | # Main model SynergyNet definition 169 | class WrapUpSynergyNet(nn.Module): 170 | def __init__(self): 171 | super(WrapUpSynergyNet, self).__init__() 172 | self.triangles = sio.loadmat('./3dmm_data/tri.mat')['tri'] -1 173 | self.triangles = torch.Tensor(self.triangles.astype(np.int64)).long() 174 | args = types.SimpleNamespace() 175 | args.arch = 'mobilenet_v2' 176 | args.checkpoint_fp = 'pretrained/best.pth.tar' 177 | 178 | # Image-to-parameter 179 | self.I2P = I2P(args) 180 | # Forward 181 | self.forwardDirection = MLP_for(68) 182 | # Reverse 183 | self.reverseDirection = MLP_rev(68) 184 | self.LMKLoss_3D = WingLoss() 185 | self.ParamLoss = ParamLoss() 186 | 187 | self.loss = {'loss_LMK_f0':0.0, 188 | 'loss_LMK_pointNet': 0.0, 189 | 'loss_Param_In':0.0, 190 | 'loss_Param_S2': 0.0, 191 | 'loss_Param_S1S2': 0.0, 192 | } 193 | 194 | self.register_buffer('param_mean', torch.Tensor(param_pack.param_mean)) 195 | self.register_buffer('param_std', torch.Tensor(param_pack.param_std)) 196 | self.register_buffer('w_shp', torch.Tensor(param_pack.w_shp)) 197 | self.register_buffer('u', torch.Tensor(param_pack.u)) 198 | self.register_buffer('w_exp', torch.Tensor(param_pack.w_exp)) 199 | 200 | # Online training needs these to parallel 201 | self.register_buffer('u_base', torch.Tensor(param_pack.u_base)) 202 | self.register_buffer('w_shp_base', torch.Tensor(param_pack.w_shp_base)) 203 | self.register_buffer('w_exp_base', torch.Tensor(param_pack.w_exp_base)) 204 | self.keypoints = torch.Tensor(param_pack.keypoints).long() 205 | 206 | self.data_param = [self.param_mean, self.param_std, self.w_shp_base, self.u_base, self.w_exp_base] 207 | 208 | try: 209 | print("loading weights from ", args.checkpoint_fp) 210 | self.load_weights(args.checkpoint_fp) 211 | except: 212 | pass 213 | self.eval() 214 | 215 | def reconstruct_vertex_62(self, param, whitening=True, dense=False, transform=True, lmk_pts=68): 216 | """ 217 | Whitening param -> 3d vertex, based on the 3dmm param: u_base, w_shp, w_exp 218 | dense: if True, return dense vertex, else return 68 sparse landmarks. All dense or sparse vertex is transformed to 219 | image coordinate space, but without alignment caused by face cropping. 220 | transform: whether transform to image space 221 | Working with batched tensors. Using Fortan-type reshape. 222 | """ 223 | 224 | if whitening: 225 | if param.shape[1] == 62: 226 | param_ = param * self.param_std[:62] + self.param_mean[:62] 227 | else: 228 | raise RuntimeError('length of params mismatch') 229 | 230 | p, offset, alpha_shp, alpha_exp = parse_param_62(param_) 231 | 232 | if dense: 233 | 234 | vertex = p @ (self.u + self.w_shp @ alpha_shp + self.w_exp @ alpha_exp).contiguous().view(-1, 53215, 3).transpose(1,2) + offset 235 | 236 | if transform: 237 | # transform to image coordinate space 238 | vertex[:, 1, :] = param_pack.std_size + 1 - vertex[:, 1, :] 239 | 240 | else: 241 | """For 68 pts""" 242 | vertex = p @ (self.u_base + self.w_shp_base @ alpha_shp + self.w_exp_base @ alpha_exp).contiguous().view(-1, lmk_pts, 3).transpose(1,2) + offset 243 | 244 | if transform: 245 | # transform to image coordinate space 246 | vertex[:, 1, :] = param_pack.std_size + 1 - vertex[:, 1, :] 247 | 248 | return vertex 249 | 250 | def forward_test(self, input): 251 | """test time forward""" 252 | _3D_attr, _ = self.I2P.forward_test(input) 253 | return _3D_attr 254 | 255 | def load_weights(self, path): 256 | model_dict = self.state_dict() 257 | checkpoint = torch.load(path, map_location=lambda storage, loc: storage)['state_dict'] 258 | 259 | # because the model is trained by multiple gpus, prefix 'module' should be removed 260 | for k in checkpoint.keys(): 261 | model_dict[k.replace('module.', '')] = checkpoint[k] 262 | 263 | self.load_state_dict(model_dict, strict=False) 264 | 265 | 266 | def get_all_outputs(self, input): 267 | """convenient api to get 3d landmarks, face pose, 3d faces""" 268 | 269 | face_boxes = FaceBoxes() 270 | rects = face_boxes(input) 271 | 272 | # storage 273 | pts_res = [] 274 | poses = [] 275 | vertices_lst = [] 276 | for idx, rect in enumerate(rects): 277 | roi_box = rect 278 | 279 | # enlarge the bbox a little and do a square crop 280 | HCenter = (rect[1] + rect[3])/2 281 | WCenter = (rect[0] + rect[2])/2 282 | side_len = roi_box[3]-roi_box[1] 283 | margin = side_len * 1.2 // 2 284 | roi_box[0], roi_box[1], roi_box[2], roi_box[3] = WCenter-margin, HCenter-margin, WCenter+margin, HCenter+margin 285 | 286 | img = crop_img(input, roi_box) 287 | img = cv2.resize(img, dsize=(120, 120), interpolation=cv2.INTER_LANCZOS4) 288 | img = torch.from_numpy(img) 289 | img = img.permute(2,0,1) 290 | img = img.unsqueeze(0) 291 | img = (img - 127.5)/ 128.0 292 | 293 | with torch.no_grad(): 294 | param = self.forward_test(img) 295 | 296 | param = param.squeeze().cpu().numpy().flatten().astype(np.float32) 297 | 298 | lmks = predict_sparseVert(param, roi_box, transform=True) 299 | vertices = predict_denseVert(param, roi_box, transform=True) 300 | angles, translation = predict_pose(param, roi_box) 301 | 302 | pts_res.append(lmks) 303 | vertices_lst.append(vertices) 304 | poses.append([angles, translation]) 305 | 306 | return pts_res, vertices_lst, poses 307 | 308 | if __name__ == '__main__': 309 | pass 310 | -------------------------------------------------------------------------------- /pretrained/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/pretrained/__init__.py -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='synergy-3dmm', 5 | version='0.0.1', 6 | description='Library for accurate and fast 3d landmarks, face mesh, and face pose prediction', 7 | packages=find_packages(exclude=('artistic*', 'benchmark*', 'loss_definition*', 'demo*', 'img*', 'pretrained*', 'Sim3DR', 'main_train*', 'singleImage*', 'model_building*', 'uv_*', 'FaceBoxes*', '3dmm_data*')), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | ], 12 | ) -------------------------------------------------------------------------------- /singleImage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | import cv2 5 | from utils.ddfa import ToTensor, Normalize 6 | from model_building import SynergyNet 7 | from utils.inference import crop_img, predict_sparseVert, draw_landmarks, predict_denseVert, predict_pose, draw_axis 8 | import argparse 9 | import torch.backends.cudnn as cudnn 10 | cudnn.benchmark = True 11 | import os 12 | import os.path as osp 13 | import glob 14 | from FaceBoxes import FaceBoxes 15 | from utils.render import render 16 | 17 | # Following 3DDFA-V2, we also use 120x120 resolution 18 | IMG_SIZE = 120 19 | 20 | def main(args): 21 | # load pre-tained model 22 | checkpoint_fp = 'pretrained/best.pth.tar' 23 | args.arch = 'mobilenet_v2' 24 | args.devices_id = [0] 25 | 26 | checkpoint = torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)['state_dict'] 27 | 28 | model = SynergyNet(args) 29 | model_dict = model.state_dict() 30 | 31 | # because the model is trained by multiple gpus, prefix 'module' should be removed 32 | for k in checkpoint.keys(): 33 | model_dict[k.replace('module.', '')] = checkpoint[k] 34 | 35 | model.load_state_dict(model_dict, strict=False) 36 | model = model.cuda() 37 | model.eval() 38 | 39 | # face detector 40 | face_boxes = FaceBoxes() 41 | 42 | # preparation 43 | transform = transforms.Compose([ToTensor(), Normalize(mean=127.5, std=128)]) 44 | if osp.isdir(args.files): 45 | if not args.files[-1] == '/': 46 | args.files = args.files + '/' 47 | if not args.png: 48 | files = sorted(glob.glob(args.files+'*.jpg')) 49 | else: 50 | files = sorted(glob.glob(args.files+'*.png')) 51 | else: 52 | files = [args.files] 53 | 54 | for img_fp in files: 55 | print("Process the image: ", img_fp) 56 | 57 | img_ori = cv2.imread(img_fp) 58 | 59 | # crop faces 60 | rects = face_boxes(img_ori) 61 | 62 | # storage 63 | pts_res = [] 64 | poses = [] 65 | vertices_lst = [] 66 | for idx, rect in enumerate(rects): 67 | roi_box = rect 68 | 69 | # enlarge the bbox a little and do a square crop 70 | HCenter = (rect[1] + rect[3])/2 71 | WCenter = (rect[0] + rect[2])/2 72 | side_len = roi_box[3]-roi_box[1] 73 | margin = side_len * 1.2 // 2 74 | roi_box[0], roi_box[1], roi_box[2], roi_box[3] = WCenter-margin, HCenter-margin, WCenter+margin, HCenter+margin 75 | 76 | img = crop_img(img_ori, roi_box) 77 | img = cv2.resize(img, dsize=(IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR) 78 | # cv2.imwrite(f'validate_{idx}.png', img) 79 | 80 | input = transform(img).unsqueeze(0) 81 | with torch.no_grad(): 82 | input = input.cuda() 83 | param = model.forward_test(input) 84 | param = param.squeeze().cpu().numpy().flatten().astype(np.float32) 85 | 86 | # inferences 87 | lmks = predict_sparseVert(param, roi_box, transform=True) 88 | vertices = predict_denseVert(param, roi_box, transform=True) 89 | angles, translation = predict_pose(param, roi_box) 90 | 91 | pts_res.append(lmks) 92 | vertices_lst.append(vertices) 93 | poses.append([angles, translation, lmks]) 94 | 95 | if not osp.exists(f'inference_output/rendering_overlay/'): 96 | os.makedirs(f'inference_output/rendering_overlay/') 97 | if not osp.exists(f'inference_output/landmarks/'): 98 | os.makedirs(f'inference_output/landmarks/') 99 | if not osp.exists(f'inference_output/poses/'): 100 | os.makedirs(f'inference_output/poses/') 101 | 102 | name = img_fp.rsplit('/',1)[-1][:-4] 103 | img_ori_copy = img_ori.copy() 104 | 105 | # mesh 106 | render(img_ori, vertices_lst, alpha=0.6, wfp=f'inference_output/rendering_overlay/{name}.jpg') 107 | 108 | # landmarks 109 | draw_landmarks(img_ori_copy, pts_res, wfp=f'inference_output/landmarks/{name}.jpg') 110 | 111 | # face orientation 112 | img_axis_plot = img_ori_copy 113 | for angles, translation, lmks in poses: 114 | img_axis_plot = draw_axis(img_axis_plot, angles[0], angles[1], 115 | angles[2], translation[0], translation[1], size = 50, pts68=lmks) 116 | wfp = f'inference_output/poses/{name}.jpg' 117 | cv2.imwrite(wfp, img_axis_plot) 118 | print(f'Save pose result to {wfp}') 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('-f', '--files', default='', help='path to a single image or path to a folder containing multiple images') 124 | parser.add_argument("--png", action="store_true", help="if images are with .png extension") 125 | parser.add_argument('--img_size', default=120, type=int) 126 | parser.add_argument('-b', '--batch-size', default=1, type=int) 127 | 128 | args = parser.parse_args() 129 | main(args) -------------------------------------------------------------------------------- /singleImage_simple.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from synergy3DMM import SynergyNet 3 | 4 | 5 | if __name__ == '__main__': 6 | model = SynergyNet() 7 | I=cv2.imread('img/sample_2.jpg', -1) 8 | # get landmark [[y, x, z], 68 (points)], mesh [[y, x, z], 53215 (points)], and face pose (Euler angles [yaw, pitch, roll] and translation [y, x, z]) 9 | lmk3d, mesh, pose = model.get_all_outputs(I) 10 | print(lmk3d[0].shape) 11 | print(mesh[0].shape) 12 | print(pose[0]) 13 | -------------------------------------------------------------------------------- /synergy3DMM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torchvision import transforms as T 5 | import scipy.io as sio 6 | 7 | # All data parameters import 8 | from utils.params import ParamsPack 9 | param_pack = ParamsPack() 10 | 11 | from backbone_nets import resnet_backbone 12 | from backbone_nets import mobilenetv1_backbone 13 | from backbone_nets import mobilenetv2_backbone 14 | from backbone_nets import ghostnet_backbone 15 | from backbone_nets.pointnet_backbone import MLP_for, MLP_rev 16 | import loss_definition 17 | from loss_definition import ParamLoss, WingLoss 18 | 19 | from backbone_nets.ResNeSt import resnest50, resnest101 20 | import time 21 | from utils.inference import predict_sparseVert, predict_denseVert, predict_pose, crop_img 22 | from FaceBoxes import FaceBoxes 23 | import cv2 24 | import types 25 | import os 26 | 27 | prefix_path = os.path.abspath(loss_definition.__file__).rsplit('/',1)[0] 28 | print(prefix_path) 29 | 30 | def parse_param_62(param): 31 | """Work for only tensor""" 32 | p_ = param[:, :12].reshape(-1, 3, 4) 33 | p = p_[:, :, :3] 34 | offset = p_[:, :, -1].reshape(-1, 3, 1) 35 | alpha_shp = param[:, 12:52].reshape(-1, 40, 1) 36 | alpha_exp = param[:, 52:62].reshape(-1, 10, 1) 37 | return p, offset, alpha_shp, alpha_exp 38 | 39 | # Image-to-parameter 40 | class I2P(nn.Module): 41 | def __init__(self, args): 42 | super(I2P, self).__init__() 43 | self.args = args 44 | # backbone definition 45 | if 'mobilenet_v2' in self.args.arch: 46 | self.backbone = getattr(mobilenetv2_backbone, args.arch)(pretrained=False) 47 | elif 'mobilenet' in self.args.arch: 48 | self.backbone = getattr(mobilenetv1_backbone, args.arch)() 49 | elif 'resnet' in self.args.arch: 50 | self.backbone = getattr(resnet_backbone, args.arch)(pretrained=False) 51 | elif 'ghostnet' in self.args.arch: 52 | self.backbone = getattr(ghostnet_backbone, args.arch)() 53 | elif 'resnest' in self.args.arch: 54 | self.backbone = resnest50() 55 | else: 56 | raise RuntimeError("Please choose [mobilenet_v2, mobilenet_1, resnet50, or ghostnet]") 57 | 58 | def forward(self,input, target): 59 | """Training time forward""" 60 | _3D_attr, avgpool = self.backbone(input) 61 | _3D_attr_GT = target.type(torch.cuda.FloatTensor) 62 | return _3D_attr, _3D_attr_GT, avgpool 63 | 64 | def forward_test(self, input): 65 | """ Testing time forward.""" 66 | _3D_attr, avgpool = self.backbone(input) 67 | return _3D_attr, avgpool 68 | 69 | # Main model SynergyNet definition 70 | class SynergyNet(nn.Module): 71 | def __init__(self): 72 | super(SynergyNet, self).__init__() 73 | self.triangles = sio.loadmat(os.path.join(prefix_path, '3dmm_data/tri.mat'))['tri'] -1 74 | self.triangles = torch.Tensor(self.triangles.astype(np.int64)).long() 75 | args = types.SimpleNamespace() 76 | args.arch = 'mobilenet_v2' 77 | args.checkpoint_fp = os.path.join(prefix_path, 'pretrained/best.pth.tar') 78 | 79 | # Image-to-parameter 80 | self.I2P = I2P(args) 81 | # Forward 82 | self.forwardDirection = MLP_for(68) 83 | # Reverse 84 | self.reverseDirection = MLP_rev(68) 85 | self.LMKLoss_3D = WingLoss() 86 | self.ParamLoss = ParamLoss() 87 | 88 | self.loss = {'loss_LMK_f0':0.0, 89 | 'loss_LMK_pointNet': 0.0, 90 | 'loss_Param_In':0.0, 91 | 'loss_Param_S2': 0.0, 92 | 'loss_Param_S1S2': 0.0, 93 | } 94 | 95 | self.register_buffer('param_mean', torch.Tensor(param_pack.param_mean)) 96 | self.register_buffer('param_std', torch.Tensor(param_pack.param_std)) 97 | self.register_buffer('w_shp', torch.Tensor(param_pack.w_shp)) 98 | self.register_buffer('u', torch.Tensor(param_pack.u)) 99 | self.register_buffer('w_exp', torch.Tensor(param_pack.w_exp)) 100 | 101 | # Online training needs these to parallel 102 | self.register_buffer('u_base', torch.Tensor(param_pack.u_base)) 103 | self.register_buffer('w_shp_base', torch.Tensor(param_pack.w_shp_base)) 104 | self.register_buffer('w_exp_base', torch.Tensor(param_pack.w_exp_base)) 105 | self.keypoints = torch.Tensor(param_pack.keypoints).long() 106 | 107 | self.data_param = [self.param_mean, self.param_std, self.w_shp_base, self.u_base, self.w_exp_base] 108 | 109 | try: 110 | print("loading weights from ", args.checkpoint_fp) 111 | self.load_weights(args.checkpoint_fp) 112 | except: 113 | pass 114 | self.eval() 115 | 116 | def reconstruct_vertex_62(self, param, whitening=True, dense=False, transform=True, lmk_pts=68): 117 | """ 118 | Whitening param -> 3d vertex, based on the 3dmm param: u_base, w_shp, w_exp 119 | dense: if True, return dense vertex, else return 68 sparse landmarks. All dense or sparse vertex is transformed to 120 | image coordinate space, but without alignment caused by face cropping. 121 | transform: whether transform to image space 122 | Working with batched tensors. Using Fortan-type reshape. 123 | """ 124 | 125 | if whitening: 126 | if param.shape[1] == 62: 127 | param_ = param * self.param_std[:62] + self.param_mean[:62] 128 | else: 129 | raise RuntimeError('length of params mismatch') 130 | 131 | p, offset, alpha_shp, alpha_exp = parse_param_62(param_) 132 | 133 | if dense: 134 | 135 | vertex = p @ (self.u + self.w_shp @ alpha_shp + self.w_exp @ alpha_exp).contiguous().view(-1, 53215, 3).transpose(1,2) + offset 136 | 137 | if transform: 138 | # transform to image coordinate space 139 | vertex[:, 1, :] = param_pack.std_size + 1 - vertex[:, 1, :] 140 | 141 | else: 142 | """For 68 pts""" 143 | vertex = p @ (self.u_base + self.w_shp_base @ alpha_shp + self.w_exp_base @ alpha_exp).contiguous().view(-1, lmk_pts, 3).transpose(1,2) + offset 144 | 145 | if transform: 146 | # transform to image coordinate space 147 | vertex[:, 1, :] = param_pack.std_size + 1 - vertex[:, 1, :] 148 | 149 | return vertex 150 | 151 | def forward_test(self, input): 152 | """test time forward""" 153 | _3D_attr, _ = self.I2P.forward_test(input) 154 | return _3D_attr 155 | 156 | def load_weights(self, path): 157 | model_dict = self.state_dict() 158 | checkpoint = torch.load(path, map_location=lambda storage, loc: storage)['state_dict'] 159 | 160 | # because the model is trained by multiple gpus, prefix 'module' should be removed 161 | for k in checkpoint.keys(): 162 | model_dict[k.replace('module.', '')] = checkpoint[k] 163 | 164 | self.load_state_dict(model_dict, strict=False) 165 | 166 | 167 | def get_all_outputs(self, input): 168 | """convenient api to get 3d landmarks, face pose, 3d faces""" 169 | 170 | face_boxes = FaceBoxes() 171 | rects = face_boxes(input) 172 | 173 | # storage 174 | pts_res = [] 175 | poses = [] 176 | vertices_lst = [] 177 | for idx, rect in enumerate(rects): 178 | roi_box = rect 179 | 180 | # enlarge the bbox a little and do a square crop 181 | HCenter = (rect[1] + rect[3])/2 182 | WCenter = (rect[0] + rect[2])/2 183 | side_len = roi_box[3]-roi_box[1] 184 | margin = side_len * 1.2 // 2 185 | roi_box[0], roi_box[1], roi_box[2], roi_box[3] = WCenter-margin, HCenter-margin, WCenter+margin, HCenter+margin 186 | 187 | img = crop_img(input, roi_box) 188 | img = cv2.resize(img, dsize=(120, 120), interpolation=cv2.INTER_LANCZOS4) 189 | img = torch.from_numpy(img) 190 | img = img.permute(2,0,1) 191 | img = img.unsqueeze(0) 192 | img = (img - 127.5)/ 128.0 193 | 194 | with torch.no_grad(): 195 | param = self.forward_test(img) 196 | 197 | param = param.squeeze().cpu().numpy().flatten().astype(np.float32) 198 | 199 | lmks = predict_sparseVert(param, roi_box, transform=True) 200 | vertices = predict_denseVert(param, roi_box, transform=True) 201 | angles, translation = predict_pose(param, roi_box) 202 | 203 | pts_res.append(lmks) 204 | vertices_lst.append(vertices) 205 | poses.append([angles, translation]) 206 | 207 | return pts_res, vertices_lst, poses 208 | 209 | if __name__ == '__main__': 210 | pass 211 | -------------------------------------------------------------------------------- /train_script.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | LOG_ALIAS=$1 4 | LOG_DIR="ckpts/logs" 5 | mkdir -p ${LOG_DIR} 6 | 7 | LOG_FILE="${LOG_DIR}/`date +'%Y-%m-%d_%H:%M.%S'`.log" 8 | 9 | python3 main_train.py --arch="mobilenet_v2" \ 10 | --start-epoch=1 \ 11 | --snapshot="ckpts/SynergyNet" \ 12 | --param-fp-train='./3dmm_data/param_all_norm_v201.pkl' \ 13 | --warmup=5 \ 14 | --batch-size=1024 \ 15 | --base-lr=0.08 \ 16 | --epochs=80 \ 17 | --milestones=48,64 \ 18 | --print-freq=50 \ 19 | --devices-id=0 \ 20 | --workers=8 \ 21 | --filelists-train="./3dmm_data/train_aug_120x120.list.train" \ 22 | --root="./train_aug_120x120" \ 23 | --log-file="${LOG_FILE}" \ 24 | --test_initial=True \ 25 | --save_val_freq=5 \ 26 | --resume="" \ 27 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/choyingw/SynergyNet/1352b871e58a3169ecf50312aa6185a6412b5e08/utils/__init__.py -------------------------------------------------------------------------------- /utils/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from utils.params import ParamsPack 4 | param_pack = ParamsPack() 5 | from math import cos, sin, atan2, asin, sqrt 6 | import cv2 7 | 8 | def write_obj(obj_name, vertices, triangles): 9 | triangles = triangles.copy() # meshlab start with 1 10 | 11 | if obj_name.split('.')[-1] != 'obj': 12 | obj_name = obj_name + '.obj' 13 | 14 | # write obj 15 | with open(obj_name, 'w') as f: 16 | # write vertices & colors 17 | for i in range(vertices.shape[1]): 18 | s = 'v {:.4f} {:.4f} {:.4f}\n'.format(vertices[0, i], vertices[1, i], vertices[2, i]) 19 | f.write(s) 20 | # write f: ver ind/ uv ind 21 | for i in range(triangles.shape[1]): 22 | s = 'f {} {} {}\n'.format(triangles[2, i], triangles[1, i], triangles[0, i]) 23 | f.write(s) 24 | 25 | def parse_param(param): 26 | p_ = param[:12].reshape(3, 4) 27 | p = p_[:, :3] 28 | offset = p_[:, -1].reshape(3, 1) 29 | alpha_shp = param[12:52].reshape(40, 1) 30 | alpha_exp = param[52:62].reshape(10, 1) 31 | return p, offset, alpha_shp, alpha_exp 32 | 33 | def P2sRt(P): 34 | t3d = P[:, 3] 35 | R1 = P[0:1, :3] 36 | R2 = P[1:2, :3] 37 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0 38 | r1 = R1 / np.linalg.norm(R1) 39 | r2 = R2 / np.linalg.norm(R2) 40 | r3 = np.cross(r1, r2) 41 | 42 | R = np.concatenate((r1, r2, r3), 0) 43 | return s, R, t3d 44 | 45 | def matrix2angle_corr(R): 46 | if R[2, 0] != 1 and R[2, 0] != -1: 47 | x = asin(R[2, 0]) 48 | y = atan2(R[1, 2] / cos(x), R[2, 2] / cos(x)) 49 | z = atan2(R[0, 1] / cos(x), R[0, 0] / cos(x)) 50 | 51 | else: # Gimbal lock 52 | z = 0 53 | if R[2, 0] == -1: 54 | x = np.pi / 2 55 | y = z + atan2(R[0, 1], R[0, 2]) 56 | else: 57 | x = -np.pi / 2 58 | y = -z + atan2(-R[0, 1], -R[0, 2]) 59 | 60 | rx, ry, rz = x*180/np.pi, y*180/np.pi, z*180/np.pi 61 | 62 | return [rx, ry, rz] 63 | 64 | def param2vert(param, dense=False, transform=True): 65 | if param.shape[0] == 62: 66 | param_ = param * param_pack.param_std[:62] + param_pack.param_mean[:62] 67 | else: 68 | raise RuntimeError('length of params mismatch') 69 | 70 | p, offset, alpha_shp, alpha_exp = parse_param(param_) 71 | 72 | if dense: 73 | vertex = p @ (param_pack.u + param_pack.w_shp @ alpha_shp + param_pack.w_exp @ alpha_exp).reshape(3, -1, order='F') + offset 74 | if transform: 75 | # transform to image coordinate space 76 | vertex[1, :] = param_pack.std_size + 1 - vertex[1, :] 77 | 78 | else: 79 | vertex = p @ (param_pack.u_base + param_pack.w_shp_base @ alpha_shp + param_pack.w_exp_base @ alpha_exp).reshape(3, -1, order='F') + offset 80 | if transform: 81 | # transform to image coordinate space 82 | vertex[1, :] = param_pack.std_size + 1 - vertex[1, :] 83 | 84 | return vertex 85 | 86 | def parse_pose(param): 87 | param = param * param_pack.param_std[:62] + param_pack.param_mean[:62] 88 | Ps = param[:12].reshape(3, -1) # camera matrix 89 | s, R, t3d = P2sRt(Ps) 90 | P = np.concatenate((R, t3d.reshape(3, -1)), axis=1) # without scale 91 | pose = matrix2angle_corr(R) # yaw, pitch, roll 92 | return P, pose, t3d 93 | 94 | 95 | def crop_img(img, roi_box): 96 | h, w = img.shape[:2] 97 | 98 | sx, sy, ex, ey, _ = [int(round(_)) for _ in roi_box] 99 | dh, dw = ey - sy, ex - sx 100 | if len(img.shape) == 3: 101 | res = np.zeros((dh, dw, 3), dtype=np.uint8) 102 | else: 103 | res = np.zeros((dh, dw), dtype=np.uint8) 104 | if sx < 0: 105 | sx, dsx = 0, -sx 106 | else: 107 | dsx = 0 108 | 109 | if ex > w: 110 | ex, dex = w, dw - (ex - w) 111 | else: 112 | dex = dw 113 | 114 | if sy < 0: 115 | sy, dsy = 0, -sy 116 | else: 117 | dsy = 0 118 | 119 | if ey > h: 120 | ey, dey = h, dh - (ey - h) 121 | else: 122 | dey = dh 123 | 124 | res[dsy:dey, dsx:dex] = img[sy:ey, sx:ex] 125 | return res 126 | 127 | def _predict_vertices(param, roi_bbox, dense, transform=True): 128 | vertex = param2vert(param, dense=dense, transform=transform) 129 | sx, sy, ex, ey, _ = roi_bbox 130 | scale_x = (ex - sx) / 120 131 | scale_y = (ey - sy) / 120 132 | vertex[0, :] = vertex[0, :] * scale_x + sx 133 | vertex[1, :] = vertex[1, :] * scale_y + sy 134 | 135 | s = (scale_x + scale_y) / 2 136 | vertex[2, :] *= s 137 | 138 | return vertex 139 | 140 | def predict_sparseVert(param, roi_box, transform=False): 141 | return _predict_vertices(param, roi_box, dense=False, transform=transform) 142 | 143 | def predict_denseVert(param, roi_box, transform=False): 144 | return _predict_vertices(param, roi_box, dense=True, transform=transform) 145 | 146 | def predict_pose(param, roi_bbox, ret_mat=False): 147 | P, angles, t3d = parse_pose(param) 148 | 149 | sx, sy, ex, ey, _ = roi_bbox 150 | scale_x = (ex - sx) / 120 151 | scale_y = (ey - sy) / 120 152 | t3d[0] = t3d[0] * scale_x + sx 153 | t3d[1] = t3d[1] * scale_y + sy 154 | 155 | if ret_mat: 156 | return P 157 | return angles, t3d 158 | 159 | def draw_landmarks(img, pts, wfp): 160 | height, width = img.shape[:2] 161 | base = 6.4 162 | plt.figure(figsize=(base, height / width * base)) 163 | plt.imshow(img[:, :, ::-1]) 164 | plt.subplots_adjust(left=0, right=1, top=1, bottom=0) 165 | plt.axis('off') 166 | 167 | if not type(pts) in [tuple, list]: 168 | pts = [pts] 169 | for i in range(len(pts)): 170 | alpha = 0.8 171 | markersize = 1.5 172 | lw = 0.7 173 | color = 'g' 174 | markeredgecolor = 'green' 175 | 176 | nums = [0, 17, 22, 27, 31, 36, 42, 48, 60, 68] 177 | 178 | # close eyes and mouths 179 | plot_close = lambda i1, i2: plt.plot([pts[i][0, i1], pts[i][0, i2]], [pts[i][1, i1], pts[i][1, i2]], 180 | color=color, lw=lw, alpha=alpha - 0.1) 181 | plot_close(41, 36) 182 | plot_close(47, 42) 183 | plot_close(59, 48) 184 | plot_close(67, 60) 185 | 186 | for ind in range(len(nums) - 1): 187 | l, r = nums[ind], nums[ind + 1] 188 | plt.plot(pts[i][0, l:r], pts[i][1, l:r], color=color, lw=lw, alpha=alpha - 0.1) 189 | 190 | plt.plot(pts[i][0, l:r], pts[i][1, l:r], marker='o', linestyle='None', markersize=markersize, 191 | color=color, 192 | markeredgecolor=markeredgecolor, alpha=alpha) 193 | 194 | plt.savefig(wfp, dpi=200) 195 | print('Save landmark result to {}'.format(wfp)) 196 | plt.close() 197 | 198 | 199 | def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size = 100, pts68=None): 200 | pitch = pitch * np.pi / 180 201 | yaw = -(yaw * np.pi / 180) 202 | roll = roll * np.pi / 180 203 | 204 | if tdx != None and tdy != None: 205 | tdx = tdx 206 | tdy = tdy 207 | else: 208 | height, width = img.shape[:2] 209 | tdx = width / 2 210 | tdy = height / 2 211 | 212 | tdx = pts68[0,30] 213 | tdy = pts68[1,30] 214 | 215 | 216 | minx, maxx = np.min(pts68[0, :]), np.max(pts68[0, :]) 217 | miny, maxy = np.min(pts68[1, :]), np.max(pts68[1, :]) 218 | llength = sqrt((maxx - minx) * (maxy - miny)) 219 | size = llength * 0.5 220 | 221 | 222 | # if pts8 != None: 223 | # tdx = 224 | 225 | # X-Axis pointing to right. drawn in red 226 | x1 = size * (cos(yaw) * cos(roll)) + tdx 227 | y1 = size * (cos(pitch) * sin(roll) + cos(roll) * sin(pitch) * sin(yaw)) + tdy 228 | 229 | # Y-Axis | drawn in green 230 | # v 231 | x2 = size * (-cos(yaw) * sin(roll)) + tdx 232 | y2 = size * (cos(pitch) * cos(roll) - sin(pitch) * sin(yaw) * sin(roll)) + tdy 233 | 234 | # Z-Axis (out of the screen) drawn in blue 235 | x3 = size * (sin(yaw)) + tdx 236 | y3 = size * (-cos(yaw) * sin(pitch)) + tdy 237 | 238 | minus=0 239 | 240 | cv2.line(img, (int(tdx), int(tdy)-minus), (int(x1),int(y1)),(0,0,255),4) 241 | cv2.line(img, (int(tdx), int(tdy)-minus), (int(x2),int(y2)),(0,255,0),4) 242 | cv2.line(img, (int(tdx), int(tdy)-minus), (int(x3),int(y3)),(255,0,0),4) 243 | 244 | return img 245 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import pickle 5 | import scipy.io as sio 6 | 7 | 8 | def mkdir(d): 9 | """only works on *nix system""" 10 | if not os.path.isdir(d) and not os.path.exists(d): 11 | os.system('mkdir -p {}'.format(d)) 12 | 13 | 14 | def _get_suffix(filename): 15 | """a.jpg -> jpg""" 16 | pos = filename.rfind('.') 17 | if pos == -1: 18 | return '' 19 | return filename[pos + 1:] 20 | 21 | 22 | def _load(fp): 23 | suffix = _get_suffix(fp) 24 | if suffix == 'npy': 25 | return np.load(fp) 26 | elif suffix == 'pkl': 27 | return pickle.load(open(fp, 'rb')) 28 | 29 | 30 | def _dump(wfp, obj): 31 | suffix = _get_suffix(wfp) 32 | if suffix == 'npy': 33 | np.save(wfp, obj) 34 | elif suffix == 'pkl': 35 | pickle.dump(obj, open(wfp, 'wb')) 36 | else: 37 | raise Exception('Unknown Type: {}'.format(suffix)) 38 | 39 | 40 | def _load_tensor(fp, mode='cpu'): 41 | if mode.lower() == 'cpu': 42 | return torch.from_numpy(_load(fp)) 43 | elif mode.lower() == 'gpu': 44 | return torch.from_numpy(_load(fp)).cuda() 45 | 46 | 47 | def _tensor_to_cuda(x): 48 | if x.is_cuda: 49 | return x 50 | else: 51 | return x.cuda() 52 | 53 | 54 | def _load_gpu(fp): 55 | return torch.from_numpy(_load(fp)).cuda() 56 | 57 | 58 | def load_bfm(model_path): 59 | suffix = _get_suffix(model_path) 60 | if suffix == 'mat': 61 | C = sio.loadmat(model_path) 62 | model = C['model_refine'] 63 | model = model[0, 0] 64 | 65 | model_new = {} 66 | w_shp = model['w'].astype(np.float32) 67 | model_new['w_shp_sim'] = w_shp[:, :40] 68 | w_exp = model['w_exp'].astype(np.float32) 69 | model_new['w_exp_sim'] = w_exp[:, :10] 70 | 71 | u_shp = model['mu_shape'] 72 | u_exp = model['mu_exp'] 73 | u = (u_shp + u_exp).astype(np.float32) 74 | model_new['mu'] = u 75 | model_new['tri'] = model['tri'].astype(np.int32) - 1 76 | 77 | # flatten it, pay attention to index value 78 | keypoints = model['keypoints'].astype(np.int32) - 1 79 | keypoints = np.concatenate((3 * keypoints, 3 * keypoints + 1, 3 * keypoints + 2), axis=0) 80 | 81 | model_new['keypoints'] = keypoints.T.flatten() 82 | 83 | # 84 | w = np.concatenate((w_shp, w_exp), axis=1) 85 | w_base = w[keypoints] 86 | w_norm = np.linalg.norm(w, axis=0) 87 | w_base_norm = np.linalg.norm(w_base, axis=0) 88 | 89 | dim = w_shp.shape[0] // 3 90 | u_base = u[keypoints].reshape(-1, 1) 91 | w_shp_base = w_shp[keypoints] 92 | w_exp_base = w_exp[keypoints] 93 | 94 | model_new['w_norm'] = w_norm 95 | model_new['w_base_norm'] = w_base_norm 96 | model_new['dim'] = dim 97 | model_new['u_base'] = u_base 98 | model_new['w_shp_base'] = w_shp_base 99 | model_new['w_exp_base'] = w_exp_base 100 | 101 | _dump(model_path.replace('.mat', '.pkl'), model_new) 102 | return model_new 103 | else: 104 | return _load(model_path) 105 | 106 | 107 | _load_cpu = _load 108 | _numpy_to_tensor = lambda x: torch.from_numpy(x) 109 | _tensor_to_numpy = lambda x: x.cpu() 110 | _numpy_to_cuda = lambda x: _tensor_to_cuda(torch.from_numpy(x)) 111 | _cuda_to_tensor = lambda x: x.cpu() 112 | _cuda_to_numpy = lambda x: x.cpu().numpy() 113 | -------------------------------------------------------------------------------- /utils/params.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | from .io import _load 4 | 5 | def make_abs_path(d): 6 | return osp.join(osp.dirname(osp.realpath(__file__)), d) 7 | 8 | class ParamsPack(): 9 | """Parameter package""" 10 | def __init__(self): 11 | try: 12 | d = make_abs_path('../3dmm_data') 13 | self.keypoints = _load(osp.join(d, 'keypoints_sim.npy')) 14 | 15 | # PCA basis for shape, expression, texture 16 | self.w_shp = _load(osp.join(d, 'w_shp_sim.npy')) 17 | self.w_exp = _load(osp.join(d, 'w_exp_sim.npy')) 18 | # param_mean and param_std are used for re-whitening 19 | meta = _load(osp.join(d, 'param_whitening.pkl')) 20 | self.param_mean = meta.get('param_mean') 21 | self.param_std = meta.get('param_std') 22 | # mean values 23 | self.u_shp = _load(osp.join(d, 'u_shp.npy')) 24 | self.u_exp = _load(osp.join(d, 'u_exp.npy')) 25 | self.u = self.u_shp + self.u_exp 26 | self.w = np.concatenate((self.w_shp, self.w_exp), axis=1) 27 | # base vector for landmarks 28 | self.w_base = self.w[self.keypoints] 29 | self.w_norm = np.linalg.norm(self.w, axis=0) 30 | self.w_base_norm = np.linalg.norm(self.w_base, axis=0) 31 | self.u_base = self.u[self.keypoints].reshape(-1, 1) 32 | self.w_shp_base = self.w_shp[self.keypoints] 33 | self.w_exp_base = self.w_exp[self.keypoints] 34 | self.std_size = 120 35 | self.dim = self.w_shp.shape[0] // 3 36 | except: 37 | raise RuntimeError('Missing data') -------------------------------------------------------------------------------- /utils/render.py: -------------------------------------------------------------------------------- 1 | # modified from 3DDFA-V2 2 | 3 | import sys 4 | 5 | sys.path.append('..') 6 | 7 | import cv2 8 | import numpy as np 9 | import scipy.io as sio 10 | 11 | from Sim3DR import RenderPipeline 12 | 13 | def _to_ctype(arr): 14 | if not arr.flags.c_contiguous: 15 | return arr.copy(order='C') 16 | return arr 17 | 18 | cfg = { 19 | 'intensity_ambient': 0.75, 20 | 'color_ambient': (1, 1, 1), 21 | 'intensity_directional': 0.7, 22 | 'color_directional': (1, 1, 1), 23 | 'intensity_specular': 0.2, 24 | 'specular_exp': 5, 25 | 'light_pos': (0, 0, 5), 26 | 'view_pos': (0, 0, 5) 27 | } 28 | 29 | render_app = RenderPipeline(**cfg) 30 | 31 | def render(img, ver_lst, alpha=0.6, wfp=None, tex=None, connectivity=None): 32 | tri = sio.loadmat('./3dmm_data/tri.mat')['tri'] - 1 33 | tri = _to_ctype(tri.T).astype(np.int32) 34 | # save solid mesh rendering and alpha overlaying on images 35 | if not connectivity is None: 36 | tri = _to_ctype(connectivity.T).astype(np.int32) 37 | 38 | overlap = img.copy() 39 | for ver_ in ver_lst: 40 | ver_ = ver_.astype(np.float32) 41 | ver = _to_ctype(ver_.T) # transpose 42 | overlap = render_app(ver, tri, overlap, texture=tex) 43 | cv2.imwrite(wfp[:-4]+'_solid'+'.png', overlap) 44 | 45 | res = cv2.addWeighted(img, 1 - alpha, overlap, alpha, 0) 46 | if wfp is not None: 47 | cv2.imwrite(wfp, res) 48 | print(f'Save mesh result to {wfp}') 49 | 50 | return res 51 | -------------------------------------------------------------------------------- /uv_texture_realFaces.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | import cv2 5 | from utils.ddfa import ToTensor, Normalize 6 | from model_building import SynergyNet 7 | from utils.inference import crop_img, predict_denseVert 8 | import argparse 9 | import torch.backends.cudnn as cudnn 10 | cudnn.benchmark = True 11 | import os 12 | import os.path as osp 13 | import glob 14 | from FaceBoxes import FaceBoxes 15 | from utils.render import render 16 | 17 | 18 | # Following 3DDFA-V2, we also use 120x120 resolution 19 | IMG_SIZE = 120 20 | 21 | def write_obj_with_colors(obj_name, vertices, triangles, colors): 22 | triangles = triangles.copy() 23 | 24 | if obj_name.split('.')[-1] != 'obj': 25 | obj_name = obj_name + '.obj' 26 | with open(obj_name, 'w') as f: 27 | for i in range(vertices.shape[1]): 28 | s = 'v {:.4f} {:.4f} {:.4f} {} {} {}\n'.format(vertices[0, i], vertices[1, i], vertices[2, i], colors[i, 2], 29 | colors[i, 1], colors[i, 0]) 30 | f.write(s) 31 | for i in range(triangles.shape[1]): 32 | s = 'f {} {} {}\n'.format(triangles[0, i], triangles[1, i], triangles[2, i]) 33 | f.write(s) 34 | 35 | def main(args): 36 | # load pre-tained model 37 | checkpoint_fp = 'pretrained/best.pth.tar' 38 | args.arch = 'mobilenet_v2' 39 | args.devices_id = [0] 40 | 41 | checkpoint = torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)['state_dict'] 42 | 43 | model = SynergyNet(args) 44 | model_dict = model.state_dict() 45 | 46 | # load BFM_UV mapping and kept indicies and deleted triangles 47 | uv_vert=np.load('3dmm_data/BFM_UV.npy') 48 | coord_u = (uv_vert[:,1]*255.0).astype(np.int32) 49 | coord_v = (uv_vert[:,0]*255.0).astype(np.int32) 50 | keep_ind = np.load('3dmm_data/keptInd.npy') 51 | tri_deletion = np.load('3dmm_data/deletedTri.npy') 52 | 53 | # because the model is trained by multiple gpus, prefix 'module' should be removed 54 | for k in checkpoint.keys(): 55 | model_dict[k.replace('module.', '')] = checkpoint[k] 56 | 57 | model.load_state_dict(model_dict, strict=False) 58 | model = model.cuda() 59 | model.eval() 60 | 61 | # face detector 62 | face_boxes = FaceBoxes() 63 | 64 | # preparation 65 | transform = transforms.Compose([ToTensor(), Normalize(mean=127.5, std=128)]) 66 | if osp.isdir(args.files): 67 | if not args.files[-1] == '/': 68 | args.files = args.files + '/' 69 | if not args.png: 70 | files = sorted(glob.glob(args.files+'*.jpg')) 71 | else: 72 | files = sorted(glob.glob(args.files+'*.png')) 73 | else: 74 | files = [args.files] 75 | 76 | for img_fp in files: 77 | print("Process the image: ", img_fp) 78 | 79 | img_ori = cv2.imread(img_fp) 80 | 81 | # crop faces 82 | rect = [0,0,256,256,1.0] # pre-cropped faces 83 | 84 | # storage 85 | vertices_lst = [] 86 | roi_box = rect 87 | img = crop_img(img_ori, roi_box) 88 | img = cv2.resize(img, dsize=(IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR) 89 | 90 | input = transform(img).unsqueeze(0) 91 | with torch.no_grad(): 92 | input = input.cuda() 93 | param = model.forward_test(input) 94 | param = param.squeeze().cpu().numpy().flatten().astype(np.float32) 95 | 96 | # dense pts 97 | vertices = predict_denseVert(param, roi_box, transform=True) 98 | vertices = vertices[:,keep_ind] 99 | vertices_lst.append(vertices) 100 | 101 | # textured obj file output 102 | if not osp.exists(f'inference_output/obj/'): 103 | os.makedirs(f'inference_output/obj/') 104 | if not osp.exists(f'inference_output/rendering_overlay/'): 105 | os.makedirs(f'inference_output/rendering_overlay/') 106 | 107 | name = img_fp.rsplit('/',1)[-1][:-11] # drop off the postfix 108 | colors = cv2.imread(f'texture_data/uv_real/{name}_fake_B.png',-1) 109 | colors = np.flip(colors,axis=0) 110 | colors_uv = (colors[coord_u, coord_v,:]) 111 | 112 | wfp = f'inference_output/obj/{name}.obj' 113 | write_obj_with_colors(wfp, vertices, tri_deletion, colors_uv[keep_ind,:].astype(np.float32)) 114 | 115 | tex = colors_uv[keep_ind,:].astype(np.float32)/255.0 116 | render(img_ori, vertices_lst, alpha=0.6, wfp=f'inference_output/rendering_overlay/{name}.jpg', tex=tex, connectivity=tri_deletion-1) 117 | 118 | if __name__ == '__main__': 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument('-f', '--files', default='', help='path to a single image or path to a folder containing multiple images') 121 | parser.add_argument("--png", action="store_true", help="if images are with .png extension") 122 | parser.add_argument('--img_size', default=120, type=int) 123 | parser.add_argument('-b', '--batch-size', default=1, type=int) 124 | 125 | args = parser.parse_args() 126 | main(args) --------------------------------------------------------------------------------