├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── environment.yml
├── examples
│   ├── 3dprint.jpg
│   ├── jupyter_gif.gif
│   └── sample.jpg
├── pix2vertex
│   ├── __init__.py
│   ├── constants.py
│   ├── detector.py
│   ├── models
│   │   ├── __init__.py
│   │   └── pix2pix.py
│   ├── reconstructor.py
│   └── utils.py
├── reconstruct_pipeline.ipynb
├── requirements.txt
├── setup.py
├── tests
│   └── test_pix2vertex.py
└── weights
    └── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | build
3 | dist
4 | *__pycache__*
5 | *.egg-info*
6 | 
7 | *.bz2
8 | *.dat
9 | *.pth
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Elad Richardson and Matan Sela
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unrestricted Facial Geometry Reconstruction Using Image-to-Image Translation - Official PyTorch Implementation
2 | 
3 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/eladrich/pix2vertex.pytorch/mybinder?filepath=reconstruct_pipeline.ipynb)
4 | [![PyPI version](https://badge.fury.io/py/pix2vertex.svg)](https://badge.fury.io/py/pix2vertex)
5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
6 | 
7 | 
8 | [[Arxiv]](https://arxiv.org/pdf/1703.10131.pdf) [[Video]](https://www.youtube.com/watch?v=6lUdSVcBB-k)
9 | 
10 | 
11 | Evaluation code for Unrestricted Facial Geometry Reconstruction Using Image-to-Image Translation. Finally ported to PyTorch!
12 | 

13 | 
14 | 
15 | 
16 | 
17 | ## Recent Updates
18 | 
19 | **`2020.10.27`**: Added STL support
20 | 
21 | **`2020.05.07`**: Added a wheel package!
22 | 
23 | **`2020.05.06`**: Added a [myBinder](https://mybinder.org/v2/gh/eladrich/pix2vertex.pytorch/mybinder?filepath=reconstruct_pipeline.ipynb) version for quick testing of the model
24 | 
25 | **`2020.04.30`**: Initial PyTorch release
26 | 
27 | # What's in this release?
28 | 
29 | The [original pix2vertex repo](https://github.com/matansel/pix2vertex) was composed of three parts:
30 | - A network that maps an input image to depth and correspondence maps, trained on synthetic facial data
31 | - A non-rigid ICP scheme for converting the output maps to a full 3D mesh
32 | - A shape-from-shading scheme for adding fine mesoscopic details
33 | 
34 | 
35 | This repo currently contains our image-to-image network, with both model and weights ported to `PyTorch`, and a simple `Python` post-processing scheme.
36 | - The released network was trained on a combination of synthetic images and unlabeled real images for some extra robustness :)
37 | 
38 | ## Installation
39 | Installation from PyPI
40 | ```bash
41 | $ pip install pix2vertex
42 | ```
43 | Installation from source
44 | ```bash
45 | $ git clone https://github.com/eladrich/pix2vertex.pytorch.git
46 | $ cd pix2vertex.pytorch
47 | $ python setup.py install
48 | ```
49 | ## Usage
50 | The quickest way to try `p2v` is using the `reconstruct` method on an input image, followed by visualization or STL creation.
51 | ```python
52 | import pix2vertex as p2v
53 | from imageio import imread
54 | 
55 | image = imread('examples/sample.jpg')  # any face image; this sample ships with the repo
56 | result, crop = p2v.reconstruct(image)
57 | 
58 | # Interactive visualization in a notebook
59 | p2v.vis_depth_interactive(result['Z_surface'])
60 | 
61 | # Static visualization using matplotlib
62 | p2v.vis_depth_matplotlib(crop, result['Z_surface'])
63 | 
64 | # Export to STL
65 | p2v.save2stl(result['Z_surface'], 'res.stl')
66 | ```
67 | For a more complete example see the `reconstruct_pipeline` notebook. You can give it a try without any installations using our [binder port](https://mybinder.org/v2/gh/eladrich/pix2vertex.pytorch/mybinder?filepath=reconstruct_pipeline.ipynb).
68 | 
69 | ### Pretrained Model
70 | Models can be downloaded from these links:
71 | - [pix2vertex model](https://drive.google.com/open?id=1op5_zyH4CWm_JFDdCUPZM4X-A045ETex)
72 | - [dlib landmark predictor](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) - note that the dlib model has its own license.
73 | 
74 | If no model path is specified, the package automagically downloads the required models.
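
If you prefer to manage the weights yourself, both classes accept explicit paths. A minimal sketch, assuming the two downloads were placed in the `weights` folder under the names the package expects:

```python
import pix2vertex as p2v

# Paths are assumptions: point them at wherever you saved the downloaded files
detector = p2v.Detector(predictor_path='weights/shape_predictor_68_face_landmarks.dat')
reconstructor = p2v.Reconstructor(weights_path='weights/faces_hybrid_and_rotated_2.pth',
                                  detector=detector)
result, crop = reconstructor.run('examples/sample.jpg')  # also accepts a loaded image array
```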
75 | 
76 | 
77 | ## TODOs
78 | - [x] Port Torch model to PyTorch
79 | - [x] Release an inference notebook (using [K3D](https://github.com/K3D-tools/K3D-jupyter))
80 | - [x] Add requirements
81 | - [x] Pack as wheel
82 | - [x] Port to MyBinder
83 | - [x] Add a simple method to export an STL file for printing
84 | - [ ] Port the Shape-from-Shading method from our original (MATLAB) implementation
85 | - [ ] Write a short blog post about the revised training scheme
86 | 
87 | ## Citation
88 | If you use this code for your research, please cite our paper Unrestricted Facial Geometry Reconstruction Using Image-to-Image Translation:
89 | 
90 | ```
91 | @article{sela2017unrestricted,
92 |   title={Unrestricted Facial Geometry Reconstruction Using Image-to-Image Translation},
93 |   author={Sela, Matan and Richardson, Elad and Kimmel, Ron},
94 |   journal={arXiv preprint arXiv:1703.10131},
95 |   year={2017}
96 | }
97 | ```
98 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: py36
2 | channels:
3 |   - menpo
4 |   - defaults
5 | dependencies:
6 |   - python=3.6
7 |   - dlib
8 |   - pip
9 |   - pip:
10 |     - matplotlib
11 |     - scikit-image
12 |     - torch
13 |     - colormap
14 |     - easydev
15 |     - k3d==2.7.4
16 |     - tqdm
17 |     - six
18 |     - requests
--------------------------------------------------------------------------------
/examples/3dprint.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/pix2vertex.pytorch/1d2bb61c91584af0ca1495b97a6a6de847df2cec/examples/3dprint.jpg
--------------------------------------------------------------------------------
/examples/jupyter_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/pix2vertex.pytorch/1d2bb61c91584af0ca1495b97a6a6de847df2cec/examples/jupyter_gif.gif
--------------------------------------------------------------------------------
/examples/sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/pix2vertex.pytorch/1d2bb61c91584af0ca1495b97a6a6de847df2cec/examples/sample.jpg
--------------------------------------------------------------------------------
/pix2vertex/__init__.py:
--------------------------------------------------------------------------------
1 | from .reconstructor import Reconstructor
2 | from .detector import Detector
3 | from .utils import vis_net_result, vis_depth_interactive, vis_pcloud_interactive, vis_depth_matplotlib, save2stl
4 | 
5 | reconstructor = None
6 | def reconstruct(image=None, verbose=False):
7 |     global reconstructor
8 |     if reconstructor is None:
9 |         reconstructor = Reconstructor()
10 |     if image is None:
11 |         import os
12 |         from .constants import sample_image
13 |         image = os.path.join(os.path.dirname(__file__), sample_image)
14 |         print('No image specified, using {} as default input image'.format(image))
15 |     return reconstructor.run(image, verbose)
16 | 
--------------------------------------------------------------------------------
/pix2vertex/constants.py:
--------------------------------------------------------------------------------
1 | predictor_url = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
2 | predictor_file = '../weights/shape_predictor_68_face_landmarks.dat'
3 | p2v_model_gdrive_id = '1op5_zyH4CWm_JFDdCUPZM4X-A045ETex'
4 | sample_image = '../examples/sample.jpg'
5 | 
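# Note (editor): the relative paths above are resolved against the package directory
# at runtime rather than the working directory; e.g. detector.py builds
#   os.path.join(os.path.dirname(__file__), predictor_file)
# before loading, which is why they start with '../'.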
--------------------------------------------------------------------------------
/pix2vertex/detector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import dlib
3 | import numpy as np
4 | import math
5 | from skimage.transform import resize
6 | from .utils import download_url, extract_file
7 | 
8 | 
9 | class Detector:
10 |     def __init__(self, predictor_path=None):
11 |         self.detector = dlib.get_frontal_face_detector()
12 |         self.set_predictor(predictor_path)
13 | 
14 |     def set_predictor(self, predictor_path):
15 |         if predictor_path is None:
16 |             from .constants import predictor_file
17 |             predictor_path = os.path.join(os.path.dirname(__file__), predictor_file)
18 |             print('Loading default detector weights from {}'.format(predictor_path))
19 |         if not os.path.exists(predictor_path):
20 |             from .constants import predictor_url
21 |             os.makedirs(os.path.dirname(predictor_path), exist_ok=True)
22 |             print('\tDownloading weights from {}...'.format(predictor_url))
23 |             download_url(predictor_url, save_path=os.path.dirname(predictor_path))
24 |             print('\tExtracting weights...')
25 |             extract_file(predictor_path + '.bz2', os.path.dirname(predictor_path))
26 |             print('\tDone!')
27 |         self.predictor = dlib.shape_predictor(predictor_path)
28 | 
29 |     def detect_and_crop(self, img, img_size=512):
30 |         dets = self.detector(img, 1)  # upsample the image once before detecting
31 |         for k, d in enumerate(dets):
32 |             print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
33 |                 k, d.left(), d.top(), d.right(), d.bottom()))
34 |         dets = dets[0]  # keep only the first detection
35 |         shape = self.predictor(img, dets)
36 |         points = shape.parts()
37 | 
38 |         pts = np.array([[p.x, p.y] for p in points])
39 |         min_x = np.min(pts[:, 0])
40 |         min_y = np.min(pts[:, 1])
41 |         max_x = np.max(pts[:, 0])
42 |         max_y = np.max(pts[:, 1])
43 |         box_width = (max_x - min_x) * 1.2
44 |         box_height = (max_y - min_y) * 1.2
45 |         bbox = np.array([min_y - box_height * 0.3, min_x, box_height, box_width]).astype(int)
46 | 
47 |         img_crop = Detector.adjust_box_and_crop(img, bbox, crop_percent=150, img_size=img_size)
48 |         # img_crop = img[bbox[0]:bbox[0]+bbox[2], bbox[1]:bbox[1]+bbox[3], :]
49 |         return img_crop
50 | 
51 |     @staticmethod
52 |     def adjust_box_and_crop(img, bbox, crop_percent=100, img_size=None):
53 |         w_ext = math.floor(bbox[2])
54 |         h_ext = math.floor(bbox[3])
55 |         bbox_center = np.round(np.array([bbox[0] + 0.5 * bbox[2], bbox[1] + 0.5 * bbox[3]]))
56 |         max_ext = np.round(crop_percent / 100 * max(w_ext, h_ext) / 2)
57 |         top = max(1, bbox_center[0] - max_ext)
58 |         left = max(1, bbox_center[1] - max_ext)
59 |         bottom = min(img.shape[0], bbox_center[0] + max_ext)
60 |         right = min(img.shape[1], bbox_center[1] + max_ext)
61 |         height = bottom - top
62 |         width = right - left
63 |         # make the frame as square as possible
64 |         if height < width:
65 |             diff = width - height
66 |             top_pad = int(max(0, np.floor(diff / 2) - top + 1))
67 |             top = max(1, top - np.floor(diff / 2))
68 |             bottom_pad = int(max(0, bottom + np.ceil(diff / 2) - img.shape[0]))
69 |             bottom = min(img.shape[0], bottom + np.ceil(diff / 2))
70 |             left_pad = 0
71 |             right_pad = 0
72 |         else:
73 |             diff = height - width
74 |             left_pad = int(max(0, np.floor(diff / 2) - left + 1))
75 |             left = max(1, left - np.floor(diff / 2))
76 |             right_pad = int(max(0, right + np.ceil(diff / 2) - img.shape[1]))
77 |             right = min(img.shape[1], right + np.ceil(diff / 2))
78 |             top_pad = 0
79 |             bottom_pad = 0
80 | 
81 |         # crop the image
82 |         img_crop = img[int(top):int(bottom), int(left):int(right), :]
83 |         # pad the image
84 |         img_crop = np.pad(img_crop, ((top_pad, bottom_pad), (left_pad, right_pad), (0, 0)), 'constant')
85 |         if img_size is not None:
86 |             img_crop = resize(img_crop, (img_size, img_size)) * 255
87 |             img_crop = img_crop.astype(np.uint8)
88 |         return img_crop
89 | 
--------------------------------------------------------------------------------
/pix2vertex/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/pix2vertex.pytorch/1d2bb61c91584af0ca1495b97a6a6de847df2cec/pix2vertex/models/__init__.py
--------------------------------------------------------------------------------
/pix2vertex/models/pix2pix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | def conv_block(in_channels, out_channels):
5 |     return nn.Sequential(
6 |         nn.LeakyReLU(0.2, inplace=True),
7 |         nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1),
8 |         nn.BatchNorm2d(out_channels)
9 |     )
10 | 
11 | 
12 | def deconv_block(in_channels, out_channels, use_dropout=False):
13 |     layers = [
14 |         nn.ReLU(inplace=True),
15 |         nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
16 |         nn.BatchNorm2d(out_channels)
17 |     ]
18 |     if use_dropout:
19 |         layers.append(nn.Dropout(0.5))
20 |     return nn.Sequential(*layers)
21 | 
22 | class UNet(nn.Module):
23 | 
24 |     def __init__(self):
25 |         super().__init__()
26 | 
27 |         self.conv_down1 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
28 | 
29 |         self.conv_down2 = conv_block(64, 128)
30 | 
31 |         self.conv_down3 = conv_block(128, 256)
32 | 
33 |         self.conv_down4 = conv_block(256, 512)
34 | 
35 |         self.conv_down5 = conv_block(512, 512)
36 | 
37 |         self.conv_down6 = conv_block(512, 512)
38 | 
39 |         self.conv_down7 = conv_block(512, 512)
40 | 
41 |         self.conv_down8 = conv_block(512, 512)
42 | 
43 |         self.conv_up1 = deconv_block(512, 512, use_dropout=True)
44 | 
45 |         self.conv_up2 = deconv_block(1024, 512, use_dropout=True)
46 | 
47 |         self.conv_up3 = deconv_block(1024, 512, use_dropout=True)
48 | 
49 |         self.conv_up4 = deconv_block(1024, 512)
50 | 
51 |         self.conv_up5 = deconv_block(1024, 256)
52 | 
53 |         self.conv_up6 = deconv_block(512, 128)
54 | 
55 |         self.conv_up7 = deconv_block(256, 64)
56 | 
57 |         self.conv_up8 = deconv_block(128, 64)
58 | 
59 |         self.conv_up9 = nn.Sequential(
60 |             nn.ReLU(inplace=True),
61 |             nn.ConvTranspose2d(64, 64, 3, stride=1, padding=1),
62 |             nn.BatchNorm2d(64)
63 |         )
64 | 
65 |         self.conv_up10 = nn.Sequential(
66 |             nn.ReLU(inplace=True),
67 |             nn.ConvTranspose2d(64, 32, 3, stride=1, padding=1),
68 |             nn.BatchNorm2d(32)
69 |         )
70 | 
71 |         self.conv_up11 = nn.Sequential(
72 |             nn.ReLU(inplace=True),
73 |             nn.ConvTranspose2d(32, 7, 3, stride=1, padding=1)
74 |         )
75 | 
76 |         # TODO: rewrite nicely
77 | 
78 |     def forward(self, x):
79 |         down1 = self.conv_down1(x)
80 |         down2 = self.conv_down2(down1)
81 |         down3 = self.conv_down3(down2)
82 |         down4 = self.conv_down4(down3)
83 |         down5 = self.conv_down5(down4)
84 |         down6 = self.conv_down6(down5)
85 |         down7 = self.conv_down7(down6)
86 |         down8 = self.conv_down8(down7)
87 | 
88 |         up1 = self.conv_up1(down8)
89 |         up2 = self.conv_up2(torch.cat([up1, down7], dim=1))
90 |         up3 = self.conv_up3(torch.cat([up2, down6], dim=1))
91 |         up4 = self.conv_up4(torch.cat([up3, down5], dim=1))
92 |         up5 = self.conv_up5(torch.cat([up4, down4], dim=1))
93 |         up6 = self.conv_up6(torch.cat([up5, down3], dim=1))
94 |         up7 = self.conv_up7(torch.cat([up6, down2], dim=1))
95 |         up8 = self.conv_up8(torch.cat([up7, down1], dim=1))
96 |         up9 = self.conv_up9(up8)
97 |         up10 = self.conv_up10(up9)
98 |         up11 = self.conv_up11(up10)
99 | 
100 |         return up11
101 | 
--------------------------------------------------------------------------------
/pix2vertex/reconstructor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import numpy as np
4 | import torch
5 | from imageio import imread
6 | 
7 | from .models import pix2pix
8 | from .detector import Detector
9 | 
10 | class Reconstructor:
11 |     def __init__(self, weights_path=None, detector=None):
12 |         if detector is None:
13 |             detector = Detector()
14 |         self.detector = detector
15 |         self.unet = pix2pix.UNet()
16 |         self.set_initial_weights(weights_path)
17 |         self.unet.train()  # as in the original pix2pix, BatchNorm in train mode works as instance normalization
18 | 
19 |     def set_initial_weights(self, weights_path):
20 |         if weights_path is None:
21 |             weights_path = os.path.join(os.path.dirname(__file__),
22 |                                         '../weights/faces_hybrid_and_rotated_2.pth')
23 |             print('loading default reconstructor weights from {}'.format(weights_path))
24 |         if not os.path.exists(weights_path):
25 |             from .utils import download_from_gdrive
26 |             from .constants import p2v_model_gdrive_id
27 |             os.makedirs(os.path.dirname(weights_path), exist_ok=True)
28 |             print('\tDownloading weights...')
29 |             download_from_gdrive(p2v_model_gdrive_id, weights_path)
30 |             print('\tDone!')
31 |         self.initial_weights = torch.load(weights_path, map_location='cpu')  # inference below runs on CPU
32 | 
33 |     def run(self, image, verbose=False):
34 |         if type(image) is str:
35 |             image = imread(image)
36 |         image_cropped = self.detector.detect_and_crop(image)
37 |         net_res = self.run_net(image_cropped)
38 |         final_res = self.post_process(net_res)
39 |         if verbose:
40 |             from . import vis_depth_interactive
41 |             vis_depth_interactive(final_res['Z_surface'])
42 |         return final_res, image_cropped
43 | 
44 |     def run_net(self, img):
45 |         # the BatchNorm layers act as instance normalization here, so the weights are reloaded before every run
46 |         self.unet.load_state_dict(copy.deepcopy(self.initial_weights), strict=True)
47 | 
48 |         # Forward
49 |         input = torch.from_numpy(img.transpose()).float()
50 |         input = input.unsqueeze(0)
51 |         input = input.transpose(2, 3)
52 |         input = input.div(255.0).mul(2).add(-1)
53 |         output = self.unet(input)
54 |         output = output.add(1).div(2).mul(255)
55 | 
56 |         # Post Processing
57 |         im_both = output.squeeze(0).detach().numpy().transpose().swapaxes(0, 1).copy()
58 |         im_pncc = im_both[:, :, 0:3]
59 |         im_depth = im_both[:, :, 3:6]
60 |         im_depth[np.logical_and(im_depth < 10, im_depth > -10)] = 0
61 |         im_pncc[np.logical_and(im_pncc < 10, im_pncc > -10)] = 0
62 | 
63 |         return {'pncc': im_pncc, 'depth': im_depth}
64 | 
65 |     def post_process(self, net_res):
66 |         im_pncc = net_res['pncc'].astype(np.float64)
67 |         im_depth = net_res['depth'].astype(np.float64)
68 |         net_X = im_depth[:, :, 0] * (1.3674) / 255 - 0.6852
69 |         net_Y = im_depth[:, :, 1] * (1.8401) / 255 - 0.9035
70 |         net_Z = im_depth[:, :, 2] * (0.7542) / 255 - 0.2997
71 |         mask = np.any(im_depth, axis=2) * np.all(im_pncc, axis=2)
72 | 
73 |         X = np.tile(np.linspace(-1, 1, im_depth.shape[1]), (im_depth.shape[0], 1))
74 |         Y = np.tile(np.linspace(1, -1, im_depth.shape[0]).reshape(-1, 1), (1, im_depth.shape[1]))
75 | 
76 |         # Normalize the fixed grid according to the network result, as X, Y are actually redundant
77 |         X = (X - np.mean(X[mask])) / np.std(X[mask]) * np.std(net_X[mask]) + np.mean(net_X[mask])
78 |         Y = (Y - np.mean(Y[mask])) / np.std(Y[mask]) * np.std(net_Y[mask]) + np.mean(net_Y[mask])
79 | 
80 |         Z = net_Z * 2  # due to image resizing
81 | 
82 |         f = 1 / (X[0, 1] - X[0, 0])
83 | 
84 |         Z_surface = Z * f
85 |         Z_surface[~mask] = np.nan
86 |         Z[~mask] = np.nan
87 | 
88 |         return {'Z': Z, 'X': X, 'Y': Y, 'Z_surface': Z_surface}
89 | 
--------------------------------------------------------------------------------
/pix2vertex/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | from tqdm import tqdm
5 | from itertools import product
6 | import struct, tarfile, zipfile  # tarfile and zipfile are required by extract_file below
7 | 
8 | ASCII_FACET = """facet normal {face[0]:e} {face[1]:e} {face[2]:e}
9 |   outer loop
10 |     vertex {face[3]:e} {face[4]:e} {face[5]:e}
11 |     vertex {face[6]:e} {face[7]:e} {face[8]:e}
12 |     vertex {face[9]:e} {face[10]:e} {face[11]:e}
13 |   endloop
14 | endfacet"""
15 | 
16 | BINARY_HEADER = "80sI"
17 | BINARY_FACET = "12fH"
18 | 
19 | 
20 | # Saving to STL is based on https://github.com/thearn/stl_tools/
21 | 
22 | def vis_depth_matplotlib(img, Z, elevation=60, azimuth=45, stride=5):
23 |     from mpl_toolkits.mplot3d import Axes3D  # registers the 3d projection on older matplotlib
24 |     from matplotlib import cm
25 |     import matplotlib.pyplot as plt
26 |     from matplotlib.colors import LightSource
27 | 
28 |     fig = plt.figure()
29 |     ax = fig.add_subplot(projection='3d')  # fig.gca(projection='3d') was removed in recent matplotlib
30 | 
31 |     # Create X and Y data
32 |     x = np.arange(0, 512, 1)
33 |     y = np.arange(0, 512, 1)
34 |     X, Y = np.meshgrid(x, y)
35 | 
36 |     ls = LightSource(azdeg=0, altdeg=90)
37 |     # shade data, creating an rgb array
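    # (editor's note: shade_rgb blends the normalized RGB image with an illumination
    # map derived from the height field Z, so the plotted surface keeps the face
    # texture while still showing relief)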
38 |     # img_c = np.concatenate((img.astype('float') / 255, np.ones((512, 512, 1))), axis=2)  # unused
39 |     rgb = ls.shade_rgb(img.astype('float') / 255, Z, blend_mode='overlay')
40 | 
41 |     surf = ax.plot_surface(X, Y, Z, cstride=stride, rstride=stride, linewidth=0, antialiased=False, facecolors=rgb)
42 |     ax.view_init(elev=elevation, azim=azimuth)
43 |     # ax.view_init(elev=90., azim=90)
44 |     # Show the plot
45 |     plt.show()
46 | 
47 | 
48 | def vis_depth_interactive(Z):
49 |     import k3d
50 |     Nx, Ny = 1, 1
51 |     xmin, xmax = 0, Z.shape[0]
52 |     ymin, ymax = 0, Z.shape[1]
53 | 
54 |     x = np.linspace(xmin, xmax, Nx)
55 |     y = np.linspace(ymin, ymax, Ny)
56 |     x, y = np.meshgrid(x, y)
57 |     plot = k3d.plot(grid_auto_fit=True, camera_auto_fit=False)
58 |     plt_surface = k3d.surface(-Z.astype(np.float32), color=0xb2ccff,
59 |                               bounds=[xmin, xmax, ymin, ymax])  # -Z for mirroring
60 |     plot += plt_surface
61 |     plot.display()
62 |     plot.camera = [242.57934019166004, 267.50948550191197, -406.62328311352337, 256, 256, -8.300323486328125,
63 |                    -0.13796270478729053, -0.987256298362836, -0.07931767413815752]
64 |     return plot
65 | 
66 | 
67 | def vis_pcloud_interactive(res, img):
68 |     import k3d
69 |     from colormap import rgb2hex
70 |     color_vals = np.zeros((img.shape[0], img.shape[1]))
71 |     for i in range(img.shape[0]):
72 |         for j in range(img.shape[1]):
73 |             color_vals[i, j] = int(rgb2hex(img[i, j, 0], img[i, j, 1], img[i, j, 2]).replace('#', '0x'), 0)
74 |     colors = color_vals.flatten()
75 | 
76 |     points = np.stack((res['X'].flatten(), res['Y'].flatten(), res['Z'].flatten()), axis=1)
77 | 
78 |     invalid_inds = np.any(np.isnan(points), axis=1)
79 |     points_valid = points[~invalid_inds]
80 |     colors_valid = colors[~invalid_inds]
81 |     plot = k3d.plot(grid_auto_fit=True, camera_auto_fit=False)
82 |     plot += k3d.points(points_valid, colors_valid, point_size=0.01, compression_level=9, shader='flat')
83 |     plot.display()
84 |     plot.camera = [-0.3568942548181382, -0.12775125650240726, 3.5390732533009452, 0.33508163690567017,
85 |                    0.3904658555984497, -0.0499117374420166, 0.11033077266672488, 0.9696364582197756, 0.2182481603445357]
86 |     return plot
87 | 
88 | 
89 | def vis_net_result(img, net_result):
90 |     plt.figure()
91 |     plt.subplot(1, 3, 1)
92 |     plt.imshow(img)
93 |     plt.title('Input Image')
94 |     plt.subplot(1, 3, 2)
95 |     plt.imshow(net_result['pncc'].astype(np.uint8))
96 |     plt.title('PNCC Visualization')
97 |     plt.subplot(1, 3, 3)
98 |     plt.imshow(net_result['depth'][:, :, 2].astype(np.uint8), cmap='gray')
99 |     plt.title('Depth Visualization')
100 |     plt.show()
101 | 
102 | 
103 | class TqdmUpTo(tqdm):
104 |     """Provides `update_to(n)` which uses `tqdm.update(delta_n)`."""
105 | 
106 |     def update_to(self, b=1, bsize=1, tsize=None):
107 |         """
108 |         b : int, optional
109 |             Number of blocks transferred so far [default: 1].
110 |         bsize : int, optional
111 |             Size of each block (in tqdm units) [default: 1].
112 |         tsize : int, optional
113 |             Total size (in tqdm units). If [default: None] remains unchanged.
114 |         """
115 |         if tsize is not None:
116 |             self.total = tsize
117 |         self.update(b * bsize - self.n)  # will also set self.n = b * bsize
118 | 
119 | 
120 | def download_url(url, save_path):
121 |     from six.moves import urllib
122 |     save_path = os.path.expanduser(save_path)
123 |     if not os.path.exists(save_path):
124 |         os.makedirs(save_path, exist_ok=True)
125 | 
126 |     filename = url.rpartition('/')[2]
127 |     filepath = os.path.join(save_path, filename)
128 | 
129 |     try:
130 |         with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1,
131 |                       desc=url.split('/')[-1]) as t:  # all optional kwargs
132 |             urllib.request.urlretrieve(url, filepath, reporthook=t.update_to)
133 |             t.total = t.n
134 |     except ValueError:
135 |         raise Exception('Failed to download! Check URL: ' + url +
136 |                         ' and local path: ' + save_path)
137 | 
138 | 
139 | def extract_file(path, to_directory=None):
140 |     path = os.path.expanduser(path)
141 |     if path.endswith('.zip'):
142 |         opener, mode = zipfile.ZipFile, 'r'
143 |     elif path.endswith(('.tar.gz', '.tgz')):
144 |         opener, mode = tarfile.open, 'r:gz'
145 |     elif path.endswith(('tar.bz2', '.tbz')):
146 |         opener, mode = tarfile.open, 'r:bz2'
147 |     elif path.endswith('.bz2'):
148 |         import bz2
149 |         opener, mode = bz2.BZ2File, 'rb'
150 |         with open(path[:-4], 'wb') as fp_out, opener(path, 'rb') as fp_in:
151 |             for data in iter(lambda: fp_in.read(100 * 1024), b''):
152 |                 fp_out.write(data)
153 |         return
154 |     else:
155 |         raise ValueError(
156 |             "Could not extract `{}` as no extractor is found!".format(path))
157 | 
158 |     if to_directory is None:
159 |         to_directory = os.path.abspath(os.path.join(path, os.path.pardir))
160 |     cwd = os.getcwd()
161 |     os.chdir(to_directory)
162 | 
163 |     try:
164 |         file = opener(path, mode)
165 |         try:
166 |             file.extractall()
167 |         finally:
168 |             file.close()
169 |     finally:
170 |         os.chdir(cwd)
171 | 
172 | 
173 | def download_from_gdrive(id, destination):
174 |     import requests
175 |     URL = "https://docs.google.com/uc?export=download"
176 | 
177 |     session = requests.Session()
178 | 
179 |     response = session.get(URL, params={'id': id}, stream=True)
180 |     token = get_confirm_token(response)
181 | 
182 |     if token:
183 |         params = {'id': id, 'confirm': token}
184 |         response = session.get(URL, params=params, stream=True)
185 | 
186 |     save_response_content(response, destination)
187 | 
188 | 
189 | def get_confirm_token(response):
190 |     for key, value in response.cookies.items():
191 |         if key.startswith('download_warning'):
192 |             return value
193 | 
194 |     return None
195 | 
196 | 
197 | def save_response_content(response, destination):
198 |     CHUNK_SIZE = 32768
199 |     t = tqdm(unit='B', unit_scale=True, miniters=1)
200 | 
201 |     with open(destination, "wb") as f:
202 |         for chunk in response.iter_content(CHUNK_SIZE):
203 |             if chunk:  # filter out keep-alive new chunks
204 |                 t.update(len(chunk))
205 |                 f.write(chunk)
206 | 
207 | 
208 | def _build_binary_stl(facets):
209 |     """returns a list of binary-packed facet records for the stl file"""
210 | 
211 |     lines = [struct.pack(BINARY_HEADER, b'Binary STL Writer', len(facets)), ]
212 |     for facet in facets:
213 |         facet = list(facet)
214 |         facet.append(0)  # pad the end with an unsigned short (the attribute byte count)
215 |         lines.append(struct.pack(BINARY_FACET, *facet))
216 |     return lines
217 | 
218 | 
219 | def _build_ascii_stl(facets):
220 |     """returns a list of ascii lines for the stl file"""
221 | 
222 |     lines = ['solid ffd_geom', ]
223 |     for facet in facets:
224 |         lines.append(ASCII_FACET.format(face=facet))
225 |     lines.append('endsolid ffd_geom')
226 |     return lines
227 | 
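# Facet layout used by the STL writers in this file: each facet is a flat array of
# 12 floats, [nx, ny, nz, x1, y1, z1, x2, y2, z2, x3, y3, z3] (normal, then three
# vertices), matching the "12fH" struct format above. A minimal sketch of writing a
# single triangle with these helpers (the coordinates are arbitrary example values):
#
#   facet = [0., 0., 1.,  0., 0., 0.,  1., 0., 0.,  0., 1., 0.]
#   writeSTL([facet], 'triangle.stl', ascii=True)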
228 | 
229 | def writeSTL(facets, file_name, ascii=False):
230 |     """writes an ASCII or binary STL file"""
231 | 
232 |     f = open(file_name, 'wb')
233 |     if ascii:
234 |         lines = _build_ascii_stl(facets)
235 |         lines_ = "\n".join(lines).encode("UTF-8")
236 |         f.write(lines_)
237 |     else:
238 |         data = _build_binary_stl(facets)
239 |         data = b"".join(data)
240 |         f.write(data)
241 | 
242 |     f.close()
243 | 
244 | 
245 | def save2stl(A, fn, scale=1, mask_val=None, ascii=False,
246 |              max_width=235.,
247 |              max_depth=140.,
248 |              max_height=150.,
249 |              solid=False,
250 |              rotate=True,
251 |              min_thickness_percent=0.1):
252 |     """
253 |     Reads a numpy array and outputs an STL file.
254 |     Inputs:
255 |         A (ndarray) - an 'm' by 'n' 2D numpy array
256 |         fn (string) - filename to use for the STL file
257 |     Optional input:
258 |         scale (float) - scales the height (surface) of the resulting STL mesh
259 |         mask_val (float) - any element of the input array that is less
260 |                            than this value will not be included in the mesh.
261 |                            default renders all vertices (x > -inf for all float x)
262 |         ascii (bool) - sets the STL format to ascii or binary (default)
263 |         max_width, max_depth, max_height (floats) - maximum size of the STL
264 |                                                      object (in mm). Match this to
265 |                                                      the dimensions of a 3D printer
266 |                                                      platform
267 |         solid (bool): sets whether to create a solid geometry (with sides and
268 |                       a bottom) or not
269 |         rotate (bool): rotates a landscape-oriented array to portrait, to
270 |                        better fit a printing platform
271 |         min_thickness_percent (float) : when creating the solid bottom face,
272 |                                         this multiplier sets the minimum thickness
273 |                                         in the final geometry (shallowest interior
274 |                                         point to bottom face), as a percentage of
275 |                                         the thickness of the model computed up to that point
276 |     Returns: (None)
277 |     """
278 | 
279 |     # Remove NaNs, setting them to the minimal valid value
280 |     A = A.copy()
281 |     A[np.isnan(A)] = A[~np.isnan(A)].min()
282 | 
283 |     m, n = A.shape
284 |     if n >= m and rotate:
285 |         # rotate to best fit a printing platform
286 |         A = np.rot90(A, k=3)
287 |         m, n = n, m
288 |     A = scale * (A - A.min())
289 | 
290 |     if not mask_val:
291 |         mask_val = A.min()  # - 1.
292 | 
293 |     facets = []
294 |     mask = np.zeros((m, n))
295 |     print("Creating top mesh...")
296 |     for i, k in product(range(m - 1), range(n - 1)):
297 |         this_pt = np.array([i - m / 2., k - n / 2., A[i, k]])
298 |         top_right = np.array([i - m / 2., k + 1 - n / 2., A[i, k + 1]])
299 |         bottom_left = np.array([i + 1. - m / 2., k - n / 2., A[i + 1, k]])
300 |         bottom_right = np.array(
301 |             [i + 1. - m / 2., k + 1 - n / 2., A[i + 1, k + 1]])
302 | 
303 |         n1, n2 = np.zeros(3), np.zeros(3)
304 |         is_a, is_b = False, False
305 |         if (this_pt[-1] > mask_val and top_right[-1] > mask_val and
306 |                 bottom_left[-1] > mask_val):
307 |             facet = np.concatenate([n1, top_right, this_pt, bottom_right])
308 |             mask[i, k] = 1
309 |             mask[i, k + 1] = 1
310 |             mask[i + 1, k] = 1
311 |             facets.append(facet)
312 | 
313 |         if (this_pt[-1] > mask_val and bottom_right[-1] > mask_val and
314 |                 bottom_left[-1] > mask_val):
315 |             facet = np.concatenate(
316 |                 [n2, bottom_right, this_pt, bottom_left])
317 |             facets.append(facet)
318 |             mask[i, k] = 1
319 |             mask[i + 1, k + 1] = 1
320 |             mask[i + 1, k] = 1
321 | 
322 |     print('\t', len(facets), 'facets')
323 |     facets = np.array(facets)
324 | 
325 |     if solid:
326 |         print("Computing edges...")
327 |         edge_mask = np.sum([np.roll(np.roll(mask, i, axis=0), k, axis=1)
328 |                             for i, k in product([-1, 0, 1], repeat=2)],
329 |                            axis=0)
330 |         edge_mask[np.where(edge_mask == 9.)] = 0.
331 |         edge_mask[np.where(edge_mask != 0.)] = 1.
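        # the outer frame of the grid is always marked as an edge below, so the
        # generated walls run all the way around the printed surface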
332 |         edge_mask[0::m - 1, :] = 1.
333 |         edge_mask[:, 0::n - 1] = 1.
334 |         X, Y = np.where(edge_mask == 1.)
335 |         locs = set(zip(X - m / 2., Y - n / 2.))  # materialize: zip() is a one-shot iterator in Python 3
336 | 
337 |         zvals = facets[:, 5::3]
338 |         zmin, zthickness = zvals.min(), zvals.ptp()
339 | 
340 |         minval = zmin - min_thickness_percent * zthickness
341 | 
342 |         bottom = []
343 |         print("Extending edges, creating bottom...")
344 |         for i, facet in enumerate(facets):
345 |             if (facet[3], facet[4]) in locs:
346 |                 facets[i][5] = minval
347 |             if (facet[6], facet[7]) in locs:
348 |                 facets[i][8] = minval
349 |             if (facet[9], facet[10]) in locs:
350 |                 facets[i][11] = minval
351 |             this_bottom = np.concatenate(
352 |                 [facet[:3], facet[6:8], [minval], facet[3:5], [minval],
353 |                  facet[9:11], [minval]])
354 |             bottom.append(this_bottom)
355 | 
356 |         facets = np.concatenate([facets, bottom])
357 | 
358 |     xsize = facets[:, 3::3].ptp()
359 |     if xsize > max_width:
360 |         facets = facets * float(max_width) / xsize
361 | 
362 |     ysize = facets[:, 4::3].ptp()
363 |     if ysize > max_depth:
364 |         facets = facets * float(max_depth) / ysize
365 | 
366 |     zsize = facets[:, 5::3].ptp()
367 |     if zsize > max_height:
368 |         facets = facets * float(max_height) / zsize
369 | 
370 |     print('Writing STL...')
371 |     writeSTL(facets, fn, ascii=ascii)
372 |     print('Done!')
373 | 
--------------------------------------------------------------------------------
/reconstruct_pipeline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "%load_ext autoreload\n",
10 |     "%autoreload 2\n",
11 |     "\n",
12 |     "import matplotlib\n",
13 |     "import matplotlib.pyplot as plt\n",
14 |     "import numpy as np\n",
15 |     "from skimage import io\n",
16 |     "import pix2vertex as p2v\n",
17 |     "\n",
18 |     "matplotlib.rcParams['figure.figsize'] = (13,7)"
19 |    ]
20 |   },
21 |   {
22 |    "cell_type": "markdown",
23 |    "metadata": {},
24 |    "source": [
25 |     "## Initializations"
26 |    ]
27 |   },
28 |   {
29 |    "cell_type": "code",
30 |    "execution_count": null,
31 |    "metadata": {},
32 |    "outputs": [],
33 |    "source": [
34 |     "detector = p2v.Detector()\n",
35 |     "reconstructor = p2v.Reconstructor(detector=detector)"
36 |    ]
37 |   },
38 |   {
39 |    "cell_type": "code",
40 |    "execution_count": null,
41 |    "metadata": {},
42 |    "outputs": [],
43 |    "source": [
44 |     "im_path = 'examples/sample.jpg' # im_path can be a URL as well!\n",
45 |     "img = io.imread(im_path)\n",
46 |     "fig = plt.figure()\n",
47 |     "plt.imshow(img)\n",
48 |     "plt.show()"
49 |    ]
50 |   },
51 |   {
52 |    "cell_type": "markdown",
53 |    "metadata": {},
54 |    "source": [
55 |     "## Inference"
56 |    ]
57 |   },
58 |   {
59 |    "cell_type": "code",
60 |    "execution_count": null,
61 |    "metadata": {},
62 |    "outputs": [],
63 |    "source": [
64 |     "img_crop = detector.detect_and_crop(img)\n",
65 |     "fig = plt.figure()\n",
66 |     "plt.imshow(img_crop)\n",
67 |     "plt.show()"
68 |    ]
69 |   },
70 |   {
71 |    "cell_type": "code",
72 |    "execution_count": null,
73 |    "metadata": {},
74 |    "outputs": [],
75 |    "source": [
76 |     "net_res = reconstructor.run_net(img_crop)\n",
77 |     "p2v.vis_net_result(img_crop, net_res)\n",
78 |     "final_res = reconstructor.post_process(net_res)"
79 |    ]
80 |   },
81 |   {
82 |    "cell_type": "markdown",
83 |    "metadata": {},
84 |    "source": [
85 |     "## Interactive Visualizations"
86 |    ]
87 |   },
88 |   {
89 |    "cell_type": "code",
90 |    "execution_count": null,
91 |    "metadata": {},
92 |    "outputs": [],
93 |    "source": [
94 |     "plot = p2v.vis_depth_interactive(final_res['Z_surface'])"
95 |    ]
96 |   },
97 |   {
98 |    "cell_type": "code",
99 |    "execution_count": null,
100 |    "metadata": {},
101 |    "outputs": [],
102 |    "source": [
103 |     "plot = p2v.vis_pcloud_interactive(final_res, img_crop)"
104 |    ]
105 |   },
106 |   {
107 |    "cell_type": "code",
108 |    "execution_count": null,
109 |    "metadata": {},
110 |    "outputs": [],
111 |    "source": [
112 |     "# Fallback matplotlib visualization\n",
113 |     "p2v.vis_depth_matplotlib(img_crop, final_res['Z_surface'])"
114 |    ]
115 |   },
116 |   {
117 |    "cell_type": "markdown",
118 |    "metadata": {},
119 |    "source": [
120 |     "## Saving Result"
121 |    ]
122 |   },
123 |   {
124 |    "cell_type": "code",
125 |    "execution_count": null,
126 |    "metadata": {},
127 |    "outputs": [],
128 |    "source": [
129 |     "p2v.save2stl(final_res['Z_surface'], 'res.stl')"
130 |    ]
131 |   },
132 |   {
133 |    "cell_type": "markdown",
134 |    "metadata": {},
135 |    "source": [
136 |     "Create a link to make the result accessible from the notebook"
137 |    ]
138 |   },
139 |   {
140 |    "cell_type": "code",
141 |    "execution_count": null,
142 |    "metadata": {},
143 |    "outputs": [],
144 |    "source": [
145 |     "from IPython.display import FileLink\n",
146 |     "FileLink('res.stl')"
147 |    ]
148 |   },
149 |   {
150 |    "cell_type": "code",
151 |    "execution_count": null,
152 |    "metadata": {},
153 |    "outputs": [],
154 |    "source": []
155 |   }
156 |  ],
157 |  "metadata": {
158 |   "kernelspec": {
159 |    "display_name": "Python 3",
160 |    "language": "python",
161 |    "name": "python3"
162 |   },
163 |   "language_info": {
164 |    "codemirror_mode": {
165 |     "name": "ipython",
166 |     "version": 3
167 |    },
168 |    "file_extension": ".py",
169 |    "mimetype": "text/x-python",
170 |    "name": "python",
171 |    "nbconvert_exporter": "python",
172 |    "pygments_lexer": "ipython3",
173 |    "version": "3.7.6"
174 |   }
175 |  },
176 |  "nbformat": 4,
177 |  "nbformat_minor": 4
178 | }
179 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | k3d==2.7.4
2 | easydev==0.9.38
3 | colormap==1.0.3
4 | torch
5 | imageio
6 | dlib
7 | scikit-image
8 | matplotlib
9 | tqdm
10 | six
11 | requests
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # MIT License
4 | #
5 | # Copyright (c) 2020 Elad Richardson and Matan Sela
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 | 
25 | import os
26 | import setuptools
27 | 
28 | with open(os.path.join(os.path.dirname(__file__), "README.md"), "r") as fh:
29 |     long_description = fh.read()
30 | 
31 | with open(os.path.join(os.path.dirname(__file__), "requirements.txt")) as f:
32 |     required = f.read().splitlines()
33 | 
34 | setuptools.setup(
35 |     name="pix2vertex",
36 |     version="1.0.4",
37 |     author="Elad Richardson, Matan Sela",
38 |     author_email="elad.richardson@gmail.com, matansel@gmail.com",
39 |     description="3D face reconstruction from a single image",
40 |     long_description=long_description,
41 |     long_description_content_type="text/markdown",
42 |     url="https://github.com/eladrich/pix2vertex.pytorch",
43 |     packages=setuptools.find_packages(exclude=["tests.*", "tests"]),
44 |     classifiers=[
45 |         "Environment :: Console",
46 |         "Intended Audience :: Developers",
47 |         "Intended Audience :: Science/Research",
48 |         "Intended Audience :: Education",
49 |         "Programming Language :: Python :: 3",
50 |         "License :: OSI Approved :: MIT License",
51 |         "Operating System :: OS Independent",
52 |     ],
53 |     install_requires=required,
54 |     include_package_data=True,
55 |     keywords="pix2vertex face reconstruction 3d pytorch pip package",
56 |     python_requires='>=3.4',
57 |     zip_safe=False
58 | )
59 | 
--------------------------------------------------------------------------------
/tests/test_pix2vertex.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from imageio import imread
3 | import pix2vertex as p2v
4 | 
5 | from unittest import TestCase
6 | 
7 | class TestPix2Vertex(TestCase):
8 |     def test_reconstruct(self):
9 |         image = imread('examples/sample.jpg')
10 |         results = p2v.reconstruct(image)
11 |         self.assertEqual(len(results), 2)
12 | 
--------------------------------------------------------------------------------
/weights/README.md:
--------------------------------------------------------------------------------
1 | Place the dlib and pix2vertex models here; if they are missing, the package downloads them automatically on first use.
2 | 
--------------------------------------------------------------------------------