├── LICENSE ├── README.md ├── data ├── imgs │ ├── front.jpg │ ├── left.jpg │ └── right.jpg └── sigma_exp.mat ├── model.py ├── test_img.py └── tools.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Fanzi Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MVF-Net: Multi-View 3D Face Morphable Model Regression 2 | Testing code for the paper. 3 | > [MVF-Net: Multi-View 3D Face Morphable Model Regression](https://arxiv.org/abs/1904.04473). 4 | > Fanzi Wu*, Linchao Bao*, Yajing Chen, Yonggen Ling, Yibing Song, Songnan Li, King Ngi Ngan, Wei Liu. 5 | > CVPR 2019. 6 | 7 | ## Installation 8 | 1. Python 2.7 (Numpy, PIL, scipy) 9 | 2. Pytorch 0.4.0, torchvision 10 | 3. face-alignment package from [https://github.com/1adrianb/face-alignment](https://github.com/1adrianb/face-alignment). This code is used for face cropping and will be replaced by face detection algorithm in the future. 11 | 12 | 4. `Model_shape.mat` and `Model_Expression.mat` from [3DDFA](http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm). 13 | ## Test 14 | You can download the CNN model from [here](https://www.dropbox.com/s/7ds3aesjjmybjh9/net.pth?dl=0) and copy it into `data` folder. 15 | Then you can test the model by: 16 | ``` 17 | python test_img.py --image_path ./data/imgs --save_dir ./result 18 | ``` 19 | If you are testing the code with your own images, please organize multiview images as: 20 | ``` 21 | folder 22 | +--front.jpg 23 | +--left.jpg 24 | +--right.jpg 25 | ``` 26 | and change `line 15` in `test_img.py` as: 27 | ``` 28 | crop_opt = True 29 | ``` 30 | ## Citation 31 | If you find this work useful in your research, please cite: 32 | ``` 33 | @inproceedings{wu2019mvf, 34 | title={MVF-Net: Multi-View 3D Face Morphable Model Regression}, 35 | author={Wu, Fanzi and Bao, Linchao and Chen, Yajing and Ling, Yonggen and Song, Yibing and Li, Songnan and Ngan, King Ngi and Liu, Wei}, 36 | booktitle={CVPR}, 37 | year={2019} 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /data/imgs/front.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanziapril/mvfnet/bb7ce9b53368cd213489e41f29bf6ce6ae200f9d/data/imgs/front.jpg -------------------------------------------------------------------------------- /data/imgs/left.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanziapril/mvfnet/bb7ce9b53368cd213489e41f29bf6ce6ae200f9d/data/imgs/left.jpg -------------------------------------------------------------------------------- /data/imgs/right.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanziapril/mvfnet/bb7ce9b53368cd213489e41f29bf6ce6ae200f9d/data/imgs/right.jpg -------------------------------------------------------------------------------- /data/sigma_exp.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fanziapril/mvfnet/bb7ce9b53368cd213489e41f29bf6ce6ae200f9d/data/sigma_exp.mat -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from collections import OrderedDict 5 | import torchvision.models as tvmodel 6 | 7 | def reset_params(net): 8 | for m in net.modules(): 9 | if isinstance(m, nn.Conv2d): 10 | nn.init.normal(m.weight, 0.0, 0.02) 11 | if m.bias is not None: 12 | nn.init.constant(m.bias, 0) 13 | elif isinstance(m, nn.Linear): 14 | nn.init.normal(m.weight, 0.0, 0.0001) 15 | if m.bias is not None: 16 | nn.init.constant(m.bias, 0) 17 | elif isinstance(m, nn.BatchNorm2d): 18 | nn.init.constant(m.weight, 1) 19 | nn.init.normal(m.weight, 1.0, 0.02) 20 | nn.init.constant(m.bias, 0) 21 | 22 | 23 | class VggEncoder(nn.Module): 24 | def __init__(self): 25 | super(VggEncoder, self).__init__() 26 | 27 | self.featChannel = 512 28 | self.layer1 = tvmodel.vgg16_bn(pretrained=True).features 29 | self.layer1 = nn.Sequential(OrderedDict([ 30 | ('conv1', nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1))), 31 | ('bn1', nn.BatchNorm2d(64)), 32 | ('relu1', nn.ReLU(True)), 33 | ('pool1', nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True)), 34 | 35 | ('conv2', nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1))), 36 | ('bn2', nn.BatchNorm2d(128)), 37 | ('relu2', nn.ReLU(True)), 38 | ('pool2', nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True)), 39 | 40 | ('conv3', nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1))), 41 | ('bn3', nn.BatchNorm2d(256)), 42 | ('relu3', nn.ReLU(True)), 43 | 44 | ('conv4', nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1))), 45 | ('bn4', nn.BatchNorm2d(256)), 46 | ('relu4', nn.ReLU(True)), 47 | ('pool3', nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True)), 48 | 49 | ('conv5', nn.Conv2d(256, 512, (3, 3), (1, 1), 1)), 50 | ('bn5', nn.BatchNorm2d(512)), 51 | ('relu5', nn.ReLU(True)), 52 | ('pool4', nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True)), 53 | 54 | ('conv6', nn.Conv2d(512, 512, (3, 3), stride=1, padding=1)), 55 | ('bn6', nn.BatchNorm2d(512)), 56 | ('relu6', nn.ReLU(True)), 57 | 58 | ('conv7', nn.Conv2d(512, 512, (3, 3), (1, 1), 1)), 59 | ('bn7', nn.BatchNorm2d(512)), 60 | ('relu7', nn.ReLU(True)), 61 | ('pool5', nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True)), 62 | ])) 63 | 64 | 65 | 66 | self.fc_3dmm = nn.Sequential(OrderedDict([ 67 | ('fc1', nn.Linear(self.featChannel*3, 256*3)), 68 | ('relu1', nn.ReLU(True)), 69 | ('fc2', nn.Linear(256*3, 228))])) 70 | 71 | self.fc_pose = nn.Sequential(OrderedDict([ 72 | ('fc3', nn.Linear(512, 256)), 73 | ('relu2', nn.ReLU(True)), 74 | ('fc4', nn.Linear(256, 7))])) 75 | reset_params(self.fc_3dmm) 76 | reset_params(self.fc_pose) 77 | 78 | def forward(self, x): 79 | imga = x[:, 0:3, :, :] 80 | feata = self.layer1(imga) 81 | feata = F.avg_pool2d(feata, feata.size()[2:]).view(feata.size(0), feata.size(1)) 82 | posea = self.fc_pose(feata) 83 | imgb = x[:, 3:6, :, :] 84 | featb = self.layer1(imgb) 85 | featb = F.avg_pool2d(featb, featb.size()[2:]).view(featb.size(0), featb.size(1)) 86 | poseb = self.fc_pose(featb) 87 | imgc = x[:, 6:9, :, :] 88 | featc = self.layer1(imgc) 89 | featc = F.avg_pool2d(featc, featc.size()[2:]).view(featc.size(0), featc.size(1)) 90 | posec = self.fc_pose(featc) 91 | para = self.fc_3dmm(torch.cat([feata, featb, featc], dim=1)) 92 | out = torch.cat([para, posea, poseb, posec], dim=1) 93 | return out 94 | 95 | 96 | -------------------------------------------------------------------------------- /test_img.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | from PIL import Image 5 | import tools 6 | import torchvision.transforms as transforms 7 | from model import VggEncoder 8 | import time 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--image_path', type=str, default=None, help='path to load images. It should include image name with: front|left|right') 12 | parser.add_argument('--save_dir', type=str, default='./result', help='path to save 3D face shapes') 13 | 14 | options = parser.parse_args() 15 | crop_opt = True # change to True if you want to crop the image 16 | imgA = Image.open(os.path.join(options.image_path, 'front.jpg')).convert('RGB') 17 | imgB = Image.open(os.path.join(options.image_path, 'left.jpg')).convert('RGB') 18 | imgC = Image.open(os.path.join(options.image_path, 'right.jpg')).convert('RGB') 19 | if crop_opt: 20 | imgA = tools.crop_image(imgA) 21 | imgB = tools.crop_image(imgB) 22 | imgC = tools.crop_image(imgC) 23 | imgA = transforms.functional.to_tensor(imgA) 24 | imgB = transforms.functional.to_tensor(imgB) 25 | imgC = transforms.functional.to_tensor(imgC) 26 | model = VggEncoder() 27 | model = torch.nn.DataParallel(model).cuda() 28 | ckpt = torch.load('data/net.pth') 29 | model.load_state_dict(ckpt) 30 | #print model 31 | input_tensor = torch.cat([imgA, imgB, imgC], 0).view(1, 9, 224, 224).cuda() 32 | start = time.time() 33 | preds = model(input_tensor) 34 | print(time.time() -start) 35 | faces3d = tools.preds_to_shape(preds[0].detach().cpu().numpy()) 36 | tools.write_ply(os.path.join(options.save_dir, 'shape.ply'), faces3d[0], faces3d[1]) 37 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import face_alignment 2 | import numpy as np 3 | import scipy.io as io 4 | import torchvision.transforms as transforms 5 | from PIL import Image 6 | import pandas as pd 7 | import sys 8 | import math 9 | 10 | model_shape = io.loadmat('data/Model_Shape.mat') 11 | kpt_index = np.reshape(model_shape['keypoints'], 68).astype(np.int32) - 1 12 | model_exp = io.loadmat('data/Model_Expression.mat') 13 | data = io.loadmat('data/sigma_exp.mat') 14 | pose_mean = np.array([0,0,0,112,112,0,0]).astype(np.float32) 15 | pose_std = np.array([math.pi/2.0,math.pi/2.0,math.pi/2.0,56,56,1,224.0 / (2 * 180000.0)]).astype(np.float32) 16 | 17 | def angle_to_rotation(angles): 18 | phi = angles[0] 19 | gamma = angles[1] 20 | theta = angles[2] 21 | 22 | R_x = np.eye(3) 23 | R_x[1, 1] = math.cos(phi) 24 | R_x[1, 2] = math.sin(phi) 25 | R_x[2, 1] = - math.sin(phi) 26 | R_x[2, 2] = math.cos(phi) 27 | 28 | R_y = np.eye(3) 29 | R_y[0, 0] = math.cos(gamma) 30 | R_y[0, 2] = - math.sin(gamma) 31 | R_y[2, 0] = math.sin(gamma) 32 | R_y[2, 2] = math.cos(gamma) 33 | 34 | R_z = np.eye(3) 35 | R_z[0, 0] = math.cos(theta) 36 | R_z[0, 1] = math.sin(theta) 37 | R_z[1, 0] = - math.sin(theta) 38 | R_z[1, 1] = math.cos(theta) 39 | 40 | return np.matmul(np.matmul(R_x, R_y), R_z) 41 | 42 | def preds_to_pose(preds): 43 | pose = preds * pose_std + pose_mean 44 | R = angle_to_rotation(pose[:3]) 45 | t2d = pose[3:5] 46 | s = pose[6] 47 | return R, t2d, s 48 | 49 | def preds_to_shape(preds): 50 | # paras = torch.mul(preds[:228, :], label_std[:199+29, :]) 51 | alpha = np.reshape(preds[:199], [199,1]) * np.reshape(model_shape['sigma'], [199,1]) 52 | beta = np.reshape(preds[199:228], [29, 1]) * 1.0/(1000.0 * np.reshape(data['sigma_exp'], [29, 1])) 53 | face_shape = np.matmul(model_shape['w'], alpha) + np.matmul(model_exp['w_exp'], beta) + model_shape['mu_shape'] 54 | face_shape = face_shape.reshape(-1, 3) 55 | 56 | R, t, s = preds_to_pose(preds[228:228+7]) 57 | kptA = np.matmul(face_shape[kpt_index], s*R[:2].transpose()) + np.repeat(np.reshape(t,[1,2]), 68, axis=0) 58 | kptA[:, 1] = 224 - kptA[:, 1] 59 | R, t, s = preds_to_pose(preds[228+7:228+14]) 60 | kptB = np.matmul(face_shape[kpt_index], s*R[:2].transpose()) + np.repeat(np.reshape(t,[1,2]), 68, axis=0) 61 | kptB[:, 1] = 224 - kptB[:, 1] 62 | 63 | R, t, s = preds_to_pose(preds[228+14:]) 64 | kptC = np.matmul(face_shape[kpt_index], s*R[:2].transpose()) + np.repeat(np.reshape(t,[1,2]), 68, axis=0) 65 | kptC[:, 1] = 224 - kptC[:, 1] 66 | return [face_shape, model_shape['tri'].astype(np.int64).transpose() - 1, kptA, kptB, kptC] 67 | 68 | 69 | def crop_image(image, res=224): 70 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False) 71 | pts = fa.get_landmarks(np.array(image)) 72 | if len(pts) < 1: 73 | assert "No face detected!" 74 | pts = np.array(pts[0]).astype(np.int32) 75 | 76 | h = image.size[1] 77 | w = image.size[0] 78 | # x-width-pts[0,:], y-height-pts[1,:] 79 | x_max = np.max(pts[:68, 0]) 80 | x_min = np.min(pts[:68, 0]) 81 | y_max = np.max(pts[:68, 1]) 82 | y_min = np.min(pts[:68, 1]) 83 | bbox = [y_min, x_min, y_max, x_max] 84 | # c (cy, cx) 85 | c = [bbox[2] - (bbox[2] - bbox[0]) / 2, bbox[3] - (bbox[3] - bbox[1]) / 2.0] 86 | c[0] = c[0] - (bbox[2] - bbox[0]) * 0.12 87 | s = (max(bbox[2] - bbox[0], bbox[3] - bbox[1]) * 1.5).astype(np.int32) 88 | old_bb = np.array([c[0] - s / 2, c[1] - s / 2, c[0] + s / 2, c[1] + s / 2]).astype(np.int32) 89 | crop_img = Image.new('RGB', (s, s)) 90 | #crop_img = torch.zeros(image.shape[0], s, s, dtype=torch.float32) 91 | 92 | shift_x = 0 - old_bb[1] 93 | shift_y = 0 - old_bb[0] 94 | old_bb = np.array([max(0, old_bb[0]), max(0, old_bb[1]), 95 | min(h, old_bb[2]), min(w, old_bb[3])]).astype(np.int32) 96 | hb = old_bb[2] - old_bb[0] 97 | wb = old_bb[3] - old_bb[1] 98 | new_bb = np.array([max(0, shift_y), max(0, shift_x), max(0, shift_y) + hb, max(0, shift_x) + wb]).astype(np.int32) 99 | cache = image.crop((old_bb[1], old_bb[0], old_bb[3], old_bb[2])) 100 | crop_img.paste(cache, (new_bb[1], new_bb[0], new_bb[3], new_bb[2])) 101 | crop_img = crop_img.resize((res, res), Image.BICUBIC) 102 | return crop_img 103 | 104 | def write_ply(filename, points=None, mesh=None, colors=None, as_text=True): 105 | points = pd.DataFrame(points, columns=["x", "y", "z"]) 106 | mesh = pd.DataFrame(mesh, columns=["v1", "v2", "v3"]) 107 | if colors is not None: 108 | colors = pd.DataFrame(colors, columns=["red", "green", "blue"]) 109 | points = pd.concat([points, colors], axis=1) 110 | """ 111 | 112 | Parameters 113 | ---------- 114 | filename: str 115 | The created file will be named with this 116 | points: ndarray 117 | mesh: ndarray 118 | as_text: boolean 119 | Set the write mode of the file. Default: binary 120 | 121 | Returns 122 | ------- 123 | boolean 124 | True if no problems 125 | 126 | """ 127 | if not filename.endswith('ply'): 128 | filename += '.ply' 129 | 130 | # open in text mode to write the header 131 | with open(filename, 'w') as ply: 132 | header = ['ply'] 133 | 134 | if as_text: 135 | header.append('format ascii 1.0') 136 | else: 137 | header.append('format binary_' + sys.byteorder + '_endian 1.0') 138 | 139 | if points is not None: 140 | header.extend(describe_element('vertex', points)) 141 | if mesh is not None: 142 | mesh = mesh.copy() 143 | mesh.insert(loc=0, column="n_points", value=3) 144 | mesh["n_points"] = mesh["n_points"].astype("u1") 145 | header.extend(describe_element('face', mesh)) 146 | 147 | header.append('end_header') 148 | 149 | for line in header: 150 | ply.write("%s\n" % line) 151 | 152 | if as_text: 153 | if points is not None: 154 | points.to_csv(filename, sep=" ", index=False, header=False, mode='a', 155 | encoding='ascii') 156 | if mesh is not None: 157 | mesh.to_csv(filename, sep=" ", index=False, header=False, mode='a', 158 | encoding='ascii') 159 | 160 | else: 161 | # open in binary/append to use tofile 162 | with open(filename, 'ab') as ply: 163 | if points is not None: 164 | points.to_records(index=False).tofile(ply) 165 | if mesh is not None: 166 | mesh.to_records(index=False).tofile(ply) 167 | 168 | return True 169 | 170 | def describe_element(name, df): 171 | """ Takes the columns of the dataframe and builds a ply-like description 172 | Parameters 173 | ---------- 174 | name: str 175 | df: pandas DataFrame 176 | Returns 177 | ------- 178 | element: list[str] 179 | """ 180 | property_formats = {'f': 'float', 'u': 'uchar', 'i': 'int'} 181 | element = ['element ' + name + ' ' + str(len(df))] 182 | 183 | if name == 'face': 184 | element.append("property list uchar int vertex_indices") 185 | 186 | else: 187 | for i in range(len(df.columns)): 188 | # get first letter of dtype to infer format 189 | f = property_formats[str(df.dtypes[i])[0]] 190 | element.append('property ' + f + ' ' + str(df.columns.values[i])) 191 | 192 | return element 193 | --------------------------------------------------------------------------------